Skip to content

Commit fd37ff3

Browse files
purplenicole730github-actions[bot]viambotdependabot[bot]
authored
RSDK-7200: add ml training wrapper (#272)
Signed-off-by: dependabot[bot] <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: viambot <[email protected]> Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
1 parent 3544470 commit fd37ff3

11 files changed

+432
-135
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ update-buf: $(node_modules)
5555
.PHONY: build-buf
5656
build-buf: $(node_modules) clean-buf
5757
$(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/googleapis/googleapis)
58-
$(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/viamrobotics/api) --path common,component,robot,service,app,provisioning
58+
$(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/viamrobotics/api) --path common,component,robot,service,app,provisioning,tagger
5959
$(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/erdaniels/gostream)
6060
$(buf) generate $$(./scripts/get-buf-lock-version.js buf.build/viamrobotics/goutils)
6161

src/app/ml-training-client.test.ts

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport';
2+
import { beforeEach, describe, expect, it, vi } from 'vitest';
3+
import {
4+
CancelTrainingJobRequest,
5+
DeleteCompletedTrainingJobRequest,
6+
GetTrainingJobRequest,
7+
GetTrainingJobResponse,
8+
ListTrainingJobsRequest,
9+
ListTrainingJobsResponse,
10+
ModelType,
11+
SubmitTrainingJobRequest,
12+
SubmitTrainingJobResponse,
13+
TrainingJobMetadata,
14+
TrainingStatus,
15+
} from '../gen/app/mltraining/v1/ml_training_pb';
16+
import { MLTrainingServiceClient } from '../gen/app/mltraining/v1/ml_training_pb_service';
17+
vi.mock('../gen/app/mltraining/v1/ml_training_pb_service');
18+
import { MlTrainingClient } from './ml-training-client';
19+
20+
const subject = () =>
21+
new MlTrainingClient('fakeServiceHoost', {
22+
transport: new FakeTransportBuilder().build(),
23+
});
24+
25+
describe('MlTrainingClient tests', () => {
26+
describe('submitTrainingJob tests', () => {
27+
const type = ModelType.MODEL_TYPE_UNSPECIFIED;
28+
beforeEach(() => {
29+
vi.spyOn(MLTrainingServiceClient.prototype, 'submitTrainingJob')
30+
// @ts-expect-error compiler is matching incorrect function signature
31+
.mockImplementationOnce((_req: SubmitTrainingJobRequest, _md, cb) => {
32+
const response = new SubmitTrainingJobResponse();
33+
response.setId('fakeId');
34+
cb(null, response);
35+
});
36+
});
37+
38+
it('submit job training job', async () => {
39+
const response = await subject().submitTrainingJob(
40+
'org_id',
41+
'dataset_id',
42+
'model_name',
43+
'model_version',
44+
type,
45+
['tag1']
46+
);
47+
expect(response).toEqual('fakeId');
48+
});
49+
});
50+
51+
describe('getTrainingJob tests', () => {
52+
const metadata: TrainingJobMetadata = new TrainingJobMetadata();
53+
metadata.setId('id');
54+
metadata.setDatasetId('dataset_id');
55+
metadata.setOrganizationId('org_id');
56+
metadata.setModelVersion('model_version');
57+
metadata.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED);
58+
metadata.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED);
59+
metadata.setSyncedModelId('synced_model_id');
60+
61+
beforeEach(() => {
62+
vi.spyOn(MLTrainingServiceClient.prototype, 'getTrainingJob')
63+
// @ts-expect-error compiler is matching incorrect function signature
64+
.mockImplementationOnce((_req: GetTrainingJobRequest, _md, cb) => {
65+
const response = new GetTrainingJobResponse();
66+
response.setMetadata(metadata);
67+
cb(null, response);
68+
});
69+
});
70+
71+
it('get training job', async () => {
72+
const response = await subject().getTrainingJob('id');
73+
expect(response).toEqual(metadata);
74+
});
75+
});
76+
77+
describe('listTrainingJobs', () => {
78+
const status = TrainingStatus.TRAINING_STATUS_UNSPECIFIED;
79+
const md1 = new TrainingJobMetadata();
80+
md1.setId('id1');
81+
md1.setDatasetId('dataset_id1');
82+
md1.setOrganizationId('org_id1');
83+
md1.setModelVersion('model_version1');
84+
md1.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED);
85+
md1.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED);
86+
md1.setSyncedModelId('synced_model_id1');
87+
const md2 = new TrainingJobMetadata();
88+
md1.setId('id2');
89+
md1.setDatasetId('dataset_id2');
90+
md1.setOrganizationId('org_id2');
91+
md1.setModelVersion('model_version2');
92+
md1.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED);
93+
md1.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED);
94+
md1.setSyncedModelId('synced_model_id2');
95+
const jobs = [md1, md2];
96+
97+
beforeEach(() => {
98+
vi.spyOn(MLTrainingServiceClient.prototype, 'listTrainingJobs')
99+
// @ts-expect-error compiler is matching incorrect function signature
100+
.mockImplementationOnce((_req: ListTrainingJobsRequest, _md, cb) => {
101+
const response = new ListTrainingJobsResponse();
102+
response.setJobsList(jobs);
103+
cb(null, response);
104+
});
105+
});
106+
107+
it('list training jobs', async () => {
108+
const response = await subject().listTrainingJobs('org_id', status);
109+
expect(response).toEqual([md1.toObject(), md2.toObject()]);
110+
});
111+
});
112+
113+
describe('cancelTrainingJob tests', () => {
114+
const id = 'id';
115+
beforeEach(() => {
116+
vi.spyOn(MLTrainingServiceClient.prototype, 'cancelTrainingJob')
117+
// @ts-expect-error compiler is matching incorrect function signature
118+
.mockImplementationOnce((req: CancelTrainingJobRequest, _md, cb) => {
119+
expect(req.getId()).toStrictEqual(id);
120+
cb(null, {});
121+
});
122+
});
123+
it('cancel training job', async () => {
124+
expect(await subject().cancelTrainingJob(id)).toStrictEqual(null);
125+
});
126+
});
127+
128+
describe('deleteCompletedTrainingJob tests', () => {
129+
const id = 'id';
130+
beforeEach(() => {
131+
vi.spyOn(
132+
MLTrainingServiceClient.prototype,
133+
'deleteCompletedTrainingJob'
134+
).mockImplementationOnce(
135+
// @ts-expect-error compiler is matching incorrect function signature
136+
(req: DeleteCompletedTrainingJobRequest, _md, cb) => {
137+
expect(req.getId()).toStrictEqual(id);
138+
cb(null, {});
139+
}
140+
);
141+
});
142+
it('delete completed training job', async () => {
143+
expect(await subject().deleteCompletedTrainingJob(id)).toEqual(null);
144+
});
145+
});
146+
});

