diff --git a/elleelleaime/generate/strategies/models/openai/openai.py b/elleelleaime/generate/strategies/models/openai/openai.py index ab406f48..4b185ee8 100644 --- a/elleelleaime/generate/strategies/models/openai/openai.py +++ b/elleelleaime/generate/strategies/models/openai/openai.py @@ -14,10 +14,12 @@ def __init__(self, model_name: str, **kwargs) -> None: self.temperature = kwargs.get("temperature", 0.0) self.n_samples = kwargs.get("n_samples", 1) self.reasoning_effort = kwargs.get("reasoning_effort", "high") + self.base_url = kwargs.get("base_url", None) + self.batching = kwargs.get("batching", True) load_dotenv() openai.api_key = os.getenv("OPENAI_API_KEY") - self.client = openai.OpenAI(api_key=openai.api_key) + self.client = openai.OpenAI(api_key=openai.api_key, base_url=self.base_url) @backoff.on_exception(backoff.expo, Exception) def _completions_with_backoff(self, **kwargs): @@ -27,8 +29,7 @@ def _generate_impl(self, chunk: List[str]) -> Any: result = [] for prompt in chunk: - # TODO: Temporary fix to handle beta version of "oX" family of models - if self.model_name.startswith("o"): + if not self.batching: result_sample = [] for _ in range(self.n_samples): completion = self._completions_with_backoff(