9
9
from oidcmsg .logging import configure_logging
10
10
from oidcmsg .util import load_yaml_config
11
11
12
- DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key' , 'server_cert' , 'filename' , 'template_dir' ,
13
- 'private_path' , 'public_path' , 'db_file' ]
12
+ DEFAULT_FILE_ATTRIBUTE_NAMES = ['server_key' , 'server_cert' , 'filename' ,
13
+ 'private_path' , 'public_path' , 'db_file' , 'jwks_file' ]
14
14
15
- URIS = ["redirect_uris" , 'issuer' , 'base_url ' ]
15
+ DEFAULT_DIR_ATTRIBUTE_NAMES = ['template_dir ' ]
16
16
17
17
18
18
def lower_or_upper (config , param , default = None ):
@@ -22,17 +22,31 @@ def lower_or_upper(config, param, default=None):
22
22
return res
23
23
24
24
25
- def add_base_path (conf : dict , base_path : str , file_attributes : List [str ]):
25
+ def add_path_to_filename (filename , base_path ):
26
+ if filename == "" or filename .startswith ("/" ):
27
+ return filename
28
+ else :
29
+ return os .path .join (base_path , filename )
30
+
31
+
32
+ def add_path_to_directory_name (directory_name , base_path ):
33
+ if directory_name .startswith ("/" ):
34
+ return directory_name
35
+ elif directory_name == "" :
36
+ return "./" + directory_name
37
+ else :
38
+ return os .path .join (base_path , directory_name )
39
+
40
+
41
+ def add_base_path (conf : dict , base_path : str , attributes : List [str ], attribute_type : str = "file" ):
26
42
for key , val in conf .items ():
27
- if key in file_attributes :
28
- if val .startswith ("/" ):
29
- continue
30
- elif val == "" :
31
- conf [key ] = "./" + val
43
+ if key in attributes :
44
+ if attribute_type == "file" :
45
+ conf [key ] = add_path_to_filename (val , base_path )
32
46
else :
33
- conf [key ] = os . path . join ( base_path , val )
47
+ conf [key ] = add_path_to_directory_name ( val , base_path )
34
48
if isinstance (val , dict ):
35
- conf [key ] = add_base_path (val , base_path , file_attributes )
49
+ conf [key ] = add_base_path (val , base_path , attributes , attribute_type )
36
50
37
51
return conf
38
52
@@ -53,41 +67,71 @@ def set_domain_and_port(conf: dict, uris: List[str], domain: str, port: int):
53
67
return conf
54
68
55
69
56
- class Base :
70
+ class Base ( dict ) :
57
71
""" Configuration base class """
58
72
73
+ parameter = {}
74
+ uris = ["issuer" , "base_url" ]
75
+
59
76
def __init__ (self ,
60
77
conf : Dict ,
61
78
base_path : str = '' ,
62
79
file_attributes : Optional [List [str ]] = None ,
80
+ dir_attributes : Optional [List [str ]] = None ,
81
+ domain : Optional [str ] = "" ,
82
+ port : Optional [int ] = 0 ,
63
83
):
84
+ dict .__init__ (self )
85
+ self ._file_attributes = file_attributes or DEFAULT_FILE_ATTRIBUTE_NAMES
86
+ self ._dir_attributes = dir_attributes or DEFAULT_DIR_ATTRIBUTE_NAMES
64
87
65
- if file_attributes is None :
66
- file_attributes = DEFAULT_FILE_ATTRIBUTE_NAMES
67
-
68
- if base_path and file_attributes :
88
+ if base_path :
69
89
# this adds a base path to all paths in the configuration
70
- add_base_path (conf , base_path , file_attributes )
90
+ if self ._file_attributes :
91
+ add_base_path (conf , base_path , self ._file_attributes , "file" )
92
+ if self ._dir_attributes :
93
+ add_base_path (conf , base_path , self ._dir_attributes , "dir" )
71
94
72
- def __getitem__ (self , item ):
73
- if item in self .__dict__ :
74
- return self .__dict__ [item ]
95
+ # entity info
96
+ self .domain = domain or conf .get ("domain" , "127.0.0.1" )
97
+ self .port = port or conf .get ("port" , 80 )
98
+
99
+ self .conf = set_domain_and_port (conf , self .uris , self .domain , self .port )
100
+
101
+ def __getattr__ (self , item , default = None ):
102
+ if item in self :
103
+ return self [item ]
75
104
else :
76
- raise KeyError
105
+ return default
77
106
78
- def get (self , item , default = None ):
79
- return getattr (self , item , default )
107
+ def __setattr__ (self , key , value ):
108
+ if key in self :
109
+ raise KeyError ('{} has already been set' .format (key ))
110
+ super (Base , self ).__setitem__ (key , value )
111
+
112
+ def __setitem__ (self , key , value ):
113
+ if key in self :
114
+ raise KeyError ('{} has already been set' .format (key ))
115
+ super (Base , self ).__setitem__ (key , value )
80
116
81
- def __contains__ (self , item ):
82
- return item in self .__dict__
117
+ def get (self , item , default = None ):
118
+ return self .__getattr__ ( item , default )
83
119
84
120
def items (self ):
85
- for key in self .__dict__ :
121
+ for key in self .keys () :
86
122
if key .startswith ('__' ) and key .endswith ('__' ):
87
123
continue
88
124
yield key , getattr (self , key )
89
125
90
- def extend (self , entity_conf , conf , base_path , file_attributes , domain , port ):
126
+ def extend (self ,
127
+ conf : Dict ,
128
+ base_path : str ,
129
+ domain : str ,
130
+ port : int ,
131
+ entity_conf : Optional [List [dict ]] = None ,
132
+ file_attributes : Optional [List [str ]] = None ,
133
+ dir_attributes : Optional [List [str ]] = None ,
134
+ ):
91
135
for econf in entity_conf :
92
136
_path = econf .get ("path" )
93
137
_cnf = conf
@@ -98,11 +142,49 @@ def extend(self, entity_conf, conf, base_path, file_attributes, domain, port):
98
142
_cls = econf ["class" ]
99
143
setattr (self , _attr ,
100
144
_cls (_cnf , base_path = base_path , file_attributes = file_attributes ,
101
- domain = domain , port = port ))
145
+ domain = domain , port = port , dir_attributes = dir_attributes ))
146
+
147
+ def complete_paths (self , conf : Dict , keys : List [str ], default_config : Dict , base_path : str ):
148
+ for key in keys :
149
+ _val = conf .get (key )
150
+ if _val is None and key in default_config :
151
+ _val = default_config [key ]
152
+ if key in self ._file_attributes :
153
+ _val = add_path_to_filename (_val , base_path )
154
+ elif key in self ._dir_attributes :
155
+ _val = add_path_to_directory_name (_val , base_path )
156
+ if not _val :
157
+ continue
158
+
159
+ setattr (self , key , _val )
160
+
161
+ def format (self , conf , base_path : str , domain : str , port : int ,
162
+ file_attributes : Optional [List [str ]] = None ,
163
+ dir_attributes : Optional [List [str ]] = None ) -> None :
164
+ """
165
+ Formats parts of the configuration. That includes replacing the strings {domain} and {port}
166
+ with the used domain and port and making references to files and directories absolute
167
+ rather then relative. The formatting is done in place.
168
+
169
+ :param dir_attributes:
170
+ :param conf: The configuration part
171
+ :param base_path: The base path used to make file/directory refrences absolute
172
+ :param file_attributes: Attribute names that refer to files or directories.
173
+ :param domain: The domain name
174
+ :param port: The port used
175
+ """
176
+ if isinstance (conf , dict ):
177
+ if file_attributes :
178
+ add_base_path (conf , base_path , file_attributes , attribute_type = "file" )
179
+ if dir_attributes :
180
+ add_base_path (conf , base_path , dir_attributes , attribute_type = "dir" )
181
+ if isinstance (conf , dict ):
182
+ set_domain_and_port (conf , self .uris , domain = domain , port = port )
102
183
103
184
104
185
class Configuration (Base ):
105
- """Server Configuration"""
186
+ """Entity Configuration Base"""
187
+ uris = ["redirect_uris" , 'issuer' , 'base_url' , 'server_name' ]
106
188
107
189
def __init__ (self ,
108
190
conf : Dict ,
@@ -111,27 +193,24 @@ def __init__(self,
111
193
file_attributes : Optional [List [str ]] = None ,
112
194
domain : Optional [str ] = "" ,
113
195
port : Optional [int ] = 0 ,
196
+ dir_attributes : Optional [List [str ]] = None ,
114
197
):
115
- Base .__init__ (self , conf , base_path = base_path , file_attributes = file_attributes )
198
+ Base .__init__ (self , conf , base_path = base_path , file_attributes = file_attributes ,
199
+ dir_attributes = dir_attributes , domain = domain , port = port )
116
200
117
- log_conf = conf .get ('logging' )
201
+ log_conf = self . conf .get ('logging' )
118
202
if log_conf :
119
203
self .logger = configure_logging (config = log_conf ).getChild (__name__ )
120
204
else :
121
205
self .logger = logging .getLogger ('oidcrp' )
122
206
123
- self .web_conf = lower_or_upper (conf , "webserver" )
124
-
125
- # entity info
126
- if not domain :
127
- domain = conf .get ("domain" , "127.0.0.1" )
128
-
129
- if not port :
130
- port = conf .get ("port" , 80 )
207
+ self .web_conf = lower_or_upper (self .conf , "webserver" )
131
208
132
209
if entity_conf :
133
- self .extend (entity_conf = entity_conf , conf = conf , base_path = base_path ,
134
- file_attributes = file_attributes , domain = domain , port = port )
210
+ self .extend (conf = self .conf , base_path = base_path ,
211
+ domain = self .domain , port = self .port , entity_conf = entity_conf ,
212
+ file_attributes = self ._file_attributes ,
213
+ dir_attributes = self ._dir_attributes )
135
214
136
215
137
216
def create_from_config_file (cls ,
@@ -140,7 +219,9 @@ def create_from_config_file(cls,
140
219
entity_conf : Optional [List [dict ]] = None ,
141
220
file_attributes : Optional [List [str ]] = None ,
142
221
domain : Optional [str ] = "" ,
143
- port : Optional [int ] = 0 ):
222
+ port : Optional [int ] = 0 ,
223
+ dir_attributes : Optional [List [str ]] = None
224
+ ):
144
225
if filename .endswith (".yaml" ):
145
226
"""Load configuration as YAML"""
146
227
_cnf = load_yaml_config (filename )
@@ -158,4 +239,4 @@ def create_from_config_file(cls,
158
239
return cls (_cnf ,
159
240
entity_conf = entity_conf ,
160
241
base_path = base_path , file_attributes = file_attributes ,
161
- domain = domain , port = port )
242
+ domain = domain , port = port , dir_attributes = dir_attributes )
0 commit comments