src/app/ml-training-client.ts

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import { type RpcOptions } from '@improbable-eng/grpc-web/dist/typings/client.d';
2+
import { MLTrainingServiceClient } from '../gen/app/mltraining/v1/ml_training_pb_service';
3+
import pb from '../gen/app/mltraining/v1/ml_training_pb';
4+
import { promisify } from '../utils';
5+
6+
type ValueOf<T> = T[keyof T];
7+
export const { ModelType } = pb;
8+
export type ModelType = ValueOf<typeof pb.ModelType>;
9+
export const { TrainingStatus } = pb;
10+
export type TrainingStatus = ValueOf<typeof pb.TrainingStatus>;
11+
12+
export class MlTrainingClient {
13+
private service: MLTrainingServiceClient;
14+
15+
constructor(serviceHost: string, grpcOptions: RpcOptions) {
16+
this.service = new MLTrainingServiceClient(serviceHost, grpcOptions);
17+
}
18+
19+
async submitTrainingJob(
20+
orgId: string,
21+
datasetId: string,
22+
modelName: string,
23+
modelVersion: string,
24+
modelType: ModelType,
25+
tagsList: string[]
26+
) {
27+
const { service } = this;
28+
29+
const req = new pb.SubmitTrainingJobRequest();
30+
req.setOrganizationId(orgId);
31+
req.setDatasetId(datasetId);
32+
req.setModelName(modelName);
33+
req.setModelVersion(modelVersion);
34+
req.setModelType(modelType);
35+
req.setTagsList(tagsList);
36+
37+
const response = await promisify<
38+
pb.SubmitTrainingJobRequest,
39+
pb.SubmitTrainingJobResponse
40+
>(service.submitTrainingJob.bind(service), req);
41+
return response.getId();
42+
}
43+
44+
async getTrainingJob(id: string) {
45+
const { service } = this;
46+
47+
const req = new pb.GetTrainingJobRequest();
48+
req.setId(id);
49+
50+
const response = await promisify<
51+
pb.GetTrainingJobRequest,
52+
pb.GetTrainingJobResponse
53+
>(service.getTrainingJob.bind(service), req);
54+
return response.getMetadata();
55+
}
56+
57+
async listTrainingJobs(orgId: string, status: TrainingStatus) {
58+
const { service } = this;
59+
60+
const req = new pb.ListTrainingJobsRequest();
61+
req.setOrganizationId(orgId);
62+
req.setStatus(status);
63+
64+
const response = await promisify<
65+
pb.ListTrainingJobsRequest,
66+
pb.ListTrainingJobsResponse
67+
>(service.listTrainingJobs.bind(service), req);
68+
return response.toObject().jobsList;
69+
}
70+
71+
async cancelTrainingJob(id: string) {
72+
const { service } = this;
73+
74+
const req = new pb.CancelTrainingJobRequest();
75+
req.setId(id);
76+
77+
await promisify<pb.CancelTrainingJobRequest, pb.CancelTrainingJobResponse>(
78+
service.cancelTrainingJob.bind(service),
79+
req
80+
);
81+
return null;
82+
}
83+
84+
async deleteCompletedTrainingJob(id: string) {
85+
const { service } = this;
86+
87+
const req = new pb.DeleteCompletedTrainingJobRequest();
88+
req.setId(id);
89+
90+
await promisify<
91+
pb.DeleteCompletedTrainingJobRequest,
92+
pb.DeleteCompletedTrainingJobResponse
93+
>(service.deleteCompletedTrainingJob.bind(service), req);
94+
return null;
95+
}
96+
}

