Skip to content

Commit 4328f57

Browse files
author
= Enea_Gore
committed
Upgrade all dependencies to the latest version and adjust openai config
1 parent 0dcae95 commit 4328f57

File tree

3 files changed

+237
-325
lines changed

3 files changed

+237
-325
lines changed

module_text_llm/module_text_llm/helpers/models/openai.py

Lines changed: 49 additions & 204 deletions
Original file line numberDiff line numberDiff line change
@@ -1,254 +1,98 @@
11
import os
2-
from contextlib import contextmanager
3-
from typing import Any, Callable, Dict, List
2+
from typing import Any , Dict
43
from pydantic import Field, validator, PositiveInt
54
from enum import Enum
6-
75
import openai
8-
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
9-
from langchain.llms import AzureOpenAI, OpenAI
10-
from langchain.llms.openai import BaseOpenAI
116
from langchain.base_language import BaseLanguageModel
12-
7+
from langchain_openai import AzureChatOpenAI, AzureOpenAI, ChatOpenAI
138
from athena.logger import logger
149
from .model_config import ModelConfig
1510

1611

1712
OPENAI_PREFIX = "openai_"
1813
AZURE_OPENAI_PREFIX = "azure_openai_"
14+
############################################################################################
15+
# START Set Enviorment variables that are automatically used by ChatOpenAI/ChatAzureOpenAI#
16+
############################################################################################
17+
# Might be worth renaming them
18+
openai_available = bool(os.environ.get("LLM_OPENAI_API_KEY"))
19+
if openai_available:
20+
os.environ["OPENAI_API_KEY"] = os.environ["LLM_OPENAI_API_KEY"]
1921

