Skip to content
This repository was archived by the owner on Jun 12, 2021. It is now read-only.

Change model for persistent storage. #13

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .coverage

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
sudo: false
dist: bionic
language: python
python:
- 3.6
- 3.7
- 3.8
- 3.9
- pypy3
addons:
apt:
packages:
-
- rustc
- cargo
install:
- pip install --upgrade pip
- pip install codecov
- pip install tox
- pip install isort
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"Topic :: Software Development :: Libraries :: Python Modules"],
install_requires=[
"pyyaml>=5.1.0",
'oidcmsg>=1.1.0',
'oidcmsg>=1.2.1',
'requests'
],
tests_require=[
Expand Down
2 changes: 1 addition & 1 deletion src/oidcservice/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import random as rnd

__author__ = 'Roland Hedberg'
__version__ = '1.1.1'
__version__ = '1.2.0'

OIDCONF_PATTERN = "{}/.well-known/openid-configuration"
CC_METHOD = {
Expand Down
47 changes: 26 additions & 21 deletions src/oidcservice/client_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from urllib.parse import quote_plus

from cryptojwt.exception import MissingKey
from cryptojwt.exception import UnsupportedAlgorithm
from cryptojwt.jws.jws import SIGNER_ALGS
from cryptojwt.jws.utils import alg2keytype
from oidcmsg.message import VREQUIRED
Expand Down Expand Up @@ -87,15 +88,15 @@ def _get_passwd(request, service, **kwargs):
try:
passwd = request["client_secret"]
except KeyError:
passwd = service.service_context.get('client_secret')
passwd = service.service_context.client_secret
return passwd

@staticmethod
def _get_user(service, **kwargs):
try:
user = kwargs["user"]
except KeyError:
user = service.service_context.get('client_id')
user = service.service_context.client_id
return user

def _get_authentication_token(self, request, service, **kwargs):
Expand Down Expand Up @@ -128,7 +129,7 @@ def _with_or_without_client_id(request, service):
'grant_type'] == 'authorization_code':
if 'client_id' not in request:
try:
request['client_id'] = service.service_context.get('client_id')
request['client_id'] = service.service_context.client_id
except AttributeError:
pass
else:
Expand Down Expand Up @@ -210,13 +211,13 @@ def modify_request(self, request, service, **kwargs):
try:
request["client_secret"] = kwargs["client_secret"]
except (KeyError, TypeError):
if _context.get('client_secret'):
request["client_secret"] = _context.get('client_secret')
if _context.client_secret:
request["client_secret"] = _context.client_secret
else:
raise AuthnFailure("Missing client secret")

# Add the client_id to the request
request["client_id"] = _context.get('client_id')
# Set the client_id in the the request
request["client_id"] = _context.client_id

def construct(self, request, service=None, http_args=None, **kwargs):
"""
Expand Down Expand Up @@ -263,7 +264,7 @@ def find_token(request, token_type, service, **kwargs):
except KeyError:
# I should pick the latest acquired token, this should be the right
# order for that.
_arg = service.multiple_extend_request_args(
_arg = service.service_context.state.multiple_extend_request_args(
{}, kwargs['key'], ['access_token'],
['auth_response', 'token_response', 'refresh_token_response'])
return _arg['access_token']
Expand Down Expand Up @@ -413,13 +414,14 @@ def _get_key_by_kid(kid, algorithm, service_context):
:py:class:`oidcservice.service_context.ServiceContext` instance
:return: A matching key
"""
_key = service_context.keyjar.get_key_by_kid(kid)
if _key:
ktype = alg2keytype(algorithm)
if _key.kty != ktype:
raise MissingKey("Wrong key type")
# signing so using my keys
for _key in service_context.keyjar.get_issuer_keys(""):
if kid == _key.kid:
ktype = alg2keytype(algorithm)
if _key.kty != ktype:
raise MissingKey("Wrong key type")

return _key
return _key

raise MissingKey("No key with kid:%s" % kid)

Expand Down Expand Up @@ -448,13 +450,13 @@ def _get_audience_and_algorithm(self, context, **kwargs):
# audience for the signed JWT depends on which endpoint
# we're talking to.
if 'authn_endpoint' in kwargs and kwargs['authn_endpoint'] in ['token_endpoint']:
reg_resp = context.get("registration_response")
reg_resp = context.registration_response
if reg_resp:
algorithm = reg_resp.get("token_endpoint_auth_signing_alg")
algorithm = reg_resp["token_endpoint_auth_signing_alg"]
else:
algorithm = context.client_preferences.get("token_endpoint_auth_signing_alg")
if algorithm is None:
_pi = context.get("provider_info")
_pi = context.provider_info
try:
algs = _pi["token_endpoint_auth_signing_alg_values_supported"]
except KeyError:
Expand All @@ -466,9 +468,9 @@ def _get_audience_and_algorithm(self, context, **kwargs):
algorithm = alg
break

audience = context.get('provider_info')['token_endpoint']
audience = context.provider_info['token_endpoint']
else:
audience = context.get('provider_info')['issuer']
audience = context.provider_info['issuer']

if not algorithm:
algorithm = self.choose_algorithm(**kwargs)
Expand All @@ -484,14 +486,17 @@ def _construct_client_assertion(self, service, **kwargs):
else:
signing_key = self._get_signing_key(algorithm, _context)

if not signing_key:
raise UnsupportedAlgorithm(algorithm)

try:
_args = {'lifetime': kwargs['lifetime']}
except KeyError:
_args = {}

# construct the signed JWT with the assertions and add
# it as value to the 'client_assertion' claim of the request
return assertion_jwt(_context.get('client_id'), signing_key, audience, algorithm, **_args)
return assertion_jwt(_context.client_id, signing_key, audience, algorithm, **_args)

def modify_request(self, request, service, **kwargs):
"""
Expand Down Expand Up @@ -593,7 +598,7 @@ def valid_service_context(service_context, when=0):
:param when: A time stamp against which the expiration time is to be checked
:return: True if the client_secret is still valid
"""
eta = service_context.get('client_secret_expires_at', 0)
eta = service_context.client_secret_expires_at
now = when or utc_time_sans_frac()
if eta != 0 and eta < now:
return False
Expand Down
10 changes: 5 additions & 5 deletions src/oidcservice/oauth2/access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, service_context, client_authn_factory=None, conf=None):
def update_service_context(self, resp, key='', **kwargs):
if 'expires_in' in resp:
resp['__expires_at'] = time_sans_frac() + int(resp['expires_in'])
self.store_item(resp, 'token_response', key)
self.service_context.state.store_item(resp, 'token_response', key)

