
Commit aaa9911

refactor: model pull api and post processing (#794)

Signed-off-by: James <[email protected]>
Co-authored-by: James <[email protected]>

1 parent: 0c36428

19 files changed: +381 -435 lines

cortex-js/src/domain/models/huggingface.interface.ts (+8 -6)

@@ -4,6 +4,13 @@ export interface HuggingFaceModelVersion {
   fileSize?: number;
   quantization?: Quantization;
 }
+
+export interface HuggingFaceRepoSibling {
+  rfilename: string;
+  downloadUrl?: string;
+  fileSize?: number;
+  quantization?: Quantization;
+}
 export interface HuggingFaceRepoData {
   id: string;
   modelId: string;
@@ -18,12 +25,7 @@ export interface HuggingFaceRepoData {
   pipeline_tag: 'text-generation';
   tags: Array<'transformers' | 'pytorch' | 'safetensors' | string>;
   cardData: Record<CardDataKeys | string, unknown>;
-  siblings: {
-    rfilename: string;
-    downloadUrl?: string;
-    fileSize?: number;
-    quantization?: Quantization;
-  }[];
+  siblings: HuggingFaceRepoSibling[];
   createdAt: string;
 }
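The siblings element type is now a named export, so file-selection logic elsewhere can be typed against it. A minimal sketch of a consumer (the pickGguf helper is hypothetical, not part of this commit; the import path assumes the repo's @/ alias for src):

import { HuggingFaceRepoSibling } from '@/domain/models/huggingface.interface';

// Hypothetical helper: prefer the first GGUF file among a repo's siblings.
const pickGguf = (
  files: HuggingFaceRepoSibling[],
): HuggingFaceRepoSibling | undefined =>
  files.find((f) => f.rfilename.endsWith('.gguf'));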

cortex-js/src/domain/models/model.event.ts (+2 -0)

@@ -7,6 +7,8 @@ const ModelLoadingEvents = [
   'stopped',
   'starting-failed',
   'stopping-failed',
+  'model-downloaded',
+  'model-deleted',
 ] as const;
 export type ModelLoadingEvent = (typeof ModelLoadingEvents)[number];
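With the two new members, the ModelLoadingEvent union also covers download and deletion lifecycle states. An illustrative type-level check, not from this commit (the import path assumes the repo's @/ alias):

import { ModelLoadingEvent } from '@/domain/models/model.event';

// Both new literals now type-check as ModelLoadingEvent.
const downloaded: ModelLoadingEvent = 'model-downloaded';
const deleted: ModelLoadingEvent = 'model-deleted';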

cortex-js/src/domain/models/model.interface.ts (+1 -0)

@@ -168,4 +168,5 @@ export interface ModelRuntimeParams {
 export interface ModelArtifact {
   mmproj?: string;
   llama_model_path?: string;
+  model_path?: string;
 }
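ModelArtifact gains an engine-agnostic model_path next to the llama.cpp-specific llama_model_path. Illustrative only; the path below is a placeholder, not from this commit:

import { ModelArtifact } from '@/domain/models/model.interface';

// Hypothetical artifact using the new generic field.
const artifact: ModelArtifact = { model_path: '/models/my-model/model.gguf' };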

cortex-js/src/infrastructure/commanders/models/model-pull.command.ts (+3 -0)

@@ -8,6 +8,7 @@ import { existsSync } from 'fs';
 import { join } from 'node:path';
 import { FileManagerService } from '@/infrastructure/services/file-manager/file-manager.service';
 import { InitCliUsecases } from '../usecases/init.cli.usecases';
+import { checkModelCompatibility } from '@/utils/model-check';
 
 @SubCommand({
   name: 'pull',
@@ -35,6 +36,8 @@ export class ModelPullCommand extends CommandRunner {
     }
     const modelId = passedParams[0];
 
+    checkModelCompatibility(modelId);
+
     await this.modelsCliUsecases.pullModel(modelId).catch((e: Error) => {
       if (e instanceof ModelNotFoundException)
         console.error('Model does not exist.');
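@/utils/model-check itself is not included in this diff. Judging by the inline ONNX-on-non-Windows guards this commit deletes from model-start.command.ts, run.command.ts, and init.cli.usecases.ts below, it plausibly centralizes that guard; a sketch under that assumption only:

// Assumed shape of @/utils/model-check — inferred from the removed inline
// checks, not shown in this commit.
export const checkModelCompatibility = (modelId: string) => {
  if (modelId.includes('onnx') && process.platform !== 'win32') {
    console.error('The ONNX engine does not support this OS yet.');
    process.exit(1);
  }
};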

cortex-js/src/infrastructure/commanders/models/model-start.command.ts (+8 -5)

@@ -14,6 +14,7 @@ import { existsSync } from 'node:fs';
 import { FileManagerService } from '@/infrastructure/services/file-manager/file-manager.service';
 import { join } from 'node:path';
 import { Engines } from '../types/engine.interface';
+import { checkModelCompatibility } from '@/utils/model-check';
 
 type ModelStartOptions = {
   attach: boolean;
@@ -58,9 +59,14 @@ export class ModelStartCommand extends CommandRunner {
       !Array.isArray(existingModel.files) ||
       /^(http|https):\/\/[^/]+\/.*/.test(existingModel.files[0])
     ) {
-      console.error('Model is not available. Please pull the model first.');
+      console.error(
+        `${modelId} not found on filesystem. Please try 'cortex pull ${modelId}' first.`,
+      );
       process.exit(1);
     }
+
+    checkModelCompatibility(modelId);
+
     const engine = existingModel.engine || 'cortex.llamacpp';
     // Pull engine if not exist
     if (
@@ -72,10 +78,7 @@ export class ModelStartCommand extends CommandRunner {
         engine,
       );
     }
-    if (engine === Engines.onnx && process.platform !== 'win32') {
-      console.error('The ONNX engine does not support this OS yet.');
-      process.exit(1);
-    }
+
     await this.cortexUsecases
       .startCortex(options.attach)
       .then(() => this.modelsCliUsecases.startModel(modelId, options.preset))

cortex-js/src/infrastructure/commanders/serve.command.ts (+2 -2)

@@ -42,8 +42,8 @@ export class ServeCommand extends CommandRunner {
       console.log(
         chalk.blue(`API Playground available at http://${host}:${port}/api`),
       );
-    } catch (err) {
-      console.error(err.message ?? err);
+    } catch {
+      console.error(`Failed to start server. Is port ${port} in use?`);
     }
   }

cortex-js/src/infrastructure/commanders/shortcuts/run.command.ts (+7 -6)

@@ -14,6 +14,7 @@ import { join } from 'path';
 import { FileManagerService } from '@/infrastructure/services/file-manager/file-manager.service';
 import { InitCliUsecases } from '../usecases/init.cli.usecases';
 import { Engines } from '../types/engine.interface';
+import { checkModelCompatibility } from '@/utils/model-check';
 
 type RunOptions = {
   threadId?: string;
@@ -55,7 +56,9 @@ export class RunCommand extends CommandRunner {
     // If not exist
     // Try Pull
     if (!(await this.modelsCliUsecases.getModel(modelId))) {
-      console.log(`Model ${modelId} not found. Try pulling model...`);
+      console.log(
+        `${modelId} not found on filesystem. Downloading from remote: https://huggingface.co/cortexhub if possible.`,
+      );
       await this.modelsCliUsecases.pullModel(modelId).catch((e: Error) => {
         if (e instanceof ModelNotFoundException)
           console.error('Model does not exist.');
@@ -71,10 +74,12 @@ export class RunCommand extends CommandRunner {
       !Array.isArray(existingModel.files) ||
       /^(http|https):\/\/[^/]+\/.*/.test(existingModel.files[0])
     ) {
-      console.error('Model is not available. Please pull the model first.');
+      console.error('Model is not available.');
       process.exit(1);
     }
 
+    checkModelCompatibility(modelId);
+
     const engine = existingModel.engine || 'cortex.llamacpp';
     // Pull engine if not exist
     if (
@@ -86,10 +91,6 @@ export class RunCommand extends CommandRunner {
         engine,
       );
     }
-    if (engine === Engines.onnx && process.platform !== 'win32') {
-      console.error('The ONNX engine does not support this OS yet.');
-      process.exit(1);
-    }
 
     return this.cortexUsecases
       .startCortex(false)

cortex-js/src/infrastructure/commanders/usecases/init.cli.usecases.ts (+1 -5)

@@ -18,6 +18,7 @@ import {
 } from '@/infrastructure/constants/cortex';
 import { checkNvidiaGPUExist, cudaVersion } from '@/utils/cuda';
 import { Engines } from '../types/engine.interface';
+import { checkModelCompatibility } from '@/utils/model-check';
 
 @Injectable()
 export class InitCliUsecases {
@@ -71,11 +72,6 @@ export class InitCliUsecases {
     )
       await this.installLlamaCppEngine(options, version);
 
-    if (engine === Engines.onnx && process.platform !== 'win32') {
-      console.error('The ONNX engine does not support this OS yet.');
-      process.exit(1);
-    }
-
     if (engine !== 'cortex.llamacpp')
       await this.installAcceleratedEngine('latest', engine);

cortex-js/src/infrastructure/commanders/usecases/models.cli.usecases.ts (+30 -172)

@@ -1,25 +1,19 @@
 import { exit } from 'node:process';
 import { ModelsUsecases } from '@/usecases/models/models.usecases';
 import { Model } from '@/domain/models/model.interface';
-import { CreateModelDto } from '@/infrastructure/dtos/models/create-model.dto';
-import { HuggingFaceRepoData } from '@/domain/models/huggingface.interface';
 import { InquirerService } from 'nest-commander';
 import { Inject, Injectable } from '@nestjs/common';
 import { Presets, SingleBar } from 'cli-progress';
-import { LLAMA_2 } from '@/infrastructure/constants/prompt-constants';
 
 import { HttpService } from '@nestjs/axios';
 import { StartModelSuccessDto } from '@/infrastructure/dtos/models/start-model-success.dto';
 import { UpdateModelDto } from '@/infrastructure/dtos/models/update-model.dto';
 import { FileManagerService } from '@/infrastructure/services/file-manager/file-manager.service';
-import { join, basename } from 'path';
+import { join } from 'path';
 import { load } from 'js-yaml';
 import { existsSync, readdirSync, readFileSync } from 'fs';
-import { isLocalModel, normalizeModelId } from '@/utils/normalize-model-id';
-import { fetchJanRepoData, getHFModelMetadata } from '@/utils/huggingface';
-import { createWriteStream, mkdirSync, promises } from 'node:fs';
-import { firstValueFrom } from 'rxjs';
-import { Engines } from '../types/engine.interface';
+import { isLocalModel } from '@/utils/normalize-model-id';
+import { HuggingFaceRepoSibling } from '@/domain/models/huggingface.interface';
 
 @Injectable()
 export class ModelsCliUsecases {
@@ -120,170 +114,34 @@ export class ModelsCliUsecases {
       console.error('Model already exists');
       process.exit(1);
     }
-
-    if (modelId.includes('onnx') || modelId.includes('tensorrt')) {
-      await this.pullEngineModelFiles(modelId);
-    } else {
-      await this.pullGGUFModel(modelId);
-      const bar = new SingleBar({}, Presets.shades_classic);
-      bar.start(100, 0);
-      const callback = (progress: number) => {
-        bar.update(progress);
-      };
-
-      try {
-        await this.modelsUsecases.downloadModel(modelId, callback);
-
-        const model = await this.modelsUsecases.findOne(modelId);
-        const fileUrl = join(
-          await this.fileService.getModelsPath(),
-          normalizeModelId(modelId),
-          basename((model?.files as string[])[0]),
-        );
-        await this.modelsUsecases.update(modelId, {
-          files: [fileUrl],
-          name: modelId.replace(':default', ''),
-        });
-      } catch (err) {
-        bar.stop();
-        throw err;
-      }
-    }
-  }
-
-  /**
-   * It's to pull engine model files from HuggingFace repository
-   * @param modelId
-   */
-  private async pullEngineModelFiles(modelId: string) {
-    const modelsContainerDir = await this.fileService.getModelsPath();
-
-    if (!existsSync(modelsContainerDir)) {
-      mkdirSync(modelsContainerDir, { recursive: true });
-    }
-
-    const modelFolder = join(modelsContainerDir, normalizeModelId(modelId));
-    await promises.mkdir(modelFolder, { recursive: true }).catch(() => {});
-
-    const files = (await fetchJanRepoData(modelId)).siblings;
-    for (const file of files) {
-      console.log(`Downloading ${file.rfilename}`);
-      const bar = new SingleBar({}, Presets.shades_classic);
-      bar.start(100, 0);
-      const response = await firstValueFrom(
-        this.httpService.get(file.downloadUrl ?? '', {
-          responseType: 'stream',
-        }),
-      );
-      if (!response) {
-        throw new Error('Failed to download model');
-      }
-
-      await new Promise((resolve, reject) => {
-        const writer = createWriteStream(join(modelFolder, file.rfilename));
-        let receivedBytes = 0;
-        const totalBytes = response.headers['content-length'];
-
-        writer.on('finish', () => {
-          resolve(true);
-        });
-
-        writer.on('error', (error) => {
-          reject(error);
-        });
-
-        response.data.on('data', (chunk: any) => {
-          receivedBytes += chunk.length;
-          bar.update(Math.floor((receivedBytes / totalBytes) * 100));
-        });
-
-        response.data.pipe(writer);
+    await this.modelsUsecases.pullModel(modelId, true, (files) => {
+      return new Promise<HuggingFaceRepoSibling>(async (resolve) => {
+        const listChoices = files
+          .filter((e) => e.quantization != null)
+          .map((e) => {
+            return {
+              name: e.quantization,
+              value: e.quantization,
+            };
+          });
+
+        if (listChoices.length > 1) {
+          const { quantization } = await this.inquirerService.inquirer.prompt({
+            type: 'list',
+            name: 'quantization',
+            message: 'Select quantization',
+            choices: listChoices,
+          });
+          resolve(
+            files
              .filter((e) => !!e.quantization)
              .find((e: any) => e.quantization === quantization) ?? files[0],
          );
        } else {
          resolve(files.find((e) => e.rfilename.includes('.gguf')) ?? files[0]);
        }
       });
-      bar.stop();
-    }
-
-    const model: CreateModelDto = load(
-      readFileSync(join(modelFolder, 'model.yml'), 'utf-8'),
-    ) as CreateModelDto;
-    model.files = [join(modelFolder)];
-    model.model = modelId;
-
-    if (!(await this.modelsUsecases.findOne(modelId)))
-      await this.modelsUsecases.create(model);
-
-    if (model.engine === Engines.tensorrtLLM) {
-      if (process.platform === 'win32')
-        console.log(
-          'Please ensure that you install MPI and its SDK to use the TensorRT engine, as it also requires the Cuda Toolkit 12.3 to work. Refs:\n- https://github.com/microsoft/Microsoft-MPI/releases/download/v10.1.1/msmpisetup.exe\n- https://github.com/microsoft/Microsoft-MPI/releases/download/v10.1.1/msmpisdk.msi',
-        );
-      else if (process.platform === 'linux')
-        console.log(
-          'Please ensure that you install OpenMPI and its SDK to use the TensorRT engine, as it also requires the Cuda Toolkit 12.3 to work.\nYou can install OpenMPI by running "sudo apt update && sudo apt install openmpi-bin libopenmpi-dev"',
-        );
-    }
-  }
-  /**
-   * It's to pull model from HuggingFace repository
-   * It could be a model from Jan's repo or other authors
-   * @param modelId HuggingFace model id. e.g. "janhq/llama-3 or llama3:7b"
-   */
-  private async pullGGUFModel(modelId: string) {
-    const data: HuggingFaceRepoData =
-      await this.modelsUsecases.fetchModelMetadata(modelId);
-
-    let modelVersion;
-
-    const listChoices = data.siblings
-      .filter((e) => e.quantization != null)
-      .map((e) => {
-        return {
-          name: e.quantization,
-          value: e.quantization,
-        };
-      });
-
-    if (listChoices.length > 1) {
-      const { quantization } = await this.inquirerService.inquirer.prompt({
-        type: 'list',
-        name: 'quantization',
-        message: 'Select quantization',
-        choices: listChoices,
-      });
-      modelVersion = data.siblings
-        .filter((e) => !!e.quantization)
-        .find((e: any) => e.quantization === quantization);
-    } else {
-      modelVersion = data.siblings.find((e) => e.rfilename.includes('.gguf'));
-    }
-
-    if (!modelVersion) throw 'No expected quantization found';
-    const metadata = await getHFModelMetadata(modelVersion.downloadUrl!);
-
-    const promptTemplate = metadata?.promptTemplate ?? LLAMA_2;
-    const stopWords: string[] = [metadata?.stopWord ?? ''];
-
-    const model: CreateModelDto = {
-      files: [modelVersion.downloadUrl ?? ''],
-      model: modelId,
-      name: modelId,
-      prompt_template: promptTemplate,
-      stop: stopWords,
-
-      // Default Inference Params
-      stream: true,
-      max_tokens: 4098,
-      frequency_penalty: 0.7,
-      presence_penalty: 0.7,
-      temperature: 0.7,
-      top_p: 0.7,
-
-      // Default Model Settings
-      ctx_len: 4096,
-      ngl: 100,
-      engine: Engines.llamaCPP,
-    };
-    if (!(await this.modelsUsecases.findOne(modelId)))
-      await this.modelsUsecases.create(model);
+    });
   }
 
  /**
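The CLI use case no longer downloads files itself; it only hands ModelsUsecases.pullModel a callback that picks one sibling (the prompted quantization, else the first .gguf), while the use case layer owns the download and post-processing. The signature below is inferred from this call site; models.usecases.ts is not shown in this excerpt, and the second parameter's name and meaning are assumptions:

// Inferred signature of the refactored API (not shown in this excerpt):
// pullModel(
//   modelId: string,
//   persist?: boolean, // assumption — only `true` is visible at this call site
//   selection?: (
//     files: HuggingFaceRepoSibling[],
//   ) => Promise<HuggingFaceRepoSibling>,
// ): Promise<void>;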
