diff --git a/discojs/src/models/model.ts b/discojs/src/models/model.ts
index 6bbc4e70c..b7e51eab7 100644
--- a/discojs/src/models/model.ts
+++ b/discojs/src/models/model.ts
@@ -17,12 +17,25 @@ import type { BatchLogs, EpochLogs } from "./logs.js";
  **/
 // TODO make it typesafe: same shape of data/input/weights
 export abstract class Model implements Disposable {
+  /**
+   * Weights from the previous round, or `undefined` when none were provided.
+   * Assigned through the `previousRoundWeights` setter.
+   */
+  protected prevRoundWeights: WeightsContainer | undefined;
   // TODO don't allow external access but upgrade train to return weights on every epoch
   /** Return training state */
   abstract get weights(): WeightsContainer;
   /** Set training state */
   abstract set weights(ws: WeightsContainer);
+  /**
+   * Record the weights of the previous round for later use by the model.
+   *
+   * @param ws weights of the previous round, `undefined` to unset them
+   */
+  set previousRoundWeights(ws: WeightsContainer | undefined) {
+    this.prevRoundWeights = ws;
+  }
 
   /**
    * Improve predictor
    *
diff --git a/discojs/src/models/tfjs.ts b/discojs/src/models/tfjs.ts
index 4e29498b9..a906be741 100644
--- a/discojs/src/models/tfjs.ts
+++ b/discojs/src/models/tfjs.ts
@@ -76,32 +76,73 @@ export class TFJS extends Model {
 
   // First iteration: replace trainOnBatch with custom loss computation
+  /**
+   * Run one optimizer step on a batch, minimizing the binary cross-entropy
+   * plus, when previous round weights are set, the proximal term
+   * `(mu / 2) * ||w - w_prev||^2`.
+   *
+   * Disposes `xs` and `ys` before returning.
+   *
+   * @returns the minimized loss (proximal term included) and the categorical accuracy averaged over the batch
+   * @throws if the model outputs anything but a single concrete tensor
+   */
   async trainFedProx(
-    xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {
-
-    debug(this.model.loss, this.model.losses, this.model.lossFunctions)
+    xs: tf.Tensor, ys: tf.Tensor,
+  ): Promise<[number, number]> {
+    // weight of the proximal term
+    const mu = 1
+    // model output of this step, assigned by the loss function and read afterwards for the accuracy
+    let logitsTensor: tf.Tensor | undefined;
     const lossFunction: () => tf.Scalar = () => {
+      // Proximal term
+      let proximalTerm: tf.Scalar = tf.scalar(0)
+      if (this.prevRoundWeights !== undefined) {
+        // squared norm
+        const norm = new WeightsContainer(this.model.getWeights())
+          .sub(this.prevRoundWeights)
+          .map(t => t.square().sum())
+          .reduce((acc, t) => tf.add(acc, t)).asScalar()
+        proximalTerm = tf.mul(mu / 2, norm)
+      }
       const logits = this.model.apply(xs)
       if (Array.isArray(logits))
         throw new Error('model outputs too many tensor')
       if (logits instanceof tf.SymbolicTensor)
         throw new Error('model outputs symbolic tensor')
-
-      // binaryCrossEntropyLoss as implemented by tensorflow.js
+      // kept so that it is still usable once `minimize` has returned
+      logitsTensor = tf.keep(logits)
+      // binaryCrossentropy as implemented by tensorflow.js
       // https://github.com/tensorflow/tfjs/blob/2644bd0d6cea677f80e44ed4a44bea5e04aabeb3/tfjs-layers/src/losses.ts#L193
       let y: tf.Tensor;
       y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
       y = tf.log(tf.div(y, tf.sub(1, y)));
-      return tf.losses.sigmoidCrossEntropy(ys, y);
+      const loss = tf.losses.sigmoidCrossEntropy(ys, y);
+      return tf.add(loss, proximalTerm)
     }
     const lossTensor = this.model.optimizer.minimize(lossFunction, true)
     if (lossTensor === null)
       throw new Error("loss should not be null")
 
+    if (logitsTensor === undefined)
+      throw new Error('loss function was not evaluated')
+    const accTensor = tf.metrics.categoricalAccuracy(ys, logitsTensor)
+    const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+    const accSumTensor = accTensor.sum()
+    const accSum = await accSumTensor.array()
+    // release the tensors before the check below can throw
+    tf.dispose([accTensor, accSumTensor, logitsTensor])
+    if (typeof accSum !== 'number')
+      throw new Error('got multiple accuracy sum')
+
     const loss = await lossTensor.array()
     tf.dispose([xs, ys, lossTensor])
 
-    // dummy accuracy for now
-    return [loss, 0]
+    const memory = tf.memory().numBytes / 1024 / 1024 / 1024
+    debug("training metrics: %O", {
+      loss,
+      memory,
+      allocated: tf.memory().numTensors,
+    });
+    return [loss, accSum / accSize]
   }
 
   async #evaluate(
diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts
index 1124137be..db6877323 100644
--- a/discojs/src/training/trainer.ts
+++ b/discojs/src/training/trainer.ts
@@ -90,7 +90,9 @@ export class Trainer {
     let previousRoundWeights: WeightsContainer | undefined;
     for (let round = 0; round < totalRound; round++) {
       await this.#client.onRoundBeginCommunication();
 
+      // hand the previous round's weights to the model before it trains
+      this.model.previousRoundWeights = previousRoundWeights;
       yield this.#runRound(dataset, validationDataset);
 
       let localWeights = this.model.weights;