Skip to content

Commit c37965e

Browse files
committed
Intial Commit
1 parent 33e4085 commit c37965e

File tree

5 files changed

+397
-5
lines changed

5 files changed

+397
-5
lines changed

examples/interceptor.py

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
from typing import TypeVar, Any
2+
import time
3+
from typing_extensions import override
4+
from openai import OpenAI
5+
from openai._interceptor import Interceptor, InterceptorRequest, InterceptorResponse
6+
from dotenv import load_dotenv
7+
8+
load_dotenv()
9+
10+
T = TypeVar("T")
11+
12+
# Define a custom logging interceptor
13+
class LoggingInterceptor(Interceptor):
14+
@override
15+
def before_request(self, request: InterceptorRequest) -> InterceptorRequest:
16+
print(f"Request: {request.method} {request.url}")
17+
print(f"Headers: {request.headers}")
18+
if request.body:
19+
print(f"Body: {request.body}")
20+
return request
21+
22+
@override
23+
def after_response(self, response: InterceptorResponse[Any]) -> InterceptorResponse[Any]:
24+
print(f"Response Status: {response.status_code}")
25+
print(f"Response Headers: {response.headers}")
26+
print(f"Response Body: {response.body}")
27+
return response
28+
29+
# Define an interceptor that implements retry logic with exponential backoff
30+
class RetryInterceptor(Interceptor):
31+
def __init__(self, max_retries: int = 3, initial_delay: float = 1.0):
32+
self.max_retries = max_retries
33+
self.initial_delay = initial_delay
34+
self.current_retry = 0
35+
36+
@override
37+
def before_request(self, request: InterceptorRequest) -> InterceptorRequest:
38+
return request
39+
40+
@override
41+
def after_response(self, response: InterceptorResponse[Any]) -> InterceptorResponse[Any]:
42+
# If response is successful or we've exhausted retries, return as is
43+
if response.status_code < 500 or self.current_retry >= self.max_retries:
44+
self.current_retry = 0 # Reset for next request
45+
return response
46+
47+
# Calculate delay with exponential backoff
48+
delay = self.initial_delay * (2 ** self.current_retry)
49+
print(f"Request failed with status {response.status_code}. Retrying in {delay} seconds...")
50+
time.sleep(delay)
51+
52+
self.current_retry += 1
53+
# Trigger a retry by raising an exception
54+
raise Exception(f"Retrying request (attempt {self.current_retry}/{self.max_retries})")
55+
56+
# Initialize the OpenAI client and add interceptors
57+
if __name__ == "__main__":
58+
# Create the interceptor chain
59+
logging_interceptor = LoggingInterceptor()
60+
retry_interceptor = RetryInterceptor(max_retries=3, initial_delay=1.0)
61+
62+
# Create client with interceptors
63+
client = OpenAI(
64+
interceptors=[logging_interceptor, retry_interceptor]
65+
)
66+
67+
# Make a request using the client
68+
response = client.chat.completions.create(
69+
model="gpt-3.5-turbo",
70+
messages=[{"role": "user", "content": "Tell me about error handling and retries in software systems."}],
71+
max_tokens=100,
72+
)
73+
74+
# Output the final response
75+
print("Final Response:", response)

src/openai/_base_client.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@
8989
)
9090
from ._legacy_response import LegacyAPIResponse
9191

92+
from ._interceptor import InterceptorChain, InterceptorRequest, InterceptorResponse, Interceptor
93+
9294
log: logging.Logger = logging.getLogger(__name__)
9395
log.addFilter(SensitiveHeadersFilter())
9496

@@ -339,6 +341,8 @@ class BaseClient(Generic[_HttpxClientT, _DefaultStreamT]):
339341
_strict_response_validation: bool
340342
_idempotency_header: str | None
341343
_default_stream_cls: type[_DefaultStreamT] | None = None
344+
_interceptor_chain: InterceptorChain
345+
342346

