Skip to content

Commit 3ad123b

Browse files
[red-knot] Narrowing on in tuple[...] and in str (#17059)
## Summary Part of #13694 Seems there a bit more to cover regarding `in` and other types, but i can cover them in different PRs ## Test Plan Add `in.md` file in narrowing conditionals folder --------- Co-authored-by: Carl Meyer <[email protected]>
1 parent a1535fb commit 3ad123b

File tree

3 files changed

+131
-0
lines changed

3 files changed

+131
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Narrowing for `in` conditionals
2+
3+
## `in` for tuples
4+
5+
```py
6+
def _(x: int):
7+
if x in (1, 2, 3):
8+
reveal_type(x) # revealed: int
9+
else:
10+
reveal_type(x) # revealed: int
11+
```
12+
13+
```py
14+
def _(x: str):
15+
if x in ("a", "b", "c"):
16+
reveal_type(x) # revealed: str
17+
else:
18+
reveal_type(x) # revealed: str
19+
```
20+
21+
```py
22+
from typing import Literal
23+
24+
def _(x: Literal[1, 2, "a", "b", False, b"abc"]):
25+
if x in (1,):
26+
reveal_type(x) # revealed: Literal[1]
27+
elif x in (2, "a"):
28+
reveal_type(x) # revealed: Literal[2, "a"]
29+
elif x in (b"abc",):
30+
reveal_type(x) # revealed: Literal[b"abc"]
31+
elif x not in (3,):
32+
reveal_type(x) # revealed: Literal["b", False]
33+
else:
34+
reveal_type(x) # revealed: Never
35+
```
36+
37+
```py
38+
def _(x: Literal["a", "b", "c", 1]):
39+
if x in ("a", "b", "c", 2):
40+
reveal_type(x) # revealed: Literal["a", "b", "c"]
41+
else:
42+
reveal_type(x) # revealed: Literal[1]
43+
```
44+
45+
## `in` for `str` and literal strings
46+
47+
```py
48+
def _(x: str):
49+
if x in "abc":
50+
reveal_type(x) # revealed: str
51+
else:
52+
reveal_type(x) # revealed: str
53+
```
54+
55+
```py
56+
from typing import Literal
57+
58+
def _(x: Literal["a", "b", "c", "d"]):
59+
if x in "abc":
60+
reveal_type(x) # revealed: Literal["a", "b", "c"]
61+
else:
62+
reveal_type(x) # revealed: Literal["d"]
63+
```
64+
65+
```py
66+
def _(x: Literal["a", "b", "c", "e"]):
67+
if x in "abcd":
68+
reveal_type(x) # revealed: Literal["a", "b", "c"]
69+
else:
70+
reveal_type(x) # revealed: Literal["e"]
71+
```
72+
73+
```py
74+
def _(x: Literal[1, "a", "b", "c", "d"]):
75+
# error: [unsupported-operator]
76+
if x in "abc":
77+
reveal_type(x) # revealed: Literal["a", "b", "c"]
78+
else:
79+
reveal_type(x) # revealed: Literal[1, "d"]
80+
```

crates/red_knot_python_semantic/src/types.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,11 @@ impl<'db> Type<'db> {
408408
matches!(self, Type::FunctionLiteral(..))
409409
}
410410

411+
pub fn is_union_of_single_valued(&self, db: &'db dyn Db) -> bool {
412+
self.into_union()
413+
.is_some_and(|union| union.elements(db).iter().all(|ty| ty.is_single_valued(db)))
414+
}
415+
411416
pub const fn into_int_literal(self) -> Option<i64> {
412417
match self {
413418
Type::IntLiteral(value) => Some(value),
@@ -422,6 +427,10 @@ impl<'db> Type<'db> {
422427
}
423428
}
424429

430+
pub fn is_string_literal(&self) -> bool {
431+
matches!(self, Type::StringLiteral(..))
432+
}
433+
425434
#[track_caller]
426435
pub fn expect_int_literal(self) -> i64 {
427436
self.into_int_literal()
@@ -5403,6 +5412,14 @@ impl<'db> StringLiteralType<'db> {
54035412
pub fn python_len(&self, db: &'db dyn Db) -> usize {
54045413
self.value(db).chars().count()
54055414
}
5415+
5416+
/// Return an iterator over each character in the string literal.
5417+
/// as would be returned by Python's `iter()`.
5418+
pub fn iter_each_char(&self, db: &'db dyn Db) -> impl Iterator<Item = Self> {
5419+
self.value(db)
5420+
.chars()
5421+
.map(|c| StringLiteralType::new(db, c.to_string().as_str()))
5422+
}
54065423
}
54075424

54085425
#[salsa::interned(debug)]

crates/red_knot_python_semantic/src/types/narrow.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ use rustc_hash::FxHashMap;
1919
use std::collections::hash_map::Entry;
2020
use std::sync::Arc;
2121

22+
use super::UnionType;
23+
2224
/// Return the type constraint that `test` (if true) would place on `definition`, if any.
2325
///
2426
/// For example, if we have this code:
@@ -288,6 +290,28 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
288290
NarrowingConstraints::from_iter([(symbol, ty)])
289291
}
290292

293+
fn evaluate_expr_in(&mut self, lhs_ty: Type<'db>, rhs_ty: Type<'db>) -> Option<Type<'db>> {
294+
if lhs_ty.is_single_valued(self.db) || lhs_ty.is_union_of_single_valued(self.db) {
295+
match rhs_ty {
296+
Type::Tuple(rhs_tuple) => Some(UnionType::from_elements(
297+
self.db,
298+
rhs_tuple.elements(self.db),
299+
)),
300+
301+
Type::StringLiteral(string_literal) => Some(UnionType::from_elements(
302+
self.db,
303+
string_literal
304+
.iter_each_char(self.db)
305+
.map(Type::StringLiteral),
306+
)),
307+
308+
_ => None,
309+
}
310+
} else {
311+
None
312+
}
313+
}
314+
291315
fn evaluate_expr_compare(
292316
&mut self,
293317
expr_compare: &ast::ExprCompare,
@@ -371,6 +395,16 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
371395
ast::CmpOp::Eq if lhs_ty.is_literal_string() => {
372396
constraints.insert(symbol, rhs_ty);
373397
}
398+
ast::CmpOp::In => {
399+
if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) {
400+
constraints.insert(symbol, ty);
401+
}
402+
}
403+
ast::CmpOp::NotIn => {
404+
if let Some(ty) = self.evaluate_expr_in(lhs_ty, rhs_ty) {
405+
constraints.insert(symbol, ty.negate(self.db));
406+
}
407+
}
374408
_ => {
375409
// TODO other comparison types
376410
}

0 commit comments

Comments
 (0)