
Commit 1be0bfe

jhrozek and rdimitrov authored
Add an OpenRouter provider (#921)
* Move _get_base_url to the base provider

  In order to properly support "muxing providers" like OpenRouter, we'll have to tell litellm (or, in the future, a native implementation) which server we want to proxy to. We were already doing that with vLLM, but since we are about to do the same for OpenRouter, let's move the `_get_base_url` method to the base provider.

* Add an openrouter provider

  OpenRouter is a "muxing provider" which itself provides access to multiple models and providers. It speaks a dialect of the OpenAI protocol, but for our purposes we can treat it as OpenAI. There are some differences in handling the requests, though:

  1) We need to know where to forward the request. By default this is `https://openrouter.ai/api/v1`, set via the `base_url` parameter.
  2) We need to prefix the model with `openrouter/`. This is a litellm-ism (see https://docs.litellm.ai/docs/providers/openrouter) which we'll be able to remove once we ditch litellm.

  Initially I was considering just exposing the OpenAI provider on an additional route and handling the prefix based on the route, but I think having an explicit provider class is better, as it allows us to handle any differences in the OpenRouter dialect easily in the future.

  Related: #878

* Add a special ProviderType for openrouter

  We can later alias it to openai if we decide to merge them.

* Add tests for the openrouter provider

* ProviderType was reversed, thanks Alejandro

---------

Co-authored-by: Radoslav Dimitrov <[email protected]>
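The two request-handling differences above amount to a small mutation of the incoming request body before it is handed to litellm. A minimal sketch in plain Python (`prepare_openrouter_request` is an illustrative helper, not a function in this commit; the URL mirrors the new default below):

```python
# Illustrative only: the two OpenRouter-specific tweaks applied to an
# incoming OpenAI-style request body before litellm sees it.
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"  # new DEFAULT_PROVIDER_URLS entry

def prepare_openrouter_request(data: dict) -> dict:
    # 1) tell litellm which server to proxy to
    data["base_url"] = OPENROUTER_BASE_URL
    # 2) litellm expects an "openrouter/" model prefix (removable once litellm goes)
    model = data.get("model", "")
    if not model.startswith("openrouter/"):
        data["model"] = f"openrouter/{model}"
    return data

# {"model": "gpt-4"} becomes
# {"model": "openrouter/gpt-4", "base_url": "https://openrouter.ai/api/v1"}
```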
1 parent 37168b5 commit 1be0bfe

File tree

11 files changed: +170 −5 lines changed

src/codegate/config.py (+1)

```diff
@@ -17,6 +17,7 @@
 # Default provider URLs
 DEFAULT_PROVIDER_URLS = {
     "openai": "https://api.openai.com/v1",
+    "openrouter": "https://openrouter.ai/api/v1",
     "anthropic": "https://api.anthropic.com/v1",
     "vllm": "http://localhost:8000",  # Base URL without /v1 path
     "ollama": "http://localhost:11434",  # Default Ollama server URL
```

src/codegate/db/models.py (+1)

```diff
@@ -128,6 +128,7 @@ class ProviderType(str, Enum):
     ollama = "ollama"
     lm_studio = "lm_studio"
     llamacpp = "llamacpp"
+    openrouter = "openai"


 class GetPromptWithOutputsRow(BaseModel):
```

src/codegate/muxing/adapter.py (+2)

```diff
@@ -106,6 +106,8 @@ def __init__(self):
             db_models.ProviderType.anthropic: self._format_antropic,
             # Our Lllamacpp provider emits OpenAI chunks
             db_models.ProviderType.llamacpp: self._format_openai,
+            # OpenRouter is a dialect of OpenAI
+            db_models.ProviderType.openrouter: self._format_openai,
         }

     def _format_ollama(self, chunk: str) -> str:
```

src/codegate/providers/__init__.py (+2)

```diff
@@ -2,13 +2,15 @@
 from codegate.providers.base import BaseProvider
 from codegate.providers.ollama.provider import OllamaProvider
 from codegate.providers.openai.provider import OpenAIProvider
+from codegate.providers.openrouter.provider import OpenRouterProvider
 from codegate.providers.registry import ProviderRegistry
 from codegate.providers.vllm.provider import VLLMProvider

 __all__ = [
     "BaseProvider",
     "ProviderRegistry",
     "OpenAIProvider",
+    "OpenRouterProvider",
     "AnthropicProvider",
     "VLLMProvider",
     "OllamaProvider",
```

src/codegate/providers/base.py (+8)

```diff
@@ -12,6 +12,7 @@

 from codegate.clients.clients import ClientType
 from codegate.codegate_logging import setup_logging
+from codegate.config import Config
 from codegate.db.connection import DbRecorder
 from codegate.pipeline.base import (
     PipelineContext,
@@ -88,6 +89,13 @@ async def process_request(
     def provider_route_name(self) -> str:
         pass

+    def _get_base_url(self) -> str:
+        """
+        Get the base URL from config with proper formatting
+        """
+        config = Config.get_config()
+        return config.provider_urls.get(self.provider_route_name) if config else ""
+
     async def _run_output_stream_pipeline(
         self,
         input_context: PipelineContext,
```
src/codegate/providers/openai/__init__.py (+3)

```diff
@@ -0,0 +1,3 @@
+from codegate.providers.openai.provider import OpenAIProvider
+
+__all__ = ["OpenAIProvider"]
```
src/codegate/providers/openrouter/provider.py (new file, +47)

```diff
@@ -0,0 +1,47 @@
+import json
+
+from fastapi import Header, HTTPException, Request
+
+from codegate.clients.detector import DetectClient
+from codegate.pipeline.factory import PipelineFactory
+from codegate.providers.openai import OpenAIProvider
+
+
+class OpenRouterProvider(OpenAIProvider):
+    def __init__(self, pipeline_factory: PipelineFactory):
+        super().__init__(pipeline_factory)
+
+    @property
+    def provider_route_name(self) -> str:
+        return "openrouter"
+
+    def _setup_routes(self):
+        @self.router.post(f"/{self.provider_route_name}/api/v1/chat/completions")
+        @self.router.post(f"/{self.provider_route_name}/chat/completions")
+        @DetectClient()
+        async def create_completion(
+            request: Request,
+            authorization: str = Header(..., description="Bearer token"),
+        ):
+            if not authorization.startswith("Bearer "):
+                raise HTTPException(status_code=401, detail="Invalid authorization header")
+
+            api_key = authorization.split(" ")[1]
+            body = await request.body()
+            data = json.loads(body)
+
+            base_url = self._get_base_url()
+            data["base_url"] = base_url
+
+            # litellm workaround - add openrouter/ prefix to model name to make it openai-compatible
+            # once we get rid of litellm, this can simply be removed
+            original_model = data.get("model", "")
+            if not original_model.startswith("openrouter/"):
+                data["model"] = f"openrouter/{original_model}"
+
+            return await self.process_request(
+                data,
+                api_key,
+                request.url.path,
+                request.state.detected_client,
+            )
```
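Since the new route speaks the OpenAI dialect, any OpenAI-compatible client can point at it. A hedged usage sketch, assuming a CodeGate instance listening on localhost:8989 (the host/port is an assumption, not part of this diff):

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8989/openrouter",  # assumed CodeGate address + the route above
    api_key="<your-openrouter-api-key>",          # forwarded as the Bearer token
)

# The provider adds the "openrouter/" prefix itself, so plain model names work.
# The SDK appends /chat/completions, matching the route registered above.
resp = client.chat.completions.create(
    model="anthropic/claude-3.5-sonnet",  # example OpenRouter model slug
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)
```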

src/codegate/providers/vllm/provider.py (+1 −3)

```diff
@@ -9,7 +9,6 @@

 from codegate.clients.clients import ClientType
 from codegate.clients.detector import DetectClient
-from codegate.config import Config
 from codegate.pipeline.factory import PipelineFactory
 from codegate.providers.base import BaseProvider, ModelFetchError
 from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
@@ -39,8 +38,7 @@ def _get_base_url(self) -> str:
         """
         Get the base URL from config with proper formatting
         """
-        config = Config.get_config()
-        base_url = config.provider_urls.get("vllm") if config else ""
+        base_url = super()._get_base_url()
         if base_url:
             base_url = base_url.rstrip("/")
             # Add /v1 if not present
```

src/codegate/server.py (+5)

```diff
@@ -18,6 +18,7 @@
 from codegate.providers.lm_studio.provider import LmStudioProvider
 from codegate.providers.ollama.provider import OllamaProvider
 from codegate.providers.openai.provider import OpenAIProvider
+from codegate.providers.openrouter.provider import OpenRouterProvider
 from codegate.providers.registry import ProviderRegistry, get_provider_registry
 from codegate.providers.vllm.provider import VLLMProvider

@@ -75,6 +76,10 @@ async def log_user_agent(request: Request, call_next):
         ProviderType.openai,
         OpenAIProvider(pipeline_factory),
     )
+    registry.add_provider(
+        ProviderType.openrouter,
+        OpenRouterProvider(pipeline_factory),
+    )
     registry.add_provider(
         ProviderType.anthropic,
         AnthropicProvider(
```
Tests for the openrouter provider (new file, +98)

```diff
@@ -0,0 +1,98 @@
+import json
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from fastapi import HTTPException
+from fastapi.requests import Request
+
+from codegate.config import DEFAULT_PROVIDER_URLS
+from codegate.pipeline.factory import PipelineFactory
+from codegate.providers.openrouter.provider import OpenRouterProvider
+
+
+@pytest.fixture
+def mock_factory():
+    return MagicMock(spec=PipelineFactory)
+
+
+@pytest.fixture
+def provider(mock_factory):
+    return OpenRouterProvider(mock_factory)
+
+
+def test_get_base_url(provider):
+    """Test that _get_base_url returns the correct OpenRouter API URL"""
+    assert provider._get_base_url() == DEFAULT_PROVIDER_URLS["openrouter"]
+
+
+@pytest.mark.asyncio
+async def test_model_prefix_added():
+    """Test that model name gets prefixed with openrouter/ when not already present"""
+    mock_factory = MagicMock(spec=PipelineFactory)
+    provider = OpenRouterProvider(mock_factory)
+    provider.process_request = AsyncMock()
+
+    # Mock request
+    mock_request = MagicMock(spec=Request)
+    mock_request.body = AsyncMock(return_value=json.dumps({"model": "gpt-4"}).encode())
+    mock_request.url.path = "/openrouter/chat/completions"
+    mock_request.state.detected_client = "test-client"
+
+    # Get the route handler function
+    route_handlers = [
+        route for route in provider.router.routes if route.path == "/openrouter/chat/completions"
+    ]
+    create_completion = route_handlers[0].endpoint
+
+    await create_completion(request=mock_request, authorization="Bearer test-token")
+
+    # Verify process_request was called with prefixed model
+    call_args = provider.process_request.call_args[0]
+    assert call_args[0]["model"] == "openrouter/gpt-4"
+
+
+@pytest.mark.asyncio
+async def test_model_prefix_preserved():
+    """Test that model name is not modified when openrouter/ prefix is already present"""
+    mock_factory = MagicMock(spec=PipelineFactory)
+    provider = OpenRouterProvider(mock_factory)
+    provider.process_request = AsyncMock()
+
+    # Mock request
+    mock_request = MagicMock(spec=Request)
+    mock_request.body = AsyncMock(return_value=json.dumps({"model": "openrouter/gpt-4"}).encode())
+    mock_request.url.path = "/openrouter/chat/completions"
+    mock_request.state.detected_client = "test-client"
+
+    # Get the route handler function
+    route_handlers = [
+        route for route in provider.router.routes if route.path == "/openrouter/chat/completions"
+    ]
+    create_completion = route_handlers[0].endpoint
+
+    await create_completion(request=mock_request, authorization="Bearer test-token")
+
+    # Verify process_request was called with unchanged model name
+    call_args = provider.process_request.call_args[0]
+    assert call_args[0]["model"] == "openrouter/gpt-4"
+
+
+@pytest.mark.asyncio
+async def test_invalid_auth_header():
+    """Test that invalid authorization header format raises HTTPException"""
+    mock_factory = MagicMock(spec=PipelineFactory)
+    provider = OpenRouterProvider(mock_factory)
+
+    mock_request = MagicMock(spec=Request)
+
+    # Get the route handler function
+    route_handlers = [
+        route for route in provider.router.routes if route.path == "/openrouter/chat/completions"
+    ]
+    create_completion = route_handlers[0].endpoint
+
+    with pytest.raises(HTTPException) as exc_info:
+        await create_completion(request=mock_request, authorization="InvalidToken")
+
+    assert exc_info.value.status_code == 401
+    assert exc_info.value.detail == "Invalid authorization header"
```

tests/test_server.py (+2 −2)

```diff
@@ -108,8 +108,8 @@ def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_fa
     # Verify all providers were registered
     registry_instance = mock_registry.return_value
     assert (
-        registry_instance.add_provider.call_count == 6
-    )  # openai, anthropic, llamacpp, vllm, ollama, lm_studio
+        registry_instance.add_provider.call_count == 7
+    )  # openai, anthropic, llamacpp, vllm, ollama, lm_studio, openrouter

     # Verify specific providers were registered
     provider_names = [call.args[0] for call in registry_instance.add_provider.call_args_list]
```

0 commit comments