@@ -1,25 +1,19 @@
 import { exit } from 'node:process';
 import { ModelsUsecases } from '@/usecases/models/models.usecases';
 import { Model } from '@/domain/models/model.interface';
-import { CreateModelDto } from '@/infrastructure/dtos/models/create-model.dto';
-import { HuggingFaceRepoData } from '@/domain/models/huggingface.interface';
 import { InquirerService } from 'nest-commander';
 import { Inject, Injectable } from '@nestjs/common';
 import { Presets, SingleBar } from 'cli-progress';
-import { LLAMA_2 } from '@/infrastructure/constants/prompt-constants';
 
 import { HttpService } from '@nestjs/axios';
 import { StartModelSuccessDto } from '@/infrastructure/dtos/models/start-model-success.dto';
 import { UpdateModelDto } from '@/infrastructure/dtos/models/update-model.dto';
 import { FileManagerService } from '@/infrastructure/services/file-manager/file-manager.service';
-import { join, basename } from 'path';
+import { join } from 'path';
 import { load } from 'js-yaml';
 import { existsSync, readdirSync, readFileSync } from 'fs';
-import { isLocalModel, normalizeModelId } from '@/utils/normalize-model-id';
-import { fetchJanRepoData, getHFModelMetadata } from '@/utils/huggingface';
-import { createWriteStream, mkdirSync, promises } from 'node:fs';
-import { firstValueFrom } from 'rxjs';
-import { Engines } from '../types/engine.interface';
+import { isLocalModel } from '@/utils/normalize-model-id';
+import { HuggingFaceRepoSibling } from '@/domain/models/huggingface.interface';
 
 @Injectable()
 export class ModelsCliUsecases {
@@ -120,170 +114,34 @@ export class ModelsCliUsecases {
       console.error('Model already exists');
       process.exit(1);
     }
-
-    if (modelId.includes('onnx') || modelId.includes('tensorrt')) {
-      await this.pullEngineModelFiles(modelId);
-    } else {
-      await this.pullGGUFModel(modelId);
-      const bar = new SingleBar({}, Presets.shades_classic);
-      bar.start(100, 0);
-      const callback = (progress: number) => {
-        bar.update(progress);
-      };
-
-      try {
-        await this.modelsUsecases.downloadModel(modelId, callback);
-
-        const model = await this.modelsUsecases.findOne(modelId);
-        const fileUrl = join(
-          await this.fileService.getModelsPath(),
-          normalizeModelId(modelId),
-          basename((model?.files as string[])[0]),
-        );
-        await this.modelsUsecases.update(modelId, {
-          files: [fileUrl],
-          name: modelId.replace(':default', ''),
-        });
-      } catch (err) {
-        bar.stop();
-        throw err;
-      }
-    }
-  }
-
-  /**
-   * It's to pull engine model files from HuggingFace repository
-   * @param modelId
-   */
-  private async pullEngineModelFiles(modelId: string) {
-    const modelsContainerDir = await this.fileService.getModelsPath();
-
-    if (!existsSync(modelsContainerDir)) {
-      mkdirSync(modelsContainerDir, { recursive: true });
-    }
-
-    const modelFolder = join(modelsContainerDir, normalizeModelId(modelId));
-    await promises.mkdir(modelFolder, { recursive: true }).catch(() => {});
-
-    const files = (await fetchJanRepoData(modelId)).siblings;
-    for (const file of files) {
-      console.log(`Downloading ${file.rfilename}`);
-      const bar = new SingleBar({}, Presets.shades_classic);
-      bar.start(100, 0);
-      const response = await firstValueFrom(
-        this.httpService.get(file.downloadUrl ?? '', {
-          responseType: 'stream',
-        }),
-      );
-      if (!response) {
-        throw new Error('Failed to download model');
-      }
-
-      await new Promise((resolve, reject) => {
-        const writer = createWriteStream(join(modelFolder, file.rfilename));
-        let receivedBytes = 0;
-        const totalBytes = response.headers['content-length'];
-
-        writer.on('finish', () => {
-          resolve(true);
-        });
-
-        writer.on('error', (error) => {
-          reject(error);
-        });
-
-        response.data.on('data', (chunk: any) => {
-          receivedBytes += chunk.length;
-          bar.update(Math.floor((receivedBytes / totalBytes) * 100));
-        });
-
-        response.data.pipe(writer);
+    await this.modelsUsecases.pullModel(modelId, true, (files) => {
+      return new Promise<HuggingFaceRepoSibling>(async (resolve) => {
+        const listChoices = files
+          .filter((e) => e.quantization != null)
+          .map((e) => {
+            return {
+              name: e.quantization,
+              value: e.quantization,
+            };
+          });
+
+        if (listChoices.length > 1) {
+          const { quantization } = await this.inquirerService.inquirer.prompt({
+            type: 'list',
+            name: 'quantization',
+            message: 'Select quantization',
+            choices: listChoices,
+          });
+          resolve(
+            files
+              .filter((e) => !!e.quantization)
+              .find((e: any) => e.quantization === quantization) ?? files[0],
+          );
+        } else {
+          resolve(files.find((e) => e.rfilename.includes('.gguf')) ?? files[0]);
+        }
       });
-      bar.stop();
-    }
-
-    const model: CreateModelDto = load(
-      readFileSync(join(modelFolder, 'model.yml'), 'utf-8'),
-    ) as CreateModelDto;
-    model.files = [join(modelFolder)];
-    model.model = modelId;
-
-    if (!(await this.modelsUsecases.findOne(modelId)))
-      await this.modelsUsecases.create(model);
-
-    if (model.engine === Engines.tensorrtLLM) {
-      if (process.platform === 'win32')
-        console.log(
-          'Please ensure that you install MPI and its SDK to use the TensorRT engine, as it also requires the Cuda Toolkit 12.3 to work. Refs:\n- https://github.com/microsoft/Microsoft-MPI/releases/download/v10.1.1/msmpisetup.exe\n- https://github.com/microsoft/Microsoft-MPI/releases/download/v10.1.1/msmpisdk.msi',
-        );
-      else if (process.platform === 'linux')
-        console.log(
-          'Please ensure that you install OpenMPI and its SDK to use the TensorRT engine, as it also requires the Cuda Toolkit 12.3 to work.\nYou can install OpenMPI by running "sudo apt update && sudo apt install openmpi-bin libopenmpi-dev"',
-        );
-    }
-  }
-  /**
-   * It's to pull model from HuggingFace repository
-   * It could be a model from Jan's repo or other authors
-   * @param modelId HuggingFace model id. e.g. "janhq/llama-3 or llama3:7b"
-   */
-  private async pullGGUFModel(modelId: string) {
-    const data: HuggingFaceRepoData =
-      await this.modelsUsecases.fetchModelMetadata(modelId);
-
-    let modelVersion;
-
-    const listChoices = data.siblings
-      .filter((e) => e.quantization != null)
-      .map((e) => {
-        return {
-          name: e.quantization,
-          value: e.quantization,
-        };
-      });
-
-    if (listChoices.length > 1) {
-      const { quantization } = await this.inquirerService.inquirer.prompt({
-        type: 'list',
-        name: 'quantization',
-        message: 'Select quantization',
-        choices: listChoices,
-      });
-      modelVersion = data.siblings
-        .filter((e) => !!e.quantization)
-        .find((e: any) => e.quantization === quantization);
-    } else {
-      modelVersion = data.siblings.find((e) => e.rfilename.includes('.gguf'));
-    }
-
-    if (!modelVersion) throw 'No expected quantization found';
-    const metadata = await getHFModelMetadata(modelVersion.downloadUrl!);
-
-    const promptTemplate = metadata?.promptTemplate ?? LLAMA_2;
-    const stopWords: string[] = [metadata?.stopWord ?? ''];
-
-    const model: CreateModelDto = {
-      files: [modelVersion.downloadUrl ?? ''],
-      model: modelId,
-      name: modelId,
-      prompt_template: promptTemplate,
-      stop: stopWords,
-
-      // Default Inference Params
-      stream: true,
-      max_tokens: 4098,
-      frequency_penalty: 0.7,
-      presence_penalty: 0.7,
-      temperature: 0.7,
-      top_p: 0.7,
-
-      // Default Model Settings
-      ctx_len: 4096,
-      ngl: 100,
-      engine: Engines.llamaCPP,
-    };
-    if (!(await this.modelsUsecases.findOne(modelId)))
-      await this.modelsUsecases.create(model);
+    });
   }
 
   /**
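This commit collapses the CLI's two bespoke download paths (`pullGGUFModel`, `pullEngineModelFiles`) into a single delegation to `ModelsUsecases.pullModel`; the CLI now only decides which repo file to fetch. As a minimal sketch (not part of the diff), the contract the new call site appears to assume looks roughly like this; `FileSelector` and the `inCliMode` parameter name are illustrative guesses, since the diff only shows the bare `true` argument:

```ts
import { HuggingFaceRepoSibling } from '@/domain/models/huggingface.interface';

// Hypothetical selector type: given the repo's sibling files, resolve with the
// one to download (e.g. the quantization the user picked at the prompt).
type FileSelector = (
  files: HuggingFaceRepoSibling[],
) => Promise<HuggingFaceRepoSibling>;

// Assumed shape of the usecase the CLI now delegates to; the real signature
// lives in ModelsUsecases and may differ.
interface PullsModels {
  pullModel(
    modelId: string,
    inCliMode: boolean, // assumption: name for the bare `true` at the call site
    selector: FileSelector,
  ): Promise<void>;
}
```

One detail worth a review comment: the selector wraps an `async` function in `new Promise(...)`. If the inquirer prompt rejects, that rejection is swallowed inside the executor and the returned promise never settles; returning the async function's own promise (dropping the explicit `new Promise`) would propagate the error.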