Skip to content

Commit 22507ca

Browse files
authored
feat: add get_all_roles_by_domain api (#316)
* feat: add get_all_roles_by_domain api * feat: use set to improve performance
1 parent 70cf615 commit 22507ca

5 files changed

Lines changed: 79 additions & 0 deletions

File tree

casbin/async_enforcer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,3 +253,17 @@ async def get_permissions_for_user_in_domain(self, user, domain):
253253
async def get_named_permissions_for_user_in_domain(self, ptype, user, domain):
254254
"""gets permissions for a user or role with named policy inside domain."""
255255
return self.get_filtered_named_policy(ptype, 0, user, domain)
256+
257+
async def get_all_roles_by_domain(self, domain):
258+
"""gets all roles associated with the domain.
259+
note: Not applicable to Domains with inheritance relationship (implicit roles)"""
260+
g = self.model.model["g"]["g"]
261+
policies = g.policy
262+
roles = set()
263+
for policy in policies:
264+
if policy[len(policy) - 1] == domain:
265+
role = policy[len(policy) - 2]
266+
if role not in roles:
267+
roles.add(role)
268+
269+
return list(roles)

casbin/enforcer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,17 @@ def get_permissions_for_user_in_domain(self, user, domain):
262262
def get_named_permissions_for_user_in_domain(self, ptype, user, domain):
263263
"""gets permissions for a user or role with named policy inside domain."""
264264
return self.get_filtered_named_policy(ptype, 0, user, domain)
265+
266+
def get_all_roles_by_domain(self, domain):
267+
"""gets all roles associated with the domain.
268+
note: Not applicable to Domains with inheritance relationship (implicit roles)"""
269+
g = self.model.model["g"]["g"]
270+
policies = g.policy
271+
roles = set()
272+
for policy in policies:
273+
if policy[len(policy) - 1] == domain:
274+
role = policy[len(policy) - 2]
275+
if role not in roles:
276+
roles.add(role)
277+
278+
return list(roles)

casbin/synced_enforcer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,3 +637,9 @@ def set_field_index(self, ptype, field, index):
637637
"""sets the index of the field name."""
638638
assertion = self._e.model["p"][ptype]
639639
assertion.field_index_map[field] = index
640+
641+
def get_all_roles_by_domain(self, domain):
642+
"""gets all roles associated with the domain.
643+
note: Not applicable to Domains with inheritance relationship (implicit roles)"""
644+
with self._rl:
645+
return self._e.get_all_roles_by_domain(domain)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
p, admin, domain1, data1, read
2+
p, admin, domain1, data1, write
3+
p, admin, domain2, data2, read
4+
p, admin, domain2, data2, write
5+
p, user, domain3, data2, read
6+
g, alice, admin, domain1
7+
g, alice, admin, domain2
8+
g, bob, admin, domain2
9+
g, bob, user, domain3

tests/test_rbac_api.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,22 @@ def test_enforce_get_roles_with_domain(self):
376376
self.assertEqual(e.get_roles_for_user_in_domain("admin", "domain2"), [])
377377
self.assertEqual(e.get_roles_for_user_in_domain("non_exist", "domain2"), [])
378378

379+
def test_get_all_roles_by_domain(self):
380+
e = self.get_enforcer(
381+
get_examples("rbac_with_domains_model.conf"),
382+
get_examples("rbac_with_domains_policy.csv"),
383+
)
384+
self.assertEqual(e.get_all_roles_by_domain("domain1"), ["admin"])
385+
self.assertEqual(e.get_all_roles_by_domain("domain2"), ["admin"])
386+
387+
e = self.get_enforcer(
388+
get_examples("rbac_with_domains_model.conf"),
389+
get_examples("rbac_with_domains_policy2.csv"),
390+
)
391+
self.assertEqual(e.get_all_roles_by_domain("domain1"), ["admin"])
392+
self.assertEqual(e.get_all_roles_by_domain("domain2"), ["admin"])
393+
self.assertEqual(e.get_all_roles_by_domain("domain3"), ["user"])
394+
379395
def test_implicit_user_api(self):
380396
e = self.get_enforcer(
381397
get_examples("rbac_model.conf"),
@@ -824,6 +840,26 @@ async def test_enforce_get_roles_with_domain(self):
824840
self.assertEqual(await e.get_roles_for_user_in_domain("admin", "domain2"), [])
825841
self.assertEqual(await e.get_roles_for_user_in_domain("non_exist", "domain2"), [])
826842

843+
async def test_get_all_roles_by_domain(self):
844+
e = self.get_enforcer(
845+
get_examples("rbac_with_domains_model.conf"),
846+
get_examples("rbac_with_domains_policy.csv"),
847+
)
848+
await e.load_policy()
849+
850+
self.assertEqual(await e.get_all_roles_by_domain("domain1"), ["admin"])
851+
self.assertEqual(await e.get_all_roles_by_domain("domain2"), ["admin"])
852+
853+
e = self.get_enforcer(
854+
get_examples("rbac_with_domains_model.conf"),
855+
get_examples("rbac_with_domains_policy2.csv"),
856+
)
857+
await e.load_policy()
858+
859+
self.assertEqual(await e.get_all_roles_by_domain("domain1"), ["admin"])
860+
self.assertEqual(await e.get_all_roles_by_domain("domain2"), ["admin"])
861+
self.assertEqual(await e.get_all_roles_by_domain("domain3"), ["user"])
862+
827863
async def test_implicit_user_api(self):
828864
e = self.get_enforcer(
829865
get_examples("rbac_model.conf"),

0 commit comments

Comments
 (0)