Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/Interceptor #2032

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 75 additions & 0 deletions examples/interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import time
from typing import Any, TypeVar
from typing_extensions import override

from openai import OpenAI
from openai._interceptor import Interceptor, InterceptorRequest, InterceptorResponse

T = TypeVar("T")


# A simple interceptor that prints every request and response to stdout
class LoggingInterceptor(Interceptor):
    """Logs each outgoing request and incoming response via ``print``."""

    @override
    def before_request(self, request: InterceptorRequest) -> InterceptorRequest:
        """Print the request method/URL, headers, and body (when present), then pass it through unchanged."""
        output = [
            f"Request: {request.method} {request.url}",
            f"Headers: {request.headers}",
        ]
        if request.body:
            output.append(f"Body: {request.body}")
        print("\n".join(output))
        return request

    @override
    def after_response(self, response: InterceptorResponse[Any]) -> InterceptorResponse[Any]:
        """Print the response status, headers, and parsed body, then pass it through unchanged."""
        for line in (
            f"Response Status: {response.status_code}",
            f"Response Headers: {response.headers}",
            f"Response Body: {response.body}",
        ):
            print(line)
        return response


# Define an interceptor that implements retry logic with exponential backoff
class RetryInterceptor(Interceptor):
    """Retries server errors (HTTP 5xx) with exponential backoff.

    NOTE(review): the retry is signalled by raising from ``after_response``,
    but ``InterceptorChain.execute_after_response`` catches and discards
    interceptor exceptions, so the raise never propagates to the transport
    layer — confirm a retry actually occurs before relying on this example.
    """

    def __init__(self, max_retries: int = 3, initial_delay: float = 1.0):
        # Maximum number of retry attempts before giving up.
        self.max_retries = max_retries
        # Base delay in seconds; doubled on each successive attempt.
        self.initial_delay = initial_delay
        # Per-instance attempt counter — sharing one instance across
        # concurrent requests would interleave the counts.
        self.current_retry = 0

    @override
    def before_request(self, request: InterceptorRequest) -> InterceptorRequest:
        # No request-side behavior; pass the request through unchanged.
        return request

    @override
    def after_response(self, response: InterceptorResponse[Any]) -> InterceptorResponse[Any]:
        # If response is successful (non-5xx) or we've exhausted retries, return as is
        if response.status_code < 500 or self.current_retry >= self.max_retries:
            self.current_retry = 0  # Reset for next request
            return response

        # Calculate delay with exponential backoff: initial_delay * 2^attempt
        delay = self.initial_delay * (2**self.current_retry)
        print(f"Request failed with status {response.status_code}. Retrying in {delay} seconds...")
        time.sleep(delay)

        self.current_retry += 1
        # Trigger a retry by raising an exception (see NOTE in the class docstring)
        raise Exception(f"Retrying request (attempt {self.current_retry}/{self.max_retries})")


# Initialize the OpenAI client and add interceptors
if __name__ == "__main__":
    # Create the interceptors — the chain applies them in list order, so the
    # logging interceptor observes each request/response before the retry one.
    logging_interceptor = LoggingInterceptor()
    retry_interceptor = RetryInterceptor(max_retries=3, initial_delay=1.0)

    # Create client with interceptors
    client = OpenAI(interceptors=[logging_interceptor, retry_interceptor])

    # Make a non-streaming chat completion request; both interceptors see it
    response = client.chat.completions.create( # type: ignore
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Tell me about error handling and retries in software systems."}],
        max_tokens=100,
        stream=False,
    )

    # Output the final response
    print("Final Response:", response)
46 changes: 45 additions & 1 deletion src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
APIConnectionError,
APIResponseValidationError,
)
from ._interceptor import Interceptor, InterceptorChain, InterceptorRequest, InterceptorResponse
from ._legacy_response import LegacyAPIResponse

log: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -339,6 +340,7 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
_strict_response_validation: bool
_idempotency_header: str | None
_default_stream_cls: type[_DefaultStreamT] | None = None
_interceptor_chain: InterceptorChain

