Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an OpenRouter provider #921

Merged
merged 6 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/codegate/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# Default provider URLs
DEFAULT_PROVIDER_URLS = {
"openai": "https://api.openai.com/v1",
"openrouter": "https://openrouter.ai/api/v1",
"anthropic": "https://api.anthropic.com/v1",
"vllm": "http://localhost:8000", # Base URL without /v1 path
"ollama": "http://localhost:11434", # Default Ollama server URL
Expand Down
1 change: 1 addition & 0 deletions src/codegate/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class ProviderType(str, Enum):
ollama = "ollama"
lm_studio = "lm_studio"
llamacpp = "llamacpp"
openrouter = "openai"


class GetPromptWithOutputsRow(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions src/codegate/muxing/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ def __init__(self):
db_models.ProviderType.anthropic: self._format_antropic,
# Our Llamacpp provider emits OpenAI chunks
db_models.ProviderType.llamacpp: self._format_openai,
# OpenRouter is a dialect of OpenAI
db_models.ProviderType.openrouter: self._format_openai,
}

def _format_ollama(self, chunk: str) -> str:
Expand Down
2 changes: 2 additions & 0 deletions src/codegate/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from codegate.providers.base import BaseProvider
from codegate.providers.ollama.provider import OllamaProvider
from codegate.providers.openai.provider import OpenAIProvider
from codegate.providers.openrouter.provider import OpenRouterProvider
from codegate.providers.registry import ProviderRegistry
from codegate.providers.vllm.provider import VLLMProvider

__all__ = [
"BaseProvider",
"ProviderRegistry",
"OpenAIProvider",
"OpenRouterProvider",
"AnthropicProvider",
"VLLMProvider",
"OllamaProvider",
Expand Down
8 changes: 8 additions & 0 deletions src/codegate/providers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from codegate.clients.clients import ClientType
from codegate.codegate_logging import setup_logging
from codegate.config import Config
from codegate.db.connection import DbRecorder
from codegate.pipeline.base import (
PipelineContext,
Expand Down Expand Up @@ -88,6 +89,13 @@ async def process_request(
def provider_route_name(self) -> str:
pass

def _get_base_url(self) -> str:
    """
    Return the configured base URL for this provider.

    Looks up ``provider_urls`` in the global config using this provider's
    route name. Always returns a string: "" when no config is loaded or
    when the provider has no configured URL (a plain ``.get`` would have
    returned ``None`` in that case, breaking the ``-> str`` contract).
    """
    config = Config.get_config()
    if not config:
        return ""
    # Default to "" so a missing provider entry still yields a string
    return config.provider_urls.get(self.provider_route_name, "")

async def _run_output_stream_pipeline(
self,
input_context: PipelineContext,
Expand Down
3 changes: 3 additions & 0 deletions src/codegate/providers/openai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from codegate.providers.openai.provider import OpenAIProvider

__all__ = ["OpenAIProvider"]
47 changes: 47 additions & 0 deletions src/codegate/providers/openrouter/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import json

from fastapi import Header, HTTPException, Request

from codegate.clients.detector import DetectClient
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.openai import OpenAIProvider


class OpenRouterProvider(OpenAIProvider):
    """Provider that routes chat-completion requests to OpenRouter.

    OpenRouter speaks the OpenAI chat-completions dialect, so this class
    reuses OpenAIProvider wholesale and only overrides the route name and
    the FastAPI routes that accept incoming requests.
    """

    def __init__(self, pipeline_factory: PipelineFactory):
        super().__init__(pipeline_factory)

    @property
    def provider_route_name(self) -> str:
        return "openrouter"

    def _setup_routes(self):
        @self.router.post(f"/{self.provider_route_name}/api/v1/chat/completions")
        @self.router.post(f"/{self.provider_route_name}/chat/completions")
        @DetectClient()
        async def create_completion(
            request: Request,
            authorization: str = Header(..., description="Bearer token"),
        ):
            """Proxy an OpenAI-style completion request to OpenRouter."""
            if not authorization.startswith("Bearer "):
                raise HTTPException(status_code=401, detail="Invalid authorization header")

            # Split at most once so a token that itself contains a space
            # is not silently truncated (split(" ")[1] would drop the rest).
            api_key = authorization.split(" ", 1)[1]
            body = await request.body()
            try:
                data = json.loads(body)
            except json.JSONDecodeError as e:
                # Surface malformed payloads as a client error rather than
                # letting the parse failure bubble up as a 500.
                raise HTTPException(status_code=400, detail="Invalid JSON body") from e

            data["base_url"] = self._get_base_url()

            # litellm workaround - add openrouter/ prefix to model name to make it openai-compatible
            # once we get rid of litellm, this can simply be removed
            original_model = data.get("model", "")
            if not original_model.startswith("openrouter/"):
                data["model"] = f"openrouter/{original_model}"

            return await self.process_request(
                data,
                api_key,
                request.url.path,
                request.state.detected_client,
            )
4 changes: 1 addition & 3 deletions src/codegate/providers/vllm/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from codegate.clients.clients import ClientType
from codegate.clients.detector import DetectClient
from codegate.config import Config
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.base import BaseProvider, ModelFetchError
from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator
Expand Down Expand Up @@ -39,8 +38,7 @@ def _get_base_url(self) -> str:
"""
Get the base URL from config with proper formatting
"""
config = Config.get_config()
base_url = config.provider_urls.get("vllm") if config else ""
base_url = super()._get_base_url()
if base_url:
base_url = base_url.rstrip("/")
# Add /v1 if not present
Expand Down
5 changes: 5 additions & 0 deletions src/codegate/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from codegate.providers.lm_studio.provider import LmStudioProvider
from codegate.providers.ollama.provider import OllamaProvider
from codegate.providers.openai.provider import OpenAIProvider
from codegate.providers.openrouter.provider import OpenRouterProvider
from codegate.providers.registry import ProviderRegistry, get_provider_registry
from codegate.providers.vllm.provider import VLLMProvider

Expand Down Expand Up @@ -75,6 +76,10 @@ async def log_user_agent(request: Request, call_next):
ProviderType.openai,
OpenAIProvider(pipeline_factory),
)
registry.add_provider(
ProviderType.openrouter,
OpenRouterProvider(pipeline_factory),
)
registry.add_provider(
ProviderType.anthropic,
AnthropicProvider(
Expand Down
98 changes: 98 additions & 0 deletions tests/providers/openrouter/test_openrouter_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import json
from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi import HTTPException
from fastapi.requests import Request

from codegate.config import DEFAULT_PROVIDER_URLS
from codegate.pipeline.factory import PipelineFactory
from codegate.providers.openrouter.provider import OpenRouterProvider


@pytest.fixture
def mock_factory():
    """A PipelineFactory stand-in so providers can be built without real pipelines."""
    factory = MagicMock(spec=PipelineFactory)
    return factory


@pytest.fixture
def provider(mock_factory):
    """An OpenRouterProvider wired to the mocked pipeline factory."""
    openrouter = OpenRouterProvider(mock_factory)
    return openrouter


def test_get_base_url(provider):
    """The provider should resolve its base URL to the default OpenRouter endpoint."""
    expected = DEFAULT_PROVIDER_URLS["openrouter"]
    assert provider._get_base_url() == expected


@pytest.mark.asyncio
async def test_model_prefix_added():
    """A bare model name must be rewritten to carry the openrouter/ prefix."""
    factory = MagicMock(spec=PipelineFactory)
    provider = OpenRouterProvider(factory)
    provider.process_request = AsyncMock()

    # Fake incoming request carrying an unprefixed model name
    request = MagicMock(spec=Request)
    request.body = AsyncMock(return_value=json.dumps({"model": "gpt-4"}).encode())
    request.url.path = "/openrouter/chat/completions"
    request.state.detected_client = "test-client"

    # Locate the registered completion endpoint on the provider's router
    endpoint = next(
        route.endpoint
        for route in provider.router.routes
        if route.path == "/openrouter/chat/completions"
    )

    await endpoint(request=request, authorization="Bearer test-token")

    # The payload handed to process_request must carry the prefixed model
    payload = provider.process_request.call_args[0][0]
    assert payload["model"] == "openrouter/gpt-4"


@pytest.mark.asyncio
async def test_model_prefix_preserved():
    """An already-prefixed model name must pass through untouched."""
    factory = MagicMock(spec=PipelineFactory)
    provider = OpenRouterProvider(factory)
    provider.process_request = AsyncMock()

    # Fake incoming request whose model already carries the prefix
    request = MagicMock(spec=Request)
    request.body = AsyncMock(
        return_value=json.dumps({"model": "openrouter/gpt-4"}).encode()
    )
    request.url.path = "/openrouter/chat/completions"
    request.state.detected_client = "test-client"

    # Locate the registered completion endpoint on the provider's router
    endpoint = next(
        route.endpoint
        for route in provider.router.routes
        if route.path == "/openrouter/chat/completions"
    )

    await endpoint(request=request, authorization="Bearer test-token")

    # The model name must be forwarded unchanged
    payload = provider.process_request.call_args[0][0]
    assert payload["model"] == "openrouter/gpt-4"


@pytest.mark.asyncio
async def test_invalid_auth_header():
    """A non-Bearer authorization header must be rejected with HTTP 401."""
    factory = MagicMock(spec=PipelineFactory)
    provider = OpenRouterProvider(factory)

    request = MagicMock(spec=Request)

    # Locate the registered completion endpoint on the provider's router
    endpoint = next(
        route.endpoint
        for route in provider.router.routes
        if route.path == "/openrouter/chat/completions"
    )

    with pytest.raises(HTTPException) as exc_info:
        await endpoint(request=request, authorization="InvalidToken")

    assert exc_info.value.status_code == 401
    assert exc_info.value.detail == "Invalid authorization header"
4 changes: 2 additions & 2 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_fa
# Verify all providers were registered
registry_instance = mock_registry.return_value
assert (
registry_instance.add_provider.call_count == 6
) # openai, anthropic, llamacpp, vllm, ollama, lm_studio
registry_instance.add_provider.call_count == 7
) # openai, anthropic, llamacpp, vllm, ollama, lm_studio, openrouter

# Verify specific providers were registered
provider_names = [call.args[0] for call in registry_instance.add_provider.call_args_list]
Expand Down