Skip to content

Commit

Permalink
tmp: overriding weight update yields same as default
Browse files Browse the repository at this point in the history
  • Loading branch information
JulienVig committed Feb 25, 2025
1 parent c20ad82 commit abe7996
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 9 deletions.
17 changes: 9 additions & 8 deletions discojs/src/default_tasks/lus_covid.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ export const lusCovid: TaskProvider<'image'> = {

// Model architecture from tensorflow.js docs:
// https://codelabs.developers.google.com/codelabs/tfjs-training-classfication/index.html#4
async getModel (): Promise<Model<'image'>> {
async getModel(): Promise<Model<'image'>> {
const seed = 42
const imageHeight = 100
const imageWidth = 100
const imageChannels = 3
Expand All @@ -55,7 +56,7 @@ export const lusCovid: TaskProvider<'image'> = {
filters: 8,
strides: 1,
activation: 'relu',
kernelInitializer: 'varianceScaling'
kernelInitializer: tf.initializers.heNormal({ seed })
}))

// The MaxPooling layer acts as a sort of downsampling using max values
Expand All @@ -69,7 +70,7 @@ export const lusCovid: TaskProvider<'image'> = {
filters: 16,
strides: 1,
activation: 'relu',
kernelInitializer: 'varianceScaling'
kernelInitializer: tf.initializers.heNormal({ seed })
}))
model.add(tf.layers.maxPooling2d({ poolSize: [2, 2], strides: [2, 2] }))

Expand All @@ -82,16 +83,16 @@ export const lusCovid: TaskProvider<'image'> = {
// output class.
model.add(tf.layers.dense({
units: numOutputClasses,
kernelInitializer: 'varianceScaling',
activation: 'softmax'
activation: 'softmax',
kernelInitializer: tf.initializers.heNormal({ seed })
}))

model.compile({
optimizer: 'sgd',
optimizer: tf.train.sgd(0.001),
loss: 'binaryCrossentropy',
metrics: ['accuracy']
})

return Promise.resolve(new models.TFJS('image', model))
}
}
}
37 changes: 36 additions & 1 deletion discojs/src/models/tfjs.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import createDebug from "debug";
import { List, Map, Range } from "immutable";
import * as tf from '@tensorflow/tfjs'

Expand All @@ -13,6 +14,8 @@ import { BatchLogs } from './index.js'
import { Model } from './index.js'
import { EpochLogs } from './logs.js'

const debug = createDebug("discojs:models:tfjs");

type Serialized<D extends DataType> = [D, tf.io.ModelArtifacts];

/** TensorFlow JavaScript model with standard training */
Expand Down Expand Up @@ -64,11 +67,43 @@ export class TFJS<D extends "image" | "tabular"> extends Model<D> {
batch: Batched<DataFormat.ModelEncoded[D]>,
): Promise<BatchLogs> {
const { xs, ys } = this.#batchToTF(batch);
const logs = await this.model.trainOnBatch(xs, ys);
// Toggling two next lines should yield the same training loss
const logs = await this.trainFedProx(xs, ys);
// const logs = await this.model.trainOnBatch(xs, ys);
tf.dispose([xs, ys])
return this.getBatchLogs(logs)
}

// First iteration: replace trainOnBatch with custom loss computation
async trainFedProx(
xs: tf.Tensor, ys: tf.Tensor): Promise<[number, number]> {

debug(this.model.loss, this.model.losses, this.model.lossFunctions)
const lossFunction: () => tf.Scalar = () => {
this.model.apply(xs)
const logits = this.model.apply(xs)
if (Array.isArray(logits))
throw new Error('model outputs too many tensor')
if (logits instanceof tf.SymbolicTensor)
throw new Error('model outputs symbolic tensor')

// binaryCrossEntropyLoss as implemented by tensorflow.js
// https://github.com/tensorflow/tfjs/blob/2644bd0d6cea677f80e44ed4a44bea5e04aabeb3/tfjs-layers/src/losses.ts#L193
let y: tf.Tensor;
y = tf.clipByValue(logits, 0.00001, 1 - 0.00001);
y = tf.log(tf.div(y, tf.sub(1, y)));
return tf.losses.sigmoidCrossEntropy(ys, y);
}
const lossTensor = this.model.optimizer.minimize(lossFunction, true)
if (lossTensor === null) throw new Error("loss should not be null")

const loss = await lossTensor.array()
tf.dispose([xs, ys, lossTensor])

// dummy accuracy for now
return [loss, 0]
}

async #evaluate(
dataset: Dataset<Batched<DataFormat.ModelEncoded[D]>>,
): Promise<Record<"accuracy" | "loss", number>> {
Expand Down

0 comments on commit abe7996

Please sign in to comment.