tmp: sketch of fedprox implementation
JulienVig committed Feb 25, 2025
1 parent abe7996 commit 76977e7
Showing 3 changed files with 54 additions and 20 deletions.
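
For reference, the FedProx objective this sketch implements: each client minimizes its usual local loss plus a proximal penalty that keeps the local weights close to the weights received at the start of the round, roughly

    F_k(w) + (mu / 2) * ||w - w_prev||^2

where w_prev denotes the previous round's global weights and mu controls how strongly local training is pulled back towards them. The three changes below store w_prev on the model, fold the penalty into the training loss, and have the trainer pass the weights along at every round.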
4 changes: 4 additions & 0 deletions discojs/src/models/model.ts
@@ -17,12 +17,16 @@ import type { BatchLogs, EpochLogs } from "./logs.js";
  **/
 // TODO make it typesafe: same shape of data/input/weights
 export abstract class Model<D extends DataType> implements Disposable {
+  protected prevRoundWeights: WeightsContainer | undefined;
   // TODO don't allow external access but upgrade train to return weights on every epoch
   /** Return training state */
   abstract get weights(): WeightsContainer;
   /** Set training state */
   abstract set weights(ws: WeightsContainer);
 
+  set previousRoundWeights(ws: WeightsContainer | undefined) {
+    this.prevRoundWeights = ws
+  }
   /**
    * Improve predictor
    *
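The setter is the only hook added to the abstract Model; concrete models read the protected prevRoundWeights field when building their loss. A minimal sketch of the intended call pattern (every name except previousRoundWeights is illustrative, not from this commit):

    // hypothetical caller, e.g. training code that just received global weights
    model.previousRoundWeights = weightsFromLastRound  // a WeightsContainer, or undefined
    // leaving it undefined (e.g. on the very first round) keeps the proximal term disabled
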
67 changes: 48 additions & 19 deletions discojs/src/models/tfjs.ts
@@ -76,32 +76,61 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
 
   // First iteration: replace trainOnBatch with custom loss computation
   async trainFedProx(
-    xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
-
-    debug(this.model.loss, this.model.losses, this.model.lossFunctions)
+    xs: tf.Tensor, ys: tf.Tensor,
+  ): Promise<[number, number]> {
+    let logitsTensor: tf.Tensor<tf.Rank>;
     const lossFunction: () => tf.Scalar = () => {
+      // Proximal term
+      let proximalTerm = tf.tensor(0)
+      if (this.prevRoundWeights !== undefined) {
+        // squared norm
+        const norm = new WeightsContainer(this.model.getWeights())
+          .sub(this.prevRoundWeights)
+          .map(t => t.square().sum())
+          .reduce((t, acc) => tf.add(t, acc)).asScalar()
+        const mu = 1
+        proximalTerm = tf.mul(mu / 2, norm)
+      }
 
-      this.model.apply(xs)
+      const logits = this.model.apply(xs)
-      if (Array.isArray(logits))
-        throw new Error('model outputs too many tensor')
-      if (logits instanceof tf.SymbolicTensor)
-        throw new Error('model outputs symbolic tensor')
-
-      // binaryCrossEntropyLoss as implemented by tensorflow.js
-      // https://github.com/tensorflow/tfjs/blob/2644bd0d6cea677f80e44ed4a44bea5e04aabeb3/tfjs-layers/src/losses.ts#L193
-      let y: tf.Tensor;
-      y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
-      y = tf.log(tf.div(y, tf.sub(1, y)));
-      return tf.losses.sigmoidCrossEntropy(ys, y);
+      if (Array.isArray(logits))
+        throw new Error('model outputs too many tensor')
+      if (logits instanceof tf.SymbolicTensor)
+        throw new Error('model outputs symbolic tensor')
+      logitsTensor = tf.keep(logits)
+      // binaryCrossentropy as implemented by tensorflow.js
+      // https://github.com/tensorflow/tfjs/blob/2644bd0d6cea677f80e44ed4a44bea5e04aabeb3/tfjs-layers/src/losses.ts#L193
+      let y: tf.Tensor;
+      y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
+      y = tf.log(tf.div(y, tf.sub(1, y)));
+      const loss = tf.losses.sigmoidCrossEntropy(ys, y);
+      console.log(loss.dataSync(), proximalTerm.dataSync())
+      return tf.add(loss, proximalTerm)
     }
     const lossTensor = this.model.optimizer.minimize(lossFunction, true)
     if (lossTensor === null) throw new Error("loss should not be null")
 
-    const loss = await lossTensor.array()
-    tf.dispose([xs, ys, lossTensor])
 
+    // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+    const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
+    const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+    const accSumTensor = accTensor.sum()
+    const accSum = await accSumTensor.array()
+    if (typeof accSum !== 'number')
+      throw new Error('got multiple accuracy sum')
+    // @ts-expect-error Variable 'logitsTensor' is used before being assigned
+    tf.dispose([accTensor, accSumTensor, logitsTensor])
+
+    const loss = await lossTensor.array()
+    tf.dispose([xs, ys, lossTensor])
+
-    // dummy accuracy for now
-    return [loss, 0]
+    const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+    debug("training metrics: %O", {
+      loss,
+      memory,
+      allocated: tf.memory().numTensors,
+    });
+    return [loss, accSum / accSize]
   }
 
   async #evaluate(
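The core of the change is the proximal term computed inside lossFunction: the squared L2 distance between the model's current weights and the previous round's weights, scaled by mu / 2 (mu is hard-coded to 1 for now) and added to the binary cross-entropy. The model's output is treated as probabilities, so it is clipped and mapped back to logits with log(p / (1 - p)) before tf.losses.sigmoidCrossEntropy, mirroring tf.js's own binaryCrossentropy; the logits are kept alive with tf.keep so accuracy can be computed after optimizer.minimize has run the closure, which is why the @ts-expect-error annotations on logitsTensor are needed. A standalone sketch of the penalty using plain tensor arrays instead of WeightsContainer (illustrative only, not code from the commit):

    import * as tf from "@tensorflow/tfjs";

    // (mu / 2) * ||w - w_prev||^2, summed over every weight tensor of the model
    function proximalTerm(
      current: tf.Tensor[],   // e.g. model.getWeights()
      previous: tf.Tensor[],  // weights received at the start of the round
      mu = 1,                 // the commit hard-codes mu = 1
    ): tf.Scalar {
      return tf.tidy(() => {
        const squaredNorm = current
          .map((w, i) => w.sub(previous[i]).square().sum())
          .reduce<tf.Tensor>((acc, t) => acc.add(t), tf.scalar(0));
        return tf.mul(mu / 2, squaredNorm).asScalar();
      });
    }

    // inside the loss closure the penalty is simply added to the data loss:
    //   return tf.add(dataLoss, proximalTerm(currentWeights, previousWeights, mu))
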
3 changes: 2 additions & 1 deletion discojs/src/training/trainer.ts
@@ -90,7 +90,8 @@ export class Trainer<D extends DataType> {
     let previousRoundWeights: WeightsContainer | undefined;
     for (let round = 0; round < totalRound; round++) {
       await this.#client.onRoundBeginCommunication();
-
+
+      this.model.previousRoundWeights = previousRoundWeights
       yield this.#runRound(dataset, validationDataset);
 
       let localWeights = this.model.weights;
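The trainer now hands the previous round's weights to the model just before running the round. Where previousRoundWeights itself gets refreshed is outside the lines shown above; presumably, once the round's communication step returns aggregated weights, it is updated along these lines (hypothetical continuation, not part of this diff):

      // after the end-of-round exchange yields the new global weights
      previousRoundWeights = aggregatedWeights  // a WeightsContainer from the server
      // on the next iteration, the setter call above anchors the proximal term to these weights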
