Skip to content

Commit b9e9e15

Browse files
fix: add model author for model source (#2038)
Co-authored-by: sangjanai <[email protected]>
1 parent 9f1a50f commit b9e9e15

File tree

4 files changed: +98 −37 lines changed

engine/services/model_source_service.cc

+63-30
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,13 @@
1414
namespace hu = huggingface_utils;
1515

1616
namespace {
17+
// Number of seconds (10 minutes) a cached model-source timestamp is compared
// against before the source is fetched again.
constexpr int kModeSourceCacheSecs = 10 * 60;
18+
19+
/// Builds the identifier "<author_hub>/<model_name>" for a model source.
///
/// @param author_hub Hub-side owner of the repository (first path segment).
/// @param model_name Repository / model name (second path segment).
/// @return The two parts joined by a single '/'; empty parts are kept as-is,
///         so two empty inputs yield "/".
std::string GenSourceId(const std::string& author_hub,
                        const std::string& model_name) {
  std::string source_id;
  // Reserve once so the concatenation needs a single allocation.
  source_id.reserve(author_hub.size() + 1 + model_name.size());
  source_id += author_hub;
  source_id += '/';
  source_id += model_name;
  return source_id;
}
23+
1724
std::vector<ModelInfo> ParseJsonString(const std::string& json_str) {
1825
std::vector<ModelInfo> models;
1926

@@ -79,19 +86,34 @@ cpp::result<bool, std::string> ModelSourceService::AddModelSource(
7986
}
8087

8188
if (auto is_org = r.pathParams.size() == 1; is_org) {
82-
auto& author = r.pathParams[0];
83-
if (author == "cortexso") {
84-
return AddCortexsoOrg(model_source);
85-
} else {
86-
return AddHfOrg(model_source, author);
87-
}
89+
return cpp::fail("Only support repository model source, url: " +
90+
model_source);
91+
// TODO(sang)
92+
// auto& hub_author = r.pathParams[0];
93+
// if (hub_author == "cortexso") {
94+
// return AddCortexsoOrg(model_source);
95+
// } else {
96+
// return AddHfOrg(model_source, hub_author);
97+
// }
8898
} else { // Repo
89-
auto const& author = r.pathParams[0];
99+
auto const& hub_author = r.pathParams[0];
90100
auto const& model_name = r.pathParams[1];
101+
// Return cache value
102+
if (auto key = GenSourceId(hub_author, model_name);
103+
src_cache_.find(key) != src_cache_.end()) {
104+
auto now = std::chrono::system_clock::now();
105+
if (std::chrono::duration_cast<std::chrono::seconds>(now -
106+
src_cache_.at(key))
107+
.count() < kModeSourceCacheSecs) {
108+
CTL_DBG("Return cache value for model source: " << model_source);
109+
return true;
110+
}
111+
}
112+
91113
if (r.pathParams[0] == "cortexso") {
92-
return AddCortexsoRepo(model_source, author, model_name);
114+
return AddCortexsoRepo(model_source, hub_author, model_name);
93115
} else {
94-
return AddHfRepo(model_source, author, model_name);
116+
return AddHfRepo(model_source, hub_author, model_name);
95117
}
96118
}
97119
}
@@ -190,9 +212,9 @@ cpp::result<ModelSource, std::string> ModelSourceService::GetModelSource(
190212
}
191213

192214
cpp::result<std::vector<std::string>, std::string>
193-
ModelSourceService::GetRepositoryList(std::string_view author,
215+
ModelSourceService::GetRepositoryList(std::string_view hub_author,
194216
std::string_view tag_filter) {
195-
std::string as(author);
217+
std::string as(hub_author);
196218
auto get_repo_list = [this, &as, &tag_filter] {
197219
std::vector<std::string> repo_list;
198220
auto const& mis = cortexso_repos_.at(as);
@@ -227,9 +249,9 @@ ModelSourceService::GetRepositoryList(std::string_view author,
227249
}
228250

229251
cpp::result<bool, std::string> ModelSourceService::AddHfOrg(
230-
const std::string& model_source, const std::string& author) {
252+
const std::string& model_source, const std::string& hub_author) {
231253
auto res = curl_utils::SimpleGet("https://huggingface.co/api/models?author=" +
232-
author);
254+
hub_author);
233255
if (res.has_value()) {
234256
auto models = ParseJsonString(res.value());
235257
// Add new models
@@ -238,9 +260,10 @@ cpp::result<bool, std::string> ModelSourceService::AddHfOrg(
238260

239261
auto author_model = string_utils::SplitBy(m.id, "/");
240262
if (author_model.size() == 2) {
241-
auto const& author = author_model[0];
263+
auto const& hub_author = author_model[0];
242264
auto const& model_name = author_model[1];
243-
auto r = AddHfRepo(model_source + "/" + model_name, author, model_name);
265+
auto r =
266+
AddHfRepo(model_source + "/" + model_name, hub_author, model_name);
244267
if (r.has_error()) {
245268
CTL_WRN(r.error());
246269
}
@@ -253,14 +276,14 @@ cpp::result<bool, std::string> ModelSourceService::AddHfOrg(
253276
}
254277

255278
cpp::result<bool, std::string> ModelSourceService::AddHfRepo(
256-
const std::string& model_source, const std::string& author,
279+
const std::string& model_source, const std::string& hub_author,
257280
const std::string& model_name) {
258281
// Get models from db
259282

260283
auto model_list_before = db_service_->GetModels(model_source)
261284
.value_or(std::vector<cortex::db::ModelEntry>{});
262285
std::unordered_set<std::string> updated_model_list;
263-
auto add_res = AddRepoSiblings(model_source, author, model_name);
286+
auto add_res = AddRepoSiblings(model_source, hub_author, model_name);
264287
if (add_res.has_error()) {
265288
return cpp::fail(add_res.error());
266289
} else {
@@ -274,15 +297,17 @@ cpp::result<bool, std::string> ModelSourceService::AddHfRepo(
274297
}
275298
}
276299
}
300+
src_cache_[GenSourceId(hub_author, model_name)] =
301+
std::chrono::system_clock::now();
277302
return true;
278303
}
279304

280305
cpp::result<std::unordered_set<std::string>, std::string>
281306
ModelSourceService::AddRepoSiblings(const std::string& model_source,
282-
const std::string& author,
307+
const std::string& hub_author,
283308
const std::string& model_name) {
284309
std::unordered_set<std::string> res;
285-
auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name);
310+
auto repo_info = hu::GetHuggingFaceModelRepoInfo(hub_author, model_name);
286311
if (repo_info.has_error()) {
287312
return cpp::fail(repo_info.error());
288313
}
@@ -293,14 +318,14 @@ ModelSourceService::AddRepoSiblings(const std::string& model_source,
293318
"supported.");
294319
}
295320

296-
auto siblings_fs = hu::GetSiblingsFileSize(author, model_name);
321+
auto siblings_fs = hu::GetSiblingsFileSize(hub_author, model_name);
297322

298323
if (siblings_fs.has_error()) {
299-
return cpp::fail("Could not get siblings file size: " + author + "/" +
300-
model_name);
324+
return cpp::fail("Could not get siblings file size: " +
325+
GenSourceId(hub_author, model_name));
301326
}
302327

303-
auto readme = hu::GetReadMe(author, model_name);
328+
auto readme = hu::GetReadMe(hub_author, model_name);
304329
std::string desc;
305330
if (!readme.has_error()) {
306331
desc = readme.value();
@@ -326,10 +351,10 @@ ModelSourceService::AddRepoSiblings(const std::string& model_source,
326351
siblings_fs_v.file_sizes.at(sibling.rfilename).size_in_bytes;
327352
}
328353
std::string model_id =
329-
author + ":" + model_name + ":" + sibling.rfilename;
354+
hub_author + ":" + model_name + ":" + sibling.rfilename;
330355
cortex::db::ModelEntry e = {
331356
.model = model_id,
332-
.author_repo_id = author,
357+
.author_repo_id = hub_author,
333358
.branch_name = "main",
334359
.path_to_model_yaml = "",
335360
.model_alias = "",
@@ -369,9 +394,9 @@ cpp::result<bool, std::string> ModelSourceService::AddCortexsoOrg(
369394
CTL_INF(m.id);
370395
auto author_model = string_utils::SplitBy(m.id, "/");
371396
if (author_model.size() == 2) {
372-
auto const& author = author_model[0];
397+
auto const& hub_author = author_model[0];
373398
auto const& model_name = author_model[1];
374-
auto r = AddCortexsoRepo(model_source + "/" + model_name, author,
399+
auto r = AddCortexsoRepo(model_source + "/" + model_name, hub_author,
375400
model_name);
376401
if (r.has_error()) {
377402
CTL_WRN(r.error());
@@ -386,7 +411,7 @@ cpp::result<bool, std::string> ModelSourceService::AddCortexsoOrg(
386411
}
387412

388413
cpp::result<bool, std::string> ModelSourceService::AddCortexsoRepo(
389-
const std::string& model_source, const std::string& author,
414+
const std::string& model_source, const std::string& hub_author,
390415
const std::string& model_name) {
391416
auto begin = std::chrono::system_clock::now();
392417
auto branches =
@@ -395,17 +420,23 @@ cpp::result<bool, std::string> ModelSourceService::AddCortexsoRepo(
395420
return cpp::fail(branches.error());
396421
}
397422

398-
auto repo_info = hu::GetHuggingFaceModelRepoInfo(author, model_name);
423+
auto repo_info = hu::GetHuggingFaceModelRepoInfo(hub_author, model_name);
399424
if (repo_info.has_error()) {
400425
return cpp::fail(repo_info.error());
401426
}
402427

403-
auto readme = hu::GetReadMe(author, model_name);
428+
auto readme = hu::GetReadMe(hub_author, model_name);
404429
std::string desc;
405430
if (!readme.has_error()) {
406431
desc = readme.value();
407432
}
408433

434+
auto author = hub_author;
435+
if (auto model_author = hu::GetModelAuthorCortexsoHub(model_name);
436+
model_author.has_value() && !model_author->empty()) {
437+
author = *model_author;
438+
}
439+
409440
// Get models from db
410441
auto model_list_before = db_service_->GetModels(model_source)
411442
.value_or(std::vector<cortex::db::ModelEntry>{});
@@ -442,6 +473,8 @@ cpp::result<bool, std::string> ModelSourceService::AddCortexsoRepo(
442473
"Duration ms: " << std::chrono::duration_cast<std::chrono::milliseconds>(
443474
end - begin)
444475
.count());
476+
src_cache_[GenSourceId(hub_author, model_name)] =
477+
std::chrono::system_clock::now();
445478
return true;
446479
}
447480

engine/services/model_source_service.h

+7-5
Original file line numberDiff line numberDiff line change
@@ -65,25 +65,25 @@ class ModelSourceService {
6565
cpp::result<ModelSource, std::string> GetModelSource(const std::string& src);
6666

6767
cpp::result<std::vector<std::string>, std::string> GetRepositoryList(
68-
std::string_view author, std::string_view tag_filter);
68+
std::string_view hub_author, std::string_view tag_filter);
6969

7070
private:
7171
cpp::result<bool, std::string> AddHfOrg(const std::string& model_source,
72-
const std::string& author);
72+
const std::string& hub_author);
7373

7474
cpp::result<bool, std::string> AddHfRepo(const std::string& model_source,
75-
const std::string& author,
75+
const std::string& hub_author,
7676
const std::string& model_name);
7777

7878
cpp::result<std::unordered_set<std::string>, std::string> AddRepoSiblings(
79-
const std::string& model_source, const std::string& author,
79+
const std::string& model_source, const std::string& hub_author,
8080
const std::string& model_name);
8181

8282
cpp::result<bool, std::string> AddCortexsoOrg(
8383
const std::string& model_source);
8484

8585
cpp::result<bool, std::string> AddCortexsoRepo(
86-
const std::string& model_source, const std::string& author,
86+
const std::string& model_source, const std::string& hub_author,
8787
const std::string& model_name);
8888

8989
cpp::result<std::string, std::string> AddCortexsoRepoBranch(
@@ -99,4 +99,6 @@ class ModelSourceService {
9999
std::atomic<bool> running_;
100100

101101
std::unordered_map<std::string, std::vector<ModelInfo>> cortexso_repos_;
102+
using TimePoint = std::chrono::time_point<std::chrono::system_clock>;
103+
std::unordered_map<std::string, TimePoint> src_cache_;
102104
};

engine/utils/huggingface_utils.h

+20
Original file line numberDiff line numberDiff line change
@@ -311,4 +311,24 @@ inline std::optional<std::string> GetDefaultBranch(
311311
return std::nullopt;
312312
}
313313
}
314+
315+
inline std::optional<std::string> GetModelAuthorCortexsoHub(
316+
const std::string& model_name) {
317+
try {
318+
auto remote_yml = curl_utils::ReadRemoteYaml(GetMetadataUrl(model_name));
319+
320+
if (remote_yml.has_error()) {
321+
return std::nullopt;
322+
}
323+
324+
auto metadata = remote_yml.value();
325+
auto author = metadata["author"];
326+
if (author.IsDefined()) {
327+
return author.as<std::string>();
328+
}
329+
return std::nullopt;
330+
} catch (const std::exception& e) {
331+
return std::nullopt;
332+
}
333+
}
314334
} // namespace huggingface_utils

engine/utils/url_parser.h

+8-2
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,11 @@ const std::regex url_regex(
6969
R"(^(([^:\/?#]+):)?(//([^\/?#]*))?([^?#]*)(\?([^#]*))?(#(.*))?)",
7070
std::regex::extended);
7171

72-
inline void SplitPathParams(const std::string& input,
72+
inline bool SplitPathParams(const std::string& input,
7373
std::vector<std::string>& pathList) {
74+
if (input.find("//") != std::string::npos) {
75+
return false;
76+
}
7477
// split the path by '/'
7578
std::string token;
7679
std::istringstream tokenStream(input);
@@ -80,6 +83,7 @@ inline void SplitPathParams(const std::string& input,
8083
}
8184
pathList.push_back(token);
8285
}
86+
return true;
8387
}
8488

8589
inline cpp::result<Url, std::string> FromUrlString(
@@ -105,7 +109,9 @@ inline cpp::result<Url, std::string> FromUrlString(
105109
} else if (counter == hostAndPortIndex) {
106110
url.host = res; // TODO: split the port for completeness
107111
} else if (counter == pathIndex) {
108-
SplitPathParams(res, url.pathParams);
112+
if (!SplitPathParams(res, url.pathParams)) {
113+
return cpp::fail("Malformed URL: " + urlString);
114+
}
109115
} else if (counter == queryIndex) {
110116
// TODO: implement
111117
}

0 commit comments

Comments
 (0)