20-
21-
#########################################################################
22-
# Monkey patching openai/langchain api #
23-
# ===================================================================== #
24-
# This allows us to have multiple api keys i.e. mixing #
25-
# openai and azure openai api keys so we can use not only deployed #
26-
# models but also models from the non-azure openai api. #
27-
# This is mostly for testing purposes, in production we can just deploy #
28-
# the models to azure that we want to use. #
29-
#########################################################################
30-
31-
# Prevent LangChain error, we will set the key later
32-
os.environ["OPENAI_API_KEY"] = ""
33-
34-
def _wrap(old: Any, new: Any) -> Callable:
35-
def repl(*args: Any, **kwargs: Any) -> Any:
36-
new(args[0]) # args[0] is self
37-
return old(*args, **kwargs)
38-
return repl
39-
40-
41-
def _async_wrap(old: Any, new: Any):
42-
async def repl(*args, **kwargs):
43-
new(args[0]) # args[0] is self
44-
return await old(*args, **kwargs)
45-
return repl
46-
47-
48-
def _set_credentials(self):
49-
openai.api_key = self.openai_api_key
50-
51-
api_type = "open_ai"
52-
api_base = "https://api.openai.com/v1"
53-
api_version = None
54-
if hasattr(self, "openai_api_type"):
55-
api_type = self.openai_api_type
56-
57-
if api_type == "azure":
58-
if hasattr(self, "openai_api_base"):
59-
api_base = self.openai_api_base
60-
if hasattr(self, "openai_api_version"):
61-
api_version = self.openai_api_version
62-
63-
openai.api_type = api_type
64-
openai.api_base = api_base
65-
openai.api_version = api_version
66-
67-
68-
# Monkey patching langchain
69-
# pylint: disable=protected-access
70-
ChatOpenAI._generate = _wrap(ChatOpenAI._generate, _set_credentials) # type: ignore
71-
ChatOpenAI._agenerate = _async_wrap(ChatOpenAI._agenerate, _set_credentials) # type: ignore
72-
BaseOpenAI._generate = _wrap(BaseOpenAI._generate, _set_credentials) # type: ignore
73-
BaseOpenAI._agenerate = _async_wrap(BaseOpenAI._agenerate, _set_credentials) # type: ignore
74-
# pylint: enable=protected-access
75-
76-
#########################################################################
77-
# Monkey patching end #
78-
#########################################################################
79-
80-
81-
def _use_azure_credentials():
82-
openai.api_type = "azure"
83-
openai.api_key = os.environ.get("LLM_AZURE_OPENAI_API_KEY")
84-
openai.api_base = os.environ.get("LLM_AZURE_OPENAI_API_BASE")
85-
# os.environ.get("LLM_AZURE_OPENAI_API_VERSION")
86-
openai.api_version = "2023-03-15-preview"
87-
88-
89-
def _use_openai_credentials():
90-
openai.api_type = "open_ai"
91-
openai.api_key = os.environ.get("LLM_OPENAI_API_KEY")
92-
openai.api_base = "https://api.openai.com/v1"
93-
openai.api_version = None
94-
95-
96-
openai_available = bool(os.environ.get("LLM_OPENAI_API_KEY"))
9722
azure_openai_available = bool(os.environ.get("LLM_AZURE_OPENAI_API_KEY"))
98-
99-
100-
# This is a hack to make sure that the openai api is set correctly
101-
# Right now it is overkill, but it will be useful when the api gets fixed and we no longer
102-
# hardcode the model names (i.e. OpenAI fixes their api)
103-
@contextmanager
104-
def _openai_client(use_azure_api: bool, is_preference: bool):
105-
"""Set the openai client to use the correct api type, if available
106-
107-
Args:
108-
use_azure_api (bool): If true, use the azure api, else use the openai api
109-
is_preference (bool): If true, it can fall back to the other api if the preferred one is not available
110-
"""
111-
if use_azure_api:
112-
if azure_openai_available:
113-
_use_azure_credentials()
114-
elif is_preference and openai_available:
115-
_use_openai_credentials()
116-
elif is_preference:
117-
raise EnvironmentError(
118-
"No OpenAI api available, please set LLM_AZURE_OPENAI_API_KEY, LLM_AZURE_OPENAI_API_BASE and "
119-
"LLM_AZURE_OPENAI_API_VERSION environment variables or LLM_OPENAI_API_KEY environment variable"
120-
)
121-
else:
122-
raise EnvironmentError(
123-
"Azure OpenAI api not available, please set LLM_AZURE_OPENAI_API_KEY, LLM_AZURE_OPENAI_API_BASE and "
124-
"LLM_AZURE_OPENAI_API_VERSION environment variables"
125-
)
126-
else:
127-
if openai_available:
128-
_use_openai_credentials()
129-
elif is_preference and azure_openai_available:
130-
_use_azure_credentials()
131-
elif is_preference:
132-
raise EnvironmentError(
133-
"No OpenAI api available, please set LLM_OPENAI_API_KEY environment variable or LLM_AZURE_OPENAI_API_KEY, "
134-
"LLM_AZURE_OPENAI_API_BASE and LLM_AZURE_OPENAI_API_VERSION environment variables"
135-
)
136-
else:
137-
raise EnvironmentError(
138-
"OpenAI api not available, please set LLM_OPENAI_API_KEY environment variable"
139-
)
140-
141-
# API client is setup correctly
142-
yield
143-
144-
145-
def _get_available_deployments(openai_models: Dict[str, List[str]], model_aliases: Dict[str, str]):
23+
if azure_openai_available:
24+
os.environ["AZURE_OPENAI_ENDPOINT"]=os.environ["LLM_AZURE_OPENAI_API_BASE"]
25+
os.environ["AZURE_OPENAI_API_KEY"]=os.environ["LLM_AZURE_OPENAI_API_KEY"]
26+
os.environ["OPENAI_API_VERSION"]=os.environ["LLM_AZURE_OPENAI_API_VERSION"]
27+
#########################################################################################
28+
# END Set Enviorment variables that are automatically used by ChatOpenAI/ChatAzureOpenAI#
29+
#########################################################################################
30+
31+
actually_deployed_azure= [
32+
"gpt-35-turbo",
33+
"gpt-4-turbo",
34+
"gpt-4-vision"
35+
]
36+
37+
new_openai_models = []
38+
39+
def _get_available_deployments():
14640
available_deployments: Dict[str, Dict[str, Any]] = {
14741
"chat_completion": {},
14842
"completion": {},
149-
"fine_tuneing": {},
43+
"fine_tune": {},
44+
"embeddings": {},
45+
"inference": {}
15046
}
15147

15248
if azure_openai_available:
153-
with _openai_client(use_azure_api=True, is_preference=False):
154-
deployments = openai.Deployment.list().get("data") or [] # type: ignore
155-
for deployment in deployments:
156-
model_name = deployment.model
157-
if model_name in model_aliases:
158-
model_name = model_aliases[model_name]
159-
if model_name in openai_models["chat_completion"]:
160-
available_deployments["chat_completion"][deployment.id] = deployment
161-
elif model_name in openai_models["completion"]:
162-
available_deployments["completion"][deployment.id] = deployment
163-
elif model_name in openai_models["fine_tuneing"]:
164-
available_deployments["fine_tuneing"][deployment.id] = deployment
49+
for deployment in actually_deployed_azure:
50+
available_deployments["chat_completion"][deployment] = deployment
51+
52+
53+
if openai_available:
54+
# This will return only the usable models
55+
openai.api_type= "openai"
56+
for model in openai.models.list():
57+
if model.owned_by == "openai":
58+
new_openai_models.append(model.id)
16559

16660
return available_deployments
16761

16862

169-
def _get_available_models(openai_models: Dict[str, List[str]],
170-
available_deployments: Dict[str, Dict[str, Any]]):
63+
def _get_available_models(available_deployments: Dict[str, Dict[str, Any]]):
17164
available_models: Dict[str, BaseLanguageModel] = {}
17265

17366
if openai_available:
174-
openai_api_key = os.environ["LLM_OPENAI_API_KEY"]
175-
for model_name in openai_models["chat_completion"]:
176-
available_models[OPENAI_PREFIX + model_name] = ChatOpenAI(
177-
model=model_name,
178-
openai_api_key=openai_api_key,
179-
client="",
180-
temperature=0
181-
)
182-
for model_name in openai_models["completion"]:
183-
available_models[OPENAI_PREFIX + model_name] = OpenAI(
184-
model=model_name,
185-
openai_api_key=openai_api_key,
186-
client="",
187-
temperature=0
67+
for model in new_openai_models:
68+
available_models[OPENAI_PREFIX + model] = ChatOpenAI(#type:ignore
69+
model=model,
18870
)
18971

19072
if azure_openai_available:
191-
azure_openai_api_key = os.environ["LLM_AZURE_OPENAI_API_KEY"]
192-
azure_openai_api_base = os.environ["LLM_AZURE_OPENAI_API_BASE"]
193-
azure_openai_api_version = os.environ["LLM_AZURE_OPENAI_API_VERSION"]
194-
19573
for model_type, Model in [("chat_completion", AzureChatOpenAI), ("completion", AzureOpenAI)]:
19674
for deployment_name, deployment in available_deployments[model_type].items():
19775
available_models[AZURE_OPENAI_PREFIX + deployment_name] = Model(
198-
model=deployment.model,
19976
deployment_name=deployment_name,
200-
openai_api_base=azure_openai_api_base,
201-
openai_api_version=azure_openai_api_version,
202-
openai_api_key=azure_openai_api_key,
203-
client="",
20477
temperature=0
20578
)
206-
20779
return available_models
20880

20981

210-
_model_aliases = {
211-
"gpt-35-turbo": "gpt-3.5-turbo",
212-
}
213-
214-
# Hardcoded because openai can't provide a trustworthly api to get the list of models and capabilities...
215-
openai_models = {
216-
"chat_completion": [
217-
"gpt-4",
218-
# "gpt-4-32k", # Not publicly available
219-
"gpt-3.5-turbo",
220-
"gpt-3.5-turbo-16k"
221-
],
222-
"completion": [
223-
"text-davinci-003",
224-
"text-curie-001",
225-
"text-babbage-001",
226-
"text-ada-001",
227-
],
228-
"fine_tuneing": [
229-
"davinci",
230-
"curie",
231-
"babbage",
232-
"ada",
233-
]
234-
}
235-
available_deployments = _get_available_deployments(openai_models, _model_aliases)
236-
available_models = _get_available_models(openai_models, available_deployments)
82+
available_deployments = _get_available_deployments()
83+
available_models = _get_available_models(available_deployments)
23784

23885
if available_models:
23986
logger.info("Available openai models: %s", ", ".join(available_models.keys()))
24087

24188
OpenAIModel = Enum('OpenAIModel', {name: name for name in available_models}) # type: ignore
242-
243-
244-
default_model_name = "gpt-3.5-turbo"
89+
default_model_name = "gpt-35-turbo"
24590
if "LLM_DEFAULT_MODEL" in os.environ and os.environ["LLM_DEFAULT_MODEL"] in available_models:
24691
default_model_name = os.environ["LLM_DEFAULT_MODEL"]
24792
if default_model_name not in available_models:
24893
default_model_name = list(available_models.keys())[0]
24994

250-
default_openai_model = OpenAIModel[default_model_name]
251-
95+
default_openai_model = OpenAIModel[default_model_name]#type:ignore
25296

25397
# Long descriptions will be displayed in the playground UI and are copied from the OpenAI docs
25498
class OpenAIModelConfig(ModelConfig):
@@ -307,7 +151,8 @@ def get_model(self) -> BaseLanguageModel:
307151
BaseLanguageModel: The model.
308152
"""
309153
model = available_models[self.model_name.value]
310-
kwargs = model._lc_kwargs
154+
kwargs = model.__dict__ #BaseLanguageModel type
155+
# kw = model._lc_kwargs
311156
secrets = {secret: getattr(model, secret) for secret in model.lc_secrets.keys()}
312157
kwargs.update(secrets)
313158

0 commit comments

Comments
 (0)