src/app/provisioning-client.test.ts

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// @vitest-environment happy-dom
2+
3+
import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport';
4+
import { beforeEach, expect, it, vi } from 'vitest';
5+
import {
6+
CloudConfig,
7+
SetNetworkCredentialsRequest,
8+
SetSmartMachineCredentialsRequest,
9+
} from '../gen/provisioning/v1/provisioning_pb';
10+
import { ProvisioningServiceClient } from '../gen/provisioning/v1/provisioning_pb_service';
11+
import { ProvisioningClient } from './provisioning-client';
12+
13+
const subject = () =>
14+
new ProvisioningClient('fakeServiceHost', {
15+
transport: new FakeTransportBuilder().build(),
16+
});
17+
18+
const testProvisioningInfo = {
19+
fragmentId: 'id',
20+
model: 'model',
21+
manufacturer: 'manufacturer',
22+
};
23+
const testNetworkInfo = {
24+
type: 'type',
25+
ssid: 'ssid',
26+
security: 'security',
27+
signal: 999,
28+
connected: 'true',
29+
lastError: 'last error',
30+
};
31+
const testSmartMachineStatus = {
32+
provisioningInfo: testProvisioningInfo,
33+
hasSmartMachineCredentials: true,
34+
isOnline: true,
35+
latestConnectionAttempt: testNetworkInfo,
36+
errorsList: ['error', 'err'],
37+
};
38+
const type = 'type';
39+
const ssid = 'ssid';
40+
const psk = 'psk';
41+
const cloud = new CloudConfig();
42+
cloud.setId('id');
43+
cloud.setSecret('secret');
44+
cloud.setAppAddress('app_address');
45+
46+
beforeEach(() => {
47+
ProvisioningServiceClient.prototype.getSmartMachineStatus = vi
48+
.fn()
49+
.mockImplementation((_req, _md, cb) => {
50+
cb(null, {
51+
toObject: () => testSmartMachineStatus,
52+
});
53+
});
54+
55+
ProvisioningServiceClient.prototype.getNetworkList = vi
56+
.fn()
57+
.mockImplementation((_req, _md, cb) => {
58+
cb(null, {
59+
toObject: () => ({ networksList: [testNetworkInfo] }),
60+
});
61+
});
62+
63+
ProvisioningServiceClient.prototype.setNetworkCredentials = vi
64+
.fn()
65+
.mockImplementation((req: SetNetworkCredentialsRequest, _md, cb) => {
66+
expect(req.getType()).toStrictEqual(type);
67+
expect(req.getSsid()).toStrictEqual(ssid);
68+
expect(req.getPsk()).toStrictEqual(psk);
69+
cb(null, {});
70+
});
71+
72+
ProvisioningServiceClient.prototype.setSmartMachineCredentials = vi
73+
.fn()
74+
.mockImplementation((req: SetSmartMachineCredentialsRequest, _md, cb) => {
75+
expect(req.getCloud()).toStrictEqual(cloud);
76+
cb(null, {});
77+
});
78+
});
79+
80+
it('getSmartMachineStatus', async () => {
81+
await expect(subject().getSmartMachineStatus()).resolves.toStrictEqual(
82+
testSmartMachineStatus
83+
);
84+
});
85+
86+
it('getNetworkList', async () => {
87+
await expect(subject().getNetworkList()).resolves.toStrictEqual([
88+
testNetworkInfo,
89+
]);
90+
});
91+
92+
it('setNetworkCredentials', async () => {
93+
await expect(
94+
subject().setNetworkCredentials(type, ssid, psk)
95+
).resolves.toStrictEqual(undefined);
96+
});
97+
98+
it('setSmartMachineCredentials', async () => {
99+
await expect(
100+
subject().setSmartMachineCredentials(cloud.toObject())
101+
).resolves.toStrictEqual(undefined);
102+
});

0 commit comments

Comments
 (0)