Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sweep: Connection to openai via OPENAI_API_BASE doesn't seem to work (βœ“ Sandbox Passed) #109

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
66 changes: 40 additions & 26 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,32 +26,46 @@ impl Manager for OpenAIClientManager {
type Error = MotorheadError;

async fn create(&self) -> Result<AnyOpenAIClient, MotorheadError> {
let openai_client = match (
env::var("AZURE_API_KEY"),
env::var("AZURE_DEPLOYMENT_ID"),
env::var("AZURE_DEPLOYMENT_ID_ADA"),
env::var("AZURE_API_BASE"),
) {
(
Ok(azure_api_key),
Ok(azure_deployment_id),
Ok(azure_deployment_id_ada),
Ok(azure_api_base),
) => {
let config = AzureConfig::new()
.with_api_base(&azure_api_base)
.with_api_key(&azure_api_key)
.with_deployment_id(azure_deployment_id)
.with_api_version("2023-05-15");

let config_ada = AzureConfig::new()
.with_api_base(&azure_api_base)
.with_api_key(&azure_api_key)
.with_deployment_id(azure_deployment_id_ada)
.with_api_version("2023-05-15");

AnyOpenAIClient::Azure {
embedding_client: Client::with_config(config_ada),
let openai_api_base = env::var("OPENAI_API_BASE").ok();
let azure_api_key = env::var("AZURE_API_KEY").ok();
let azure_deployment_id = env::var("AZURE_DEPLOYMENT_ID").ok();
let azure_deployment_id_ada = env::var("AZURE_DEPLOYMENT_ID_ADA").ok();
let azure_api_base = env::var("AZURE_API_BASE").ok();

let openai_client = if let Some(api_base) = openai_api_base {
let embedding_config = OpenAIConfig::default().with_api_base(&api_base);
let completion_config = OpenAIConfig::default().with_api_base(&api_base);

AnyOpenAIClient::OpenAI {
embedding_client: Client::with_config(embedding_config),
completion_client: Client::with_config(completion_config),
}
} else if azure_api_key.is_some() && azure_deployment_id.is_some() && azure_deployment_id_ada.is_some() && azure_api_base.is_some() {
let config = AzureConfig::new()
.with_api_base(azure_api_base.as_ref().unwrap())
.with_api_key(azure_api_key.as_ref().unwrap())
.with_deployment_id(azure_deployment_id.unwrap())
.with_api_version("2023-05-15");

let config_ada = AzureConfig::new()
.with_api_base(azure_api_base.as_ref().unwrap())
.with_api_key(azure_api_key.as_ref().unwrap())
.with_deployment_id(azure_deployment_id_ada.unwrap())
.with_api_version("2023-05-15");

AnyOpenAIClient::Azure {
embedding_client: Client::with_config(config_ada),
completion_client: Client::with_config(config),
}
} else {
AnyOpenAIClient::OpenAI {
embedding_client: Client::new(),
completion_client: Client::new(),
}
};

Ok(openai_client)
}
completion_client: Client::with_config(config),
}
}
Expand Down
Loading