mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-29 13:25:17 +00:00
[red-knot] Implement type narrowing for boolean conditionals (#14037)
## Summary This PR enables red-knot to support type narrowing based on `and` and `or` conditionals, including nested combinations and their negation (for `elif` / `else` blocks and for `not` operator). Part of #13694. In order to address this properly (hopefully 😅), I had to run `NarrowingConstraintsBuilder` functions recursively. In the first commit I introduced a minor refactor - instead of mutating `self.constraints`, the new constraints are now returned as function return values. I also modified the constraints map to be optional, preventing unnecessary hashmap allocations. Thanks @carljm for your support on this :) The second commit contains the logic and tests for handling boolean ops, with improvements to intersections handling in `is_subtype_of` . As I'm still new to Rust and the internals of type checkers, I’d be more than happy to hear any insights or suggestions. Thank you! --------- Co-authored-by: Carl Meyer <carl@astral.sh>
This commit is contained in:
parent
bb25bd9c6c
commit
6c56a7a868
4 changed files with 591 additions and 59 deletions
|
@ -528,6 +528,46 @@ impl<'db> Type<'db> {
|
|||
.elements(db)
|
||||
.iter()
|
||||
.any(|&elem_ty| ty.is_subtype_of(db, elem_ty)),
|
||||
(Type::Intersection(self_intersection), Type::Intersection(target_intersection)) => {
|
||||
// Check that all target positive values are covered in self positive values
|
||||
target_intersection
|
||||
.positive(db)
|
||||
.iter()
|
||||
.all(|&target_pos_elem| {
|
||||
self_intersection
|
||||
.positive(db)
|
||||
.iter()
|
||||
.any(|&self_pos_elem| self_pos_elem.is_subtype_of(db, target_pos_elem))
|
||||
})
|
||||
// Check that all target negative values are excluded in self, either by being
|
||||
// subtypes of a self negative value or being disjoint from a self positive value.
|
||||
&& target_intersection
|
||||
.negative(db)
|
||||
.iter()
|
||||
.all(|&target_neg_elem| {
|
||||
// Is target negative value is subtype of a self negative value
|
||||
self_intersection.negative(db).iter().any(|&self_neg_elem| {
|
||||
target_neg_elem.is_subtype_of(db, self_neg_elem)
|
||||
// Is target negative value is disjoint from a self positive value?
|
||||
}) || self_intersection.positive(db).iter().any(|&self_pos_elem| {
|
||||
target_neg_elem.is_disjoint_from(db, self_pos_elem)
|
||||
})
|
||||
})
|
||||
}
|
||||
(Type::Intersection(intersection), ty) => intersection
|
||||
.positive(db)
|
||||
.iter()
|
||||
.any(|&elem_ty| elem_ty.is_subtype_of(db, ty)),
|
||||
(ty, Type::Intersection(intersection)) => {
|
||||
intersection
|
||||
.positive(db)
|
||||
.iter()
|
||||
.all(|&pos_ty| ty.is_subtype_of(db, pos_ty))
|
||||
&& intersection
|
||||
.negative(db)
|
||||
.iter()
|
||||
.all(|&neg_ty| neg_ty.is_disjoint_from(db, ty))
|
||||
}
|
||||
(Type::Instance(self_class), Type::Instance(target_class)) => {
|
||||
self_class.is_subclass_of(db, target_class)
|
||||
}
|
||||
|
@ -2190,6 +2230,11 @@ mod tests {
|
|||
Ty::BuiltinInstance("FloatingPointError"),
|
||||
Ty::BuiltinInstance("Exception")
|
||||
)]
|
||||
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::BuiltinInstance("int"))]
|
||||
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
|
||||
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
|
||||
#[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]})]
|
||||
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("str")], neg: vec![Ty::StringLiteral("foo")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
|
||||
fn is_subtype_of(from: Ty, to: Ty) {
|
||||
let db = setup_db();
|
||||
assert!(from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
|
||||
|
@ -2210,6 +2255,11 @@ mod tests {
|
|||
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(42)]), Ty::Tuple(vec![Ty::BuiltinInstance("str")]))]
|
||||
#[test_case(Ty::Tuple(vec![Ty::Todo]), Ty::Tuple(vec![Ty::IntLiteral(2)]))]
|
||||
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::Todo]))]
|
||||
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(3)]})]
|
||||
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})]
|
||||
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]})]
|
||||
#[test_case(Ty::BuiltinInstance("int"), Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})]
|
||||
#[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(1)]})]
|
||||
fn is_not_subtype_of(from: Ty, to: Ty) {
|
||||
let db = setup_db();
|
||||
assert!(!from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
|
||||
|
@ -2241,6 +2291,34 @@ mod tests {
|
|||
assert!(type_u.is_subtype_of(&db, Ty::BuiltinInstance("object").into_type(&db)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_subtype_of_intersection_of_class_instances() {
|
||||
let mut db = setup_db();
|
||||
db.write_dedented(
|
||||
"/src/module.py",
|
||||
"
|
||||
class A: ...
|
||||
a = A()
|
||||
class B: ...
|
||||
b = B()
|
||||
",
|
||||
)
|
||||
.unwrap();
|
||||
let module = ruff_db::files::system_path_to_file(&db, "/src/module.py").unwrap();
|
||||
|
||||
let a_ty = super::global_symbol(&db, module, "a").expect_type();
|
||||
let b_ty = super::global_symbol(&db, module, "b").expect_type();
|
||||
let intersection = IntersectionBuilder::new(&db)
|
||||
.add_positive(a_ty)
|
||||
.add_positive(b_ty)
|
||||
.build();
|
||||
|
||||
assert_eq!(intersection.display(&db).to_string(), "A & B");
|
||||
assert!(!a_ty.is_subtype_of(&db, b_ty));
|
||||
assert!(intersection.is_subtype_of(&db, b_ty));
|
||||
assert!(intersection.is_subtype_of(&db, a_ty));
|
||||
}
|
||||
|
||||
#[test_case(
|
||||
Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]),
|
||||
Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)])
|
||||
|
|
|
@ -25,12 +25,12 @@
|
|||
//! * No type in an intersection can be a supertype of any other type in the intersection (just
|
||||
//! eliminate the supertype from the intersection).
|
||||
//! * An intersection containing two non-overlapping types should simplify to [`Type::Never`].
|
||||
|
||||
use super::KnownClass;
|
||||
use crate::types::{IntersectionType, Type, UnionType};
|
||||
use crate::{Db, FxOrderSet};
|
||||
use smallvec::SmallVec;
|
||||
|
||||
use super::KnownClass;
|
||||
|
||||
pub(crate) struct UnionBuilder<'db> {
|
||||
elements: Vec<Type<'db>>,
|
||||
db: &'db dyn Db,
|
||||
|
@ -80,7 +80,6 @@ impl<'db> UnionBuilder<'db> {
|
|||
to_remove.push(index);
|
||||
}
|
||||
}
|
||||
|
||||
match to_remove[..] {
|
||||
[] => self.elements.push(to_add),
|
||||
[index] => self.elements[index] = to_add,
|
||||
|
@ -103,7 +102,6 @@ impl<'db> UnionBuilder<'db> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
|
@ -386,8 +384,9 @@ mod tests {
|
|||
use crate::program::{Program, SearchPathSettings};
|
||||
use crate::python_version::PythonVersion;
|
||||
use crate::stdlib::typing_symbol;
|
||||
use crate::types::{KnownClass, StringLiteralType, UnionBuilder};
|
||||
use crate::types::{global_symbol, KnownClass, StringLiteralType, UnionBuilder};
|
||||
use crate::ProgramSettings;
|
||||
use ruff_db::files::system_path_to_file;
|
||||
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
|
||||
use test_case::test_case;
|
||||
|
||||
|
@ -993,4 +992,66 @@ mod tests {
|
|||
.build();
|
||||
assert_eq!(result, ty);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_intersection_of_two_unions_simplify() {
|
||||
let mut db = setup_db();
|
||||
db.write_dedented(
|
||||
"/src/module.py",
|
||||
"
|
||||
class A: ...
|
||||
class B: ...
|
||||
a = A()
|
||||
b = B()
|
||||
",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let file = system_path_to_file(&db, "src/module.py").expect("file to exist");
|
||||
|
||||
let a = global_symbol(&db, file, "a").expect_type();
|
||||
let b = global_symbol(&db, file, "b").expect_type();
|
||||
let union = UnionBuilder::new(&db).add(a).add(b).build();
|
||||
assert_eq!(union.display(&db).to_string(), "A | B");
|
||||
let reversed_union = UnionBuilder::new(&db).add(b).add(a).build();
|
||||
assert_eq!(reversed_union.display(&db).to_string(), "B | A");
|
||||
let intersection = IntersectionBuilder::new(&db)
|
||||
.add_positive(union)
|
||||
.add_positive(reversed_union)
|
||||
.build();
|
||||
assert_eq!(intersection.display(&db).to_string(), "B | A");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_union_of_two_intersections_simplify() {
|
||||
let mut db = setup_db();
|
||||
db.write_dedented(
|
||||
"/src/module.py",
|
||||
"
|
||||
class A: ...
|
||||
class B: ...
|
||||
a = A()
|
||||
b = B()
|
||||
",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let file = system_path_to_file(&db, "src/module.py").expect("file to exist");
|
||||
|
||||
let a = global_symbol(&db, file, "a").expect_type();
|
||||
let b = global_symbol(&db, file, "b").expect_type();
|
||||
let intersection = IntersectionBuilder::new(&db)
|
||||
.add_positive(a)
|
||||
.add_positive(b)
|
||||
.build();
|
||||
let reversed_intersection = IntersectionBuilder::new(&db)
|
||||
.add_positive(b)
|
||||
.add_positive(a)
|
||||
.build();
|
||||
let union = UnionBuilder::new(&db)
|
||||
.add(intersection)
|
||||
.add(reversed_intersection)
|
||||
.build();
|
||||
assert_eq!(union.display(&db).to_string(), "A & B");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,12 +5,15 @@ use crate::semantic_index::expression::Expression;
|
|||
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable};
|
||||
use crate::semantic_index::symbol_table;
|
||||
use crate::types::{
|
||||
infer_expression_types, IntersectionBuilder, KnownFunction, Type, UnionBuilder,
|
||||
infer_expression_types, IntersectionBuilder, KnownClass, KnownFunction, Truthiness, Type,
|
||||
UnionBuilder,
|
||||
};
|
||||
use crate::Db;
|
||||
use itertools::Itertools;
|
||||
use ruff_python_ast as ast;
|
||||
use ruff_python_ast::{BoolOp, ExprBoolOp};
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::collections::hash_map::Entry;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Return the type constraint that `test` (if true) would place on `definition`, if any.
|
||||
|
@ -34,21 +37,20 @@ pub(crate) fn narrowing_constraint<'db>(
|
|||
constraint: Constraint<'db>,
|
||||
definition: Definition<'db>,
|
||||
) -> Option<Type<'db>> {
|
||||
match constraint.node {
|
||||
let constraints = match constraint.node {
|
||||
ConstraintNode::Expression(expression) => {
|
||||
if constraint.is_positive {
|
||||
all_narrowing_constraints_for_expression(db, expression)
|
||||
.get(&definition.symbol(db))
|
||||
.copied()
|
||||
} else {
|
||||
all_negative_narrowing_constraints_for_expression(db, expression)
|
||||
.get(&definition.symbol(db))
|
||||
.copied()
|
||||
}
|
||||
}
|
||||
ConstraintNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern)
|
||||
.get(&definition.symbol(db))
|
||||
.copied(),
|
||||
ConstraintNode::Pattern(pattern) => all_narrowing_constraints_for_pattern(db, pattern),
|
||||
};
|
||||
if let Some(constraints) = constraints {
|
||||
constraints.get(&definition.symbol(db)).copied()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -56,7 +58,7 @@ pub(crate) fn narrowing_constraint<'db>(
|
|||
fn all_narrowing_constraints_for_pattern<'db>(
|
||||
db: &'db dyn Db,
|
||||
pattern: PatternConstraint<'db>,
|
||||
) -> NarrowingConstraints<'db> {
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
NarrowingConstraintsBuilder::new(db, ConstraintNode::Pattern(pattern), true).finish()
|
||||
}
|
||||
|
||||
|
@ -64,7 +66,7 @@ fn all_narrowing_constraints_for_pattern<'db>(
|
|||
fn all_narrowing_constraints_for_expression<'db>(
|
||||
db: &'db dyn Db,
|
||||
expression: Expression<'db>,
|
||||
) -> NarrowingConstraints<'db> {
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), true).finish()
|
||||
}
|
||||
|
||||
|
@ -72,7 +74,7 @@ fn all_narrowing_constraints_for_expression<'db>(
|
|||
fn all_negative_narrowing_constraints_for_expression<'db>(
|
||||
db: &'db dyn Db,
|
||||
expression: Expression<'db>,
|
||||
) -> NarrowingConstraints<'db> {
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), false).finish()
|
||||
}
|
||||
|
||||
|
@ -100,11 +102,52 @@ fn generate_isinstance_constraint<'db>(
|
|||
|
||||
type NarrowingConstraints<'db> = FxHashMap<ScopedSymbolId, Type<'db>>;
|
||||
|
||||
fn merge_constraints_and<'db>(
|
||||
into: &mut NarrowingConstraints<'db>,
|
||||
from: NarrowingConstraints<'db>,
|
||||
db: &'db dyn Db,
|
||||
) {
|
||||
for (key, value) in from {
|
||||
match into.entry(key) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
*entry.get_mut() = IntersectionBuilder::new(db)
|
||||
.add_positive(*entry.get())
|
||||
.add_positive(value)
|
||||
.build();
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn merge_constraints_or<'db>(
|
||||
into: &mut NarrowingConstraints<'db>,
|
||||
from: &NarrowingConstraints<'db>,
|
||||
db: &'db dyn Db,
|
||||
) {
|
||||
for (key, value) in from {
|
||||
match into.entry(*key) {
|
||||
Entry::Occupied(mut entry) => {
|
||||
*entry.get_mut() = UnionBuilder::new(db).add(*entry.get()).add(*value).build();
|
||||
}
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(KnownClass::Object.to_instance(db));
|
||||
}
|
||||
}
|
||||
}
|
||||
for (key, value) in into.iter_mut() {
|
||||
if !from.contains_key(key) {
|
||||
*value = KnownClass::Object.to_instance(db);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct NarrowingConstraintsBuilder<'db> {
|
||||
db: &'db dyn Db,
|
||||
constraint: ConstraintNode<'db>,
|
||||
is_positive: bool,
|
||||
constraints: NarrowingConstraints<'db>,
|
||||
}
|
||||
|
||||
impl<'db> NarrowingConstraintsBuilder<'db> {
|
||||
|
@ -113,24 +156,31 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
db,
|
||||
constraint,
|
||||
is_positive,
|
||||
constraints: NarrowingConstraints::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(mut self) -> NarrowingConstraints<'db> {
|
||||
match self.constraint {
|
||||
fn finish(mut self) -> Option<NarrowingConstraints<'db>> {
|
||||
let constraints: Option<NarrowingConstraints<'db>> = match self.constraint {
|
||||
ConstraintNode::Expression(expression) => {
|
||||
self.evaluate_expression_constraint(expression, self.is_positive);
|
||||
self.evaluate_expression_constraint(expression, self.is_positive)
|
||||
}
|
||||
ConstraintNode::Pattern(pattern) => self.evaluate_pattern_constraint(pattern),
|
||||
};
|
||||
if let Some(mut constraints) = constraints {
|
||||
constraints.shrink_to_fit();
|
||||
Some(constraints)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
self.constraints.shrink_to_fit();
|
||||
self.constraints
|
||||
}
|
||||
|
||||
fn evaluate_expression_constraint(&mut self, expression: Expression<'db>, is_positive: bool) {
|
||||
fn evaluate_expression_constraint(
|
||||
&mut self,
|
||||
expression: Expression<'db>,
|
||||
is_positive: bool,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let expression_node = expression.node_ref(self.db).node();
|
||||
self.evaluate_expression_node_constraint(expression_node, expression, is_positive);
|
||||
self.evaluate_expression_node_constraint(expression_node, expression, is_positive)
|
||||
}
|
||||
|
||||
fn evaluate_expression_node_constraint(
|
||||
|
@ -138,52 +188,51 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
expression_node: &ruff_python_ast::Expr,
|
||||
expression: Expression<'db>,
|
||||
is_positive: bool,
|
||||
) {
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
match expression_node {
|
||||
ast::Expr::Compare(expr_compare) => {
|
||||
self.add_expr_compare(expr_compare, expression, is_positive);
|
||||
self.evaluate_expr_compare(expr_compare, expression, is_positive)
|
||||
}
|
||||
ast::Expr::Call(expr_call) => {
|
||||
self.add_expr_call(expr_call, expression, is_positive);
|
||||
self.evaluate_expr_call(expr_call, expression, is_positive)
|
||||
}
|
||||
ast::Expr::UnaryOp(unary_op) if unary_op.op == ast::UnaryOp::Not => {
|
||||
self.evaluate_expression_node_constraint(
|
||||
&unary_op.operand,
|
||||
expression,
|
||||
!is_positive,
|
||||
);
|
||||
}
|
||||
_ => {} // TODO other test expression kinds
|
||||
ast::Expr::UnaryOp(unary_op) if unary_op.op == ast::UnaryOp::Not => self
|
||||
.evaluate_expression_node_constraint(&unary_op.operand, expression, !is_positive),
|
||||
ast::Expr::BoolOp(bool_op) => self.evaluate_bool_op(bool_op, expression, is_positive),
|
||||
_ => None, // TODO other test expression kinds
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_pattern_constraint(&mut self, pattern: PatternConstraint<'db>) {
|
||||
fn evaluate_pattern_constraint(
|
||||
&mut self,
|
||||
pattern: PatternConstraint<'db>,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let subject = pattern.subject(self.db);
|
||||
|
||||
match pattern.pattern(self.db).node() {
|
||||
ast::Pattern::MatchValue(_) => {
|
||||
// TODO
|
||||
None // TODO
|
||||
}
|
||||
ast::Pattern::MatchSingleton(singleton_pattern) => {
|
||||
self.add_match_pattern_singleton(subject, singleton_pattern);
|
||||
self.evaluate_match_pattern_singleton(subject, singleton_pattern)
|
||||
}
|
||||
ast::Pattern::MatchSequence(_) => {
|
||||
// TODO
|
||||
None // TODO
|
||||
}
|
||||
ast::Pattern::MatchMapping(_) => {
|
||||
// TODO
|
||||
None // TODO
|
||||
}
|
||||
ast::Pattern::MatchClass(_) => {
|
||||
// TODO
|
||||
None // TODO
|
||||
}
|
||||
ast::Pattern::MatchStar(_) => {
|
||||
// TODO
|
||||
None // TODO
|
||||
}
|
||||
ast::Pattern::MatchAs(_) => {
|
||||
// TODO
|
||||
None // TODO
|
||||
}
|
||||
ast::Pattern::MatchOr(_) => {
|
||||
// TODO
|
||||
None // TODO
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -199,12 +248,12 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn add_expr_compare(
|
||||
fn evaluate_expr_compare(
|
||||
&mut self,
|
||||
expr_compare: &ast::ExprCompare,
|
||||
expression: Expression<'db>,
|
||||
is_positive: bool,
|
||||
) {
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let ast::ExprCompare {
|
||||
range: _,
|
||||
left,
|
||||
|
@ -214,14 +263,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
if !left.is_name_expr() && comparators.iter().all(|c| !c.is_name_expr()) {
|
||||
// If none of the comparators are name expressions,
|
||||
// we have no symbol to narrow down the type of.
|
||||
return;
|
||||
return None;
|
||||
}
|
||||
if !is_positive && comparators.len() > 1 {
|
||||
// We can't negate a constraint made by a multi-comparator expression, since we can't
|
||||
// know which comparison part is the one being negated.
|
||||
// For example, the negation of `x is 1 is y is 2`, would be `(x is not 1) or (y is not 1) or (y is not 2)`
|
||||
// and that requires cross-symbol constraints, which we don't support yet.
|
||||
return;
|
||||
return None;
|
||||
}
|
||||
let scope = self.scope();
|
||||
let inference = infer_expression_types(self.db, expression);
|
||||
|
@ -229,6 +278,7 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
let comparator_tuples = std::iter::once(&**left)
|
||||
.chain(comparators)
|
||||
.tuple_windows::<(&ruff_python_ast::Expr, &ruff_python_ast::Expr)>();
|
||||
let mut constraints = NarrowingConstraints::default();
|
||||
for (op, (left, right)) in std::iter::zip(&**ops, comparator_tuples) {
|
||||
if let ast::Expr::Name(ast::ExprName {
|
||||
range: _,
|
||||
|
@ -246,20 +296,20 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
self.constraints.insert(symbol, ty);
|
||||
constraints.insert(symbol, ty);
|
||||
} else {
|
||||
// Non-singletons cannot be safely narrowed using `is not`
|
||||
}
|
||||
}
|
||||
ast::CmpOp::Is => {
|
||||
self.constraints.insert(symbol, rhs_ty);
|
||||
constraints.insert(symbol, rhs_ty);
|
||||
}
|
||||
ast::CmpOp::NotEq => {
|
||||
if rhs_ty.is_single_valued(self.db) {
|
||||
let ty = IntersectionBuilder::new(self.db)
|
||||
.add_negative(rhs_ty)
|
||||
.build();
|
||||
self.constraints.insert(symbol, ty);
|
||||
constraints.insert(symbol, ty);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
|
@ -268,14 +318,15 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
}
|
||||
}
|
||||
}
|
||||
Some(constraints)
|
||||
}
|
||||
|
||||
fn add_expr_call(
|
||||
fn evaluate_expr_call(
|
||||
&mut self,
|
||||
expr_call: &ast::ExprCall,
|
||||
expression: Expression<'db>,
|
||||
is_positive: bool,
|
||||
) {
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let scope = self.scope();
|
||||
let inference = infer_expression_types(self.db, expression);
|
||||
|
||||
|
@ -299,18 +350,21 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
if !is_positive {
|
||||
constraint = constraint.negate(self.db);
|
||||
}
|
||||
self.constraints.insert(symbol, constraint);
|
||||
let mut constraints = NarrowingConstraints::default();
|
||||
constraints.insert(symbol, constraint);
|
||||
return Some(constraints);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn add_match_pattern_singleton(
|
||||
fn evaluate_match_pattern_singleton(
|
||||
&mut self,
|
||||
subject: &ast::Expr,
|
||||
pattern: &ast::PatternMatchSingleton,
|
||||
) {
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
if let Some(ast::ExprName { id, .. }) = subject.as_name_expr() {
|
||||
// SAFETY: we should always have a symbol for every Name node.
|
||||
let symbol = self.symbols().symbol_id_by_name(id).unwrap();
|
||||
|
@ -320,7 +374,64 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
|
|||
ast::Singleton::True => Type::BooleanLiteral(true),
|
||||
ast::Singleton::False => Type::BooleanLiteral(false),
|
||||
};
|
||||
self.constraints.insert(symbol, ty);
|
||||
let mut constraints = NarrowingConstraints::default();
|
||||
constraints.insert(symbol, ty);
|
||||
Some(constraints)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_bool_op(
|
||||
&mut self,
|
||||
expr_bool_op: &ExprBoolOp,
|
||||
expression: Expression<'db>,
|
||||
is_positive: bool,
|
||||
) -> Option<NarrowingConstraints<'db>> {
|
||||
let inference = infer_expression_types(self.db, expression);
|
||||
let scope = self.scope();
|
||||
let mut sub_constraints = expr_bool_op
|
||||
.values
|
||||
.iter()
|
||||
// filter our arms with statically known truthiness
|
||||
.filter(|expr| {
|
||||
inference
|
||||
.expression_ty(expr.scoped_ast_id(self.db, scope))
|
||||
.bool(self.db)
|
||||
!= match expr_bool_op.op {
|
||||
BoolOp::And => Truthiness::AlwaysTrue,
|
||||
BoolOp::Or => Truthiness::AlwaysFalse,
|
||||
}
|
||||
})
|
||||
.map(|sub_expr| {
|
||||
self.evaluate_expression_node_constraint(sub_expr, expression, is_positive)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
match (expr_bool_op.op, is_positive) {
|
||||
(BoolOp::And, true) | (BoolOp::Or, false) => {
|
||||
let mut aggregation: Option<NarrowingConstraints> = None;
|
||||
for sub_constraint in sub_constraints.into_iter().flatten() {
|
||||
if let Some(ref mut some_aggregation) = aggregation {
|
||||
merge_constraints_and(some_aggregation, sub_constraint, self.db);
|
||||
} else {
|
||||
aggregation = Some(sub_constraint);
|
||||
}
|
||||
}
|
||||
aggregation
|
||||
}
|
||||
(BoolOp::Or, true) | (BoolOp::And, false) => {
|
||||
let (first, rest) = sub_constraints.split_first_mut()?;
|
||||
if let Some(ref mut first) = first {
|
||||
for rest_constraint in rest {
|
||||
if let Some(rest_constraint) = rest_constraint {
|
||||
merge_constraints_or(first, rest_constraint, self.db);
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
first.clone()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue