diff --git a/src/codegate/config.py b/src/codegate/config.py index fb4d08bf..11cd96bf 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -17,6 +17,7 @@ # Default provider URLs DEFAULT_PROVIDER_URLS = { "openai": "https://api.openai.com/v1", + "openrouter": "https://openrouter.ai/api/v1", "anthropic": "https://api.anthropic.com/v1", "vllm": "http://localhost:8000", # Base URL without /v1 path "ollama": "http://localhost:11434", # Default Ollama server URL diff --git a/src/codegate/db/models.py b/src/codegate/db/models.py index 8d9a5261..21da54fe 100644 --- a/src/codegate/db/models.py +++ b/src/codegate/db/models.py @@ -128,6 +128,7 @@ class ProviderType(str, Enum): ollama = "ollama" lm_studio = "lm_studio" llamacpp = "llamacpp" + openrouter = "openrouter" class GetPromptWithOutputsRow(BaseModel): diff --git a/src/codegate/muxing/adapter.py b/src/codegate/muxing/adapter.py index 3a2d74ca..1fbabcaf 100644 --- a/src/codegate/muxing/adapter.py +++ b/src/codegate/muxing/adapter.py @@ -106,6 +106,8 @@ def __init__(self): db_models.ProviderType.anthropic: self._format_antropic, # Our Lllamacpp provider emits OpenAI chunks db_models.ProviderType.llamacpp: self._format_openai, + # OpenRouter is a dialect of OpenAI + db_models.ProviderType.openrouter: self._format_openai, } def _format_ollama(self, chunk: str) -> str: diff --git a/src/codegate/providers/__init__.py b/src/codegate/providers/__init__.py index d8a35e8e..69ec022e 100644 --- a/src/codegate/providers/__init__.py +++ b/src/codegate/providers/__init__.py @@ -2,6 +2,7 @@ from codegate.providers.base import BaseProvider from codegate.providers.ollama.provider import OllamaProvider from codegate.providers.openai.provider import OpenAIProvider +from codegate.providers.openrouter.provider import OpenRouterProvider from codegate.providers.registry import ProviderRegistry from codegate.providers.vllm.provider import VLLMProvider @@ -9,6 +10,7 @@ "BaseProvider", "ProviderRegistry", "OpenAIProvider", + 
"OpenRouterProvider", "AnthropicProvider", "VLLMProvider", "OllamaProvider", diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 05a1f922..269fd0e5 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -12,6 +12,7 @@ from codegate.clients.clients import ClientType from codegate.codegate_logging import setup_logging +from codegate.config import Config from codegate.db.connection import DbRecorder from codegate.pipeline.base import ( PipelineContext, @@ -88,6 +89,13 @@ async def process_request( def provider_route_name(self) -> str: pass + def _get_base_url(self) -> str: + """ + Get the base URL from config with proper formatting + """ + config = Config.get_config() + return config.provider_urls.get(self.provider_route_name, "") if config else "" + async def _run_output_stream_pipeline( self, input_context: PipelineContext, diff --git a/src/codegate/providers/openai/__init__.py b/src/codegate/providers/openai/__init__.py index e69de29b..9055468d 100644 --- a/src/codegate/providers/openai/__init__.py +++ b/src/codegate/providers/openai/__init__.py @@ -0,0 +1,3 @@ +from codegate.providers.openai.provider import OpenAIProvider + +__all__ = ["OpenAIProvider"] diff --git a/src/codegate/providers/openrouter/provider.py b/src/codegate/providers/openrouter/provider.py new file mode 100644 index 00000000..0e770a01 --- /dev/null +++ b/src/codegate/providers/openrouter/provider.py @@ -0,0 +1,47 @@ +import json + +from fastapi import Header, HTTPException, Request + +from codegate.clients.detector import DetectClient +from codegate.pipeline.factory import PipelineFactory +from codegate.providers.openai import OpenAIProvider + + +class OpenRouterProvider(OpenAIProvider): + def __init__(self, pipeline_factory: PipelineFactory): + super().__init__(pipeline_factory) + + @property + def provider_route_name(self) -> str: + return "openrouter" + + def _setup_routes(self): + 
@self.router.post(f"/{self.provider_route_name}/api/v1/chat/completions") + @self.router.post(f"/{self.provider_route_name}/chat/completions") + @DetectClient() + async def create_completion( + request: Request, + authorization: str = Header(..., description="Bearer token"), + ): + if not authorization.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Invalid authorization header") + + api_key = authorization.split(" ")[1] + body = await request.body() + data = json.loads(body) + + base_url = self._get_base_url() + data["base_url"] = base_url + + # litellm workaround - add openrouter/ prefix to model name to make it openai-compatible + # once we get rid of litellm, this can simply be removed + original_model = data.get("model", "") + if not original_model.startswith("openrouter/"): + data["model"] = f"openrouter/{original_model}" + + return await self.process_request( + data, + api_key, + request.url.path, + request.state.detected_client, + ) diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py index 48c74cce..16f73d6d 100644 --- a/src/codegate/providers/vllm/provider.py +++ b/src/codegate/providers/vllm/provider.py @@ -9,7 +9,6 @@ from codegate.clients.clients import ClientType from codegate.clients.detector import DetectClient -from codegate.config import Config from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider, ModelFetchError from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator @@ -39,8 +38,7 @@ def _get_base_url(self) -> str: """ Get the base URL from config with proper formatting """ - config = Config.get_config() - base_url = config.provider_urls.get("vllm") if config else "" + base_url = super()._get_base_url() if base_url: base_url = base_url.rstrip("/") # Add /v1 if not present diff --git a/src/codegate/server.py b/src/codegate/server.py index 215f6e09..57503b12 100644 --- a/src/codegate/server.py +++ 
b/src/codegate/server.py @@ -18,6 +18,7 @@ from codegate.providers.lm_studio.provider import LmStudioProvider from codegate.providers.ollama.provider import OllamaProvider from codegate.providers.openai.provider import OpenAIProvider +from codegate.providers.openrouter.provider import OpenRouterProvider from codegate.providers.registry import ProviderRegistry, get_provider_registry from codegate.providers.vllm.provider import VLLMProvider @@ -75,6 +76,10 @@ async def log_user_agent(request: Request, call_next): ProviderType.openai, OpenAIProvider(pipeline_factory), ) + registry.add_provider( + ProviderType.openrouter, + OpenRouterProvider(pipeline_factory), + ) registry.add_provider( ProviderType.anthropic, AnthropicProvider( diff --git a/tests/providers/openrouter/test_openrouter_provider.py b/tests/providers/openrouter/test_openrouter_provider.py new file mode 100644 index 00000000..14bafd35 --- /dev/null +++ b/tests/providers/openrouter/test_openrouter_provider.py @@ -0,0 +1,98 @@ +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest +from fastapi import HTTPException +from fastapi.requests import Request + +from codegate.config import DEFAULT_PROVIDER_URLS +from codegate.pipeline.factory import PipelineFactory +from codegate.providers.openrouter.provider import OpenRouterProvider + + +@pytest.fixture +def mock_factory(): + return MagicMock(spec=PipelineFactory) + + +@pytest.fixture +def provider(mock_factory): + return OpenRouterProvider(mock_factory) + + +def test_get_base_url(provider): + """Test that _get_base_url returns the correct OpenRouter API URL""" + assert provider._get_base_url() == DEFAULT_PROVIDER_URLS["openrouter"] + + +@pytest.mark.asyncio +async def test_model_prefix_added(): + """Test that model name gets prefixed with openrouter/ when not already present""" + mock_factory = MagicMock(spec=PipelineFactory) + provider = OpenRouterProvider(mock_factory) + provider.process_request = AsyncMock() + + # Mock request + 
mock_request = MagicMock(spec=Request) + mock_request.body = AsyncMock(return_value=json.dumps({"model": "gpt-4"}).encode()) + mock_request.url.path = "/openrouter/chat/completions" + mock_request.state.detected_client = "test-client" + + # Get the route handler function + route_handlers = [ + route for route in provider.router.routes if route.path == "/openrouter/chat/completions" + ] + create_completion = route_handlers[0].endpoint + + await create_completion(request=mock_request, authorization="Bearer test-token") + + # Verify process_request was called with prefixed model + call_args = provider.process_request.call_args[0] + assert call_args[0]["model"] == "openrouter/gpt-4" + + +@pytest.mark.asyncio +async def test_model_prefix_preserved(): + """Test that model name is not modified when openrouter/ prefix is already present""" + mock_factory = MagicMock(spec=PipelineFactory) + provider = OpenRouterProvider(mock_factory) + provider.process_request = AsyncMock() + + # Mock request + mock_request = MagicMock(spec=Request) + mock_request.body = AsyncMock(return_value=json.dumps({"model": "openrouter/gpt-4"}).encode()) + mock_request.url.path = "/openrouter/chat/completions" + mock_request.state.detected_client = "test-client" + + # Get the route handler function + route_handlers = [ + route for route in provider.router.routes if route.path == "/openrouter/chat/completions" + ] + create_completion = route_handlers[0].endpoint + + await create_completion(request=mock_request, authorization="Bearer test-token") + + # Verify process_request was called with unchanged model name + call_args = provider.process_request.call_args[0] + assert call_args[0]["model"] == "openrouter/gpt-4" + + +@pytest.mark.asyncio +async def test_invalid_auth_header(): + """Test that invalid authorization header format raises HTTPException""" + mock_factory = MagicMock(spec=PipelineFactory) + provider = OpenRouterProvider(mock_factory) + + mock_request = MagicMock(spec=Request) + + # Get the 
route handler function + route_handlers = [ + route for route in provider.router.routes if route.path == "/openrouter/chat/completions" + ] + create_completion = route_handlers[0].endpoint + + with pytest.raises(HTTPException) as exc_info: + await create_completion(request=mock_request, authorization="InvalidToken") + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid authorization header" diff --git a/tests/test_server.py b/tests/test_server.py index 46e2f867..1e06c096 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -108,8 +108,8 @@ def test_provider_registration(mock_registry, mock_secrets_mgr, mock_pipeline_fa # Verify all providers were registered registry_instance = mock_registry.return_value assert ( - registry_instance.add_provider.call_count == 6 - ) # openai, anthropic, llamacpp, vllm, ollama, lm_studio + registry_instance.add_provider.call_count == 7 + ) # openai, anthropic, llamacpp, vllm, ollama, lm_studio, openrouter # Verify specific providers were registered provider_names = [call.args[0] for call in registry_instance.add_provider.call_args_list]