diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md index f2a24ae965..58850844ff 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md @@ -174,3 +174,11 @@ def _(flag: bool): if isinstance(x, int, foo="bar"): reveal_type(x) # revealed: Literal[1] | Literal["a"] ``` + +## `type[]` types are narrowed as well as class-literal types + +```py +def _(x: object, y: type[int]): + if isinstance(x, y): + reveal_type(x) # revealed: int +``` diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md index b1099a1f7a..d539a28a02 100644 --- a/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/issubclass.md @@ -239,3 +239,11 @@ t = int if flag() else str if issubclass(t, int, foo="bar"): reveal_type(t) # revealed: Literal[int, str] ``` + +### `type[]` types are narrowed as well as class-literal types + +```py +def _(x: type, y: type[int]): + if issubclass(x, y): + reveal_type(x) # revealed: type[int] +``` diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 1da7b4b54b..011484f26b 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -15,6 +15,7 @@ pub(crate) use self::display::TypeArrayDisplay; pub(crate) use self::infer::{ infer_deferred_types, infer_definition_types, infer_expression_types, infer_scope_types, }; +pub use self::narrow::KnownConstraintFunction; pub(crate) use self::signatures::Signature; use crate::module_name::ModuleName; use crate::module_resolver::{file_to_module, resolve_module, KnownModule}; @@ -3043,14 +3044,6 @@ impl<'db> FunctionType<'db> { } } -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum KnownConstraintFunction { - /// `builtins.isinstance` - IsInstance, - /// `builtins.issubclass` - IsSubclass, -} - /// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might /// have special behavior. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index ef1d629303..cc572eeb7c 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -7,8 +7,8 @@ use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; use crate::semantic_index::symbol_table; use crate::types::{ - infer_expression_types, ClassLiteralType, IntersectionBuilder, KnownClass, - KnownConstraintFunction, KnownFunction, Truthiness, Type, UnionBuilder, + infer_expression_types, ClassBase, ClassLiteralType, IntersectionBuilder, KnownClass, + KnownFunction, SubclassOfType, Truthiness, Type, UnionBuilder, }; use crate::Db; use itertools::Itertools; @@ -83,28 +83,37 @@ fn all_negative_narrowing_constraints_for_expression<'db>( NarrowingConstraintsBuilder::new(db, ConstraintNode::Expression(expression), false).finish() } -/// Generate a constraint from the type of a `classinfo` argument to `isinstance` or `issubclass`. -/// -/// The `classinfo` argument can be a class literal, a tuple of (tuples of) class literals. PEP 604 -/// union types are not yet supported. Returns `None` if the `classinfo` argument has a wrong type. -fn generate_classinfo_constraint<'db, F>( - db: &'db dyn Db, - classinfo: &Type<'db>, - to_constraint: F, -) -> Option> -where - F: Fn(ClassLiteralType<'db>) -> Type<'db> + Copy, -{ - match classinfo { - Type::Tuple(tuple) => { - let mut builder = UnionBuilder::new(db); - for element in tuple.elements(db) { - builder = builder.add(generate_classinfo_constraint(db, element, to_constraint)?); +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum KnownConstraintFunction { + /// `builtins.isinstance` + IsInstance, + /// `builtins.issubclass` + IsSubclass, +} + +impl KnownConstraintFunction { + /// Generate a constraint from the type of a `classinfo` argument to `isinstance` or `issubclass`. + /// + /// The `classinfo` argument can be a class literal, a tuple of (tuples of) class literals. PEP 604 + /// union types are not yet supported. Returns `None` if the `classinfo` argument has a wrong type. + fn generate_constraint<'db>(self, db: &'db dyn Db, classinfo: Type<'db>) -> Option> { + match classinfo { + Type::Tuple(tuple) => { + let mut builder = UnionBuilder::new(db); + for element in tuple.elements(db) { + builder = builder.add(self.generate_constraint(db, *element)?); + } + Some(builder.build()) } - Some(builder.build()) + Type::ClassLiteral(ClassLiteralType { class }) + | Type::SubclassOf(SubclassOfType { + base: ClassBase::Class(class), + }) => Some(match self { + KnownConstraintFunction::IsInstance => Type::instance(class), + KnownConstraintFunction::IsSubclass => Type::subclass_of(class), + }), + _ => None, } - Type::ClassLiteral(class_literal_type) => Some(to_constraint(*class_literal_type)), - _ => None, } } @@ -429,24 +438,13 @@ impl<'db> NarrowingConstraintsBuilder<'db> { let class_info_ty = inference.expression_ty(class_info.scoped_expression_id(self.db, scope)); - let to_constraint = match function { - KnownConstraintFunction::IsInstance => { - |class_literal: ClassLiteralType<'db>| Type::instance(class_literal.class) - } - KnownConstraintFunction::IsSubclass => { - |class_literal: ClassLiteralType<'db>| { - Type::subclass_of(class_literal.class) - } - } - }; - - generate_classinfo_constraint(self.db, &class_info_ty, to_constraint).map( - |constraint| { + function + .generate_constraint(self.db, class_info_ty) + .map(|constraint| { let mut constraints = NarrowingConstraints::default(); constraints.insert(symbol, constraint.negate_if(self.db, !is_positive)); constraints - }, - ) + }) } // for the expression `bool(E)`, we further narrow the type based on `E` Type::ClassLiteral(class_type)