1
1
from functools import cached_property
2
2
from typing import List , Optional
3
+ import logging
3
4
4
5
from langchain .chat_models import ChatOpenAI
5
6
from langchain .llms import OpenAI
11
12
from autolabel .cache import BaseCache
12
13
13
14
15
+ logger = logging .getLogger (__name__ )
16
+
17
+
14
18
class OpenAILLM (BaseModel ):
15
19
CHAT_ENGINE_MODELS = [
16
20
"gpt-3.5-turbo" ,
@@ -76,6 +80,14 @@ def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
76
80
77
81
# populate model params and initialize the LLM
78
82
model_params = config .model_params ()
83
+ if config .logit_bias ():
84
+ logit_bias = self ._generate_logit_bias (config )
85
+ # if logit_bias or max_tokens is specified already, we don't want to overwrite it
86
+ model_params = {
87
+ ** logit_bias ,
88
+ ** model_params ,
89
+ }
90
+
79
91
if self ._engine == "chat" :
80
92
self .model_params = {** self .DEFAULT_PARAMS_CHAT_ENGINE , ** model_params }
81
93
self .llm = ChatOpenAI (model_name = self .model_name , ** self .model_params )
@@ -86,6 +98,32 @@ def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
86
98
}
87
99
self .llm = OpenAI (model_name = self .model_name , ** self .model_params )
88
100
101
+ def _generate_logit_bias (self , config : AutolabelConfig ) -> None :
102
+ """Generates logit bias for the labels specified in the config
103
+
104
+ Args:
105
+ config (AutolabelConfig): AutolabelConfig object
106
+
107
+ Returns:
108
+ Dict: logit bias and max tokens
109
+ """
110
+ if len (config .labels_list ()) == 0 :
111
+ logger .warning (
112
+ "No labels specified in the config. Skipping logit bias generation."
113
+ )
114
+ return {}
115
+ encoding = tiktoken .encoding_for_model (self .model_name )
116
+ logit_bias = {}
117
+ max_tokens = 0
118
+ for label in config .labels_list ():
119
+ if label not in logit_bias :
120
+ tokens = encoding .encode (label )
121
+ for token in tokens :
122
+ logit_bias [token ] = 100
123
+ max_tokens = max (max_tokens , len (tokens ))
124
+
125
+ return {"logit_bias" : logit_bias , "max_tokens" : max_tokens }
126
+
89
127
def _label (self , prompts : List [str ]) -> LLMResult :
90
128
if self ._engine == "chat" :
91
129
# Need to convert list[prompts] -> list[messages]
0 commit comments