@@ -143,6 +143,21 @@ cpp::result<DownloadTask, std::string> GetDownloadTask(
143
143
}
144
144
} // namespace
145
145
146
+ ModelService::ModelService (std::shared_ptr<DatabaseService> db_service,
147
+ std::shared_ptr<HardwareService> hw_service,
148
+ std::shared_ptr<DownloadService> download_service,
149
+ std::shared_ptr<InferenceService> inference_service,
150
+ std::shared_ptr<EngineServiceI> engine_svc,
151
+ cortex::TaskQueue& task_queue)
152
+ : db_service_(db_service),
153
+ hw_service_(hw_service),
154
+ download_service_{download_service},
155
+ inference_svc_ (inference_service),
156
+ engine_svc_(engine_svc),
157
+ task_queue_(task_queue) {
158
+ ProcessBgrTasks ();
159
+ };
160
+
146
161
void ModelService::ForceIndexingModelList () {
147
162
CTL_INF (" Force indexing model list" );
148
163
@@ -331,8 +346,17 @@ cpp::result<DownloadTask, std::string> ModelService::HandleDownloadUrlAsync(
331
346
return download_service_->AddTask (downloadTask, on_finished);
332
347
}
333
348
349
+ std::optional<hardware::Estimation> ModelService::GetEstimation (
350
+ const std::string& model_handle) {
351
+ std::lock_guard l (es_mtx_);
352
+ if (auto it = es_.find (model_handle); it != es_.end ()) {
353
+ return it->second ;
354
+ }
355
+ return std::nullopt;
356
+ }
357
+
334
358
cpp::result<std::optional<hardware::Estimation>, std::string>
335
- ModelService::GetEstimation (const std::string& model_handle,
359
+ ModelService::EstimateModel (const std::string& model_handle,
336
360
const std::string& kv_cache, int n_batch,
337
361
int n_ubatch) {
338
362
namespace fs = std::filesystem;
@@ -548,7 +572,7 @@ ModelService::DownloadModelFromCortexsoAsync(
548
572
// Close the file
549
573
pyvenv_cfg.close ();
550
574
// Add executable permission to python
551
- set_permission_utils::SetExecutePermissionsRecursive (venv_path);
575
+ ( void ) set_permission_utils::SetExecutePermissionsRecursive (venv_path);
552
576
} else {
553
577
CTL_ERR (" Failed to extract venv.zip" );
554
578
};
@@ -828,7 +852,7 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
828
852
CTL_WRN (" Error: " + res.error ());
829
853
for (auto & depend : depends) {
830
854
if (depend != model_handle) {
831
- StopModel (depend);
855
+ auto sr = StopModel (depend);
832
856
}
833
857
}
834
858
return cpp::fail (" Model failed to start dependency '" + depend +
@@ -945,6 +969,11 @@ cpp::result<StartModelResult, std::string> ModelService::StartModel(
945
969
946
970
json_helper::MergeJson (json_data, params_override);
947
971
972
+ // Set default cpu_threads if it is not configured
973
+ if (!json_data.isMember (" cpu_threads" )) {
974
+ json_data[" cpu_threads" ] = GetCpuThreads ();
975
+ }
976
+
948
977
// Set the latest ctx_len
949
978
if (ctx_len) {
950
979
json_data[" ctx_len" ] =
@@ -1329,6 +1358,10 @@ ModelService::MayFallbackToCpu(const std::string& model_path, int ngl,
1329
1358
return warning;
1330
1359
}
1331
1360
1361
+ int ModelService::GetCpuThreads () const {
1362
+ return std::max (std::thread::hardware_concurrency () / 2 , 1u );
1363
+ }
1364
+
1332
1365
cpp::result<std::shared_ptr<ModelMetadata>, std::string>
1333
1366
ModelService::GetModelMetadata (const std::string& model_id) const {
1334
1367
if (model_id.empty ()) {
@@ -1381,4 +1414,28 @@ std::string ModelService::GetEngineByModelId(
1381
1414
auto mc = yaml_handler.GetModelConfig ();
1382
1415
CTL_DBG (mc.engine );
1383
1416
return mc.engine ;
1417
+ }
1418
+
1419
+ void ModelService::ProcessBgrTasks () {
1420
+ CTL_INF (" Start processing background tasks" )
1421
+ auto cb = [this ] {
1422
+ CTL_DBG (" Estimate model resource usage" );
1423
+ auto list_entry = db_service_->LoadModelList ();
1424
+ if (list_entry) {
1425
+ for (const auto & model_entry : list_entry.value ()) {
1426
+ // Only process local models
1427
+ if (model_entry.status == cortex::db::ModelStatus::Downloaded) {
1428
+ auto es = EstimateModel (model_entry.model );
1429
+ if (es.has_value ()) {
1430
+ std::lock_guard l (es_mtx_);
1431
+ es_[model_entry.model ] = es.value ();
1432
+ }
1433
+ }
1434
+ }
1435
+ }
1436
+ };
1437
+
1438
+ auto clone = cb;
1439
+ task_queue_.RunInQueue (std::move (cb));
1440
+ task_queue_.RunEvery (std::chrono::seconds (10 ), std::move (clone));
1384
1441
}
0 commit comments