Skip to content

Commit c44ab9a

Browse files
committed
feat(llm): enhance LiteLLMCaller to support custom API key handling and environment variable expansion
1 parent 252d963 commit c44ab9a

File tree

2 files changed

+146
-6
lines changed

2 files changed

+146
-6
lines changed

backend/modules/llm/litellm_caller.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,13 @@ def _get_model_kwargs(self, model_name: str, temperature: Optional[float] = None
8585
except ValueError as e:
8686
logger.error(f"Failed to resolve API key for model {model_name}: {e}")
8787
raise
88-
88+
8989
if api_key:
90+
# Always pass api_key to LiteLLM for all providers
91+
kwargs["api_key"] = api_key
92+
93+
# Additionally set provider-specific env vars for LiteLLM's internal logic
9094
if "openrouter" in model_config.model_url:
91-
kwargs["api_key"] = api_key
92-
# LiteLLM will automatically set the correct env var
9395
os.environ["OPENROUTER_API_KEY"] = api_key
9496
elif "openai" in model_config.model_url:
9597
os.environ["OPENAI_API_KEY"] = api_key
@@ -99,6 +101,10 @@ def _get_model_kwargs(self, model_name: str, temperature: Optional[float] = None
99101
os.environ["GOOGLE_API_KEY"] = api_key
100102
elif "cerebras" in model_config.model_url:
101103
os.environ["CEREBRAS_API_KEY"] = api_key
104+
else:
105+
# Custom endpoint - set OPENAI_API_KEY as fallback
106+
# (most custom endpoints are OpenAI-compatible)
107+
os.environ["OPENAI_API_KEY"] = api_key
102108

103109
# Set custom API base for non-standard endpoints
104110
if hasattr(model_config, 'model_url') and model_config.model_url:

backend/tests/test_llm_env_expansion.py

Lines changed: 137 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,14 +146,148 @@ def test_litellm_caller_handles_literal_extra_headers(self):
146146
)
147147
}
148148
)
149-
149+
150150
# Create LiteLLMCaller
151151
caller = LiteLLMCaller(llm_config, debug_mode=True)
152-
152+
153153
# Get model kwargs - this should work without errors
154154
model_kwargs = caller._get_model_kwargs("test-model")
155-
155+
156156
# Verify that extra_headers were passed through
157157
assert "extra_headers" in model_kwargs
158158
assert model_kwargs["extra_headers"]["HTTP-Referer"] == "https://literal-app.com"
159159
assert model_kwargs["extra_headers"]["X-Title"] == "LiteralApp"
160+
161+
def test_custom_endpoint_with_env_var_api_key(self, monkeypatch):
162+
"""Custom endpoint should pass api_key in kwargs when using env var."""
163+
monkeypatch.setenv("CUSTOM_LLM_KEY", "sk-custom-12345")
164+
165+
# Create LLM config for custom endpoint with env var in api_key
166+
llm_config = LLMConfig(
167+
models={
168+
"custom-model": ModelConfig(
169+
model_name="custom-model-name",
170+
model_url="https://custom-llm.example.com/v1",
171+
api_key="${CUSTOM_LLM_KEY}"
172+
)
173+
}
174+
)
175+
176+
# Create LiteLLMCaller
177+
caller = LiteLLMCaller(llm_config, debug_mode=True)
178+
179+
# Get model kwargs
180+
model_kwargs = caller._get_model_kwargs("custom-model")
181+
182+
# Verify that api_key is in kwargs (critical for custom endpoints)
183+
assert "api_key" in model_kwargs
184+
assert model_kwargs["api_key"] == "sk-custom-12345"
185+
186+
# Verify that api_base is set for custom endpoint
187+
assert "api_base" in model_kwargs
188+
assert model_kwargs["api_base"] == "https://custom-llm.example.com/v1"
189+
190+
# Verify fallback env var is set for OpenAI-compatible endpoints
191+
import os
192+
assert os.environ.get("OPENAI_API_KEY") == "sk-custom-12345"
193+
194+
def test_custom_endpoint_with_literal_api_key(self):
195+
"""Custom endpoint should pass api_key in kwargs when using literal value."""
196+
# Create LLM config for custom endpoint with literal api_key
197+
llm_config = LLMConfig(
198+
models={
199+
"custom-model": ModelConfig(
200+
model_name="custom-model-name",
201+
model_url="https://custom-llm.example.com/v1",
202+
api_key="sk-literal-custom-key"
203+
)
204+
}
205+
)
206+
207+
# Create LiteLLMCaller
208+
caller = LiteLLMCaller(llm_config, debug_mode=True)
209+
210+
# Get model kwargs
211+
model_kwargs = caller._get_model_kwargs("custom-model")
212+
213+
# Verify that api_key is in kwargs (critical for custom endpoints)
214+
assert "api_key" in model_kwargs
215+
assert model_kwargs["api_key"] == "sk-literal-custom-key"
216+
217+
# Verify that api_base is set for custom endpoint
218+
assert "api_base" in model_kwargs
219+
assert model_kwargs["api_base"] == "https://custom-llm.example.com/v1"
220+
221+
def test_custom_endpoint_with_extra_headers(self, monkeypatch):
222+
"""Custom endpoint should handle extra_headers correctly."""
223+
monkeypatch.setenv("CUSTOM_API_KEY", "sk-custom-auth")
224+
monkeypatch.setenv("CUSTOM_TENANT", "tenant-123")
225+
226+
# Create LLM config for custom endpoint with extra headers
227+
llm_config = LLMConfig(
228+
models={
229+
"custom-model": ModelConfig(
230+
model_name="custom-model-name",
231+
model_url="https://custom-llm.example.com/v1",
232+
api_key="${CUSTOM_API_KEY}",
233+
extra_headers={
234+
"X-Tenant-ID": "${CUSTOM_TENANT}",
235+
"X-Custom-Header": "custom-value"
236+
}
237+
)
238+
}
239+
)
240+
241+
# Create LiteLLMCaller
242+
caller = LiteLLMCaller(llm_config, debug_mode=True)
243+
244+
# Get model kwargs
245+
model_kwargs = caller._get_model_kwargs("custom-model")
246+
247+
# Verify api_key is passed
248+
assert "api_key" in model_kwargs
249+
assert model_kwargs["api_key"] == "sk-custom-auth"
250+
251+
# Verify extra_headers are resolved and passed
252+
assert "extra_headers" in model_kwargs
253+
assert model_kwargs["extra_headers"]["X-Tenant-ID"] == "tenant-123"
254+
assert model_kwargs["extra_headers"]["X-Custom-Header"] == "custom-value"
255+
256+
# Verify api_base is set
257+
assert "api_base" in model_kwargs
258+
259+
def test_known_providers_still_get_api_key_in_kwargs(self):
260+
"""Verify that known providers also get api_key in kwargs (backward compatibility)."""
261+
# Test OpenAI
262+
llm_config = LLMConfig(
263+
models={
264+
"openai-model": ModelConfig(
265+
model_name="gpt-4",
266+
model_url="https://api.openai.com/v1",
267+
api_key="sk-openai-test"
268+
)
269+
}
270+
)
271+
caller = LiteLLMCaller(llm_config, debug_mode=True)
272+
model_kwargs = caller._get_model_kwargs("openai-model")
273+
274+
# OpenAI should get api_key in kwargs
275+
assert "api_key" in model_kwargs
276+
assert model_kwargs["api_key"] == "sk-openai-test"
277+
278+
# Test OpenRouter
279+
llm_config = LLMConfig(
280+
models={
281+
"openrouter-model": ModelConfig(
282+
model_name="meta-llama/llama-3-70b",
283+
model_url="https://openrouter.ai/api/v1",
284+
api_key="sk-or-test"
285+
)
286+
}
287+
)
288+
caller = LiteLLMCaller(llm_config, debug_mode=True)
289+
model_kwargs = caller._get_model_kwargs("openrouter-model")
290+
291+
# OpenRouter should get api_key in kwargs
292+
assert "api_key" in model_kwargs
293+
assert model_kwargs["api_key"] == "sk-or-test"

0 commit comments

Comments (0)