def __init__(
self,
Expand All @@ -353,6 +355,7 @@ def __init__(
proxies: ProxiesTypes | None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
interceptors: list[Interceptor] | None = None,
) -> None:
self._version = version
self._base_url = self._enforce_trailing_slash(URL(base_url))
Expand All @@ -372,6 +375,8 @@ def __init__(
"max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `openai.DEFAULT_MAX_RETRIES`"
)

self._interceptor_chain = InterceptorChain(interceptors)

def _enforce_trailing_slash(self, url: URL) -> URL:
if url.raw_path.endswith(b"/"):
return url
Expand Down Expand Up @@ -463,8 +468,27 @@ def _build_request(
else:
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")

# Build base headers and params
headers = self._build_headers(options, retries_taken=retries_taken)
params = _merge_mappings(self.default_query, options.params)
prepared_url = self._prepare_url(options.url)

# Execute request interceptors
interceptor_request = InterceptorRequest(
method=options.method,
url=str(prepared_url),
headers=dict(headers),
params=dict(params),
body=json_data,
)
interceptor_request = self._interceptor_chain.execute_before_request(interceptor_request)

# Apply interceptor modifications
headers = httpx.Headers(interceptor_request.headers)
params = interceptor_request.params or {}
json_data = interceptor_request.body
prepared_url = URL(interceptor_request.url)

content_type = headers.get("Content-Type")
files = options.files

Expand Down Expand Up @@ -506,7 +530,7 @@ def _build_request(
return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
headers=headers,
timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
method=options.method,
method=interceptor_request.method,
url=prepared_url,
# the `Query` type that we use is incompatible with qs'
# `Params` type as it needs to be typed as `Mapping[str, object]`
Expand Down Expand Up @@ -582,6 +606,22 @@ def _process_response_data(
return cast(ResponseT, data)

try:
# Create InterceptorResponse and execute interceptors
interceptor_response = InterceptorResponse(
status_code=response.status_code,
headers=dict(response.headers),
body=data,
request=InterceptorRequest(
method=response.request.method,
url=str(response.request.url),
headers=dict(response.request.headers),
body=response.request._content if response.request._content else None,
),
raw_response=response,
)
interceptor_response = self._interceptor_chain.execute_after_response(interceptor_response)
data = interceptor_response.body

if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol):
return cast(ResponseT, cast_to.build(response=response, data=data))

Expand Down Expand Up @@ -796,6 +836,7 @@ def __init__(
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
_strict_response_validation: bool,
interceptors: list[Interceptor] | None = None,
) -> None:
kwargs: dict[str, Any] = {}
if limits is not None:
Expand Down Expand Up @@ -859,6 +900,7 @@ def __init__(
custom_query=custom_query,
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
interceptors=interceptors,
)
self._client = http_client or SyncHttpxClientWrapper(
base_url=base_url,
Expand Down Expand Up @@ -1382,6 +1424,7 @@ def __init__(
http_client: httpx.AsyncClient | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
interceptors: list[Interceptor] | None = None,
) -> None:
kwargs: dict[str, Any] = {}
if limits is not None:
Expand Down Expand Up @@ -1445,6 +1488,7 @@ def __init__(
custom_query=custom_query,
custom_headers=custom_headers,
_strict_response_validation=_strict_response_validation,
interceptors=interceptors,
)
self._client = http_client or AsyncHttpxClientWrapper(
base_url=base_url,
Expand Down
10 changes: 10 additions & 0 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

import httpx

from openai._interceptor import Interceptor

from . import _exceptions
from ._qs import Querystring
from ._types import (
Expand Down Expand Up @@ -96,6 +98,7 @@ def __init__(
# outlining your use-case to help us decide if it should be
# part of our public interface in the future.
_strict_response_validation: bool = False,
interceptors: list[Interceptor] | None = None,
) -> None:
"""Construct a new synchronous openai client instance.

Expand Down Expand Up @@ -136,6 +139,7 @@ def __init__(
custom_headers=default_headers,
custom_query=default_query,
_strict_response_validation=_strict_response_validation,
interceptors=interceptors,
)

self._default_stream_cls = Stream
Expand Down Expand Up @@ -192,6 +196,7 @@ def copy(
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
interceptors: list[Interceptor] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Expand Down Expand Up @@ -227,6 +232,7 @@ def copy(
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
interceptors=interceptors,
**_extra_kwargs,
)

Expand Down Expand Up @@ -323,6 +329,7 @@ def __init__(
# outlining your use-case to help us decide if it should be
# part of our public interface in the future.
_strict_response_validation: bool = False,
interceptors: list[Interceptor] | None = None,
) -> None:
"""Construct a new async openai client instance.

Expand Down Expand Up @@ -363,6 +370,7 @@ def __init__(
custom_headers=default_headers,
custom_query=default_query,
_strict_response_validation=_strict_response_validation,
interceptors=interceptors,
)

self._default_stream_cls = AsyncStream
Expand Down Expand Up @@ -419,6 +427,7 @@ def copy(
set_default_headers: Mapping[str, str] | None = None,
default_query: Mapping[str, object] | None = None,
set_default_query: Mapping[str, object] | None = None,
interceptors: list[Interceptor] | None = None,
_extra_kwargs: Mapping[str, Any] = {},
) -> Self:
"""
Expand Down Expand Up @@ -454,6 +463,7 @@ def copy(
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
interceptors=interceptors,
**_extra_kwargs,
)

Expand Down
75 changes: 75 additions & 0 deletions src/openai/_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Union, Generic, TypeVar, Optional
from dataclasses import dataclass

import httpx

from ._types import Body

T = TypeVar("T")


@dataclass
class InterceptorRequest:
    """Request data container for interceptor processing."""

    # HTTP method of the outgoing request.
    method: str
    # Fully-prepared target URL.
    url: str
    # Header name -> value mapping.
    headers: Dict[str, str]
    # Optional query parameters.
    params: Optional[Dict[str, Any]] = None
    # Optional request payload: a JSON-able body or raw bytes.
    body: Optional[Union[Body, bytes]] = None


@dataclass
class InterceptorResponse(Generic[T]):
    """Response data container for interceptor processing.

    Generic over ``T``, the type of the processed response body.
    """

    # HTTP status code of the response.
    status_code: int
    # Header name -> value mapping.
    headers: Dict[str, str]
    # Response body, as processed by the client before interceptors run.
    body: T
    # The request (as seen by interceptors) that produced this response.
    request: InterceptorRequest
    # Underlying httpx response, for low-level access.
    raw_response: httpx.Response


class Interceptor(ABC):
    """Abstract request/response hook.

    Subclasses may transform the outgoing request and the incoming response;
    each hook must return the (possibly modified) object it was given.
    """

    @abstractmethod
    def before_request(self, request: InterceptorRequest) -> InterceptorRequest:
        """Called with the outgoing request; returns the request to send."""

    @abstractmethod
    def after_response(self, response: InterceptorResponse[T]) -> InterceptorResponse[T]:
        """Called with the received response; returns the response to use."""


log: logging.Logger = logging.getLogger(__name__)


class InterceptorChain:
    """Chain of interceptors for sequential request/response processing.

    Interceptors run in registration order; each hook receives the previous
    hook's return value. A hook that raises is logged and skipped, so a single
    faulty interceptor cannot break the request pipeline — note this also
    means an interceptor cannot abort or retry a request by raising.
    """

    def __init__(self, interceptors: Optional[list[Interceptor]] = None):
        # Defensive default: treat None (and empty) as "no interceptors".
        self._interceptors = interceptors or []

    def add_interceptor(self, interceptor: Interceptor) -> None:
        """Append *interceptor* to the end of the chain."""
        self._interceptors.append(interceptor)

    def execute_before_request(self, request: InterceptorRequest) -> InterceptorRequest:
        """Run each interceptor's ``before_request`` hook and return the final request."""
        current_request = request
        for interceptor in self._interceptors:
            try:
                current_request = interceptor.before_request(current_request)
            except Exception:
                # Was a silent `continue`: failures vanished without a trace.
                # Log them so broken interceptors are diagnosable, while keeping
                # the backward-compatible skip-on-failure semantics.
                log.exception("before_request interceptor %r failed; skipping", interceptor)
        return current_request

    def execute_after_response(self, response: InterceptorResponse[T]) -> InterceptorResponse[T]:
        """Run each interceptor's ``after_response`` hook and return the final response."""
        current_response = response
        for interceptor in self._interceptors:
            try:
                current_response = interceptor.after_response(current_response)
            except Exception:
                log.exception("after_response interceptor %r failed; skipping", interceptor)
        return current_response
1 change: 1 addition & 0 deletions src/openai/_utils/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def blocking_func(arg1, arg2, kwarg1=None):
# blocking code
return result


result = asyncify(blocking_function)(arg1, arg2, kwarg1=value1)
```

Expand Down
Loading