Skip to content

Commit 6d1b083

Browse files
committed
Refactor for httpx 0.16
1 parent 858b9a3 commit 6d1b083

File tree

2 files changed

+44
-24
lines changed

2 files changed

+44
-24
lines changed

Diff for: httpx_gssapi/exceptions.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,20 @@
55
This module contains the set of exceptions.
66
77
"""
8-
from httpx import HTTPError
8+
from httpx import HTTPError, Request, Response
99

1010

1111
class MutualAuthenticationError(HTTPError):
1212
"""Mutual Authentication Error"""
13-
def __str__(self):
14-
return f"Unable to authenticate {self.response}"
1513

16-
def __repr__(self):
17-
return f"{__class__.__name__}('{self}')"
14+
def __init__(self, *,
15+
request: Request = None,
16+
response: Response):
17+
self.response = response
18+
super().__init__(
19+
f"Unable to authenticate {self.response}",
20+
request=request or self.response.request,
21+
)
1822

1923

2024
class SPNEGOExchangeError(HTTPError):

Diff for: httpx_gssapi/gssapi_.py

+35-19
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import re
22
import logging
3+
from itertools import chain
34
from functools import wraps
45
from typing import Generator, Optional, List, Any
56

67
from base64 import b64encode, b64decode
78

89
import gssapi
9-
from gssapi import SecurityContext as SecCtx
10+
from gssapi import SecurityContext
1011
from gssapi.exceptions import GSSError
1112

1213
import httpx
@@ -60,7 +61,7 @@ def _sanitize_response(response: Response):
6061
response.headers[header] = headers[header]
6162

6263

63-
def _handle_gsserror(*, gss_stage: str, result: Any):
64+
def _handle_gsserror(*, gss_stage: str, result: Any = ...):
6465
"""
6566
Decorator to handle GSSErrors and properly log them against the decorated
6667
function's name.
@@ -69,8 +70,9 @@ def _handle_gsserror(*, gss_stage: str, result: Any):
6970
Name of GSS stage that the function is handling. Typically either
7071
'initializing' or 'stepping'.
7172
:param result:
72-
The result to return if a GSSError is raised. If it's an Exception
73-
type, then it will be raised with the logged message.
73+
The result to return if a GSSError is raised. If the result is a
74+
callable, it will be called first with the message and all args
75+
and kwargs.
7476
"""
7577

7678
def _decor(func):
@@ -81,15 +83,27 @@ def _wrapper(*args, **kwargs):
8183
except GSSError as error:
8284
msg = f"{gss_stage} context failed: {error.gen_message()}"
8385
log.exception(f"{func.__name__}(): {msg}")
84-
if isinstance(result, type) and issubclass(result, Exception):
85-
raise result(msg)
86+
if callable(result):
87+
return result(msg, *args, **kwargs)
8688
return result
8789

8890
return _wrapper
8991

9092
return _decor
9193

9294

95+
def _gss_to_spnego_error(message: str, *args: Any, **kwargs: Any):
96+
"""Helper function to _handle_gsserror to raise SPNEGOExchangeErrors."""
97+
try:
98+
request = next(
99+
a for a in chain(args, kwargs.values())
100+
if isinstance(a, Request)
101+
)
102+
except StopIteration: # sanity check
103+
raise RuntimeError("No request in arguments!")
104+
raise SPNEGOExchangeError(message, request=request)
105+
106+
93107
class HTTPSPNEGOAuth(Auth):
94108
"""
95109
Attaches HTTP GSSAPI Authentication to the given Request object.
@@ -146,7 +160,7 @@ def auth_flow(self, request: Request) -> FlowGen:
146160

147161
def handle_response(self,
148162
response: Response,
149-
ctx: SecCtx = None) -> FlowGen:
163+
ctx: SecurityContext = None) -> FlowGen:
150164
num_401s = 0
151165
while response.status_code == 401 and num_401s < 2:
152166
num_401s += 1
@@ -171,7 +185,7 @@ def handle_response(self,
171185

172186
self.handle_mutual_auth(response, ctx)
173187

174-
def handle_mutual_auth(self, response: Response, ctx: SecCtx):
188+
def handle_mutual_auth(self, response: Response, ctx: SecurityContext):
175189
"""
176190
Handles all responses with the exception of 401s.
177191
@@ -212,18 +226,18 @@ def handle_mutual_auth(self, response: Response, ctx: SecCtx):
212226
log.error("handle_other(): Mutual authentication failed")
213227
raise MutualAuthenticationError(response=response)
214228

215-
@_handle_gsserror(gss_stage='stepping', result=SPNEGOExchangeError)
229+
@_handle_gsserror(gss_stage='stepping', result=_gss_to_spnego_error)
216230
def set_auth_header(self,
217231
request: Request,
218-
response: Response = None) -> SecCtx:
232+
response: Response = None) -> SecurityContext:
219233
"""
220234
Create a new security context, generate the GSSAPI authentication
221235
token, and insert it into the request header. The new security context
222236
will be returned.
223237
224238
If any GSSAPI step fails, raise SPNEGOExchangeError with failure detail.
225239
"""
226-
ctx = self._make_context(request.url.host)
240+
ctx = self._make_context(request)
227241

228242
token = _negotiate_value(response) if response is not None else None
229243
gss_resp = ctx.step(token or None)
@@ -238,7 +252,9 @@ def set_auth_header(self,
238252
return ctx
239253

240254
@_handle_gsserror(gss_stage="stepping", result=False)
241-
def authenticate_server(self, response: Response, ctx: SecCtx) -> bool:
255+
def authenticate_server(self,
256+
response: Response,
257+
ctx: SecurityContext) -> bool:
242258
"""
243259
Uses GSSAPI to authenticate the server by extracting the negotiate
244260
value from the response and stepping the security context.
@@ -254,22 +270,22 @@ def authenticate_server(self, response: Response, ctx: SecCtx) -> bool:
254270
log.debug("authenticate_server(): authentication successful")
255271
return True
256272

257-
@_handle_gsserror(gss_stage="initializing", result=SPNEGOExchangeError)
258-
def _make_context(self, host: str) -> SecCtx:
273+
@_handle_gsserror(gss_stage='initializing', result=_gss_to_spnego_error)
274+
def _make_context(self, request: Request) -> SecurityContext:
259275
"""
260276
Create a GSSAPI security context for handling the authentication.
261277
262-
:param host:
263-
Hostname to create context for. Only used if it isn't included
264-
in :py:attr:`target_name`
278+
:param request:
279+
Request to make the context for. The hostname from it is
280+
used if it isn't included in :py:attr:`target_name`.
265281
"""
266282
name = self.target_name
267283
if type(name) != gssapi.Name: # type(name) is str
268284
if '@' not in name:
269-
name += f"@{host}"
285+
name += f"@{request.url.host}"
270286
name = gssapi.Name(name, gssapi.NameType.hostbased_service)
271287

272-
return SecCtx(
288+
return SecurityContext(
273289
usage="initiate",
274290
flags=self._gssflags,
275291
name=name,

0 commit comments

Comments
 (0)