Skip to content

Commit

Permalink
Added wikitext task, added text-dataset for core, node and web, renam…
Browse files Browse the repository at this point in the history
…ed task.taskId to task.id
  • Loading branch information
peacefulotter committed Feb 3, 2024
1 parent a25a9af commit a0be7af
Show file tree
Hide file tree
Showing 62 changed files with 5,803 additions and 3,310 deletions.
103 changes: 53 additions & 50 deletions discojs/discojs-core/src/async_informant.ts
Original file line number Diff line number Diff line change
@@ -1,64 +1,67 @@
import { AggregatorBase } from './aggregator'

export class AsyncInformant<T> {
private _round = 0
private _currentNumberOfParticipants = 0
private _totalNumberOfParticipants = 0
private _averageNumberOfParticipants = 0
private _round = 0

Check failure on line 4 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4
private _currentNumberOfParticipants = 0

Check failure on line 5 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4
private _totalNumberOfParticipants = 0

Check failure on line 6 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4
private _averageNumberOfParticipants = 0

Check failure on line 7 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4

constructor (
private readonly aggregator: AggregatorBase<T>
) {}
constructor(private readonly aggregator: AggregatorBase<T>) {}

Check failure on line 9 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4

Check failure on line 9 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Missing space before function parentheses

update (): void {
console.debug('before:')
this.printAllInfos()
if (this.round === 0 || this.round < this.aggregator.round) {
this._round = this.aggregator.round
this._currentNumberOfParticipants = this.aggregator.size
this._averageNumberOfParticipants = this.totalNumberOfParticipants / this.round
this._totalNumberOfParticipants += this.currentNumberOfParticipants
} else {
this._round = this.aggregator.round
update(): void {

Check failure on line 11 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 2 spaces but found 4

Check failure on line 11 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Missing space before function parentheses
console.debug('before:')

Check failure on line 12 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 4 spaces but found 8
this.printAllInfos()

Check failure on line 13 in discojs/discojs-core/src/async_informant.ts

View workflow job for this annotation

GitHub Actions / lint-lib-core

Expected indentation of 4 spaces but found 8
if (this.round === 0 || this.round < this.aggregator.round) {
this._round = this.aggregator.round
this._currentNumberOfParticipants = this.aggregator.size
this._averageNumberOfParticipants =
this.totalNumberOfParticipants / this.round
this._totalNumberOfParticipants += this.currentNumberOfParticipants
} else {
this._round = this.aggregator.round
}
console.debug('after:')
this.printAllInfos()
}
console.debug('after:')
this.printAllInfos()
}

// Getter functions
get round (): number {
return this._round
}
// Getter functions
get round(): number {
return this._round
}

get currentNumberOfParticipants (): number {
return this._currentNumberOfParticipants
}
get currentNumberOfParticipants(): number {
return this._currentNumberOfParticipants
}

get totalNumberOfParticipants (): number {
return this._totalNumberOfParticipants
}
get totalNumberOfParticipants(): number {
return this._totalNumberOfParticipants
}

get averageNumberOfParticipants (): number {
return this._averageNumberOfParticipants
}
get averageNumberOfParticipants(): number {
return this._averageNumberOfParticipants
}

getAllStatistics (): Record<
'round' | 'currentNumberOfParticipants' | 'totalNumberOfParticipants' | 'averageNumberOfParticipants', number
> {
return {
round: this.round,
currentNumberOfParticipants: this.currentNumberOfParticipants,
totalNumberOfParticipants: this.totalNumberOfParticipants,
averageNumberOfParticipants: this.averageNumberOfParticipants
getAllStatistics(): Record<
| 'round'
| 'currentNumberOfParticipants'
| 'totalNumberOfParticipants'
| 'averageNumberOfParticipants',
number
> {
return {
round: this.round,
currentNumberOfParticipants: this.currentNumberOfParticipants,
totalNumberOfParticipants: this.totalNumberOfParticipants,
averageNumberOfParticipants: this.averageNumberOfParticipants,
}
}
}

// Debug
public printAllInfos (): void {
console.debug('task:', this.aggregator.task.taskID)
console.debug('round:', this.round)
console.debug('participants:', this.currentNumberOfParticipants)
console.debug('total:', this.totalNumberOfParticipants)
console.debug('average:', this.averageNumberOfParticipants)
}
// Debug
public printAllInfos(): void {
console.debug('task:', this.aggregator.task.id)
console.debug('round:', this.round)
console.debug('participants:', this.currentNumberOfParticipants)
console.debug('total:', this.totalNumberOfParticipants)
console.debug('average:', this.averageNumberOfParticipants)
}
}
198 changes: 102 additions & 96 deletions discojs/discojs-core/src/client/base.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import { Set } from 'immutable'
import axios from 'axios'

import { tf, Task, TrainingInformant, serialization, WeightsContainer } from '..'
import {
tf,
Task,
TrainingInformant,
serialization,
WeightsContainer,
} from '..'
import { NodeID } from './types'
import { EventConnection } from './event_connection'
import { Aggregator } from '../aggregator'
Expand All @@ -11,119 +17,119 @@ import { Aggregator } from '../aggregator'
* communication with other nodes, be it peers or a server.
*/
export abstract class Base {
/**
* Own ID provided by the network's server.
*/
protected _ownId?: NodeID
/**
* The network's server.
*/
protected _server?: EventConnection
/**
* The aggregator's result produced after aggregation.
*/
protected aggregationResult?: Promise<WeightsContainer>

constructor (
/**
* The network server's URL to connect to.
* Own ID provided by the network's server.
*/
public readonly url: URL,
protected _ownId?: NodeID
/**
* The client's corresponding task.
* The network's server.
*/
public readonly task: Task,
protected _server?: EventConnection
/**
* The client's aggregator.
* The aggregator's result produced after aggregation.
*/
public readonly aggregator: Aggregator
) {}
protected aggregationResult?: Promise<WeightsContainer>

/**
* Handles the connection process from the client to any sort of network server.
*/
async connect (): Promise<void> {}
constructor(
/**
* The network server's URL to connect to.
*/
public readonly url: URL,
/**
* The client's corresponding task.
*/
public readonly task: Task,
/**
* The client's aggregator.
*/
public readonly aggregator: Aggregator
) {}

/**
* Handles the disconnection process of the client from any sort of network server.
*/
async disconnect (): Promise<void> {}
/**
* Handles the connection process from the client to any sort of network server.
*/
async connect(): Promise<void> {}

/**
* Fetches the latest model available on the network's server, for the adequate task.
* @returns The latest model
*/
async getLatestModel (): Promise<tf.LayersModel> {
const url = new URL('', this.url.href)
if (!url.pathname.endsWith('/')) {
url.pathname += '/'
}
url.pathname += `tasks/${this.task.taskID}/model.json`
/**
* Handles the disconnection process of the client from any sort of network server.
*/
async disconnect(): Promise<void> {}

const response = await axios.get(url.href)
/**
* Fetches the latest model available on the network's server, for the adequate task.
* @returns The latest model
*/
async getLatestModel(): Promise<tf.LayersModel> {
const url = new URL('', this.url.href)
if (!url.pathname.endsWith('/')) {
url.pathname += '/'
}
url.pathname += `tasks/${this.task.id}/model.json`

return await serialization.model.decode(response.data)
}
const response = await axios.get(url.href)

/**
* Communication callback called once at the beginning of the training instance.
* @param weights The initial model weights
* @param trainingInformant The training informant
*/
async onTrainBeginCommunication (
weights: WeightsContainer,
trainingInformant: TrainingInformant
): Promise<void> {}
return await serialization.model.decode(response.data)
}

/**
* Communication callback called once at the end of the training instance.
* @param weights The final model weights
* @param trainingInformant The training informant
*/
async onTrainEndCommunication (
weights: WeightsContainer,
trainingInformant: TrainingInformant
): Promise<void> {}
/**
* Communication callback called once at the beginning of the training instance.
* @param weights The initial model weights
* @param trainingInformant The training informant
*/
async onTrainBeginCommunication(
weights: WeightsContainer,
trainingInformant: TrainingInformant
): Promise<void> {}

/**
* Communication callback called at the beginning of every training round.
* @param weights The most recent local weight updates
* @param round The current training round
* @param trainingInformant The training informant
*/
async onRoundBeginCommunication (
weights: WeightsContainer,
round: number,
trainingInformant: TrainingInformant
): Promise<void> {}
/**
* Communication callback called once at the end of the training instance.
* @param weights The final model weights
* @param trainingInformant The training informant
*/
async onTrainEndCommunication(
weights: WeightsContainer,
trainingInformant: TrainingInformant
): Promise<void> {}

/**
* Communication callback called the end of every training round.
* @param weights The most recent local weight updates
* @param round The current training round
* @param trainingInformant The training informant
*/
async onRoundEndCommunication (
weights: WeightsContainer,
round: number,
trainingInformant: TrainingInformant
): Promise<void> {}
/**
* Communication callback called at the beginning of every training round.
* @param weights The most recent local weight updates
* @param round The current training round
* @param trainingInformant The training informant
*/
async onRoundBeginCommunication(
weights: WeightsContainer,
round: number,
trainingInformant: TrainingInformant
): Promise<void> {}

get nodes (): Set<NodeID> {
return this.aggregator.nodes
}
/**
* Communication callback called the end of every training round.
* @param weights The most recent local weight updates
* @param round The current training round
* @param trainingInformant The training informant
*/
async onRoundEndCommunication(
weights: WeightsContainer,
round: number,
trainingInformant: TrainingInformant
): Promise<void> {}

get nodes(): Set<NodeID> {
return this.aggregator.nodes
}

get ownId (): NodeID {
if (this._ownId === undefined) {
throw new Error('the node is not connected')
get ownId(): NodeID {
if (this._ownId === undefined) {
throw new Error('the node is not connected')
}
return this._ownId
}
return this._ownId
}

get server (): EventConnection {
if (this._server === undefined) {
throw new Error('server undefined, not connected')
get server(): EventConnection {
if (this._server === undefined) {
throw new Error('server undefined, not connected')
}
return this._server
}
return this._server
}
}
Loading

0 comments on commit a0be7af

Please sign in to comment.