mirror of
https://github.com/astral-sh/ruff.git
synced 2025-08-03 02:12:22 +00:00
[red-knot] Type narrow in else clause (#13918)
## Summary Add support for type narrowing in elif and else scopes as part of #13694. ## Test Plan - mdtest - builder unit test for union negation. --------- Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
parent
3006d6da23
commit
35f007f17f
17 changed files with 363 additions and 53 deletions
|
@ -19,7 +19,8 @@ reveal_type(1 <= "" and 0 < 1) # revealed: @Todo | Literal[True]
|
|||
|
||||
```py
|
||||
# TODO: implement lookup of `__eq__` on typeshed `int` stub.
|
||||
def int_instance() -> int: ...
|
||||
def int_instance() -> int:
|
||||
return 42
|
||||
|
||||
reveal_type(1 == int_instance()) # revealed: @Todo
|
||||
reveal_type(9 < int_instance()) # revealed: bool
|
||||
|
|
|
@ -59,7 +59,8 @@ reveal_type(c >= d) # revealed: Literal[True]
|
|||
|
||||
```py
|
||||
def bool_instance() -> bool: ...
|
||||
def int_instance() -> int: ...
|
||||
def int_instance() -> int:
|
||||
return 42
|
||||
|
||||
a = (bool_instance(),)
|
||||
b = (int_instance(),)
|
||||
|
@ -159,7 +160,8 @@ reveal_type(a >= a) # revealed: @Todo
|
|||
"Membership Test Comparisons" refers to the operators `in` and `not in`.
|
||||
|
||||
```py
|
||||
def int_instance() -> int: ...
|
||||
def int_instance() -> int:
|
||||
return 42
|
||||
|
||||
a = (1, 2)
|
||||
b = ((3, 4), (1, 2))
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
# Narrowing for conditionals with elif and else
|
||||
|
||||
## Positive contributions become negative in elif-else blocks
|
||||
|
||||
```py
|
||||
def int_instance() -> int:
|
||||
return 42
|
||||
|
||||
x = int_instance()
|
||||
|
||||
if x == 1:
|
||||
# cannot narrow; could be a subclass of `int`
|
||||
reveal_type(x) # revealed: int
|
||||
elif x == 2:
|
||||
reveal_type(x) # revealed: int & ~Literal[1]
|
||||
elif x != 3:
|
||||
reveal_type(x) # revealed: int & ~Literal[1] & ~Literal[2] & ~Literal[3]
|
||||
```
|
||||
|
||||
## Positive contributions become negative in elif-else blocks, with simplification
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
x = 1 if bool_instance() else 2 if bool_instance() else 3
|
||||
|
||||
if x == 1:
|
||||
# TODO should be Literal[1]
|
||||
reveal_type(x) # revealed: Literal[1, 2, 3]
|
||||
elif x == 2:
|
||||
# TODO should be Literal[2]
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[3]
|
||||
```
|
||||
|
||||
## Multiple negative contributions using elif, with simplification
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
x = 1 if bool_instance() else 2 if bool_instance() else 3
|
||||
|
||||
if x != 1:
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
elif x != 2:
|
||||
# TODO should be `Literal[1]`
|
||||
reveal_type(x) # revealed: Literal[1, 3]
|
||||
elif x == 3:
|
||||
# TODO should be Never
|
||||
reveal_type(x) # revealed: Literal[1, 2, 3]
|
||||
else:
|
||||
# TODO should be Never
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
```
|
|
@ -11,6 +11,8 @@ x = None if flag else 1
|
|||
|
||||
if x is None:
|
||||
reveal_type(x) # revealed: None
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
|
||||
reveal_type(x) # revealed: None | Literal[1]
|
||||
```
|
||||
|
@ -30,6 +32,8 @@ y = x if flag else None
|
|||
|
||||
if y is x:
|
||||
reveal_type(y) # revealed: A
|
||||
else:
|
||||
reveal_type(y) # revealed: A | None
|
||||
|
||||
reveal_type(y) # revealed: A | None
|
||||
```
|
||||
|
@ -50,4 +54,26 @@ reveal_type(y) # revealed: bool
|
|||
if y is x is False: # Interpreted as `(y is x) and (x is False)`
|
||||
reveal_type(x) # revealed: Literal[False]
|
||||
reveal_type(y) # revealed: bool
|
||||
else:
|
||||
# The negation of the clause above is (y is not x) or (x is not False)
|
||||
# So we can't narrow the type of x or y here, because each arm of the `or` could be true
|
||||
reveal_type(x) # revealed: bool
|
||||
reveal_type(y) # revealed: bool
|
||||
```
|
||||
|
||||
## `is` in elif clause
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
x = None if bool_instance() else (1 if bool_instance() else True)
|
||||
|
||||
reveal_type(x) # revealed: None | Literal[1] | Literal[True]
|
||||
if x is None:
|
||||
reveal_type(x) # revealed: None
|
||||
elif x is True:
|
||||
reveal_type(x) # revealed: Literal[True]
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
```
|
||||
|
|
|
@ -13,6 +13,8 @@ x = None if flag else 1
|
|||
|
||||
if x is not None:
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
else:
|
||||
reveal_type(x) # revealed: None
|
||||
|
||||
reveal_type(x) # revealed: None | Literal[1]
|
||||
```
|
||||
|
@ -29,6 +31,8 @@ reveal_type(x) # revealed: bool
|
|||
|
||||
if x is not False:
|
||||
reveal_type(x) # revealed: Literal[True]
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[False]
|
||||
```
|
||||
|
||||
## `is not` for non-singleton types
|
||||
|
@ -43,6 +47,27 @@ y = 345
|
|||
|
||||
if x is not y:
|
||||
reveal_type(x) # revealed: Literal[345]
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[345]
|
||||
```
|
||||
|
||||
## `is not` for other types
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
class A: ...
|
||||
|
||||
x = A()
|
||||
y = x if bool_instance() else None
|
||||
|
||||
if y is not x:
|
||||
reveal_type(y) # revealed: A | None
|
||||
else:
|
||||
reveal_type(y) # revealed: A
|
||||
|
||||
reveal_type(y) # revealed: A | None
|
||||
```
|
||||
|
||||
## `is not` in chained comparisons
|
||||
|
@ -63,4 +88,10 @@ reveal_type(y) # revealed: bool
|
|||
if y is not x is not False: # Interpreted as `(y is not x) and (x is not False)`
|
||||
reveal_type(x) # revealed: Literal[True]
|
||||
reveal_type(y) # revealed: bool
|
||||
else:
|
||||
# The negation of the clause above is (y is x) or (x is False)
|
||||
# So we can't narrow the type of x or y here, because each arm of the `or` could be true
|
||||
|
||||
reveal_type(x) # revealed: bool
|
||||
reveal_type(y) # revealed: bool
|
||||
```
|
||||
|
|
|
@ -3,7 +3,8 @@
|
|||
## Multiple negative contributions
|
||||
|
||||
```py
|
||||
def int_instance() -> int: ...
|
||||
def int_instance() -> int:
|
||||
return 42
|
||||
|
||||
x = int_instance()
|
||||
|
||||
|
@ -27,3 +28,29 @@ if x != 1:
|
|||
if x != 2:
|
||||
reveal_type(x) # revealed: Literal[3]
|
||||
```
|
||||
|
||||
## elif-else blocks
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
x = 1 if bool_instance() else 2 if bool_instance() else 3
|
||||
|
||||
if x != 1:
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
if x == 2:
|
||||
# TODO should be `Literal[2]`
|
||||
reveal_type(x) # revealed: Literal[2, 3]
|
||||
elif x == 3:
|
||||
reveal_type(x) # revealed: Literal[3]
|
||||
else:
|
||||
reveal_type(x) # revealed: Never
|
||||
|
||||
elif x != 2:
|
||||
# TODO should be Literal[1]
|
||||
reveal_type(x) # revealed: Literal[1, 3]
|
||||
else:
|
||||
# TODO should be Never
|
||||
reveal_type(x) # revealed: Literal[1, 2, 3]
|
||||
```
|
||||
|
|
|
@ -11,6 +11,9 @@ x = None if flag else 1
|
|||
|
||||
if x != None:
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
else:
|
||||
# TODO should be None
|
||||
reveal_type(x) # revealed: None | Literal[1]
|
||||
```
|
||||
|
||||
## `!=` for other singleton types
|
||||
|
@ -24,6 +27,9 @@ x = True if flag else False
|
|||
|
||||
if x != False:
|
||||
reveal_type(x) # revealed: Literal[True]
|
||||
else:
|
||||
# TODO should be Literal[False]
|
||||
reveal_type(x) # revealed: bool
|
||||
```
|
||||
|
||||
## `x != y` where `y` is of literal type
|
||||
|
@ -54,6 +60,25 @@ C = A if flag else B
|
|||
|
||||
if C != A:
|
||||
reveal_type(C) # revealed: Literal[B]
|
||||
else:
|
||||
# TODO should be Literal[A]
|
||||
reveal_type(C) # revealed: Literal[A, B]
|
||||
```
|
||||
|
||||
## `x != y` where `y` has multiple single-valued options
|
||||
|
||||
```py
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
x = 1 if bool_instance() else 2
|
||||
y = 2 if bool_instance() else 3
|
||||
|
||||
if x != y:
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
else:
|
||||
# TODO should be Literal[2]
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
```
|
||||
|
||||
## `!=` for non-single-valued types
|
||||
|
@ -74,3 +99,21 @@ y = int_instance()
|
|||
if x != y:
|
||||
reveal_type(x) # revealed: int | None
|
||||
```
|
||||
|
||||
## Mix of single-valued and non-single-valued types
|
||||
|
||||
```py
|
||||
def int_instance() -> int:
|
||||
return 42
|
||||
|
||||
def bool_instance() -> bool:
|
||||
return True
|
||||
|
||||
x = 1 if bool_instance() else 2
|
||||
y = 2 if bool_instance() else int_instance()
|
||||
|
||||
if x != y:
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal[1, 2]
|
||||
```
|
||||
|
|
|
@ -40,6 +40,8 @@ x = 1 if flag else "a"
|
|||
|
||||
if isinstance(x, (int, str)):
|
||||
reveal_type(x) # revealed: Literal[1] | Literal["a"]
|
||||
else:
|
||||
reveal_type(x) # revealed: Never
|
||||
|
||||
if isinstance(x, (int, bytes)):
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
|
@ -51,6 +53,8 @@ if isinstance(x, (bytes, str)):
|
|||
# one of the possibilities:
|
||||
if isinstance(x, (int, object)):
|
||||
reveal_type(x) # revealed: Literal[1] | Literal["a"]
|
||||
else:
|
||||
reveal_type(x) # revealed: Never
|
||||
|
||||
y = 1 if flag1 else "a" if flag2 else b"b"
|
||||
if isinstance(y, (int, str)):
|
||||
|
@ -75,6 +79,8 @@ x = 1 if flag else "a"
|
|||
|
||||
if isinstance(x, (bool, (bytes, int))):
|
||||
reveal_type(x) # revealed: Literal[1]
|
||||
else:
|
||||
reveal_type(x) # revealed: Literal["a"]
|
||||
```
|
||||
|
||||
## Class types
|
||||
|
@ -82,6 +88,7 @@ if isinstance(x, (bool, (bytes, int))):
|
|||
```py
|
||||
class A: ...
|
||||
class B: ...
|
||||
class C: ...
|
||||
|
||||
def get_object() -> object: ...
|
||||
|
||||
|
@ -91,6 +98,16 @@ if isinstance(x, A):
|
|||
reveal_type(x) # revealed: A
|
||||
if isinstance(x, B):
|
||||
reveal_type(x) # revealed: A & B
|
||||
else:
|
||||
reveal_type(x) # revealed: A & ~B
|
||||
|
||||
if isinstance(x, (A, B)):
|
||||
reveal_type(x) # revealed: A | B
|
||||
elif isinstance(x, (A, C)):
|
||||
reveal_type(x) # revealed: C & ~A & ~B
|
||||
else:
|
||||
# TODO: Should be simplified to ~A & ~B & ~C
|
||||
reveal_type(x) # revealed: object & ~A & ~B & ~C
|
||||
```
|
||||
|
||||
## No narrowing for instances of `builtins.type`
|
||||
|
|
|
@ -26,7 +26,8 @@ reveal_type(y) # revealed: Unknown
|
|||
## Function return
|
||||
|
||||
```py
|
||||
def int_instance() -> int: ...
|
||||
def int_instance() -> int:
|
||||
return 42
|
||||
|
||||
a = b"abcde"[int_instance()]
|
||||
# TODO: Support overloads... Should be `bytes`
|
||||
|
|
|
@ -23,7 +23,8 @@ reveal_type(b) # revealed: Unknown
|
|||
## Function return
|
||||
|
||||
```py
|
||||
def int_instance() -> int: ...
|
||||
def int_instance() -> int:
|
||||
return 42
|
||||
|
||||
a = "abcde"[int_instance()]
|
||||
# TODO: Support overloads... Should be `str`
|
||||
|
|
|
@ -27,7 +27,7 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
|
|||
use crate::semantic_index::SemanticIndex;
|
||||
use crate::Db;
|
||||
|
||||
use super::constraint::{Constraint, PatternConstraint};
|
||||
use super::constraint::{Constraint, ConstraintNode, PatternConstraint};
|
||||
use super::definition::{
|
||||
AssignmentKind, DefinitionCategory, ExceptHandlerDefinitionNodeRef,
|
||||
MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef,
|
||||
|
@ -243,12 +243,23 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
definition
|
||||
}
|
||||
|
||||
fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> {
|
||||
fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Constraint<'db> {
|
||||
let expression = self.add_standalone_expression(constraint_node);
|
||||
self.current_use_def_map_mut()
|
||||
.record_constraint(Constraint::Expression(expression));
|
||||
let constraint = Constraint {
|
||||
node: ConstraintNode::Expression(expression),
|
||||
is_positive: true,
|
||||
};
|
||||
self.current_use_def_map_mut().record_constraint(constraint);
|
||||
|
||||
expression
|
||||
constraint
|
||||
}
|
||||
|
||||
fn add_negated_constraint(&mut self, constraint: Constraint<'db>) {
|
||||
self.current_use_def_map_mut()
|
||||
.record_constraint(Constraint {
|
||||
node: constraint.node,
|
||||
is_positive: false,
|
||||
});
|
||||
}
|
||||
|
||||
fn push_assignment(&mut self, assignment: CurrentAssignment<'db>) {
|
||||
|
@ -285,7 +296,10 @@ impl<'db> SemanticIndexBuilder<'db> {
|
|||
countme::Count::default(),
|
||||
);
|
||||
self.current_use_def_map_mut()
|
||||
.record_constraint(Constraint::Pattern(pattern_constraint));
|
||||
.record_constraint(Constraint {
|
||||
node: ConstraintNode::Pattern(pattern_constraint),
|
||||
is_positive: true,
|
||||
});
|
||||
pattern_constraint
|
||||
}
|
||||
|
||||
|
@ -639,7 +653,8 @@ where
|
|||
ast::Stmt::If(node) => {
|
||||
self.visit_expr(&node.test);
|
||||
let pre_if = self.flow_snapshot();
|
||||
self.add_expression_constraint(&node.test);
|
||||
let constraint = self.add_expression_constraint(&node.test);
|
||||
let mut constraints = vec![constraint];
|
||||
self.visit_body(&node.body);
|
||||
let mut post_clauses: Vec<FlowSnapshot> = vec![];
|
||||
for clause in &node.elif_else_clauses {
|
||||
|
@ -649,7 +664,14 @@ where
|
|||
// we can only take an elif/else branch if none of the previous ones were
|
||||
// taken, so the block entry state is always `pre_if`
|
||||
self.flow_restore(pre_if.clone());
|
||||
self.visit_elif_else_clause(clause);
|
||||
for constraint in &constraints {
|
||||
self.add_negated_constraint(*constraint);
|
||||
}
|
||||
if let Some(elif_test) = &clause.test {
|
||||
self.visit_expr(elif_test);
|
||||
constraints.push(self.add_expression_constraint(elif_test));
|
||||
}
|
||||
self.visit_body(&clause.body);
|
||||
}
|
||||
for post_clause_state in post_clauses {
|
||||
self.flow_merge(post_clause_state);
|
||||
|
|
|
@ -7,7 +7,13 @@ use crate::semantic_index::expression::Expression;
|
|||
use crate::semantic_index::symbol::{FileScopeId, ScopeId};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub(crate) enum Constraint<'db> {
|
||||
pub(crate) struct Constraint<'db> {
|
||||
pub(crate) node: ConstraintNode<'db>,
|
||||
pub(crate) is_positive: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub(crate) enum ConstraintNode<'db> {
|
||||
Expression(Expression<'db>),
|
||||
Pattern(PatternConstraint<'db>),
|
||||
}
|
||||
|
|
|
@ -332,6 +332,11 @@ impl<'db> Type<'db> {
|
|||
.expect("Expected a Type::ModuleLiteral variant")
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn negate(&self, db: &'db dyn Db) -> Type<'db> {
|
||||
IntersectionBuilder::new(db).add_negative(*self).build()
|
||||
}
|
||||
|
||||
pub const fn into_union_type(self) -> Option<UnionType<'db>> {
|
||||
match self {
|
||||
Type::Union(union_type) => Some(union_type),
|
||||
|
|
|
@ -173,14 +173,10 @@ impl<'db> IntersectionBuilder<'db> {
|
|||
pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self {
|
||||
// See comments above in `add_positive`; this is just the negated version.
|
||||
if let Type::Union(union) = ty {
|
||||
union
|
||||
.elements(self.db)
|
||||
.iter()
|
||||
.map(|elem| self.clone().add_negative(*elem))
|
||||
.fold(IntersectionBuilder::empty(self.db), |mut builder, sub| {
|
||||
builder.intersections.extend(sub.intersections);
|
||||
builder
|
||||
})
|
||||
for elem in union.elements(self.db) {
|
||||
self = self.add_negative(*elem);
|
||||
}
|
||||
self
|
||||
} else {
|
||||
for inner in &mut self.intersections {
|
||||
inner.add_negative(self.db, ty);
|
||||
|
@ -667,6 +663,27 @@ mod tests {
|
|||
assert_eq!(ty, Type::IntLiteral(1));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_negative_union_de_morgan() {
|
||||
let db = setup_db();
|
||||
|
||||
let union = UnionBuilder::new(&db)
|
||||
.add(Type::IntLiteral(1))
|
||||
.add(Type::IntLiteral(2))
|
||||
.build();
|
||||
assert_eq!(union.display(&db).to_string(), "Literal[1, 2]");
|
||||
|
||||
let ty = IntersectionBuilder::new(&db).add_negative(union).build();
|
||||
|
||||
let expected = IntersectionBuilder::new(&db)
|
||||
.add_negative(Type::IntLiteral(1))
|
||||
.add_negative(Type::IntLiteral(2))
|
||||
.build();
|
||||
|
||||
assert_eq!(ty.display(&db).to_string(), "~Literal[1] & ~Literal[2]");
|
||||
assert_eq!(ty, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_intersection_simplify_positive_type_and_positive_subtype() {
|
||||
let db = setup_db();
|
||||
|
|
|
@ -2415,7 +2415,6 @@ impl<'db> TypeInferenceBuilder<'db> {
|
|||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let ty = bindings_ty(self.db, definitions, unbound_ty);
|
||||
|
||||
if ty.is_unbound() {
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::semantic_index::ast_ids::HasScopedAstId;
|
||||
use crate::semantic_index::constraint::{Constraint, PatternConstraint};
|
||||
use crate::semantic_index::constraint::{Constraint, ConstraintNode, PatternConstraint};
|
||||
use crate::semantic_index::definition::Definition;
|
||||
use crate::semantic_index::expression::Expression;
|
||||
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
|
||||
|
@ -34,13 +34,19 @@ pub(crate) fn narrowing_constraint<'db>(
|
|||
constraint: Constraint<'db>,
|
||||
definition: Definition<'db>,
|
||||
) -> Option<Type<'db>> {
|
||||
match constraint {
|
||||
Constraint::Expression(expression) => {
|
||||
all_narrowing_constraints_for_expression(db, expression)
|
||||
.get(&definition.symbol(db))
|
||||
.copied()
|
||||
match constraint.node {
|
||||
ConstraintNode::Expression(expression) => {
|
||||
if constraint.is_positive {
|
||||
all_narrowing_constraints_for_expression(db, expression)
|
||||
.get(&definition.symbol(db))
|
||||
.copied()
|
||||
} else {
|
||||
all_negative_narrowing_constraints_for_expression(db, expression)
|
||||
.get(&definition.symbol(db))
|
||||
.copied()
|
||||
}
|
||||
}
|
||||
Constraint::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern)
|
||||
ConstraintNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern)
|
||||
.get(&definition.symbol(db))
|
||||
.copied(),
|
||||
}
|
||||
|
@ -51,7 +57,7 @@ fn all_narrowing_constraints_for_pattern<'db>(
|
|||
db: &'db dyn Db,
|
||||
pattern: PatternConstraint<'db>,
|
||||
) -> NarrowingConstraints<'db> {
|
||||
NarrowingConstraintsBuilder::new(db, Constraint::Pattern(pattern)).finish()
|
||||
NarrowingConstraintsBuilder::new(db, ConstraintNode::Pattern(pattern), true).finish()
|
||||
}
|
||||
|
||||
#[salsa::tracked(return_ref)]
|
||||
|
@ -59,7 +65,15 @@ fn all_narrowing_constraints_for_expression<'db>(
|
|||
db: &'db dyn Db,
|
||||
expression: Expression<'db>,
|
||||
) -> NarrowingConstraints<'db> {
|
||||
NarrowingConstraintsBuilder::new(db, Constraint::Expression(expression)).finish()
|
||||
NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), true).finish()
|
||||
}
|
||||
|
||||
#[salsa::tracked(return_ref)]
|
||||
fn all_negative_narrowing_constraints_for_expression<'db>(
|
||||
db: &'db dyn Db,
|
||||
expression: Expression<'db>,
|
||||
) -> NarrowingConstraints<'db> {
|
||||
NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), false).finish()
|
||||
}
|
||||
|
||||
/// Generate a constraint from the *type* of the second argument of an `isinstance` call.
|
||||
|
@ -88,36 +102,39 @@ type NarrowingConstraints<'db> = FxHashMap<ScopedSymbolId, Type<'db>>;
|
|||
|
||||
struct NarrowingConstraintsBuilder<'db> {
|
||||
db: &'db dyn Db,
|
||||
constraint: Constraint<'db>,
|
||||
constraint: ConstraintNode<'db>,
|
||||
is_positive: bool,
|
||||
constraints: NarrowingConstraints<'db>,
|
||||
}
|
||||
|
||||
impl<'db> NarrowingConstraintsBuilder<'db> {
|
||||
fn new(db: &'db dyn Db, constraint: Constraint<'db>) -> Self {
|
||||
fn new(db: &'db dyn Db, constraint: ConstraintNode<'db>, is_positive: bool) -> Self {
|
||||
Self {
|
||||
db,
|
||||
constraint,
|
||||
is_positive,
|
||||
constraints: NarrowingConstraints::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(mut self) -> NarrowingConstraints<'db> {
|
||||
match self.constraint {
|
||||
Constraint::Expression(expression) => self.evaluate_expression_constraint(expression),
|
||||
Constraint::Pattern(pattern) => self.evaluate_pattern_constraint(pattern),
|
||||
ConstraintNode::Expression(expression) => {
|
||||
self.evaluate_expression_constraint(expression, self.is_positive);
|
||||
}
|
||||
ConstraintNode::Pattern(pattern) => self.evaluate_pattern_constraint(pattern),
|
||||
}
|
||||
|
||||
self.constraints.shrink_to_fit();
|
||||
self.constraints
|
||||
}
|
||||
|
||||
fn evaluate_expression_constraint(&mut self, expression: Expression<'db>) {
|
||||
fn evaluate_expression_constraint(&mut self, expression: Expression<'db>, is_positive: bool) {
|
||||
match expression.node_ref(self.db).node() {
|
||||
ast::Expr::Compare(expr_compare) => {
|
||||
self.add_expr_compare(expr_compare, expression);
|
||||
self.add_expr_compare(expr_compare, expression, is_positive);
|
||||
}
|
||||
ast::Expr::Call(expr_call) => {
|
||||
self.add_expr_call(expr_call, expression);
|
||||
self.add_expr_call(expr_call, expression, is_positive);
|
||||
}
|
||||
_ => {} // TODO other test expression kinds
|
||||
}
|
||||
|
@ -160,12 +177,17 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
|
||||
fn scope(&self) -> ScopeId<'db> {
|
||||
match self.constraint {
|
||||
Constraint::Expression(expression) => expression.scope(self.db),
|
||||
Constraint::Pattern(pattern) => pattern.scope(self.db),
|
||||
ConstraintNode::Expression(expression) => expression.scope(self.db),
|
||||
ConstraintNode::Pattern(pattern) => pattern.scope(self.db),
|
||||
}
|
||||
}
|
||||
|
||||
fn add_expr_compare(&mut self, expr_compare: &ast::ExprCompare, expression: Expression<'db>) {
|
||||
fn add_expr_compare(
|
||||
&mut self,
|
||||
expr_compare: &ast::ExprCompare,
|
||||
expression: Expression<'db>,
|
||||
is_positive: bool,
|
||||
) {
|
||||
let ast::ExprCompare {
|
||||
range: _,
|
||||
left,
|
||||
|
@ -177,6 +199,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
// we have no symbol to narrow down the type of.
|
||||
return;
|
||||
}
|
||||
if !is_positive && comparators.len() > 1 {
|
||||
// We can't negate a constraint made by a multi-comparator expression, since we can't
|
||||
// know which comparison part is the one being negated.
|
||||
// For example, the negation of `x is 1 is y is 2`, would be `(x is not 1) or (y is not 1) or (y is not 2)`
|
||||
// and that requires cross-symbol constraints, which we don't support yet.
|
||||
return;
|
||||
}
|
||||
let scope = self.scope();
|
||||
let inference = infer_expression_types(self.db, expression);
|
||||
|
||||
|
@ -192,12 +221,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
{
|
||||
// SAFETY: we should always have a symbol for every Name node.
|
||||
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
||||
let comp_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope));
|
||||
match op {
|
||||
let rhs_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope));
|
||||
|
||||
match if is_positive { *op } else { op.negate() } {
|
||||
ast::CmpOp::IsNot => {
|
||||
if comp_ty.is_singleton() {
|
||||
if rhs_ty.is_singleton() {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(comp_ty)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
self.constraints.insert(symbol, ty);
|
||||
} else {
|
||||
|
@ -205,12 +235,12 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}
|
||||
}
|
||||
ast::CmpOp::Is => {
|
||||
self.constraints.insert(symbol, comp_ty);
|
||||
self.constraints.insert(symbol, rhs_ty);
|
||||
}
|
||||
ast::CmpOp::NotEq => {
|
||||
if comp_ty.is_single_valued(self.db) {
|
||||
if rhs_ty.is_single_valued(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(comp_ty)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
self.constraints.insert(symbol, ty);
|
||||
}
|
||||
|
@ -223,7 +253,12 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn add_expr_call(&mut self, expr_call: &ast::ExprCall, expression: Expression<'db>) {
|
||||
fn add_expr_call(
|
||||
&mut self,
|
||||
expr_call: &ast::ExprCall,
|
||||
expression: Expression<'db>,
|
||||
is_positive: bool,
|
||||
) {
|
||||
let scope = self.scope();
|
||||
let inference = infer_expression_types(self.db, expression);
|
||||
|
||||
|
@ -242,7 +277,11 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
|
||||
// TODO: add support for PEP 604 union types on the right hand side:
|
||||
// isinstance(x, str | (int | float))
|
||||
if let Some(constraint) = generate_isinstance_constraint(self.db, &rhs_type) {
|
||||
if let Some(mut constraint) = generate_isinstance_constraint(self.db, &rhs_type)
|
||||
{
|
||||
if !is_positive {
|
||||
constraint = constraint.negate(self.db);
|
||||
}
|
||||
self.constraints.insert(symbol, constraint);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue