Skip to content

Commit

Permalink
simplify openai flow
Browse files Browse the repository at this point in the history
  • Loading branch information
Enea_Gore committed Jul 23, 2024
1 parent dfd1878 commit 7a936df
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 129 deletions.
12 changes: 6 additions & 6 deletions module_text_llm/module_text_llm/generate_suggestions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional, Sequence
from pydantic import BaseModel, Field

# from pydantic import BaseModel, Field
from langchain_core.pydantic_v1 import BaseModel,Field, ValidationError
from athena import emit_meta
from athena.text import Exercise, Submission, Feedback
from athena.logger import logger
Expand All @@ -24,17 +24,17 @@ class FeedbackModel(BaseModel):
description="ID of the grading instruction that was used to generate this feedback, or empty if no grading instruction was used"
)

class Config:
title = "Feedback"
# class Config:
# title = "Feedback"


class AssessmentModel(BaseModel):
"""Collection of feedbacks making up an assessment"""

feedbacks: Sequence[FeedbackModel] = Field(description="Assessment feedbacks")

class Config:
title = "Assessment"
# class Config:
# title = "Assessment"


async def generate_suggestions(exercise: Exercise, submission: Submission, config: BasicApproachConfig, debug: bool) -> List[Feedback]:
Expand Down
79 changes: 22 additions & 57 deletions module_text_llm/module_text_llm/helpers/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_openai.chat_models.base import BaseChatOpenAI
from pydantic import BaseModel, ValidationError
# from pydantic import BaseModel, ValidationError
from langchain_core.pydantic_v1 import BaseModel, ValidationError
import tiktoken
from langchain_openai import AzureChatOpenAI, AzureOpenAI, ChatOpenAI

Expand Down Expand Up @@ -88,8 +89,8 @@ def supports_function_calling(model: BaseLanguageModel):
"""
return isinstance(model, ChatOpenAI)

def is_azure_call(model:BaseLanguageModel):
return isinstance(model, AzureChatOpenAI)
def is_openai(model: BaseLanguageModel):
    """Return True when the model is an OpenAI-backed chat model (Azure-hosted or direct)."""
    return isinstance(model, (AzureChatOpenAI, ChatOpenAI))

def get_chat_prompt_with_formatting_instructions(
model: BaseLanguageModel,
Expand All @@ -110,14 +111,16 @@ def get_chat_prompt_with_formatting_instructions(
Returns:
ChatPromptTemplate: ChatPromptTemplate with formatting instructions (if necessary)
"""
if supports_function_calling(model):
system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)
return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
# if supports_function_calling(model):
# system_message_prompt = SystemMessagePromptTemplate.from_template(system_message)
# human_message_prompt = HumanMessagePromptTemplate.from_template(human_message)
# return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])

output_parser = PydanticOutputParser(pydantic_object=pydantic_object)
system_message_prompt = SystemMessagePromptTemplate.from_template(system_message + "\n{format_instructions}")
system_message_prompt.prompt.partial_variables = {"format_instructions": output_parser.get_format_instructions()}#type: ignore
# system_message_prompt.prompt.partial_variables = {"format_instructions": output_parser.get_format_instructions()}#type: ignore
system_message_prompt.prompt.partial_variables = {"format_instructions": ""}#type: ignore

