Skip to content

Commit 3d17897

Browse files
authored
[ty] Fix narrowing and reachability of class patterns with arguments (#19512)
## Summary I noticed that our type narrowing and reachability analysis was incorrect for class patterns that are not irrefutable. The test cases below compare the old and the new behavior: ```py from dataclasses import dataclass @DataClass class Point: x: int y: int class Other: ... def _(target: Point): y = 1 match target: case Point(0, 0): y = 2 case Point(x=0, y=1): y = 3 case Point(x=1, y=0): y = 4 reveal_type(y) # revealed: Literal[1, 2, 3, 4] (previously: Literal[2]) def _(target: Point | Other): match target: case Point(0, 0): reveal_type(target) # revealed: Point case Point(x=0, y=1): reveal_type(target) # revealed: Point (previously: Never) case Point(x=1, y=0): reveal_type(target) # revealed: Point (previously: Never) case Other(): reveal_type(target) # revealed: Other (previously: Other & ~Point) ``` ## Test Plan New Markdown test
1 parent fa1df4c commit 3d17897

File tree

5 files changed

+111
-11
lines changed

5 files changed

+111
-11
lines changed

crates/ty_python_semantic/resources/mdtest/conditional/match.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ def _(subject: C):
8080
A `case` branch with a class pattern is taken if the subject is an instance of the given class, and
8181
all subpatterns in the class pattern match.
8282

83+
### Without arguments
84+
8385
```py
8486
from typing import final
8587

@@ -136,6 +138,51 @@ def _(target: FooSub | str):
136138
reveal_type(y) # revealed: Literal[1, 3, 4]
137139
```
138140

141+
### With arguments
142+
143+
```py
144+
from typing_extensions import assert_never
145+
from dataclasses import dataclass
146+
147+
@dataclass
148+
class Point:
149+
x: int
150+
y: int
151+
152+
class Other: ...
153+
154+
def _(target: Point):
155+
y = 1
156+
157+
match target:
158+
case Point(0, 0):
159+
y = 2
160+
case Point(x=0, y=1):
161+
y = 3
162+
case Point(x=1, y=0):
163+
y = 4
164+
165+
reveal_type(y) # revealed: Literal[1, 2, 3, 4]
166+
167+
def _(target: Point):
168+
match target:
169+
case Point(x, y): # irrefutable sub-patterns
170+
pass
171+
case _:
172+
assert_never(target)
173+
174+
def _(target: Point | Other):
175+
match target:
176+
case Point(0, 0):
177+
reveal_type(target) # revealed: Point
178+
case Point(x=0, y=1):
179+
reveal_type(target) # revealed: Point
180+
case Point(x=1, y=0):
181+
reveal_type(target) # revealed: Point
182+
case Other():
183+
reveal_type(target) # revealed: Other
184+
```
185+
139186
## Singleton match
140187

141188
Singleton patterns are matched based on identity, not equality comparisons or `isinstance()` checks.

crates/ty_python_semantic/src/semantic_index/builder.rs

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ use crate::semantic_index::place::{
3535
PlaceExprWithFlags, PlaceTableBuilder, Scope, ScopeId, ScopeKind, ScopedPlaceId,
3636
};
3737
use crate::semantic_index::predicate::{
38-
CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
39-
PredicateOrLiteral, ScopedPredicateId, StarImportPlaceholderPredicate,
38+
CallableAndCallExpr, ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate,
39+
PredicateNode, PredicateOrLiteral, ScopedPredicateId, StarImportPlaceholderPredicate,
4040
};
4141
use crate::semantic_index::re_exports::exported_names;
4242
use crate::semantic_index::reachability_constraints::{
@@ -697,7 +697,25 @@ impl<'db, 'ast> SemanticIndexBuilder<'db, 'ast> {
697697
}
698698
ast::Pattern::MatchClass(pattern) => {
699699
let cls = self.add_standalone_expression(&pattern.cls);
700-
PatternPredicateKind::Class(cls)
700+
701+
PatternPredicateKind::Class(
702+
cls,
703+
if pattern
704+
.arguments
705+
.patterns
706+
.iter()
707+
.all(ast::Pattern::is_irrefutable)
708+
&& pattern
709+
.arguments
710+
.keywords
711+
.iter()
712+
.all(|kw| kw.pattern.is_irrefutable())
713+
{
714+
ClassPatternKind::Irrefutable
715+
} else {
716+
ClassPatternKind::Refutable
717+
},
718+
)
701719
}
702720
ast::Pattern::MatchOr(pattern) => {
703721
let predicates = pattern

crates/ty_python_semantic/src/semantic_index/predicate.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,13 +116,25 @@ pub(crate) enum PredicateNode<'db> {
116116
StarImportPlaceholder(StarImportPlaceholderPredicate<'db>),
117117
}
118118

119+
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, salsa::Update)]
120+
pub(crate) enum ClassPatternKind {
121+
Irrefutable,
122+
Refutable,
123+
}
124+
125+
impl ClassPatternKind {
126+
pub(crate) fn is_irrefutable(self) -> bool {
127+
matches!(self, ClassPatternKind::Irrefutable)
128+
}
129+
}
130+
119131
/// Pattern kinds for which we support type narrowing and/or static reachability analysis.
120132
#[derive(Debug, Clone, Hash, PartialEq, salsa::Update)]
121133
pub(crate) enum PatternPredicateKind<'db> {
122134
Singleton(Singleton),
123135
Value(Expression<'db>),
124136
Or(Vec<PatternPredicateKind<'db>>),
125-
Class(Expression<'db>),
137+
Class(Expression<'db>, ClassPatternKind),
126138
Unsupported,
127139
}
128140

crates/ty_python_semantic/src/semantic_index/reachability_constraints.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -689,13 +689,20 @@ impl ReachabilityConstraints {
689689
});
690690
truthiness
691691
}
692-
PatternPredicateKind::Class(class_expr) => {
692+
PatternPredicateKind::Class(class_expr, kind) => {
693693
let subject_ty = infer_expression_type(db, subject);
694694
let class_ty = infer_expression_type(db, *class_expr).to_instance(db);
695695

696696
class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
697697
if subject_ty.is_subtype_of(db, class_ty) {
698-
Truthiness::AlwaysTrue
698+
if kind.is_irrefutable() {
699+
Truthiness::AlwaysTrue
700+
} else {
701+
// A class pattern like `case Point(x=0, y=0)` is not irrefutable,
702+
// i.e. it does not match all instances of `Point`. This means that
703+
// we can't tell for sure if this pattern will match or not.
704+
Truthiness::Ambiguous
705+
}
699706
} else if subject_ty.is_disjoint_from(db, class_ty) {
700707
Truthiness::AlwaysFalse
701708
} else {

crates/ty_python_semantic/src/types/narrow.rs

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ use crate::semantic_index::expression::Expression;
33
use crate::semantic_index::place::{PlaceExpr, PlaceTable, ScopeId, ScopedPlaceId};
44
use crate::semantic_index::place_table;
55
use crate::semantic_index::predicate::{
6-
CallableAndCallExpr, PatternPredicate, PatternPredicateKind, Predicate, PredicateNode,
6+
CallableAndCallExpr, ClassPatternKind, PatternPredicate, PatternPredicateKind, Predicate,
7+
PredicateNode,
78
};
89
use crate::types::enums::{enum_member_literals, enum_metadata};
910
use crate::types::function::KnownFunction;
@@ -398,15 +399,18 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
398399
&mut self,
399400
pattern_predicate_kind: &PatternPredicateKind<'db>,
400401
subject: Expression<'db>,
402+
is_positive: bool,
401403
) -> Option<NarrowingConstraints<'db>> {
402404
match pattern_predicate_kind {
403405
PatternPredicateKind::Singleton(singleton) => {
404406
self.evaluate_match_pattern_singleton(subject, *singleton)
405407
}
406-
PatternPredicateKind::Class(cls) => self.evaluate_match_pattern_class(subject, *cls),
408+
PatternPredicateKind::Class(cls, kind) => {
409+
self.evaluate_match_pattern_class(subject, *cls, *kind, is_positive)
410+
}
407411
PatternPredicateKind::Value(expr) => self.evaluate_match_pattern_value(subject, *expr),
408412
PatternPredicateKind::Or(predicates) => {
409-
self.evaluate_match_pattern_or(subject, predicates)
413+
self.evaluate_match_pattern_or(subject, predicates, is_positive)
410414
}
411415
PatternPredicateKind::Unsupported => None,
412416
}
@@ -418,7 +422,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
418422
is_positive: bool,
419423
) -> Option<NarrowingConstraints<'db>> {
420424
let subject = pattern.subject(self.db);
421-
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject)
425+
self.evaluate_pattern_predicate_kind(pattern.kind(self.db), subject, is_positive)
422426
.map(|mut constraints| {
423427
negate_if(&mut constraints, self.db, !is_positive);
424428
constraints
@@ -905,7 +909,16 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
905909
&mut self,
906910
subject: Expression<'db>,
907911
cls: Expression<'db>,
912+
kind: ClassPatternKind,
913+
is_positive: bool,
908914
) -> Option<NarrowingConstraints<'db>> {
915+
if !kind.is_irrefutable() && !is_positive {
916+
// A class pattern like `case Point(x=0, y=0)` is not irrefutable. In the positive case,
917+
// we can still narrow the type of the match subject to `Point`. But in the negative case,
918+
// we cannot exclude `Point` as a possibility.
919+
return None;
920+
}
921+
909922
let subject = place_expr(subject.node_ref(self.db, self.module))?;
910923
let place = self.expect_place(&subject);
911924

@@ -930,12 +943,15 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
930943
&mut self,
931944
subject: Expression<'db>,
932945
predicates: &Vec<PatternPredicateKind<'db>>,
946+
is_positive: bool,
933947
) -> Option<NarrowingConstraints<'db>> {
934948
let db = self.db;
935949

936950
predicates
937951
.iter()
938-
.filter_map(|predicate| self.evaluate_pattern_predicate_kind(predicate, subject))
952+
.filter_map(|predicate| {
953+
self.evaluate_pattern_predicate_kind(predicate, subject, is_positive)
954+
})
939955
.reduce(|mut constraints, constraints_| {
940956
merge_constraints_or(&mut constraints, &constraints_, db);
941957
constraints

0 commit comments

Comments
 (0)