2020from casbin .persist import Adapter
2121from casbin .persist .adapters import FileAdapter
2222from casbin .rbac import default_role_manager
23- from casbin .util import generate_g_function , SimpleEval , util
23+ from casbin .util import generate_g_function , SimpleEval , util , generate_conditional_g_function
2424from casbin .util .log import configure_logging
2525
2626
@@ -47,6 +47,7 @@ class CoreEnforcer:
4747 adapter = None
4848 watcher = None
4949 rm_map = None
50+ cond_rm_map = None
5051
5152 enabled = False
5253 auto_save = False
@@ -104,6 +105,7 @@ def init_with_model_and_adapter(self, m, adapter=None):
104105
105106 def _initialize (self ):
106107 self .rm_map = dict ()
108+ self .cond_rm_map = dict ()
107109 self .eft = get_effector (self .model ["e" ]["e" ].value )
108110 self .watcher = None
109111
@@ -192,10 +194,26 @@ def init_rm_map(self):
192194 if "g" in self .model .keys ():
193195 for ptype in self .model ["g" ]:
194196 assertion = self .model ["g" ][ptype ]
195- if assertion .value .count ("_" ) == 2 :
196- self .rm_map [ptype ] = default_role_manager .RoleManager (10 )
197- else :
198- self .rm_map [ptype ] = default_role_manager .DomainManager (10 )
197+ if ptype in self .rm_map :
198+ rm = self .rm_map [ptype ]
199+ rm .clear ()
200+ continue
201+
202+ if len (assertion .tokens ) <= 2 and len (assertion .params_tokens ) == 0 :
203+ assertion .rm = default_role_manager .RoleManager (10 )
204+ self .rm_map [ptype ] = assertion .rm
205+
206+ if len (assertion .tokens ) <= 2 and len (assertion .params_tokens ) != 0 :
207+ assertion .cond_rm = default_role_manager .ConditionalRoleManager (10 )
208+ self .cond_rm_map [ptype ] = assertion .cond_rm
209+
210+ if len (assertion .tokens ) > 2 :
211+ if len (assertion .params_tokens ) == 0 :
212+ assertion .rm = default_role_manager .DomainManager (10 )
213+ self .rm_map [ptype ] = assertion .rm
214+ else :
215+ assertion .cond_rm = default_role_manager .ConditionalDomainManager (10 )
216+ self .cond_rm_map [ptype ] = assertion .cond_rm
199217
200218 def load_policy (self ):
201219 """reloads the policy from file/database."""
@@ -216,8 +234,13 @@ def load_policy(self):
216234 need_to_rebuild = True
217235 for rm in self .rm_map .values ():
218236 rm .clear ()
237+ if len (self .rm_map ) != 0 :
238+ new_model .build_role_links (self .rm_map )
219239
220- new_model .build_role_links (self .rm_map )
240+ for crm in self .cond_rm_map .values ():
241+ crm .clear ()
242+ if len (self .cond_rm_map ) != 0 :
243+ new_model .build_conditional_role_links (self .cond_rm_map )
221244
222245 self .model = new_model
223246
@@ -313,6 +336,40 @@ def add_named_domain_matching_func(self, ptype, fn):
313336
314337 return False
315338
339+ def add_named_link_condition_func (self , ptype , user , role , fn ):
340+ """Add condition function fn for Link userName->roleName,
341+ when fn returns true, Link is valid, otherwise invalid"""
342+ if ptype in self .cond_rm_map :
343+ rm = self .cond_rm_map [ptype ]
344+ rm .add_link_condition_func (user , role , fn )
345+ return True
346+ return False
347+
348+ def add_named_domain_link_condition_func (self , ptype , user , role , domain , fn ):
349+ """Add condition function fn for Link userName-> {roleName, domain},
350+ when fn returns true, Link is valid, otherwise invalid"""
351+ if ptype in self .cond_rm_map :
352+ rm = self .cond_rm_map [ptype ]
353+ rm .add_domain_link_condition_func (user , role , domain , fn )
354+ return True
355+ return False
356+
357+ def set_named_link_condition_func_params (self , ptype , user , role , * params ):
358+ """Sets the parameters of the condition function fn for Link userName->roleName"""
359+ if ptype in self .cond_rm_map :
360+ rm = self .cond_rm_map [ptype ]
361+ rm .set_link_condition_func_params (user , role , * params )
362+ return True
363+ return False
364+
365+ def set_named_domain_link_condition_func_params (self , ptype , user , role , domain , * params ):
366+ """Sets the parameters of the condition function fn for Link userName->{roleName, domain}"""
367+ if ptype in self .cond_rm_map :
368+ rm = self .cond_rm_map [ptype ]
369+ rm .set_domain_link_condition_func_params (user , role , domain , * params )
370+ return True
371+ return False
372+
316373 def new_enforce_context (self , suffix : str ) -> EnforceContext :
317374 return EnforceContext (
318375 rtype = "r" + suffix ,
@@ -346,8 +403,10 @@ def enforce_ex(self, *rvals):
346403
347404 if "g" in self .model .keys ():
348405 for key , ast in self .model ["g" ].items ():
349- rm = ast .rm
350- functions [key ] = generate_g_function (rm )
406+ if len (self .rm_map ) != 0 :
407+ functions [key ] = generate_g_function (ast .rm )
408+ if len (self .cond_rm_map ) != 0 :
409+ functions [key ] = generate_conditional_g_function (ast .cond_rm )
351410
352411 if len (rvals ) != 0 :
353412 if isinstance (rvals [0 ], EnforceContext ):
0 commit comments