system_message_prompt.prompt.input_variables.remove("format_instructions") #type:ignore
human_message_prompt = HumanMessagePromptTemplate.from_template(human_message + "\n\nJSON response following the provided schema:")
return ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt])
Expand All @@ -129,7 +132,7 @@ async def predict_and_parse(
pydantic_object: Type[T],
tags: Optional[List[str]]
) -> Optional[T]:
print("type of model is ", type(model))
print("The model type you are using:", type(model))
"""Predicts an LLM completion using the model and parses the output using the provided Pydantic model
Args:
Expand All @@ -153,51 +156,13 @@ async def predict_and_parse(
tags.append(f"run-{experiment.run_id}")

chat_prompt.tags = tags

if supports_function_calling(model):
#chain = create_structured_output_chain(pydantic_object, llm=model, prompt=chat_prompt, tags=tags)
openai_functions = [convert_to_openai_function(pydantic_object)]

runnable = chat_prompt | model.bind(functions=openai_functions).with_retry(
retry_if_exception_type=(ValueError, OutputParserException),
wait_exponential_jitter=True,
stop_after_attempt=3,
) | JsonOutputFunctionsParser()
try:
output_dict = await runnable.ainvoke(prompt_input)
print(output_dict)
return pydantic_object.model_validate(output_dict)
except (OutputParserException, ValidationError):
# In the future, we should probably have some recovery mechanism here (i.e. fix the output with another prompt)
return None
elif is_azure_call(model) :
output_parser = PydanticOutputParser(pydantic_object=pydantic_object)
try:
runnable = chat_prompt | model.with_retry(
retry_if_exception_type=(ValueError, OutputParserException),
wait_exponential_jitter=True,
stop_after_attempt=3,
) | output_parser
output_dict = await runnable.ainvoke(prompt_input)
print(output_dict)
return pydantic_object.model_validate(output_dict)
except (OutputParserException, ValidationError):
# In the future, we should probably have some recovery mechanism here (i.e. fix the output with another prompt)
return None
else:
output_parser = PydanticOutputParser(pydantic_object=pydantic_object)
#chain = LLMChain(llm=model, prompt=chat_prompt, output_parser=output_parser, tags=tags)
runnable = chat_prompt | model.with_retry(
retry_if_exception_type=(ValueError, OutputParserException),
wait_exponential_jitter=True,
stop_after_attempt=3,
) | output_parser

try:
output_dict = await runnable.ainvoke(prompt_input)
return pydantic_object.model_validate(output_dict)

except (OutputParserException, ValidationError):
# In the future, we should probably have some recovery mechanism here (i.e. fix the output with another prompt)
# The future is now, or maybe soon
return None
if is_openai(model):
runnable = chat_prompt | model.with_structured_output(pydantic_object) # type: ignore

try:
output_dict = await runnable.ainvoke(prompt_input)
return pydantic_object.validate(output_dict)
except (OutputParserException, ValidationError):
# In the future, we should probably have some recovery mechanism here (i.e. fix the output with another prompt)
# The future is now, or maybe soon
return None
12 changes: 6 additions & 6 deletions module_text_llm/module_text_llm/helpers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
except AttributeError:
pass

# try:
# import module_text_llm.helpers.models.llama as ollama_config #type: ignore
# types.append(ollama_config.OllamaModelConfig)
# # DefaultModelConfig = ollama_config.OllamaModelConfig
# except AttributeError:
# pass
try:
import module_text_llm.helpers.models.llama as ollama_config #type: ignore
types.append(ollama_config.OllamaModelConfig)
# DefaultModelConfig = ollama_config.OllamaModelConfig
except AttributeError:
pass


if not types:
Expand Down
6 changes: 3 additions & 3 deletions module_text_llm/module_text_llm/helpers/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from langchain_community.llms import Ollama # type: ignore
from module_text_llm.helpers.models.model_config import ModelConfig # type: ignore
from pydantic import validator, Field, PositiveInt,field_validator
from pydantic import validator, Field, PositiveInt
from langchain.base_language import BaseLanguageModel
import os
from langchain_community.chat_models import ChatOllama # type: ignore
Expand Down Expand Up @@ -35,7 +35,7 @@
model = name,
base_url = os.environ["OLLAMA_ENDPOINT"],
headers = auth_header,

format = "json"
) for name in ollama_models
}

Expand All @@ -62,7 +62,7 @@ class OllamaModelConfig(ModelConfig):
frequency_penalty: float = Field(default=0, ge=-2, le=2, description="")

base_url : str = Field(default="https://gpu-artemis.ase.cit.tum.de/ollama", description=" Base Url where ollama is hosted")
@field_validator('max_tokens')
@validator('max_tokens')
def max_tokens_must_be_positive(cls, v):
"""
Validate that max_tokens is a positive integer.
Expand Down
108 changes: 51 additions & 57 deletions module_text_llm/module_text_llm/helpers/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,53 @@
from enum import Enum
import openai
from langchain.base_language import BaseLanguageModel
from langchain_openai import AzureChatOpenAI, AzureOpenAI, ChatOpenAI, OpenAI

from langchain_openai import AzureChatOpenAI, AzureOpenAI, ChatOpenAI
from athena.logger import logger
from .model_config import ModelConfig


OPENAI_PREFIX = "openai_"
AZURE_OPENAI_PREFIX = "azure_openai_"

openai_available = bool(os.environ.get("LLM_OPENAI_API_KEY"))
if openai_available:
############################################################################################
# START Set Enviorment variables that are automatically used by ChatOpenAI/ChatAzureOpenAI#
############################################################################################
# Might be worth renaming them
openai_available = bool(os.environ.get("LLM_OPENAI_API_KEY"))
if openai_available:
os.environ["OPENAI_API_KEY"] = os.environ["LLM_OPENAI_API_KEY"]
models_ai_api = openai.OpenAI().models.list()#type:ignore


azure_openai_available = bool(os.environ.get("LLM_AZURE_OPENAI_API_KEY"))
if azure_openai_available:
os.environ["AZURE_OPENAI_ENDPOINT"]=os.environ["LLM_AZURE_OPENAI_API_BASE"]
os.environ["AZURE_OPENAI_API_KEY"]=os.environ["LLM_AZURE_OPENAI_API_KEY"]
os.environ["OPENAI_API_VERSION"]=os.environ["LLM_AZURE_OPENAI_API_VERSION"]

