Skip to content

Commit 5a85eed

Browse files
authored
Fixed SCIM search for large queries (#2049)
1 parent 13a3cde commit 5a85eed

File tree

11 files changed

+410
-56
lines changed

11 files changed

+410
-56
lines changed

main/settings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@
145145
"SERVICE_PROVIDER_CONFIG_MODEL": "scim.config.LearnSCIMServiceProviderConfig",
146146
"USER_ADAPTER": "scim.adapters.LearnSCIMUser",
147147
"USER_MODEL_GETTER": "scim.adapters.get_user_model_for_scim",
148-
"USER_FILTER_PARSER": "scim.filters.LearnUserFilterQuery",
148+
"USER_FILTER_PARSER": "scim.filters.UserFilterQuery",
149149
"GET_IS_AUTHENTICATED_PREDICATE": "scim.utils.is_authenticated_predicate",
150150
}
151151

poetry.lock

Lines changed: 15 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ llama-index-agent-openai = "^0.4.1"
9292
langchain-experimental = "^0.3.4"
9393
langchain-openai = "^0.3.2"
9494
deepmerge = "^2.0"
95+
pyparsing = "^3.2.1"
9596

9697

9798
[tool.poetry.group.dev.dependencies]
@@ -117,6 +118,7 @@ freezegun = "^1.4.0"
117118
pytest-xdist = { version = "^3.6.1", extras = ["psutil"] }
118119
anys = "^0.3.0"
119120
locust = "^2.31.2"
121+
traceback-with-variables = "^2.1.1"
120122

121123

122124
[build-system]

scim/filters.py

Lines changed: 114 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,126 @@
1+
import operator
2+
from collections.abc import Callable
13
from typing import Optional
24

3-
from django_scim.filters import UserFilterQuery
5+
from django.contrib.auth import get_user_model
6+
from django.db.models import Model, Q
7+
from pyparsing import ParseResults
48

5-
from scim.parser.queries.sql import PatchedSQLQuery
9+
from scim.parser.grammar import Filters, TermType
610

711

8-
class LearnUserFilterQuery(UserFilterQuery):
12+
class FilterQuery:
913
"""Filters for users"""
1014

11-
query_class = PatchedSQLQuery
15+
model_cls: type[Model]
1216

