
Commit 8833d45

Implemented logit bias (#385)

* implemented logit bias
* fixed comment
* moved logit bias generation to `openai.py`
* removed unused code from config
* formatting
* formatting
* Revert "formatting". This reverts commit 5418326.
1 parent b49ff19 · commit 8833d45

File tree: 2 files changed, +43 −0 lines


src/autolabel/configs/config.py (+5)
```diff
@@ -29,6 +29,7 @@ class AutolabelConfig(BaseConfig):
     MODEL_NAME_KEY = "name"
     MODEL_PARAMS_KEY = "params"
     COMPUTE_CONFIDENCE_KEY = "compute_confidence"
+    LOGIT_BIAS_KEY = "logit_bias"
 
     # Embedding config keys (config["embedding"][<key>])
     EMBEDDING_PROVIDER_KEY = "provider"
@@ -124,6 +125,10 @@ def confidence(self) -> bool:
         """Returns true if the model is able to return a confidence score along with its predictions"""
         return self._model_config.get(self.COMPUTE_CONFIDENCE_KEY, False)
 
+    def logit_bias(self) -> bool:
+        """Returns true if the model is configured to use a logit bias"""
+        return self._model_config.get(self.LOGIT_BIAS_KEY, False)
+
     # Embedding config
     def embedding_provider(self) -> str:
         """Returns the name of the entity that provides the model used for computing embeddings"""
```

src/autolabel/models/openai.py (+38)
```diff
@@ -1,5 +1,6 @@
 from functools import cached_property
 from typing import List, Optional
+import logging
 
 from langchain.chat_models import ChatOpenAI
 from langchain.llms import OpenAI
@@ -11,6 +12,9 @@
 from autolabel.cache import BaseCache
 
 
+logger = logging.getLogger(__name__)
+
+
 class OpenAILLM(BaseModel):
     CHAT_ENGINE_MODELS = [
         "gpt-3.5-turbo",
```
```diff
@@ -76,6 +80,14 @@ def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
 
         # populate model params and initialize the LLM
         model_params = config.model_params()
+        if config.logit_bias():
+            logit_bias = self._generate_logit_bias(config)
+            # if logit_bias or max_tokens is specified already, we don't want to overwrite it
+            model_params = {
+                **logit_bias,
+                **model_params,
+            }
+
         if self._engine == "chat":
             self.model_params = {**self.DEFAULT_PARAMS_CHAT_ENGINE, **model_params}
             self.llm = ChatOpenAI(model_name=self.model_name, **self.model_params)
```
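
Note the merge order here: because `model_params` is spread last, any `logit_bias` or `max_tokens` the user set explicitly in the config wins over the generated values, which is what the inline comment is guarding. A minimal standalone sketch of that precedence, with made-up values:

```python
# Later keys win in a dict merge, so user-supplied params override
# the generated ones. The values below are illustrative only.
generated = {"logit_bias": {15: 100}, "max_tokens": 1}
user_params = {"max_tokens": 4}

merged = {**generated, **user_params}
print(merged)  # {'logit_bias': {15: 100}, 'max_tokens': 4}
```

The generation itself happens in `_generate_logit_bias`, added below: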
```diff
@@ -86,6 +98,32 @@ def __init__(self, config: AutolabelConfig, cache: BaseCache = None) -> None:
             }
             self.llm = OpenAI(model_name=self.model_name, **self.model_params)
 
+    def _generate_logit_bias(self, config: AutolabelConfig) -> dict:
+        """Generates logit bias for the labels specified in the config
+
+        Args:
+            config (AutolabelConfig): AutolabelConfig object
+
+        Returns:
+            Dict: logit bias and max tokens
+        """
+        if len(config.labels_list()) == 0:
+            logger.warning(
+                "No labels specified in the config. Skipping logit bias generation."
+            )
+            return {}
+        encoding = tiktoken.encoding_for_model(self.model_name)
+        logit_bias = {}
+        max_tokens = 0
+        for label in config.labels_list():
+            if label not in logit_bias:
+                tokens = encoding.encode(label)
+                for token in tokens:
+                    logit_bias[token] = 100
+                max_tokens = max(max_tokens, len(tokens))
+
+        return {"logit_bias": logit_bias, "max_tokens": max_tokens}
+
     def _label(self, prompts: List[str]) -> LLMResult:
         if self._engine == "chat":
             # Need to convert list[prompts] -> list[messages]
```
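
To see what `_generate_logit_bias` actually builds, here is a standalone sketch of the same construction, assuming the hypothetical labels `["positive", "negative"]` and the `gpt-3.5-turbo` tokenizer (requires the `tiktoken` package). A bias of +100 is the maximum the OpenAI API accepts and effectively restricts sampling to the labels' tokens, while `max_tokens` is capped at the longest label's token count:

```python
# Standalone sketch of the bias construction above, using hypothetical labels.
import tiktoken

labels = ["positive", "negative"]
encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")

logit_bias = {}
max_tokens = 0
for label in labels:
    tokens = encoding.encode(label)
    for token in tokens:
        logit_bias[token] = 100  # +100 is the maximum bias the API allows
    max_tokens = max(max_tokens, len(tokens))

print({"logit_bias": logit_bias, "max_tokens": max_tokens})
```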
