Skip to content

Commit f964e2a

Browse files
authored
feat: Fixed the subjectPriority sorting algorithm and support for checking the subject role link loop (#322)
* fix: Fixed the `subjectPriority` sorting algorithm and support for checking the subject role link loop. * fix: Run black
1 parent 45bcc8b commit f964e2a

File tree

2 files changed

+79
-31
lines changed

2 files changed

+79
-31
lines changed

casbin/model/model.py

Lines changed: 32 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -134,14 +134,20 @@ def compare_policy(policy):
134134
name = self.get_name_with_domain(domain, policy[sub_index])
135135
return subject_hierarchy_map.get(name, 0)
136136

137-
assertion.policy = sorted(assertion.policy, key=compare_policy, reverse=True)
137+
assertion.policy = sorted(assertion.policy, key=compare_policy)
138138
for i, policy in enumerate(assertion.policy):
139139
assertion.policy_map[",".join(policy)] = i
140140

141141
def get_subject_hierarchy_map(self, policies):
142-
subject_hierarchy_map = {}
143-
# Tree structure of role
144-
policy_map = {}
142+
"""
143+
Get the subject hierarchy from the policy.
144+
Select the lowest level subject in multiple rounds until all subjects are selected.
145+
Return the subject hierarchy dictionary, the subject is the key, and the level is the value.
146+
The level starts from 0 and increases in turn. The smaller the level, the higher the priority.
147+
"""
148+
# Init unsorted policy, and subject
149+
unsorted_policy = []
150+
unsorted_sub = set()
145151
for policy in policies:
146152
if len(policy) < 2:
147153
raise RuntimeError("policy g expect 2 more params")
@@ -150,33 +156,28 @@ def get_subject_hierarchy_map(self, policies):
150156
domain = policy[2]
151157
child = self.get_name_with_domain(domain, policy[0])
152158
parent = self.get_name_with_domain(domain, policy[1])
153-
if parent not in policy_map.keys():
154-
policy_map[parent] = [child]
155-
else:
156-
policy_map[parent].append(child)
157-
if child not in subject_hierarchy_map.keys():
158-
subject_hierarchy_map[child] = 0
159-
if parent not in subject_hierarchy_map.keys():
160-
subject_hierarchy_map[parent] = 0
161-
subject_hierarchy_map[child] = 1
162-
# Use queues for levelOrder
163-
queue = []
164-
for k, v in subject_hierarchy_map.items():
165-
root = k
166-
if v != 0:
167-
continue
168-
lv = 0
169-
queue.append(root)
170-
while len(queue) != 0:
171-
sz = len(queue)
172-
for _ in range(sz):
173-
node = queue.pop(0)
174-
subject_hierarchy_map[node] = lv
175-
if node in policy_map.keys():
176-
for child in policy_map[node]:
177-
queue.append(child)
178-
lv += 1
179-
return subject_hierarchy_map
159+
unsorted_policy.append([child, parent])
160+
unsorted_sub.add(child)
161+
unsorted_sub.add(parent)
162+
# sort policy,and update sorted_sub_list
163+
sorted_sub_list = []
164+
while len(unsorted_policy) > 0:
165+
# get all parent subject
166+
parent_sub = {p[1] for p in unsorted_policy if p[1] != ""}
167+
# remove parent subject from unsorted_sub
168+
sorted_sub = unsorted_sub - parent_sub
169+
if not sorted_sub:
170+
raise RuntimeError("cycle dependency in subject hierarchy.subjects: {}".format(unsorted_sub))
171+
# update sorted_sub_list
172+
sorted_sub_list.append(sorted_sub)
173+
# remove sorted subject, and update unsorted_policy
174+
unsorted_policy = [p for p in unsorted_policy if p[0] not in sorted_sub]
175+
# update unsorted_sub
176+
unsorted_sub = unsorted_sub - sorted_sub
177+
if len(unsorted_sub) > 0:
178+
sorted_sub_list.append(unsorted_sub)
179+
# Tree structure of subject
180+
return {sub: i for i, subs in enumerate(sorted_sub_list) for sub in subs}
180181

181182
def get_name_with_domain(self, domain, name):
182183
return "{}{}{}".format(domain, DEFAULT_SEPARATOR, name)

tests/model/test_model.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from unittest import TestCase
2+
3+
from casbin import Model
4+
from casbin.model.model import DEFAULT_DOMAIN
5+
6+
7+
class TestModel(TestCase):
8+
m = Model()
9+
10+
def check_hierarchy(self, policies: list, subject_hierarchy_map: dict):
11+
"""check_hierarchy checks the hierarchy of the subject hierarchy map"""
12+
for policy in policies:
13+
if len(policy) < 2:
14+
raise RuntimeError("policy g expect 2 more params")
15+
domain = DEFAULT_DOMAIN
16+
if len(policy) != 2:
17+
domain = policy[2]
18+
child = self.m.get_name_with_domain(domain, policy[0])
19+
parent = self.m.get_name_with_domain(domain, policy[1])
20+
assert subject_hierarchy_map[child] < subject_hierarchy_map[parent]
21+
22+
def test_get_subject_hierarchy_map(self):
23+
# test 1
24+
policies = [
25+
["A1", "B1"],
26+
["A1", "B2"],
27+
["A2", "B3"],
28+
]
29+
res = self.m.get_subject_hierarchy_map(policies)
30+
self.check_hierarchy(policies, res)
31+
# test 2
32+
policies = [
33+
["A1", "B1"],
34+
["B1", "B2"],
35+
["B2", "B3"],
36+
["B1", "B4"],
37+
["A1", "B2"],
38+
]
39+
res = self.m.get_subject_hierarchy_map(policies)
40+
self.check_hierarchy(policies, res)
41+
# test 3
42+
policies = [
43+
["B1", "B2"],
44+
["B2", "B3"],
45+
["B3", "B1"],
46+
]
47+
self.assertRaises(RuntimeError, self.m.get_subject_hierarchy_map, policies)

0 commit comments

Comments
 (0)