[red-knot] Type narrow in else clause (#13918)

## Summary

Add support for type narrowing in elif and else scopes as part of
#13694.

## Test Plan

- mdtest
- builder unit test for union negation.

---------

Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
TomerBin 2024-10-26 19:22:57 +03:00 committed by GitHub
parent 3006d6da23
commit 35f007f17f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 363 additions and 53 deletions

View file

@ -19,7 +19,8 @@ reveal_type(1 <= "" and 0 < 1) # revealed: @Todo | Literal[True]
```py
# TODO: implement lookup of `__eq__` on typeshed `int` stub.
def int_instance() -> int: ...
def int_instance() -> int:
return 42
reveal_type(1 == int_instance()) # revealed: @Todo
reveal_type(9 < int_instance()) # revealed: bool

View file

@ -59,7 +59,8 @@ reveal_type(c >= d) # revealed: Literal[True]
```py
def bool_instance() -> bool: ...
def int_instance() -> int: ...
def int_instance() -> int:
return 42
a = (bool_instance(),)
b = (int_instance(),)
@ -159,7 +160,8 @@ reveal_type(a >= a) # revealed: @Todo
"Membership Test Comparisons" refers to the operators `in` and `not in`.
```py
def int_instance() -> int: ...
def int_instance() -> int:
return 42
a = (1, 2)
b = ((3, 4), (1, 2))

View file

@ -0,0 +1,57 @@
# Narrowing for conditionals with elif and else
## Positive contributions become negative in elif-else blocks
```py
def int_instance() -> int:
return 42
x = int_instance()
if x == 1:
# cannot narrow; could be a subclass of `int`
reveal_type(x) # revealed: int
elif x == 2:
reveal_type(x) # revealed: int & ~Literal[1]
elif x != 3:
reveal_type(x) # revealed: int & ~Literal[1] & ~Literal[2] & ~Literal[3]
```
## Positive contributions become negative in elif-else blocks, with simplification
```py
def bool_instance() -> bool:
return True
x = 1 if bool_instance() else 2 if bool_instance() else 3
if x == 1:
# TODO should be Literal[1]
reveal_type(x) # revealed: Literal[1, 2, 3]
elif x == 2:
# TODO should be Literal[2]
reveal_type(x) # revealed: Literal[2, 3]
else:
reveal_type(x) # revealed: Literal[3]
```
## Multiple negative contributions using elif, with simplification
```py
def bool_instance() -> bool:
return True
x = 1 if bool_instance() else 2 if bool_instance() else 3
if x != 1:
reveal_type(x) # revealed: Literal[2, 3]
elif x != 2:
# TODO should be `Literal[1]`
reveal_type(x) # revealed: Literal[1, 3]
elif x == 3:
# TODO should be Never
reveal_type(x) # revealed: Literal[1, 2, 3]
else:
# TODO should be Never
reveal_type(x) # revealed: Literal[1, 2]
```

View file

@ -11,6 +11,8 @@ x = None if flag else 1
if x is None:
reveal_type(x) # revealed: None
else:
reveal_type(x) # revealed: Literal[1]
reveal_type(x) # revealed: None | Literal[1]
```
@ -30,6 +32,8 @@ y = x if flag else None
if y is x:
reveal_type(y) # revealed: A
else:
reveal_type(y) # revealed: A | None
reveal_type(y) # revealed: A | None
```
@ -50,4 +54,26 @@ reveal_type(y) # revealed: bool
if y is x is False: # Interpreted as `(y is x) and (x is False)`
reveal_type(x) # revealed: Literal[False]
reveal_type(y) # revealed: bool
else:
# The negation of the clause above is (y is not x) or (x is not False)
# So we can't narrow the type of x or y here, because each arm of the `or` could be true
reveal_type(x) # revealed: bool
reveal_type(y) # revealed: bool
```
## `is` in elif clause
```py
def bool_instance() -> bool:
return True
x = None if bool_instance() else (1 if bool_instance() else True)
reveal_type(x) # revealed: None | Literal[1] | Literal[True]
if x is None:
reveal_type(x) # revealed: None
elif x is True:
reveal_type(x) # revealed: Literal[True]
else:
reveal_type(x) # revealed: Literal[1]
```

View file

@ -13,6 +13,8 @@ x = None if flag else 1
if x is not None:
reveal_type(x) # revealed: Literal[1]
else:
reveal_type(x) # revealed: None
reveal_type(x) # revealed: None | Literal[1]
```
@ -29,6 +31,8 @@ reveal_type(x) # revealed: bool
if x is not False:
reveal_type(x) # revealed: Literal[True]
else:
reveal_type(x) # revealed: Literal[False]
```
## `is not` for non-singleton types
@ -43,6 +47,27 @@ y = 345
if x is not y:
reveal_type(x) # revealed: Literal[345]
else:
reveal_type(x) # revealed: Literal[345]
```
## `is not` for other types
```py
def bool_instance() -> bool:
return True
class A: ...
x = A()
y = x if bool_instance() else None
if y is not x:
reveal_type(y) # revealed: A | None
else:
reveal_type(y) # revealed: A
reveal_type(y) # revealed: A | None
```
## `is not` in chained comparisons
@ -63,4 +88,10 @@ reveal_type(y) # revealed: bool
if y is not x is not False: # Interpreted as `(y is not x) and (x is not False)`
reveal_type(x) # revealed: Literal[True]
reveal_type(y) # revealed: bool
else:
# The negation of the clause above is (y is x) or (x is False)
# So we can't narrow the type of x or y here, because each arm of the `or` could be true
reveal_type(x) # revealed: bool
reveal_type(y) # revealed: bool
```

View file

@ -3,7 +3,8 @@
## Multiple negative contributions
```py
def int_instance() -> int: ...
def int_instance() -> int:
return 42
x = int_instance()
@ -27,3 +28,29 @@ if x != 1:
if x != 2:
reveal_type(x) # revealed: Literal[3]
```
## elif-else blocks
```py
def bool_instance() -> bool:
return True
x = 1 if bool_instance() else 2 if bool_instance() else 3
if x != 1:
reveal_type(x) # revealed: Literal[2, 3]
if x == 2:
# TODO should be `Literal[2]`
reveal_type(x) # revealed: Literal[2, 3]
elif x == 3:
reveal_type(x) # revealed: Literal[3]
else:
reveal_type(x) # revealed: Never
elif x != 2:
# TODO should be Literal[1]
reveal_type(x) # revealed: Literal[1, 3]
else:
# TODO should be Never
reveal_type(x) # revealed: Literal[1, 2, 3]
```

View file

@ -11,6 +11,9 @@ x = None if flag else 1
if x != None:
reveal_type(x) # revealed: Literal[1]
else:
# TODO should be None
reveal_type(x) # revealed: None | Literal[1]
```
## `!=` for other singleton types
@ -24,6 +27,9 @@ x = True if flag else False
if x != False:
reveal_type(x) # revealed: Literal[True]
else:
# TODO should be Literal[False]
reveal_type(x) # revealed: bool
```
## `x != y` where `y` is of literal type
@ -54,6 +60,25 @@ C = A if flag else B
if C != A:
reveal_type(C) # revealed: Literal[B]
else:
# TODO should be Literal[A]
reveal_type(C) # revealed: Literal[A, B]
```
## `x != y` where `y` has multiple single-valued options
```py
def bool_instance() -> bool:
return True
x = 1 if bool_instance() else 2
y = 2 if bool_instance() else 3
if x != y:
reveal_type(x) # revealed: Literal[1, 2]
else:
# TODO should be Literal[2]
reveal_type(x) # revealed: Literal[1, 2]
```
## `!=` for non-single-valued types
@ -74,3 +99,21 @@ y = int_instance()
if x != y:
reveal_type(x) # revealed: int | None
```
## Mix of single-valued and non-single-valued types
```py
def int_instance() -> int:
return 42
def bool_instance() -> bool:
return True
x = 1 if bool_instance() else 2
y = 2 if bool_instance() else int_instance()
if x != y:
reveal_type(x) # revealed: Literal[1, 2]
else:
reveal_type(x) # revealed: Literal[1, 2]
```

View file

@ -40,6 +40,8 @@ x = 1 if flag else "a"
if isinstance(x, (int, str)):
reveal_type(x) # revealed: Literal[1] | Literal["a"]
else:
reveal_type(x) # revealed: Never
if isinstance(x, (int, bytes)):
reveal_type(x) # revealed: Literal[1]
@ -51,6 +53,8 @@ if isinstance(x, (bytes, str)):
# one of the possibilities:
if isinstance(x, (int, object)):
reveal_type(x) # revealed: Literal[1] | Literal["a"]
else:
reveal_type(x) # revealed: Never
y = 1 if flag1 else "a" if flag2 else b"b"
if isinstance(y, (int, str)):
@ -75,6 +79,8 @@ x = 1 if flag else "a"
if isinstance(x, (bool, (bytes, int))):
reveal_type(x) # revealed: Literal[1]
else:
reveal_type(x) # revealed: Literal["a"]
```
## Class types
@ -82,6 +88,7 @@ if isinstance(x, (bool, (bytes, int))):
```py
class A: ...
class B: ...
class C: ...
def get_object() -> object: ...
@ -91,6 +98,16 @@ if isinstance(x, A):
reveal_type(x) # revealed: A
if isinstance(x, B):
reveal_type(x) # revealed: A & B
else:
reveal_type(x) # revealed: A & ~B
if isinstance(x, (A, B)):
reveal_type(x) # revealed: A | B
elif isinstance(x, (A, C)):
reveal_type(x) # revealed: C & ~A & ~B
else:
# TODO: Should be simplified to ~A & ~B & ~C
reveal_type(x) # revealed: object & ~A & ~B & ~C
```
## No narrowing for instances of `builtins.type`

View file

@ -26,7 +26,8 @@ reveal_type(y) # revealed: Unknown
## Function return
```py
def int_instance() -> int: ...
def int_instance() -> int:
return 42
a = b"abcde"[int_instance()]
# TODO: Support overloads... Should be `bytes`

View file

@ -23,7 +23,8 @@ reveal_type(b) # revealed: Unknown
## Function return
```py
def int_instance() -> int: ...
def int_instance() -> int:
return 42
a = "abcde"[int_instance()]
# TODO: Support overloads... Should be `str`

View file

@ -27,7 +27,7 @@ use crate::semantic_index::use_def::{FlowSnapshot, UseDefMapBuilder};
use crate::semantic_index::SemanticIndex;
use crate::Db;
use super::constraint::{Constraint, PatternConstraint};
use super::constraint::{Constraint, ConstraintNode, PatternConstraint};
use super::definition::{
AssignmentKind, DefinitionCategory, ExceptHandlerDefinitionNodeRef,
MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef,
@ -243,12 +243,23 @@ impl<'db> SemanticIndexBuilder<'db> {
definition
}
fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Expression<'db> {
fn add_expression_constraint(&mut self, constraint_node: &ast::Expr) -> Constraint<'db> {
let expression = self.add_standalone_expression(constraint_node);
self.current_use_def_map_mut()
.record_constraint(Constraint::Expression(expression));
let constraint = Constraint {
node: ConstraintNode::Expression(expression),
is_positive: true,
};
self.current_use_def_map_mut().record_constraint(constraint);
expression
constraint
}
fn add_negated_constraint(&mut self, constraint: Constraint<'db>) {
self.current_use_def_map_mut()
.record_constraint(Constraint {
node: constraint.node,
is_positive: false,
});
}
fn push_assignment(&mut self, assignment: CurrentAssignment<'db>) {
@ -285,7 +296,10 @@ impl<'db> SemanticIndexBuilder<'db> {
countme::Count::default(),
);
self.current_use_def_map_mut()
.record_constraint(Constraint::Pattern(pattern_constraint));
.record_constraint(Constraint {
node: ConstraintNode::Pattern(pattern_constraint),
is_positive: true,
});
pattern_constraint
}
@ -639,7 +653,8 @@ where
ast::Stmt::If(node) => {
self.visit_expr(&node.test);
let pre_if = self.flow_snapshot();
self.add_expression_constraint(&node.test);
let constraint = self.add_expression_constraint(&node.test);
let mut constraints = vec![constraint];
self.visit_body(&node.body);
let mut post_clauses: Vec<FlowSnapshot> = vec![];
for clause in &node.elif_else_clauses {
@ -649,7 +664,14 @@ where
// we can only take an elif/else branch if none of the previous ones were
// taken, so the block entry state is always `pre_if`
self.flow_restore(pre_if.clone());
self.visit_elif_else_clause(clause);
for constraint in &constraints {
self.add_negated_constraint(*constraint);
}
if let Some(elif_test) = &clause.test {
self.visit_expr(elif_test);
constraints.push(self.add_expression_constraint(elif_test));
}
self.visit_body(&clause.body);
}
for post_clause_state in post_clauses {
self.flow_merge(post_clause_state);

View file

@ -7,7 +7,13 @@ use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{FileScopeId, ScopeId};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum Constraint<'db> {
pub(crate) struct Constraint<'db> {
pub(crate) node: ConstraintNode<'db>,
pub(crate) is_positive: bool,
}
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum ConstraintNode<'db> {
Expression(Expression<'db>),
Pattern(PatternConstraint<'db>),
}

View file

@ -332,6 +332,11 @@ impl<'db> Type<'db> {
.expect("Expected a Type::ModuleLiteral variant")
}
#[must_use]
pub fn negate(&self, db: &'db dyn Db) -> Type<'db> {
IntersectionBuilder::new(db).add_negative(*self).build()
}
pub const fn into_union_type(self) -> Option<UnionType<'db>> {
match self {
Type::Union(union_type) => Some(union_type),

View file

@ -173,14 +173,10 @@ impl<'db> IntersectionBuilder<'db> {
pub(crate) fn add_negative(mut self, ty: Type<'db>) -> Self {
// See comments above in `add_positive`; this is just the negated version.
if let Type::Union(union) = ty {
union
.elements(self.db)
.iter()
.map(|elem| self.clone().add_negative(*elem))
.fold(IntersectionBuilder::empty(self.db), |mut builder, sub| {
builder.intersections.extend(sub.intersections);
builder
})
for elem in union.elements(self.db) {
self = self.add_negative(*elem);
}
self
} else {
for inner in &mut self.intersections {
inner.add_negative(self.db, ty);
@ -667,6 +663,27 @@ mod tests {
assert_eq!(ty, Type::IntLiteral(1));
}
#[test]
fn build_negative_union_de_morgan() {
let db = setup_db();
let union = UnionBuilder::new(&db)
.add(Type::IntLiteral(1))
.add(Type::IntLiteral(2))
.build();
assert_eq!(union.display(&db).to_string(), "Literal[1, 2]");
let ty = IntersectionBuilder::new(&db).add_negative(union).build();
let expected = IntersectionBuilder::new(&db)
.add_negative(Type::IntLiteral(1))
.add_negative(Type::IntLiteral(2))
.build();
assert_eq!(ty.display(&db).to_string(), "~Literal[1] & ~Literal[2]");
assert_eq!(ty, expected);
}
#[test]
fn build_intersection_simplify_positive_type_and_positive_subtype() {
let db = setup_db();

View file

@ -2415,7 +2415,6 @@ impl<'db> TypeInferenceBuilder<'db> {
} else {
None
};
let ty = bindings_ty(self.db, definitions, unbound_ty);
if ty.is_unbound() {

View file

@ -1,5 +1,5 @@
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::constraint::{Constraint, PatternConstraint};
use crate::semantic_index::constraint::{Constraint, ConstraintNode, PatternConstraint};
use crate::semantic_index::definition::Definition;
use crate::semantic_index::expression::Expression;
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
@ -34,13 +34,19 @@ pub(crate) fn narrowing_constraint<'db>(
constraint: Constraint<'db>,
definition: Definition<'db>,
) -> Option<Type<'db>> {
match constraint {
Constraint::Expression(expression) => {
all_narrowing_constraints_for_expression(db, expression)
.get(&definition.symbol(db))
.copied()
match constraint.node {
ConstraintNode::Expression(expression) => {
if constraint.is_positive {
all_narrowing_constraints_for_expression(db, expression)
.get(&definition.symbol(db))
.copied()
} else {
all_negative_narrowing_constraints_for_expression(db, expression)
.get(&definition.symbol(db))
.copied()
}
}
Constraint::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern)
ConstraintNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern)
.get(&definition.symbol(db))
.copied(),
}
@ -51,7 +57,7 @@ fn all_narrowing_constraints_for_pattern<'db>(
db: &'db dyn Db,
pattern: PatternConstraint<'db>,
) -> NarrowingConstraints<'db> {
NarrowingConstraintsBuilder::new(db, Constraint::Pattern(pattern)).finish()
NarrowingConstraintsBuilder::new(db, ConstraintNode::Pattern(pattern), true).finish()
}
#[salsa::tracked(return_ref)]
@ -59,7 +65,15 @@ fn all_narrowing_constraints_for_expression<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> NarrowingConstraints<'db> {
NarrowingConstraintsBuilder::new(db, Constraint::Expression(expression)).finish()
NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), true).finish()
}
#[salsa::tracked(return_ref)]
fn all_negative_narrowing_constraints_for_expression<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> NarrowingConstraints<'db> {
NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), false).finish()
}
/// Generate a constraint from the *type* of the second argument of an `isinstance` call.
@ -88,36 +102,39 @@ type NarrowingConstraints<'db> = FxHashMap<ScopedSymbolId, Type<'db>>;
struct NarrowingConstraintsBuilder<'db> {
db: &'db dyn Db,
constraint: Constraint<'db>,
constraint: ConstraintNode<'db>,
is_positive: bool,
constraints: NarrowingConstraints<'db>,
}
impl<'db> NarrowingConstraintsBuilder<'db> {
fn new(db: &'db dyn Db, constraint: Constraint<'db>) -> Self {
fn new(db: &'db dyn Db, constraint: ConstraintNode<'db>, is_positive: bool) -> Self {
Self {
db,
constraint,
is_positive,
constraints: NarrowingConstraints::default(),
}
}
fn finish(mut self) -> NarrowingConstraints<'db> {
match self.constraint {
Constraint::Expression(expression) => self.evaluate_expression_constraint(expression),
Constraint::Pattern(pattern) => self.evaluate_pattern_constraint(pattern),
ConstraintNode::Expression(expression) => {
self.evaluate_expression_constraint(expression, self.is_positive);
}
ConstraintNode::Pattern(pattern) => self.evaluate_pattern_constraint(pattern),
}
self.constraints.shrink_to_fit();
self.constraints
}
fn evaluate_expression_constraint(&mut self, expression: Expression<'db>) {
fn evaluate_expression_constraint(&mut self, expression: Expression<'db>, is_positive: bool) {
match expression.node_ref(self.db).node() {
ast::Expr::Compare(expr_compare) => {
self.add_expr_compare(expr_compare, expression);
self.add_expr_compare(expr_compare, expression, is_positive);
}
ast::Expr::Call(expr_call) => {
self.add_expr_call(expr_call, expression);
self.add_expr_call(expr_call, expression, is_positive);
}
_ => {} // TODO other test expression kinds
}
@ -160,12 +177,17 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
fn scope(&self) -> ScopeId<'db> {
match self.constraint {
Constraint::Expression(expression) => expression.scope(self.db),
Constraint::Pattern(pattern) => pattern.scope(self.db),
ConstraintNode::Expression(expression) => expression.scope(self.db),
ConstraintNode::Pattern(pattern) => pattern.scope(self.db),
}
}
fn add_expr_compare(&mut self, expr_compare: &ast::ExprCompare, expression: Expression<'db>) {
fn add_expr_compare(
&mut self,
expr_compare: &ast::ExprCompare,
expression: Expression<'db>,
is_positive: bool,
) {
let ast::ExprCompare {
range: _,
left,
@ -177,6 +199,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
// we have no symbol to narrow down the type of.
return;
}
if !is_positive && comparators.len() > 1 {
// We can't negate a constraint made by a multi-comparator expression, since we can't
// know which comparison part is the one being negated.
// For example, the negation of `x is 1 is y is 2`, would be `(x is not 1) or (y is not 1) or (y is not 2)`
// and that requires cross-symbol constraints, which we don't support yet.
return;
}
let scope = self.scope();
let inference = infer_expression_types(self.db, expression);
@ -192,12 +221,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
{
// SAFETY: we should always have a symbol for every Name node.
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
let comp_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope));
match op {
let rhs_ty = inference.expression_ty(right.scoped_ast_id(self.db, scope));
match if is_positive { *op } else { op.negate() } {
ast::CmpOp::IsNot => {
if comp_ty.is_singleton() {
if rhs_ty.is_singleton() {
let ty = IntersectionBuilder::new(self.db)
.add_negative(comp_ty)
.add_negative(rhs_ty)
.build();
self.constraints.insert(symbol, ty);
} else {
@ -205,12 +235,12 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}
}
ast::CmpOp::Is => {
self.constraints.insert(symbol, comp_ty);
self.constraints.insert(symbol, rhs_ty);
}
ast::CmpOp::NotEq => {
if comp_ty.is_single_valued(self.db) {
if rhs_ty.is_single_valued(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(comp_ty)
.add_negative(rhs_ty)
.build();
self.constraints.insert(symbol, ty);
}
@ -223,7 +253,12 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
}
}
fn add_expr_call(&mut self, expr_call: &ast::ExprCall, expression: Expression<'db>) {
fn add_expr_call(
&mut self,
expr_call: &ast::ExprCall,
expression: Expression<'db>,
is_positive: bool,
) {
let scope = self.scope();
let inference = infer_expression_types(self.db, expression);
@ -242,7 +277,11 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
// TODO: add support for PEP 604 union types on the right hand side:
// isinstance(x, str | (int | float))
if let Some(constraint) = generate_isinstance_constraint(self.db, &rhs_type) {
if let Some(mut constraint) = generate_isinstance_constraint(self.db, &rhs_type)
{
if !is_positive {
constraint = constraint.negate(self.db);
}
self.constraints.insert(symbol, constraint);
}
}