343347
def __init__(
344348
self,
@@ -353,6 +357,8 @@ def __init__(
353357
proxies: ProxiesTypes | None,
354358
custom_headers: Mapping[str, str] | None = None,
355359
custom_query: Mapping[str, object] | None = None,
360+
interceptors: list[Interceptor] | None = None,
361+
356362
) -> None:
357363
self._version = version
358364
self._base_url = self._enforce_trailing_slash(URL(base_url))
@@ -372,7 +378,9 @@ def __init__(
372378
"max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `openai.DEFAULT_MAX_RETRIES`"
373379
)
374380

375-
def _enforce_trailing_slash(self, url: URL) -> URL:
381+
382+
self._interceptor_chain = InterceptorChain(interceptors)
383+
def _enforce_trailing_slash(self, url: URL) -> URL:
376384
if url.raw_path.endswith(b"/"):
377385
return url
378386
return url.copy_with(raw_path=url.raw_path + b"/")
@@ -463,8 +471,27 @@ def _build_request(
463471
else:
464472
raise RuntimeError(f"Unexpected JSON data type, {type(json_data)}, cannot merge with `extra_body`")
465473

474+
# Build base headers and params
466475
headers = self._build_headers(options, retries_taken=retries_taken)
467476
params = _merge_mappings(self.default_query, options.params)
477+
prepared_url = self._prepare_url(options.url)
478+
479+
# Execute request interceptors
480+
interceptor_request = InterceptorRequest(
481+
method=options.method,
482+
url=str(prepared_url),
483+
headers=dict(headers),
484+
params=dict(params),
485+
body=json_data,
486+
)
487+
interceptor_request = self._interceptor_chain.execute_before_request(interceptor_request)
488+
489+
# Apply interceptor modifications
490+
headers = httpx.Headers(interceptor_request.headers)
491+
params = interceptor_request.params or {}
492+
json_data = interceptor_request.body
493+
prepared_url = URL(interceptor_request.url)
494+
468495
content_type = headers.get("Content-Type")
469496
files = options.files
470497

@@ -506,7 +533,7 @@ def _build_request(
506533
return self._client.build_request( # pyright: ignore[reportUnknownMemberType]
507534
headers=headers,
508535
timeout=self.timeout if isinstance(options.timeout, NotGiven) else options.timeout,
509-
method=options.method,
536+
method=interceptor_request.method,
510537
url=prepared_url,
511538
# the `Query` type that we use is incompatible with qs'
512539
# `Params` type as it needs to be typed as `Mapping[str, object]`
@@ -582,6 +609,22 @@ def _process_response_data(
582609
return cast(ResponseT, data)
583610

584611
try:
612+
# Create InterceptorResponse and execute interceptors
613+
interceptor_response = InterceptorResponse(
614+
status_code=response.status_code,
615+
headers=dict(response.headers),
616+
body=data,
617+
request=InterceptorRequest(
618+
method=response.request.method,
619+
url=str(response.request.url),
620+
headers=dict(response.request.headers),
621+
body=response.request._content if response.request._content else None,
622+
),
623+
raw_response=response,
624+
)
625+
interceptor_response = self._interceptor_chain.execute_after_response(interceptor_response)
626+
data = interceptor_response.body
627+
585628
if inspect.isclass(cast_to) and issubclass(cast_to, ModelBuilderProtocol):
586629
return cast(ResponseT, cast_to.build(response=response, data=data))
587630

@@ -796,6 +839,9 @@ def __init__(
796839
custom_headers: Mapping[str, str] | None = None,
797840
custom_query: Mapping[str, object] | None = None,
798841
_strict_response_validation: bool,
842+
843+
interceptors: list[Interceptor] | None = None,
844+
799845
) -> None:
800846
kwargs: dict[str, Any] = {}
801847
if limits is not None:
@@ -859,6 +905,9 @@ def __init__(
859905
custom_query=custom_query,
860906
custom_headers=custom_headers,
861907
_strict_response_validation=_strict_response_validation,
908+
909+
interceptors=interceptors,
910+
862911
)
863912
self._client = http_client or SyncHttpxClientWrapper(
864913
base_url=base_url,
@@ -1382,6 +1431,8 @@ def __init__(
13821431
http_client: httpx.AsyncClient | None = None,
13831432
custom_headers: Mapping[str, str] | None = None,
13841433
custom_query: Mapping[str, object] | None = None,
1434+
1435+
interceptors: list[Interceptor] | None = None,
13851436
) -> None:
13861437
kwargs: dict[str, Any] = {}
13871438
if limits is not None:
@@ -1445,6 +1496,9 @@ def __init__(
14451496
custom_query=custom_query,
14461497
custom_headers=custom_headers,
14471498
_strict_response_validation=_strict_response_validation,
1499+
1500+
interceptors=interceptors,
1501+
14481502
)
14491503
self._client = http_client or AsyncHttpxClientWrapper(
14501504
base_url=base_url,

src/openai/_client.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing_extensions import Self, override
88

99
import httpx
10+
from openai._interceptor import Interceptor
1011

1112
from . import _exceptions
1213
from ._qs import Querystring
@@ -96,7 +97,8 @@ def __init__(
9697
# outlining your use-case to help us decide if it should be
9798
# part of our public interface in the future.
9899
_strict_response_validation: bool = False,
99-
) -> None:
100+
interceptors: list[Interceptor] | None = None,
101+
) -> None:
100102
"""Construct a new synchronous openai client instance.
101103
102104
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
@@ -136,7 +138,8 @@ def __init__(
136138
custom_headers=default_headers,
137139
custom_query=default_query,
138140
_strict_response_validation=_strict_response_validation,
139-
)
141+
interceptors=interceptors,
142+
)
140143

141144
self._default_stream_cls = Stream
142145

@@ -192,6 +195,7 @@ def copy(
192195
set_default_headers: Mapping[str, str] | None = None,
193196
default_query: Mapping[str, object] | None = None,
194197
set_default_query: Mapping[str, object] | None = None,
198+
interceptors: list[Interceptor] | None = None,
195199
_extra_kwargs: Mapping[str, Any] = {},
196200
) -> Self:
197201
"""
@@ -227,6 +231,7 @@ def copy(
227231
max_retries=max_retries if is_given(max_retries) else self.max_retries,
228232
default_headers=headers,
229233
default_query=params,
234+
interceptors=interceptors,
230235
**_extra_kwargs,
231236
)
232237

@@ -323,7 +328,8 @@ def __init__(
323328
# outlining your use-case to help us decide if it should be
324329
# part of our public interface in the future.
325330
_strict_response_validation: bool = False,
326-
) -> None:
331+
interceptors: list[Interceptor] | None = None,
332+
) -> None:
327333
"""Construct a new async openai client instance.
328334
329335
This automatically infers the following arguments from their corresponding environment variables if they are not provided:
@@ -363,6 +369,7 @@ def __init__(
363369
custom_headers=default_headers,
364370
custom_query=default_query,
365371
_strict_response_validation=_strict_response_validation,
372+
interceptors=interceptors,
366373
)
367374

368375
self._default_stream_cls = AsyncStream
@@ -419,6 +426,7 @@ def copy(
419426
set_default_headers: Mapping[str, str] | None = None,
420427
default_query: Mapping[str, object] | None = None,
421428
set_default_query: Mapping[str, object] | None = None,
429+
interceptors: list[Interceptor] | None = None,
422430
_extra_kwargs: Mapping[str, Any] = {},
423431
) -> Self:
424432
"""
@@ -454,6 +462,7 @@ def copy(
454462
max_retries=max_retries if is_given(max_retries) else self.max_retries,
455463
default_headers=headers,
456464
default_query=params,
465+
interceptors=interceptors,
457466
**_extra_kwargs,
458467
)
459468

src/openai/_interceptor.py

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from __future__ import annotations
2+
3+
from abc import ABC, abstractmethod
4+
from dataclasses import dataclass
5+
from typing import Dict, Optional, TypeVar, Generic, Any, Union
6+
7+
import httpx
8+
9+
from ._types import Body
10+
11+
T = TypeVar("T")
12+
13+
@dataclass
14+
class InterceptorRequest:
15+
"""Container for request data that can be modified by interceptors"""
16+
method: str
17+
url: str
18+
headers: Dict[str, str]
19+
params: Optional[Dict[str, Any]] = None
20+
body: Optional[Union[Body, bytes]] = None
21+
22+
@dataclass
23+
class InterceptorResponse(Generic[T]):
24+
"""Container for response data that can be processed by interceptors"""
25+
status_code: int
26+
headers: Dict[str, str]
27+
body: T
28+
request: InterceptorRequest
29+
raw_response: httpx.Response
30+
31+
class Interceptor(ABC):
32+
"""Base class for implementing request/response interceptors"""
33+
34+
@abstractmethod
35+
def before_request(self, request: InterceptorRequest) -> InterceptorRequest:
36+
"""Process and optionally modify the request before it is sent.
37+
38+
Args:
39+
request: The request to process
40+
41+
Returns:
42+
The processed request
43+
"""
44+
pass
45+
46+
@abstractmethod
47+
def after_response(self, response: InterceptorResponse[T]) -> InterceptorResponse[T]:
48+
"""Process and optionally modify the response after it is received.
49+
50+
Args:
51+
response: The response to process
52+
53+
Returns:
54+
The processed response
55+
"""
56+
pass
57+
58+
class InterceptorChain:
59+
"""Manages a chain of interceptors that process requests/responses in sequence"""
60+
61+
def __init__(self, interceptors: Optional[list[Interceptor]] = None):
62+
self._interceptors = interceptors or []
63+
64+
def add_interceptor(self, interceptor: Interceptor) -> None:
65+
"""Add an interceptor to the chain"""
66+
self._interceptors.append(interceptor)
67+
68+
def execute_before_request(self, request: InterceptorRequest) -> InterceptorRequest:
69+
"""Execute all interceptors' before_request methods in sequence"""
70+
print("\n=== Intercepted Request ===")
71+
print(f"Method: {request.method}")
72+
print(f"URL: {request.url}")
73+
print(f"Headers: {request.headers}")
74+
if request.params:
75+
print(f"Query Params: {request.params}")
76+
if request.body:
77+
print(f"Request Body: {request.body}")
78+
print("========================\n")
79+
80+
current_request = request
81+
for interceptor in self._interceptors:
82+
try:
83+
current_request = interceptor.before_request(current_request)
84+
except Exception as e:
85+
# Log error but continue processing
86+
print(f"Error in interceptor {interceptor.__class__.__name__}: {e}")
87+
return current_request
88+
89+
def execute_after_response(self, response: InterceptorResponse[T]) -> InterceptorResponse[T]:
90+
"""Execute all interceptors' after_response methods in sequence"""
91+
print("\n=== Intercepted Response ===")
92+
print(f"Status Code: {response.status_code}")
93+
print(f"Headers: {response.headers}")
94+
print(f"Response Body: {response.body}")
95+
print("=========================\n")
96+
97+
current_response = response
98+
for interceptor in self._interceptors:
99+
try:
100+
current_response = interceptor.after_response(current_response)
101+
except Exception as e:
102+
# Log error but continue processing
103+
print(f"Error in interceptor {interceptor.__class__.__name__}: {e}")
104+
return current_response

0 commit comments

Comments
 (0)