def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs):
"""
Expand All @@ -44,11 +44,11 @@ def oauth_pre_construct(self, request_args=None, post_args=None, **kwargs):
_state = get_state_parameter(request_args, kwargs)
parameters = list(self.msg_type.c_param.keys())

_args = self.extend_request_args({}, oauth2.AuthorizationRequest,
'auth_request', _state, parameters)
_args = self.service_context.state.extend_request_args({}, oauth2.AuthorizationRequest,
'auth_request', _state, parameters)

_args = self.extend_request_args(_args, oauth2.AuthorizationResponse,
'auth_response', _state, parameters)
_args = self.service_context.state.extend_request_args(_args, oauth2.AuthorizationResponse,
'auth_response', _state, parameters)

if "grant_type" not in _args:
_args["grant_type"] = "authorization_code"
Expand Down
15 changes: 8 additions & 7 deletions src/oidcservice/oauth2/authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from oidcmsg.oauth2 import ResponseMessage
from oidcmsg.time_util import time_sans_frac

from oidcservice.oauth2.utils import (get_state_parameter, pick_redirect_uris,
set_state_parameter)
from oidcservice.oauth2.utils import get_state_parameter
from oidcservice.oauth2.utils import pick_redirect_uris
from oidcservice.oauth2.utils import set_state_parameter
from oidcservice.service import Service

LOGGER = logging.getLogger(__name__)
Expand All @@ -32,20 +33,20 @@ def __init__(self, service_context, client_authn_factory=None, conf=None):
def update_service_context(self, resp, key='', **kwargs):
if 'expires_in' in resp:
resp['__expires_at'] = time_sans_frac() + int(resp['expires_in'])
self.store_item(resp, 'auth_response', key)
self.service_context.state.store_item(resp, 'auth_response', key)

def store_auth_request(self, request_args=None, **kwargs):
"""Store the authorization request in the state DB."""
_key = get_state_parameter(request_args, kwargs)
self.store_item(request_args, 'auth_request', _key)
self.service_context.state.store_item(request_args, 'auth_request', _key)
return request_args

def gather_request_args(self, **kwargs):
ar_args = Service.gather_request_args(self, **kwargs)

if 'redirect_uri' not in ar_args:
try:
ar_args['redirect_uri'] = self.service_context.get('redirect_uris')[0]
ar_args['redirect_uri'] = self.service_context.redirect_uris[0]
except (KeyError, AttributeError):
raise MissingParameter('redirect_uri')

Expand All @@ -68,8 +69,8 @@ def post_parse_response(self, response, **kwargs):
pass
else:
if _key:
item = self.get_item(oauth2.AuthorizationRequest,
'auth_request', _key)
item = self.service_context.state.get_item(oauth2.AuthorizationRequest,
'auth_request', _key)
try:
response["scope"] = item["scope"]
except KeyError:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ def __init__(self, service_context, client_authn_factory=None, conf=None):
def update_service_context(self, resp, key='cc', **kwargs):
if 'expires_in' in resp:
resp['__expires_at'] = time_sans_frac() + int(resp['expires_in'])
self.store_item(resp, 'token_response', key)
self.service_context.state.store_item(resp, 'token_response', key)
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ def __init__(self, service_context, client_authn_factory=None, conf=None):
self.post_construct.append(self.cc_post_construct)

def cc_pre_construct(self, request_args=None, **kwargs):
_state_id = kwargs.get("state", "cc")
parameters = ['refresh_token']
_args = self.extend_request_args({}, oauth2.AccessTokenResponse,
'token_response', 'cc', parameters)
_state_interface = self.service_context.state
_args = _state_interface.extend_request_args({}, oauth2.AccessTokenResponse,
'token_response', _state_id, parameters)

_args = self.extend_request_args(_args, oauth2.AccessTokenResponse,
'refresh_token_response', 'cc',
parameters)
_args = _state_interface.extend_request_args(_args, oauth2.AccessTokenResponse,
'refresh_token_response', _state_id,
parameters)

if request_args is None:
request_args = _args
Expand All @@ -50,4 +52,4 @@ def cc_post_construct(self, request_args, **kwargs):
def update_service_context(self, resp, key='cc', **kwargs):
if 'expires_in' in resp:
resp['__expires_at'] = time_sans_frac() + int(resp['expires_in'])
self.store_item(resp, 'token_response', key)
self.service_context.state.store_item(resp, 'token_response', key)
10 changes: 5 additions & 5 deletions src/oidcservice/oauth2/provider_info_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def get_endpoint(self):
:return: Service endpoint
"""
try:
_iss = self.service_context.get('issuer')
_iss = self.service_context.issuer
except AttributeError:
_iss = self.endpoint

Expand Down Expand Up @@ -115,12 +115,12 @@ def _update_service_context(self, resp):
# url that was used as service endpoint (without the .well-known part)
if "issuer" in resp:
_pcr_issuer = self._verify_issuer(resp,
self.service_context.get('issuer'))
self.service_context.issuer)
else: # No prior knowledge
_pcr_issuer = self.service_context.get('issuer')
_pcr_issuer = self.service_context.issuer

self.service_context.set('issuer', _pcr_issuer)
self.service_context.set('provider_info', resp)
self.service_context.issuer= _pcr_issuer
self.service_context.provider_info= resp

self._set_endpoints(resp)

Expand Down
13 changes: 7 additions & 6 deletions src/oidcservice/oauth2/refresh_access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,20 @@ def __init__(self, service_context, client_authn_factory=None, conf=None):
def update_service_context(self, resp, key='', **kwargs):
if 'expires_in' in resp:
resp['__expires_at'] = time_sans_frac() + int(resp['expires_in'])
self.store_item(resp, 'token_response', key)
self.service_context.state.store_item(resp, 'token_response', key)

def oauth_pre_construct(self, request_args=None, **kwargs):
"""Preconstructor of request arguments"""
_state = get_state_parameter(request_args, kwargs)
parameters = list(self.msg_type.c_param.keys())

_args = self.extend_request_args({}, oauth2.AccessTokenResponse,
'token_response', _state, parameters)
_si = self.service_context.state
_args = _si.extend_request_args({}, oauth2.AccessTokenResponse,
'token_response', _state, parameters)

_args = self.extend_request_args(_args, oauth2.AccessTokenResponse,
'refresh_token_response', _state,
parameters)
_args = _si.extend_request_args(_args, oauth2.AccessTokenResponse,
'refresh_token_response', _state,
parameters)

if request_args is None:
request_args = _args
Expand Down
6 changes: 3 additions & 3 deletions src/oidcservice/oauth2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ def pick_redirect_uris(request_args=None, service=None, **kwargs):
if 'redirect_uri' in request_args:
return request_args, {}

_callback = _context.get('callback')
_callback = _context.callback
if _callback:
try:
_response_type = request_args['response_type']
except KeyError:
_response_type = _context.get('behaviour')['response_types'][0]
_response_type = _context.behaviour['response_types'][0]
request_args['response_type'] = _response_type

try:
Expand All @@ -41,7 +41,7 @@ def pick_redirect_uris(request_args=None, service=None, **kwargs):
else:
request_args['redirect_uri'] = _callback['implicit']
else:
request_args['redirect_uri'] = _context.get('redirect_uris')[0]
request_args['redirect_uri'] = _context.redirect_uris[0]

return request_args, {}

Expand Down
Loading