From abf0e546bda733eaa3948c35fd99dd7dae847174 Mon Sep 17 00:00:00 2001 From: "sweep-ai[bot]" <128439645+sweep-ai[bot]@users.noreply.github.com> Date: Mon, 29 Jan 2024 07:26:42 +0000 Subject: [PATCH] feat: Updated src/models.rs --- src/models.rs | 66 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 40 insertions(+), 26 deletions(-) diff --git a/src/models.rs b/src/models.rs index 38502b5..fd51c65 100644 --- a/src/models.rs +++ b/src/models.rs @@ -26,32 +26,46 @@ impl Manager for OpenAIClientManager { type Error = MotorheadError; async fn create(&self) -> Result { - let openai_client = match ( - env::var("AZURE_API_KEY"), - env::var("AZURE_DEPLOYMENT_ID"), - env::var("AZURE_DEPLOYMENT_ID_ADA"), - env::var("AZURE_API_BASE"), - ) { - ( - Ok(azure_api_key), - Ok(azure_deployment_id), - Ok(azure_deployment_id_ada), - Ok(azure_api_base), - ) => { - let config = AzureConfig::new() - .with_api_base(&azure_api_base) - .with_api_key(&azure_api_key) - .with_deployment_id(azure_deployment_id) - .with_api_version("2023-05-15"); - - let config_ada = AzureConfig::new() - .with_api_base(&azure_api_base) - .with_api_key(&azure_api_key) - .with_deployment_id(azure_deployment_id_ada) - .with_api_version("2023-05-15"); - - AnyOpenAIClient::Azure { - embedding_client: Client::with_config(config_ada), + let openai_api_base = env::var("OPENAI_API_BASE").ok(); + let azure_api_key = env::var("AZURE_API_KEY").ok(); + let azure_deployment_id = env::var("AZURE_DEPLOYMENT_ID").ok(); + let azure_deployment_id_ada = env::var("AZURE_DEPLOYMENT_ID_ADA").ok(); + let azure_api_base = env::var("AZURE_API_BASE").ok(); + + let openai_client = if let Some(api_base) = openai_api_base { + let embedding_config = OpenAIConfig::default().with_api_base(&api_base); + let completion_config = OpenAIConfig::default().with_api_base(&api_base); + + AnyOpenAIClient::OpenAI { + embedding_client: Client::with_config(embedding_config), + completion_client: Client::with_config(completion_config), + } + } else if azure_api_key.is_some() && azure_deployment_id.is_some() && azure_deployment_id_ada.is_some() && azure_api_base.is_some() { + let config = AzureConfig::new() + .with_api_base(azure_api_base.as_ref().unwrap()) + .with_api_key(azure_api_key.as_ref().unwrap()) + .with_deployment_id(azure_deployment_id.unwrap()) + .with_api_version("2023-05-15"); + + let config_ada = AzureConfig::new() + .with_api_base(azure_api_base.as_ref().unwrap()) + .with_api_key(azure_api_key.as_ref().unwrap()) + .with_deployment_id(azure_deployment_id_ada.unwrap()) + .with_api_version("2023-05-15"); + + AnyOpenAIClient::Azure { + embedding_client: Client::with_config(config_ada), + completion_client: Client::with_config(config), + } + } else { + AnyOpenAIClient::OpenAI { + embedding_client: Client::new(), + completion_client: Client::new(), + } + }; + + Ok(openai_client) + } completion_client: Client::with_config(config), } }