Skip to content

Commit 67537d6

Browse files
authored
feat: port fastbin to casbin (#318)
* feat: port fastbin * feat: implement FastEnforcer * fix: remove redundant init code
1 parent 94b2172 commit 67537d6

File tree

11 files changed

+373
-15
lines changed

11 files changed

+373
-15
lines changed

casbin/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .enforcer import *
1616
from .synced_enforcer import SyncedEnforcer
1717
from .distributed_enforcer import DistributedEnforcer
18+
from .fast_enforcer import FastEnforcer
1819
from .async_enforcer import AsyncEnforcer
1920
from . import util
2021
from .persist import *

casbin/core_enforcer.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import logging
1615
import copy
16+
import logging
1717

1818
from casbin.effect import Effector, get_effector, effect_to_bool
1919
from casbin.model import Model, FunctionMap
@@ -202,7 +202,6 @@ def load_policy(self):
202202
new_model.clear_policy()
203203

204204
try:
205-
206205
self.adapter.load_policy(new_model)
207206

208207
new_model.sort_policies_by_subject_hierarchy()
@@ -212,7 +211,6 @@ def load_policy(self):
212211
new_model.print_policy()
213212

214213
if self.auto_build_role_links:
215-
216214
need_to_rebuild = True
217215
for rm in self.rm_map.values():
218216
rm.clear()
@@ -222,7 +220,6 @@ def load_policy(self):
222220
self.model = new_model
223221

224222
except Exception as e:
225-
226223
if self.auto_build_role_links and need_to_rebuild:
227224
self.build_role_links()
228225

@@ -315,7 +312,6 @@ def add_named_domain_matching_func(self, ptype, fn):
315312
return False
316313

317314
def new_enforce_context(self, suffix: str) -> EnforceContext:
318-
319315
return EnforceContext(
320316
rtype="r" + suffix,
321317
ptype="p" + suffix,

casbin/distributed_enforcer.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from casbin import SyncedEnforcer
16-
import logging
17-
18-
from casbin.persist import batch_adapter
1915
from casbin.model.policy_op import PolicyOp
16+
from casbin.persist import batch_adapter
2017
from casbin.persist.adapters import update_adapter
18+
from casbin.synced_enforcer import SyncedEnforcer
2119

2220

2321
class DistributedEnforcer(SyncedEnforcer):

casbin/fast_enforcer.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import logging
2+
from typing import Sequence
3+
4+
from casbin.enforcer import Enforcer
5+
from casbin.model import Model, FastModel, fast_policy_filter, FunctionMap
6+
from casbin.persist.adapters import FileAdapter
7+
from casbin.util.log import configure_logging
8+
9+
10+
class FastEnforcer(Enforcer):
11+
_cache_key_order: Sequence[int] = None
12+
13+
def __init__(self, model=None, adapter=None, enable_log=False, cache_key_order: Sequence[int] = None):
14+
self._cache_key_order = cache_key_order
15+
super().__init__(model, adapter, enable_log)
16+
17+
def new_model(self, path="", text=""):
18+
"""creates a model."""
19+
if self._cache_key_order is None:
20+
m = Model()
21+
else:
22+
m = FastModel(self._cache_key_order)
23+
if len(path) > 0:
24+
m.load_model(path)
25+
else:
26+
m.load_model_from_text(text)
27+
28+
return m
29+
30+
def enforce(self, *rvals):
31+
"""decides whether a "subject" can access a "object" with the operation "action",
32+
input parameters are usually: (sub, obj, act).
33+
"""
34+
if FastEnforcer._cache_key_order is None:
35+
result, _ = self.enforce_ex(*rvals)
36+
else:
37+
keys = [rvals[x] for x in self._cache_key_order]
38+
with fast_policy_filter(self.model.model["p"]["p"].policy, *keys):
39+
result, _ = self.enforce_ex(*rvals)
40+
41+
return result

casbin/model/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515
from .assertion import Assertion
16+
from .function import FunctionMap
1617
from .model import Model
18+
from .model_fast import FastModel
1719
from .policy import Policy
18-
from .function import FunctionMap
20+
from .policy_fast import FastPolicy, fast_policy_filter

casbin/model/model_fast.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2023 The casbin Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Sequence
16+
17+
from .model import Model
18+
from .policy_fast import FastPolicy
19+
20+
21+
class FastModel(Model):
22+
_cache_key_order: Sequence[int]
23+
24+
def __init__(self, cache_key_order: Sequence[int]) -> None:
25+
super().__init__()
26+
self._cache_key_order = cache_key_order
27+
28+
def add_def(self, sec: str, key: str, value: Any) -> None:
29+
super().add_def(sec, key, value)
30+
if sec == "p" and key == "p":
31+
self.model[sec][key].policy = FastPolicy(self._cache_key_order)
32+
33+
def clear_policy(self) -> None:
34+
"""clears all current policy."""
35+
super().clear_policy()
36+
self.model["p"]["p"].policy = FastPolicy(self._cache_key_order)

casbin/model/policy_fast.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2023 The casbin Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from contextlib import contextmanager
16+
from typing import Any, Container, Dict, Iterable, Iterator, Optional, Sequence, Set, cast
17+
18+
19+
def in_cache(cache: Dict[str, Any], keys: Sequence[str]) -> Optional[Set[Sequence[str]]]:
20+
if keys[0] in cache:
21+
if len(keys) > 1:
22+
return in_cache(cache[keys[-0]], keys[1:])
23+
return cast(Set[Sequence[str]], cache[keys[0]])
24+
else:
25+
return None
26+
27+
28+
class FastPolicy(Container[Sequence[str]]):
29+
_cache: Dict[str, Any]
30+
_current_filter: Optional[Set[Sequence[str]]]
31+
_cache_key_order: Sequence[int]
32+
33+
def __init__(self, cache_key_order: Sequence[int]) -> None:
34+
self._cache = {}
35+
self._current_filter = None
36+
self._cache_key_order = cache_key_order
37+
38+
def __iter__(self) -> Iterator[Sequence[str]]:
39+
yield from self.__get_policy()
40+
41+
def __len__(self) -> int:
42+
return len(list(self.__get_policy()))
43+
44+
def __contains__(self, item: object) -> bool:
45+
if not isinstance(item, (list, tuple)) or len(self._cache_key_order) >= len(item):
46+
return False
47+
keys = [item[x] for x in self._cache_key_order]
48+
exists = in_cache(self._cache, keys)
49+
if not exists:
50+
return False
51+
return tuple(item) in exists
52+
53+
def __getitem__(self, item: int) -> Sequence[str]:
54+
for i, entry in enumerate(self):
55+
if i == item:
56+
return entry
57+
raise KeyError("No such value exists")
58+
59+
def append(self, item: Sequence[str]) -> None:
60+
cache = self._cache
61+
keys = [item[x] for x in self._cache_key_order]
62+
63+
for key in keys[:-1]:
64+
if key not in cache:
65+
cache[key] = dict()
66+
cache = cache[key]
67+
if keys[-1] not in cache:
68+
cache[keys[-1]] = set()
69+
70+
cache[keys[-1]].add(tuple(item))
71+
72+
def remove(self, policy: Sequence[str]) -> bool:
73+
keys = [policy[x] for x in self._cache_key_order]
74+
exists = in_cache(self._cache, keys)
75+
if not exists:
76+
return True
77+
78+
exists.remove(tuple(policy))
79+
return True
80+
81+
def __get_policy(self) -> Iterable[Sequence[str]]:
82+
if self._current_filter is not None:
83+
return (list(x) for x in self._current_filter)
84+
else:
85+
return (list(v2) for v in self._cache.values() for v1 in v.values() for v2 in v1)
86+
87+
def apply_filter(self, *keys: str) -> None:
88+
value = in_cache(self._cache, keys)
89+
self._current_filter = value or set()
90+
91+
def clear_filter(self) -> None:
92+
self._current_filter = None
93+
94+
95+
@contextmanager
96+
def fast_policy_filter(policy: FastPolicy, *keys: str) -> Iterator[None]:
97+
try:
98+
policy.apply_filter(*keys)
99+
yield
100+
finally:
101+
policy.clear_filter()

tests/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from . import benchmarks
16+
from . import config
17+
from . import model
18+
from . import rbac
19+
from . import util
1520
from .test_distributed_api import TestDistributedApi
1621
from .test_enforcer import *
22+
from .test_fast_enforcer import TestFastEnforcer
1723
from .test_filter import TestFilteredAdapter
1824
from .test_frontend import TestFrontend
1925
from .test_management_api import TestManagementApi, TestManagementApiSynced
2026
from .test_rbac_api import TestRbacApi, TestRbacApiSynced
21-
from . import benchmarks
22-
from . import config
23-
from . import model
24-
from . import rbac
25-
from . import util

tests/model/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
# limitations under the License.
1414

1515
from .test_policy import TestPolicy
16+
from .test_policy_fast import TestContextManager, TestFastPolicy

tests/model/test_policy_fast.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2023 The casbin Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import TestCase
16+
17+
from casbin.model import FastPolicy, fast_policy_filter
18+
19+
20+
class TestFastPolicy(TestCase):
21+
def test_able_to_add_rules(self) -> None:
22+
policy = FastPolicy([2, 1])
23+
24+
policy.append(["sub", "obj", "read"])
25+
26+
assert list(policy) == [["sub", "obj", "read"]]
27+
28+
def test_does_not_add_duplicates(self) -> None:
29+
policy = FastPolicy([2, 1])
30+
31+
policy.append(["sub", "obj", "read"])
32+
policy.append(["sub", "obj", "read"])
33+
34+
assert list(policy) == [["sub", "obj", "read"]]
35+
36+
def test_can_remove_rules(self) -> None:
37+
policy = FastPolicy([2, 1])
38+
39+
policy.append(["sub", "obj", "read"])
40+
policy.remove(["sub", "obj", "read"])
41+
42+
assert list(policy) == []
43+
44+
def test_returns_lengtt(self) -> None:
45+
policy = FastPolicy([2, 1])
46+
47+
policy.append(["sub", "obj", "read"])
48+
49+
assert len(policy) == 1
50+
51+
def test_supports_in_keyword(self) -> None:
52+
policy = FastPolicy([2, 1])
53+
54+
policy.append(["sub", "obj", "read"])
55+
56+
assert ["sub", "obj", "read"] in policy
57+
58+
def test_supports_filters(self) -> None:
59+
policy = FastPolicy([2, 1])
60+
61+
policy.append(["sub", "obj", "read"])
62+
policy.append(["sub", "obj", "read2"])
63+
policy.append(["sub", "obj2", "read2"])
64+
65+
policy.apply_filter("read2", "obj2")
66+
67+
assert list(policy) == [["sub", "obj2", "read2"]]
68+
69+
def test_clears_filters(self) -> None:
70+
policy = FastPolicy([2, 1])
71+
72+
policy.append(["sub", "obj", "read"])
73+
policy.append(["sub", "obj", "read2"])
74+
policy.append(["sub", "obj2", "read2"])
75+
76+
policy.apply_filter("read2", "obj2")
77+
policy.clear_filter()
78+
79+
assert list(policy) == [
80+
["sub", "obj", "read"],
81+
["sub", "obj", "read2"],
82+
["sub", "obj2", "read2"],
83+
]
84+
85+
86+
class TestContextManager:
87+
def test_fast_policy_filter(self) -> None:
88+
policy = FastPolicy([2, 1])
89+
90+
policy.append(["sub", "obj", "read"])
91+
policy.append(["sub", "obj", "read2"])
92+
policy.append(["sub", "obj2", "read2"])
93+
94+
with fast_policy_filter(policy, "read2", "obj2"):
95+
assert list(policy) == [["sub", "obj2", "read2"]]
96+
97+
assert list(policy) == [
98+
["sub", "obj", "read"],
99+
["sub", "obj", "read2"],
100+
["sub", "obj2", "read2"],
101+
]

0 commit comments

Comments
 (0)