[red-knot] Support for not-equal narrowing (#13749)

Add type narrowing for `!=` expression as stated in
#13694.

###  Test Plan

Add tests in new md format.

---------

Co-authored-by: David Peter <mail@david-peter.de>
This commit is contained in:
Alex 2024-10-22 00:08:33 +03:00 committed by GitHub
parent e39110e18b
commit 9d102799f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 153 additions and 0 deletions

View file

@ -0,0 +1,13 @@
# Narrowing for nested conditionals
```py
def int_instance() -> int: ...
x = int_instance()
if x != 1:
if x != 2:
if x != 3:
reveal_type(x) # revealed: int & ~Literal[1] & ~Literal[2] & ~Literal[3]
```

View file

@ -0,0 +1,58 @@
# Narrowing for `!=` conditionals
## `x != None`
```py
x = None if flag else 1
if x != None:
reveal_type(x) # revealed: Literal[1]
```
## `!=` for other singleton types
```py
x = True if flag else False
if x != False:
reveal_type(x) # revealed: Literal[True]
```
## `x != y` where `y` is of literal type
```py
x = 1 if flag else 2
if x != 1:
reveal_type(x) # revealed: Literal[2]
```
## `x != y` where `y` is a single-valued type
```py
class A: ...
class B: ...
C = A if flag else B
if C != A:
reveal_type(C) # revealed: Literal[B]
```
## `!=` for non-single-valued types
Only single-valued types should narrow the type:
```py
def int_instance() -> int: ...
x = int_instance() if flag else None
y = int_instance()
if x != y:
reveal_type(x) # revealed: int | None
```

View file

@ -671,6 +671,55 @@ impl<'db> Type<'db> {
}
}
/// Return true if this type is non-empty and all inhabitants of this type compare equal.
pub(crate) fn is_single_valued(self, db: &'db dyn Db) -> bool {
match self {
Type::None
| Type::Function(..)
| Type::Module(..)
| Type::Class(..)
| Type::IntLiteral(..)
| Type::BooleanLiteral(..)
| Type::StringLiteral(..)
| Type::BytesLiteral(..) => true,
Type::Tuple(tuple) => tuple
.elements(db)
.iter()
.all(|elem| elem.is_single_valued(db)),
Type::Instance(class_type) => match class_type.known(db) {
Some(KnownClass::NoneType) => true,
Some(
KnownClass::Bool
| KnownClass::Object
| KnownClass::Bytes
| KnownClass::Type
| KnownClass::Int
| KnownClass::Float
| KnownClass::Str
| KnownClass::List
| KnownClass::Tuple
| KnownClass::Set
| KnownClass::Dict
| KnownClass::GenericAlias
| KnownClass::ModuleType
| KnownClass::FunctionType,
) => false,
None => false,
},
Type::Any
| Type::Never
| Type::Unknown
| Type::Unbound
| Type::Todo
| Type::Union(..)
| Type::Intersection(..)
| Type::LiteralString => false,
}
}
/// Resolve a member access of a type.
///
/// For example, if `foo` is `Type::Instance(<Bar>)`,
@ -1973,6 +2022,31 @@ mod tests {
assert!(from.into_type(&db).is_singleton());
}
#[test_case(Ty::None)]
#[test_case(Ty::BooleanLiteral(true))]
#[test_case(Ty::IntLiteral(1))]
#[test_case(Ty::StringLiteral("abc"))]
#[test_case(Ty::BytesLiteral("abc"))]
#[test_case(Ty::Tuple(vec![]))]
#[test_case(Ty::Tuple(vec![Ty::BooleanLiteral(true), Ty::IntLiteral(1)]))]
fn is_single_valued(from: Ty) {
let db = setup_db();
assert!(from.into_type(&db).is_single_valued(&db));
}
#[test_case(Ty::Never)]
#[test_case(Ty::Any)]
#[test_case(Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]))]
#[test_case(Ty::Tuple(vec![Ty::None, Ty::BuiltinInstance("int")]))]
#[test_case(Ty::BuiltinInstance("str"))]
#[test_case(Ty::LiteralString)]
fn is_not_single_valued(from: Ty) {
let db = setup_db();
assert!(!from.into_type(&db).is_single_valued(&db));
}
#[test_case(Ty::Never)]
#[test_case(Ty::IntLiteral(345))]
#[test_case(Ty::BuiltinInstance("str"))]

View file

@ -178,6 +178,14 @@ impl<'db> NarrowingConstraintsBuilder<'db> {
ast::CmpOp::Is => {
self.constraints.insert(symbol, comp_ty);
}
ast::CmpOp::NotEq => {
if comp_ty.is_single_valued(self.db) {
let ty = IntersectionBuilder::new(self.db)
.add_negative(comp_ty)
.build();
self.constraints.insert(symbol, ty);
}
}
_ => {
// TODO other comparison types
}