13-
attr_map: dict[tuple[Optional[str], Optional[str], Optional[str]], str] = {
14-
("userName", None, None): "auth_user.username",
15-
("emails", "value", None): "auth_user.email",
16-
("active", None, None): "auth_user.is_active",
17-
("fullName", None, None): "profiles_profile.name",
18-
("name", "givenName", None): "auth_user.first_name",
19-
("name", "familyName", None): "auth_user.last_name",
17+
attr_map: dict[tuple[str, Optional[str]], tuple[str, ...]]
18+
19+
related_selects: list[str] = []
20+
21+
dj_op_mapping = {
22+
"eq": "exact",
23+
"ne": "exact",
24+
"gt": "gt",
25+
"ge": "gte",
26+
"lt": "lt",
27+
"le": "lte",
28+
"pr": "isnull",
29+
"co": "contains",
30+
"sw": "startswith",
31+
"ew": "endswith",
2032
}
2133

22-
joins: tuple[str, ...] = (
23-
"INNER JOIN profiles_profile ON profiles_profile.user_id = auth_user.id",
24-
)
34+
dj_negated_ops = ("ne", "pr")
35+
36+
@classmethod
37+
def _filter_expr(cls, parsed: ParseResults) -> Q:
38+
if parsed is None:
39+
msg = "Expected a filter, got: None"
40+
raise ValueError(msg)
41+
42+
if parsed.term_type == TermType.attr_expr:
43+
return cls._attr_expr(parsed)
44+
45+
msg = f"Unsupported term type: {parsed.term_type}"
46+
raise ValueError(msg)
47+
48+
@classmethod
49+
def _attr_expr(cls, parsed: ParseResults) -> Q:
50+
dj_op = cls.dj_op_mapping[parsed.comparison_operator.lower()]
51+
52+
scim_keys = (parsed.attr_name, parsed.sub_attr)
53+
54+
path_parts = list(
55+
filter(
56+
lambda part: part is not None,
57+
(
58+
*cls.attr_map.get(scim_keys, scim_keys),
59+
dj_op,
60+
),
61+
)
62+
)
63+
path = "__".join(path_parts)
64+
65+
q = Q(**{path: parsed.value})
66+
67+
if parsed.comparison_operator in cls.dj_negated_ops:
68+
q = ~q
69+
70+
return q
71+
72+
@classmethod
73+
def _filters(cls, parsed: ParseResults) -> Q:
74+
parsed_iter = iter(parsed)
75+
q = cls._filter_expr(next(parsed_iter))
76+
77+
try:
78+
while operator := cls._logical_op(next(parsed_iter)):
79+
filter_q = cls._filter_expr(next(parsed_iter))
80+
81+
# combine the previous and next Q() objects using the bitwise operator
82+
q = operator(q, filter_q)
83+
except StopIteration:
84+
pass
85+
86+
return q
87+
88+
@classmethod
89+
def _logical_op(cls, parsed: ParseResults) -> Callable[[Q, Q], Q] | None:
90+
"""Convert a defined operator to the corresponding bitwise operator"""
91+
if parsed is None:
92+
return None
93+
94+
if parsed.logical_operator.lower() == "and":
95+
return operator.and_
96+
elif parsed.logical_operator.lower() == "or":
97+
return operator.or_
98+
else:
99+
msg = f"Unexpected operator: {parsed.operator}"
100+
raise ValueError(msg)
25101

26102
@classmethod
27-
def search(cls, filter_query, request=None):
28-
return super().search(filter_query, request=request)
103+
def search(cls, filter_query, request=None): # noqa: ARG003
104+
"""Create a search query"""
105+
parsed = Filters.parse_string(filter_query, parse_all=True)
106+
107+
return cls.model_cls.objects.select_related(*cls.related_selects).filter(
108+
cls._filters(parsed)
109+
)
110+
111+
112+
class UserFilterQuery(FilterQuery):
113+
"""FilterQuery for User"""
114+
115+
attr_map: dict[tuple[str, Optional[str]], tuple[str, ...]] = {
116+
("userName", None): ("username",),
117+
("emails", "value"): ("email",),
118+
("active", None): ("is_active",),
119+
("fullName", None): ("profile", "name"),
120+
("name", "givenName"): ("first_name",),
121+
("name", "familyName"): ("last_name",),
122+
}
123+
124+
related_selects = ["profile"]
125+
126+
model_cls = get_user_model()

scim/parser/grammar.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
"""
2+
SCIM filter parsers
3+
+ _tag_term_type(TermType.attr_name)
4+
5+
This module aims to compliantly parse SCIM filter queries per the spec:
6+
https://datatracker.ietf.org/doc/html/rfc7644#section-3.4.2.2
7+
8+
Note that this implementation defines things slightly differently
9+
because a naive implementation exactly matching the filter grammar will
10+
result in hitting Python's recursion limit because the grammar defines
11+
logical lists (AND/OR chains) as a recursive relationship.
12+
13+
This implementation avoids that by defining separately FilterExpr and
14+
Filter. As a result of this, some definitions are collapsed and removed
15+
(e.g. valFilter => FilterExpr).
16+
"""
17+
18+
from enum import StrEnum, auto
19+
20+
from pyparsing import (
21+
CaselessKeyword,
22+
Char,
23+
Combine,
24+
DelimitedList,
25+
FollowedBy,
26+
Forward,
27+
Group,
28+
Literal,
29+
Suppress,
30+
Tag,
31+
alphanums,
32+
alphas,
33+
common,
34+
dbl_quoted_string,
35+
nested_expr,
36+
one_of,
37+
remove_quotes,
38+
ungroup,
39+
)
40+
41+
42+
class TagName(StrEnum):
43+
"""Tag names"""
44+
45+
term_type = auto()
46+
value_type = auto()
47+
48+
49+
class TermType(StrEnum):
50+
"""Tag term type"""
51+
52+
urn = auto()
53+
attr_name = auto()
54+
attr_path = auto()
55+
attr_expr = auto()
56+
value_path = auto()
57+
presence = auto()
58+
59+
logical_op = auto()
60+
compare_op = auto()
61+
negation_op = auto()
62+
63+
filter_expr = auto()
64+
filters = auto()
65+
66+
67+
class ValueType(StrEnum):
68+
"""Tag value_type"""
69+
70+
boolean = auto()
71+
number = auto()
72+
string = auto()
73+
null = auto()
74+
75+
76+
def _tag_term_type(term_type: TermType) -> Tag:
77+
return Tag(TagName.term_type.name, term_type)
78+
79+
80+
def _tag_value_type(value_type: ValueType) -> Tag:
81+
return Tag(TagName.value_type.name, value_type)
82+
83+
84+
NameChar = Char(alphanums + "_-")
85+
AttrName = Combine(
86+
Char(alphas)
87+
+ NameChar[...]
88+
# ensure we're not somehow parsing an URN
89+
+ ~FollowedBy(":")
90+
).set_results_name("attr_name") + _tag_term_type(TermType.attr_name)
91+
92+
# Example URN-qualifed attr:
93+
# urn:ietf:params:scim:schemas:core:2.0:User:userName
94+
# |--------------- URN --------------------|:| attr |
95+
UrnAttr = Combine(
96+
Combine(
97+
Literal("urn:")
98+
+ DelimitedList(
99+
# characters ONLY if followed by colon
100+
Char(alphanums + ".-_")[1, ...] + FollowedBy(":"),
101+
# separator
102+
Literal(":"),
103+
# combine everything back into a singular token
104+
combine=True,
105+
)[1, ...]
106+
).set_results_name("urn")
107+
# separator between URN and attribute name
108+
+ Literal(":")
109+
+ AttrName
110+
+ _tag_term_type(TermType.urn)
111+
)
112+
113+
114+
SubAttr = ungroup(Combine(Suppress(".") + AttrName)).set_results_name("sub_attr") ^ (
115+
Tag("sub_attr", None)
116+
)
117+
118+
AttrPath = (
119+
(
120+
# match on UrnAttr first
121+
UrnAttr ^ AttrName
122+
)
123+
+ SubAttr
124+
+ _tag_term_type(TermType.attr_path)
125+
)
126+
127+
ComparisonOperator = one_of(
128+
["eq", "ne", "co", "sw", "ew", "gt", "lt", "ge", "le"],
129+
caseless=True,
130+
as_keyword=True,
131+
).set_results_name("comparison_operator") + _tag_term_type(TermType.compare_op)
132+
133+
LogicalOperator = Group(
134+
one_of(["or", "and"], caseless=True).set_results_name("logical_operator")
135+
+ _tag_term_type(TermType.logical_op)
136+
)
137+
138+
NegationOperator = Group(
139+
(
140+
CaselessKeyword("not")
141+
+ _tag_term_type(TermType.negation_op)
142+
+ Tag("negated", True) # noqa: FBT003
143+
)[..., 1]
144+
^ Tag("negated", False) # noqa: FBT003
145+
)
146+
147+
ValueTrue = Literal("true").set_parse_action(lambda: True) + _tag_value_type(
148+
ValueType.boolean
149+
)
150+
ValueFalse = Literal("false").set_parse_action(lambda: False) + _tag_value_type(
151+
ValueType.boolean
152+
)
153+
ValueNull = Literal("null").set_parse_action(lambda: None) + _tag_value_type(
154+
ValueType.null
155+
)
156+
ValueNumber = (common.integer | common.fnumber) + _tag_value_type(ValueType.number)
157+
ValueString = dbl_quoted_string.set_parse_action(remove_quotes) + _tag_value_type(
158+
ValueType.string
159+
)
160+
161+
ComparisonValue = ungroup(
162+
ValueTrue | ValueFalse | ValueNull | ValueNumber | ValueString
163+
).set_results_name("value")
164+
165+
AttrPresence = Group(
166+
AttrPath + Literal("pr").set_results_name("presence").set_parse_action(lambda: True)
167+
) + _tag_term_type(TermType.presence)
168+
AttrExpression = AttrPresence | Group(
169+
AttrPath + ComparisonOperator + ComparisonValue + _tag_term_type(TermType.attr_expr)
170+
)
171+
172+
# these are forward references, so that we can have
173+
# parsers circularly reference themselves
174+
FilterExpr = Forward()
175+
Filters = Forward()
176+
177+
ValuePath = Group(AttrPath + nested_expr("[", "]", Filters)).set_results_name(
178+
"value_path"
179+
) + _tag_term_type(TermType.value_path)
180+
181+
FilterExpr <<= (
182+
AttrExpression | ValuePath | (NegationOperator + nested_expr("(", ")", Filters))
183+
) + _tag_term_type(TermType.filter_expr)
184+
185+
Filters <<= (
186+
# comment to force it to wrap the below for operator precedence
187+
(FilterExpr + (LogicalOperator + FilterExpr)[...])
188+
+ _tag_term_type(TermType.filters)
189+
)

0 commit comments

Comments
 (0)