#########################################################################################
# END Set Enviorment variables that are automatically used by ChatOpenAI/ChatAzureOpenAI#
#########################################################################################

actually_deployed_azure= [
"gpt-35-turbo"
]

new_openai_models = []

# """
# Calling AzureOpenAI.models returns a lot of models which are not available.
# A very hacky way to check if the deployments are actually deployed and available in azure.
# Might incure some minor costs, using the results at actually_deployed_azure for now.
# """
# def _check_deployment_availability():
# deployments = openai.AzureOpenAI.models.list()
# responding_deployments = []
# try:
# for deployment in deployments:
# deployment_model = AzureChatOpenAI(model=deployment.id)
# deployment_model.invoke("")
# responding_deployments.append(deployment)
# except:
# pass
# return responding_deployments

def _get_available_deployments():
available_deployments: Dict[str, Dict[str, Any]] = {
"chat_completion": {},
Expand All @@ -35,71 +61,39 @@ def _get_available_deployments():
}

if azure_openai_available:
deployments = openai.AzureOpenAI().models.list() or []#type:ignore
# This returns a lot of unusable models
deployments = openai.AzureOpenAI().models.list() or []#type:ignore
for deployment in deployments:
if deployment.capabilities["chat_completion"]:#type:ignore
available_deployments["chat_completion"][deployment.id] = deployment

# if openai_available:
# models_ai_api = openai.OpenAI().models.list()#type:ignore
# for model in models_ai_api:
# # print(model)
# pass
if(deployment.id in actually_deployed_azure):
if deployment.capabilities["chat_completion"]:#type:ignore
available_deployments["chat_completion"][deployment.id] = deployment
elif deployment.capabilities["completion"]:#type:ignore
available_deployments["completion"][deployment.id] = deployment

if openai_available:
# This will return only the usable models
openai.api_type= "openai"
for model in openai.models.list():
if model.owned_by == "openai":
new_openai_models.append(model.id)

return available_deployments

openai_models = {
"chat_completion": [
"gpt-4",
"gpt-4o-mini",
# "gpt-35",
# "gpt-4-32k", # Not publicly available
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k"
],
"completion": [
"text-davinci-003",
"text-curie-001",
"text-babbage-001",
"text-ada-001",
],
"fine_tuneing": [
"davinci",
"curie",
"babbage",
"ada",
]
}

def _get_available_models(available_deployments: Dict[str, Dict[str, Any]]):
available_models: Dict[str, BaseLanguageModel] = {}

if openai_available:
openai_api_key = os.environ["LLM_OPENAI_API_KEY"]
for model in openai_models["chat_completion"]:
for model in new_openai_models:
available_models[OPENAI_PREFIX + model] = ChatOpenAI(#type:ignore
model=model,
)
# print( )
for model in openai_models["completion"]:
available_models[OPENAI_PREFIX + model] = OpenAI(#type:ignore
model= model
)


if azure_openai_available:
azure_openai_api_key = os.environ["LLM_AZURE_OPENAI_API_KEY"]
azure_openai_api_base = os.environ["LLM_AZURE_OPENAI_API_BASE"]
azure_openai_api_version = os.environ["LLM_AZURE_OPENAI_API_VERSION"]

for model_type, Model in [("chat_completion", AzureChatOpenAI), ("completion", AzureOpenAI)]:
for deployment_name, deployment in available_deployments[model_type].items():
available_models[AZURE_OPENAI_PREFIX + deployment_name] = Model(
deployment_name=deployment_name,
azure_endpoint=azure_openai_api_base,
openai_api_version=azure_openai_api_version,
openai_api_key=azure_openai_api_key,
client="",
temperature=0
)
return available_models
Expand All @@ -109,10 +103,10 @@ def _get_available_models(available_deployments: Dict[str, Dict[str, Any]]):
available_models = _get_available_models(available_deployments)

if available_models:
logger.info("Available openai models: %s", ", ".join(available_models.keys()))
# logger.info("Available openai models: %s", ", ".join(available_models.keys()))

OpenAIModel = Enum('OpenAIModel', {name: name for name in available_models}) # type: ignore
default_model_name = "gpt-3.5-turbo"
default_model_name = "gpt-35-turbo"
if "LLM_DEFAULT_MODEL" in os.environ and os.environ["LLM_DEFAULT_MODEL"] in available_models:
default_model_name = os.environ["LLM_DEFAULT_MODEL"]
if default_model_name not in available_models:
Expand Down

0 comments on commit 7a936df

Please sign in to comment.