Skip to content

Commit 937b649

Browse files
authored
feat: align add_policy API with Golang Casbin (#398)
1 parent f3f130f commit 937b649

2 files changed

Lines changed: 20 additions & 16 deletions

File tree

casbin/async_internal_enforcer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,8 @@ async def save_policy(self):
117117

118118
async def _add_policy(self, sec, ptype, rule):
119119
"""async adds a rule to the current policy."""
120-
rule_added = self.model.add_policy(sec, ptype, rule)
121-
if not rule_added:
122-
return rule_added
120+
if self.model.has_policy(sec, ptype, rule):
121+
return False
123122

124123
if self.adapter and self.auto_save:
125124
result = await self.adapter.add_policy(sec, ptype, rule)
@@ -136,13 +135,15 @@ async def _add_policy(self, sec, ptype, rule):
136135
else:
137136
self.watcher.update()
138137

138+
rule_added = self.model.add_policy(sec, ptype, rule)
139+
139140
return rule_added
140141

141142
async def _add_policies(self, sec, ptype, rules):
142143
"""async adds rules to the current policy."""
143-
rules_added = self.model.add_policies(sec, ptype, rules)
144-
if not rules_added:
145-
return rules_added
144+
for rule in rules:
145+
if self.model.has_policy(sec, ptype, rule):
146+
return False
146147

147148
if self.adapter and self.auto_save:
148149
if hasattr(self.adapter, "add_policies") is False:
@@ -162,6 +163,8 @@ async def _add_policies(self, sec, ptype, rules):
162163
else:
163164
self.watcher.update()
164165

166+
rules_added = self.model.add_policies(sec, ptype, rules)
167+
165168
return rules_added
166169

167170
async def _update_policy(self, sec, ptype, old_rule, new_rule):

casbin/internal_enforcer.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ class InternalEnforcer(CoreEnforcer):
2222

2323
def _add_policy(self, sec, ptype, rule):
2424
"""adds a rule to the current policy."""
25-
rule_added = self.model.add_policy(sec, ptype, rule)
26-
if not rule_added:
27-
return rule_added
25+
if self.model.has_policy(sec, ptype, rule):
26+
return False
2827

2928
if self.adapter and self.auto_save:
3029
if self.adapter.add_policy(sec, ptype, rule) is False:
@@ -36,13 +35,15 @@ def _add_policy(self, sec, ptype, rule):
3635
else:
3736
self.watcher.update()
3837

38+
rule_added = self.model.add_policy(sec, ptype, rule)
39+
3940
return rule_added
4041

4142
def _add_policies(self, sec, ptype, rules):
4243
"""adds rules to the current policy."""
43-
rules_added = self.model.add_policies(sec, ptype, rules)
44-
if not rules_added:
45-
return rules_added
44+
for rule in rules:
45+
if self.model.has_policy(sec, ptype, rule):
46+
return False
4647

4748
if self.adapter and self.auto_save:
4849
if hasattr(self.adapter, "add_policies") is False:
@@ -57,14 +58,12 @@ def _add_policies(self, sec, ptype, rules):
5758
else:
5859
self.watcher.update()
5960

61+
rules_added = self.model.add_policies(sec, ptype, rules)
62+
6063
return rules_added
6164

6265
def _add_policies_ex(self, sec, ptype, rules):
6366
"""adds rules to the current policy."""
64-
rules_added = self.model.add_policies_ex(sec, ptype, rules)
65-
if not rules_added:
66-
return rules_added
67-
6867
if self.adapter and self.auto_save:
6968
if hasattr(self.adapter, "add_policies_ex") is False:
7069
return False
@@ -78,6 +77,8 @@ def _add_policies_ex(self, sec, ptype, rules):
7877
else:
7978
self.watcher.update()
8079

80+
rules_added = self.model.add_policies_ex(sec, ptype, rules)
81+
8182
return rules_added
8283

8384
def _update_policy(self, sec, ptype, old_rule, new_rule):

0 commit comments

Comments
 (0)