1
1
import re
2
2
import logging
3
+ from itertools import chain
3
4
from functools import wraps
4
5
from typing import Generator , Optional , List , Any
5
6
6
7
from base64 import b64encode , b64decode
7
8
8
9
import gssapi
9
- from gssapi import SecurityContext as SecCtx
10
+ from gssapi import SecurityContext
10
11
from gssapi .exceptions import GSSError
11
12
12
13
import httpx
@@ -60,7 +61,7 @@ def _sanitize_response(response: Response):
60
61
response .headers [header ] = headers [header ]
61
62
62
63
63
- def _handle_gsserror (* , gss_stage : str , result : Any ):
64
+ def _handle_gsserror (* , gss_stage : str , result : Any = ... ):
64
65
"""
65
66
Decorator to handle GSSErrors and properly log them against the decorated
66
67
function's name.
@@ -69,8 +70,9 @@ def _handle_gsserror(*, gss_stage: str, result: Any):
69
70
Name of GSS stage that the function is handling. Typically either
70
71
'initializing' or 'stepping'.
71
72
:param result:
72
- The result to return if a GSSError is raised. If it's an Exception
73
- type, then it will be raised with the logged message.
73
+ The result to return if a GSSError is raised. If the result is a
74
+ callable, it will be called first with the message and all args
75
+ and kwargs.
74
76
"""
75
77
76
78
def _decor (func ):
@@ -81,15 +83,27 @@ def _wrapper(*args, **kwargs):
81
83
except GSSError as error :
82
84
msg = f"{ gss_stage } context failed: { error .gen_message ()} "
83
85
log .exception (f"{ func .__name__ } (): { msg } " )
84
- if isinstance (result , type ) and issubclass ( result , Exception ):
85
- raise result (msg )
86
+ if callable (result ):
87
+ return result (msg , * args , ** kwargs )
86
88
return result
87
89
88
90
return _wrapper
89
91
90
92
return _decor
91
93
92
94
95
+ def _gss_to_spnego_error (message : str , * args : Any , ** kwargs : Any ):
96
+ """Helper function to _handle_gsserror to raise SPNEGOExchangeErrors."""
97
+ try :
98
+ request = next (
99
+ a for a in chain (args , kwargs .values ())
100
+ if isinstance (a , Request )
101
+ )
102
+ except StopIteration : # sanity check
103
+ raise RuntimeError ("No request in arguments!" )
104
+ raise SPNEGOExchangeError (message , request = request )
105
+
106
+
93
107
class HTTPSPNEGOAuth (Auth ):
94
108
"""
95
109
Attaches HTTP GSSAPI Authentication to the given Request object.
@@ -146,7 +160,7 @@ def auth_flow(self, request: Request) -> FlowGen:
146
160
147
161
def handle_response (self ,
148
162
response : Response ,
149
- ctx : SecCtx = None ) -> FlowGen :
163
+ ctx : SecurityContext = None ) -> FlowGen :
150
164
num_401s = 0
151
165
while response .status_code == 401 and num_401s < 2 :
152
166
num_401s += 1
@@ -171,7 +185,7 @@ def handle_response(self,
171
185
172
186
self .handle_mutual_auth (response , ctx )
173
187
174
- def handle_mutual_auth (self , response : Response , ctx : SecCtx ):
188
+ def handle_mutual_auth (self , response : Response , ctx : SecurityContext ):
175
189
"""
176
190
Handles all responses with the exception of 401s.
177
191
@@ -212,18 +226,18 @@ def handle_mutual_auth(self, response: Response, ctx: SecCtx):
212
226
log .error ("handle_other(): Mutual authentication failed" )
213
227
raise MutualAuthenticationError (response = response )
214
228
215
- @_handle_gsserror (gss_stage = 'stepping' , result = SPNEGOExchangeError )
229
+ @_handle_gsserror (gss_stage = 'stepping' , result = _gss_to_spnego_error )
216
230
def set_auth_header (self ,
217
231
request : Request ,
218
- response : Response = None ) -> SecCtx :
232
+ response : Response = None ) -> SecurityContext :
219
233
"""
220
234
Create a new security context, generate the GSSAPI authentication
221
235
token, and insert it into the request header. The new security context
222
236
will be returned.
223
237
224
238
If any GSSAPI step fails, raise SPNEGOExchangeError with failure detail.
225
239
"""
226
- ctx = self ._make_context (request . url . host )
240
+ ctx = self ._make_context (request )
227
241
228
242
token = _negotiate_value (response ) if response is not None else None
229
243
gss_resp = ctx .step (token or None )
@@ -238,7 +252,9 @@ def set_auth_header(self,
238
252
return ctx
239
253
240
254
@_handle_gsserror (gss_stage = "stepping" , result = False )
241
- def authenticate_server (self , response : Response , ctx : SecCtx ) -> bool :
255
+ def authenticate_server (self ,
256
+ response : Response ,
257
+ ctx : SecurityContext ) -> bool :
242
258
"""
243
259
Uses GSSAPI to authenticate the server by extracting the negotiate
244
260
value from the response and stepping the security context.
@@ -254,22 +270,22 @@ def authenticate_server(self, response: Response, ctx: SecCtx) -> bool:
254
270
log .debug ("authenticate_server(): authentication successful" )
255
271
return True
256
272
257
- @_handle_gsserror (gss_stage = " initializing" , result = SPNEGOExchangeError )
258
- def _make_context (self , host : str ) -> SecCtx :
273
+ @_handle_gsserror (gss_stage = ' initializing' , result = _gss_to_spnego_error )
274
+ def _make_context (self , request : Request ) -> SecurityContext :
259
275
"""
260
276
Create a GSSAPI security context for handling the authentication.
261
277
262
- :param host :
263
- Hostname to create context for. Only used if it isn't included
264
- in :py:attr:`target_name`
278
+ :param request :
279
+ Request to make the context for. The hostname from it is
280
+ used if it isn't included in :py:attr:`target_name`.
265
281
"""
266
282
name = self .target_name
267
283
if type (name ) != gssapi .Name : # type(name) is str
268
284
if '@' not in name :
269
- name += f"@{ host } "
285
+ name += f"@{ request . url . host } "
270
286
name = gssapi .Name (name , gssapi .NameType .hostbased_service )
271
287
272
- return SecCtx (
288
+ return SecurityContext (
273
289
usage = "initiate" ,
274
290
flags = self ._gssflags ,
275
291
name = name ,
0 commit comments