Skip to content

Commit 490d452

Browse files
committed
fix mypy issues
1 parent fd3fae4 commit 490d452

File tree

4 files changed

+17
-12
lines changed

4 files changed

+17
-12
lines changed

src/schematic/datastream/datastream_client.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _coerce_nulls(data: dict, model_cls: type) -> dict:
5353
def _validate(model_cls: type, raw: Any) -> Any:
5454
"""Validate raw data into a Pydantic model, coercing Go-style nulls first."""
5555
if isinstance(raw, dict):
56-
return model_cls.model_validate(_coerce_nulls(raw, model_cls))
56+
return model_cls.model_validate(_coerce_nulls(raw, model_cls)) # type: ignore[attr-defined]
5757
return raw
5858

5959

@@ -396,8 +396,8 @@ async def check_flag(
396396
else:
397397
tasks.append(_resolved(cached_user))
398398

399-
company, user = await asyncio.gather(*tasks)
400-
return self._evaluate_flag(flag, company, user)
399+
results: list = await asyncio.gather(*tasks)
400+
return self._evaluate_flag(flag, results[0], results[1])
401401

402402
async def update_company_metrics(self, keys: Dict[str, str], event: str, quantity: int) -> None:
403403
"""Update company metrics locally in cache (for track events)."""
@@ -893,10 +893,10 @@ def _clear_pending_requests(self) -> None:
893893
fut.set_exception(RuntimeError("DataStream client disconnected"))
894894
self._pending_company.clear()
895895

896-
for futures in self._pending_user.values():
897-
for fut in futures:
898-
if not fut.done():
899-
fut.set_exception(RuntimeError("DataStream client disconnected"))
896+
for user_futures in self._pending_user.values():
897+
for user_fut in user_futures:
898+
if not user_fut.done():
899+
user_fut.set_exception(RuntimeError("DataStream client disconnected"))
900900
self._pending_user.clear()
901901

902902
if self._pending_flags is not None and not self._pending_flags.done():

src/schematic/datastream/websocket_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def __init__(self, options: ClientOptions) -> None:
112112
raise ValueError("url is required")
113113
if not options.api_key:
114114
raise ValueError("api_key is required")
115-
if not options.message_handler:
115+
if options.message_handler is None: # type: ignore[operator]
116116
raise ValueError("message_handler is required")
117117

118118
# Auto-convert HTTP(S) URLs to WebSocket URLs

tests/datastream/test_merge.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def test_sets_base_plan_to_none(self) -> None:
221221
class TestPartialCompanyMissingID:
222222
def test_raises_value_error(self) -> None:
223223
existing = base_company()
224-
partial = {"traits": []}
224+
partial: dict[str, list[str]] = {"traits": []}
225225

226226
with pytest.raises(ValueError, match="missing required field: id"):
227227
partial_company(existing, partial)
@@ -314,6 +314,7 @@ def test_full_entity_partial_message(self) -> None:
314314
# Credit balances merge: existing credit-1 overwritten, credit-new added
315315
assert merged.credit_balances == {"credit-1": 999.0, "credit-new": 50.0}
316316

317+
assert merged.entitlements is not None
317318
assert len(merged.entitlements) == 2
318319
assert merged.entitlements[0].feature_id == "feat-new"
319320
assert merged.entitlements[1].feature_id == "feat-2"
@@ -493,13 +494,15 @@ def test_full_copy_is_independent(self) -> None:
493494
)
494495

495496
cp = deep_copy_company(orig)
497+
assert cp is not None
496498

497499
assert cp.id == orig.id
498500
assert cp.account_id == orig.account_id
499501
assert cp.environment_id == orig.environment_id
500502
assert cp.base_plan_id == orig.base_plan_id
501503
assert cp.keys == orig.keys
502504
assert cp.credit_balances == orig.credit_balances
505+
assert cp.subscription is not None
503506
assert cp.subscription.id == "sub-1"
504507
assert cp.metrics[0].value == 42
505508

@@ -518,6 +521,7 @@ def test_empty_fields(self) -> None:
518521
rules=[],
519522
)
520523
cp = deep_copy_user(orig)
524+
assert cp is not None
521525

522526
assert cp.id == "u1"
523527
assert cp.keys == {}
@@ -528,6 +532,7 @@ def test_full_copy_is_independent(self) -> None:
528532
orig = base_user().model_copy(update={"rules": [_make_rule("r1")]})
529533

530534
cp = deep_copy_user(orig)
535+
assert cp is not None
531536

532537
assert cp.id == orig.id
533538
assert cp.account_id == orig.account_id

tests/datastream/test_websocket_client.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ async def handler(msg): pass
171171
def test_init_missing_message_handler() -> None:
172172
with pytest.raises(ValueError, match="message_handler is required"):
173173
DatastreamWSClient(
174-
ClientOptions(url="wss://example.com", api_key="key", message_handler=None, logger=logger)
174+
ClientOptions(url="wss://example.com", api_key="key", message_handler=None, logger=logger) # type: ignore[arg-type]
175175
)
176176

177177

@@ -236,7 +236,7 @@ async def test_string_message_delivered_to_handler() -> None:
236236

237237
async def test_bytes_message_delivered_to_handler() -> None:
238238
payload = json.dumps({"entity_type": "rulesengine.Flag", "message_type": "full", "data": None})
239-
ws = MockWebSocket(messages=[payload.encode()])
239+
ws = MockWebSocket(messages=[payload.encode()]) # type: ignore[list-item]
240240
client, ws, received = make_client(ws=ws)
241241

242242
with patch("schematic.datastream.websocket_client.websockets.connect", make_connect(ws)):
@@ -431,7 +431,7 @@ async def test_stops_at_max_reconnect_attempts() -> None:
431431
ClientOptions(
432432
url="wss://test.example.com/datastream",
433433
api_key="key",
434-
message_handler=lambda _: None,
434+
message_handler=lambda _: None, # type: ignore[arg-type,return-value]
435435
logger=logger,
436436
min_reconnect_delay=0.0,
437437
max_reconnect_delay=0.0,

0 commit comments

Comments
 (0)