|
| 1 | +import { FakeTransportBuilder } from '@improbable-eng/grpc-web-fake-transport'; |
| 2 | +import { beforeEach, describe, expect, it, vi } from 'vitest'; |
| 3 | +import { |
| 4 | + CancelTrainingJobRequest, |
| 5 | + DeleteCompletedTrainingJobRequest, |
| 6 | + GetTrainingJobRequest, |
| 7 | + GetTrainingJobResponse, |
| 8 | + ListTrainingJobsRequest, |
| 9 | + ListTrainingJobsResponse, |
| 10 | + ModelType, |
| 11 | + SubmitTrainingJobRequest, |
| 12 | + SubmitTrainingJobResponse, |
| 13 | + TrainingJobMetadata, |
| 14 | + TrainingStatus, |
| 15 | +} from '../gen/app/mltraining/v1/ml_training_pb'; |
| 16 | +import { MLTrainingServiceClient } from '../gen/app/mltraining/v1/ml_training_pb_service'; |
| 17 | +vi.mock('../gen/app/mltraining/v1/ml_training_pb_service'); |
| 18 | +import { MlTrainingClient } from './ml-training-client'; |
| 19 | + |
| 20 | +const subject = () => |
| 21 | + new MlTrainingClient('fakeServiceHoost', { |
| 22 | + transport: new FakeTransportBuilder().build(), |
| 23 | + }); |
| 24 | + |
| 25 | +describe('MlTrainingClient tests', () => { |
| 26 | + describe('submitTrainingJob tests', () => { |
| 27 | + const type = ModelType.MODEL_TYPE_UNSPECIFIED; |
| 28 | + beforeEach(() => { |
| 29 | + vi.spyOn(MLTrainingServiceClient.prototype, 'submitTrainingJob') |
| 30 | + // @ts-expect-error compiler is matching incorrect function signature |
| 31 | + .mockImplementationOnce((_req: SubmitTrainingJobRequest, _md, cb) => { |
| 32 | + const response = new SubmitTrainingJobResponse(); |
| 33 | + response.setId('fakeId'); |
| 34 | + cb(null, response); |
| 35 | + }); |
| 36 | + }); |
| 37 | + |
| 38 | + it('submit job training job', async () => { |
| 39 | + const response = await subject().submitTrainingJob( |
| 40 | + 'org_id', |
| 41 | + 'dataset_id', |
| 42 | + 'model_name', |
| 43 | + 'model_version', |
| 44 | + type, |
| 45 | + ['tag1'] |
| 46 | + ); |
| 47 | + expect(response).toEqual('fakeId'); |
| 48 | + }); |
| 49 | + }); |
| 50 | + |
| 51 | + describe('getTrainingJob tests', () => { |
| 52 | + const metadata: TrainingJobMetadata = new TrainingJobMetadata(); |
| 53 | + metadata.setId('id'); |
| 54 | + metadata.setDatasetId('dataset_id'); |
| 55 | + metadata.setOrganizationId('org_id'); |
| 56 | + metadata.setModelVersion('model_version'); |
| 57 | + metadata.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED); |
| 58 | + metadata.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED); |
| 59 | + metadata.setSyncedModelId('synced_model_id'); |
| 60 | + |
| 61 | + beforeEach(() => { |
| 62 | + vi.spyOn(MLTrainingServiceClient.prototype, 'getTrainingJob') |
| 63 | + // @ts-expect-error compiler is matching incorrect function signature |
| 64 | + .mockImplementationOnce((_req: GetTrainingJobRequest, _md, cb) => { |
| 65 | + const response = new GetTrainingJobResponse(); |
| 66 | + response.setMetadata(metadata); |
| 67 | + cb(null, response); |
| 68 | + }); |
| 69 | + }); |
| 70 | + |
| 71 | + it('get training job', async () => { |
| 72 | + const response = await subject().getTrainingJob('id'); |
| 73 | + expect(response).toEqual(metadata); |
| 74 | + }); |
| 75 | + }); |
| 76 | + |
| 77 | + describe('listTrainingJobs', () => { |
| 78 | + const status = TrainingStatus.TRAINING_STATUS_UNSPECIFIED; |
| 79 | + const md1 = new TrainingJobMetadata(); |
| 80 | + md1.setId('id1'); |
| 81 | + md1.setDatasetId('dataset_id1'); |
| 82 | + md1.setOrganizationId('org_id1'); |
| 83 | + md1.setModelVersion('model_version1'); |
| 84 | + md1.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED); |
| 85 | + md1.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED); |
| 86 | + md1.setSyncedModelId('synced_model_id1'); |
| 87 | + const md2 = new TrainingJobMetadata(); |
| 88 | + md1.setId('id2'); |
| 89 | + md1.setDatasetId('dataset_id2'); |
| 90 | + md1.setOrganizationId('org_id2'); |
| 91 | + md1.setModelVersion('model_version2'); |
| 92 | + md1.setModelType(ModelType.MODEL_TYPE_UNSPECIFIED); |
| 93 | + md1.setStatus(TrainingStatus.TRAINING_STATUS_UNSPECIFIED); |
| 94 | + md1.setSyncedModelId('synced_model_id2'); |
| 95 | + const jobs = [md1, md2]; |
| 96 | + |
| 97 | + beforeEach(() => { |
| 98 | + vi.spyOn(MLTrainingServiceClient.prototype, 'listTrainingJobs') |
| 99 | + // @ts-expect-error compiler is matching incorrect function signature |
| 100 | + .mockImplementationOnce((_req: ListTrainingJobsRequest, _md, cb) => { |
| 101 | + const response = new ListTrainingJobsResponse(); |
| 102 | + response.setJobsList(jobs); |
| 103 | + cb(null, response); |
| 104 | + }); |
| 105 | + }); |
| 106 | + |
| 107 | + it('list training jobs', async () => { |
| 108 | + const response = await subject().listTrainingJobs('org_id', status); |
| 109 | + expect(response).toEqual([md1.toObject(), md2.toObject()]); |
| 110 | + }); |
| 111 | + }); |
| 112 | + |
| 113 | + describe('cancelTrainingJob tests', () => { |
| 114 | + const id = 'id'; |
| 115 | + beforeEach(() => { |
| 116 | + vi.spyOn(MLTrainingServiceClient.prototype, 'cancelTrainingJob') |
| 117 | + // @ts-expect-error compiler is matching incorrect function signature |
| 118 | + .mockImplementationOnce((req: CancelTrainingJobRequest, _md, cb) => { |
| 119 | + expect(req.getId()).toStrictEqual(id); |
| 120 | + cb(null, {}); |
| 121 | + }); |
| 122 | + }); |
| 123 | + it('cancel training job', async () => { |
| 124 | + expect(await subject().cancelTrainingJob(id)).toStrictEqual(null); |
| 125 | + }); |
| 126 | + }); |
| 127 | + |
| 128 | + describe('deleteCompletedTrainingJob tests', () => { |
| 129 | + const id = 'id'; |
| 130 | + beforeEach(() => { |
| 131 | + vi.spyOn( |
| 132 | + MLTrainingServiceClient.prototype, |
| 133 | + 'deleteCompletedTrainingJob' |
| 134 | + ).mockImplementationOnce( |
| 135 | + // @ts-expect-error compiler is matching incorrect function signature |
| 136 | + (req: DeleteCompletedTrainingJobRequest, _md, cb) => { |
| 137 | + expect(req.getId()).toStrictEqual(id); |
| 138 | + cb(null, {}); |
| 139 | + } |
| 140 | + ); |
| 141 | + }); |
| 142 | + it('delete completed training job', async () => { |
| 143 | + expect(await subject().deleteCompletedTrainingJob(id)).toEqual(null); |
| 144 | + }); |
| 145 | + }); |
| 146 | +}); |
0 commit comments