Skip to content

Commit 5dc9ed4

Browse files
fix: get models for anthropic and cohere (#2115)
* fix: get models for anthropic * fix: model list for cohere --------- Co-authored-by: sangjanai <[email protected]>
1 parent c65a632 commit 5dc9ed4

File tree

2 files changed

+22
-27
lines changed

2 files changed

+22
-27
lines changed

engine/extensions/remote-engine/remote_engine.cc

+22-23
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,14 @@ CurlResponse RemoteEngine::MakeGetModelsRequest(
252252
return response;
253253
}
254254

255-
std::string api_key_header =
256-
ReplaceApiKeyPlaceholder(header_template, api_key);
255+
std::unordered_map<std::string, std::string> replacements = {
256+
{"api_key", api_key}};
257+
auto hs = ReplaceHeaderPlaceholders(header_template, replacements);
257258

258259
struct curl_slist* headers = nullptr;
259-
headers = curl_slist_append(headers, api_key_header.c_str());
260+
for (auto const& h : hs) {
261+
headers = curl_slist_append(headers, h.c_str());
262+
}
260263
headers = curl_slist_append(headers, "Content-Type: application/json");
261264

262265
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
@@ -699,25 +702,7 @@ Json::Value RemoteEngine::GetRemoteModels(const std::string& url,
699702
const std::string& api_key,
700703
const std::string& header_template) {
701704
if (url.empty()) {
702-
if (engine_name_ == kAnthropicEngine) {
703-
Json::Value json_resp;
704-
Json::Value model_array(Json::arrayValue);
705-
for (const auto& m : kAnthropicModels) {
706-
Json::Value val;
707-
val["id"] = std::string(m);
708-
val["engine"] = "anthropic";
709-
val["created"] = "_";
710-
val["object"] = "model";
711-
model_array.append(val);
712-
}
713-
714-
json_resp["object"] = "list";
715-
json_resp["data"] = model_array;
716-
CTL_INF("Remote models responded");
717-
return json_resp;
718-
} else {
719-
return Json::Value();
720-
}
705+
return Json::Value();
721706
} else {
722707
auto response = MakeGetModelsRequest(url, api_key, header_template);
723708
if (response.error) {
@@ -728,9 +713,23 @@ Json::Value RemoteEngine::GetRemoteModels(const std::string& url,
728713
}
729714
CTL_DBG(response.body);
730715
auto body_json = json_helper::ParseJsonString(response.body);
731-
if (body_json.isMember("error")) {
716+
if (body_json.isMember("error") && !body_json["error"].isNull()) {
732717
return body_json["error"];
733718
}
719+
720+
// hardcode for cohere
721+
if (url.find("api.cohere.ai") != std::string::npos) {
722+
if (body_json.isMember("models")) {
723+
for (auto& model : body_json["models"]) {
724+
if (model.isMember("name")) {
725+
model["id"] = model["name"];
726+
model.removeMember("name");
727+
}
728+
}
729+
body_json["data"] = body_json["models"];
730+
body_json.removeMember("models");
731+
}
732+
}
734733
return body_json;
735734
}
736735
}

engine/utils/engine_constants.h

-4
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,6 @@
33
constexpr const auto kLlamaEngine = "llama-cpp";
44
constexpr const auto kPythonEngine = "python-engine";
55

6-
constexpr const auto kOpenAiEngine = "openai";
7-
constexpr const auto kAnthropicEngine = "anthropic";
8-
9-
106
constexpr const auto kRemote = "remote";
117
constexpr const auto kLocal = "local";
128

0 commit comments

Comments
 (0)