diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 1e2d2500ca..f360184189 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -401,6 +401,16 @@ impl<'db> Type<'db> { } } + /// Return true if the type is a class or a union of classes. + pub fn is_class(&self, db: &'db dyn Db) -> bool { + match self { + Type::Union(union) => union.elements(db).iter().all(|ty| ty.is_class(db)), + Type::Class(_) => true, + // / TODO include type[X], once we add that type + _ => false, + } + } + /// Return true if this type is a [subtype of] type `target`. /// /// [subtype of]: https://typing.readthedocs.io/en/latest/spec/concepts.html#subtype-supertype-and-type-equivalence diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 33155ad6cb..b8b529865f 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -1322,6 +1322,22 @@ impl<'db> TypeInferenceBuilder<'db> { ); } + /// Emit a diagnostic declaring that a type does not support subscripting. + pub(super) fn non_subscriptable_diagnostic( + &mut self, + node: AnyNodeRef, + non_subscriptable_ty: Type<'db>, + ) { + self.add_diagnostic( + node, + "non-subscriptable", + format_args!( + "Cannot subscript object of type '{}' with no `__getitem__` method.", + non_subscriptable_ty.display(self.db) + ), + ); + } + fn infer_for_statement_definition( &mut self, target: &ast::ExprName, @@ -2588,7 +2604,35 @@ impl<'db> TypeInferenceBuilder<'db> { Type::Unknown }) } - _ => Type::Todo, + (value_ty, slice_ty) => { + // Resolve the value to its class. + let value_meta_ty = value_ty.to_meta_type(self.db); + + // If the class defines `__getitem__`, return its return type. + // + // See: https://docs.python.org/3/reference/datamodel.html#class-getitem-versus-getitem + let dunder_getitem_method = value_meta_ty.member(self.db, "__getitem__"); + if !dunder_getitem_method.is_unbound() { + return dunder_getitem_method + .call(self.db, &[slice_ty]) + .unwrap_with_diagnostic(self.db, value.as_ref().into(), self); + } + + // Otherwise, if the value is itself a class and defines `__class_getitem__`, + // return its return type. + if value_ty.is_class(self.db) { + let dunder_class_getitem_method = value_ty.member(self.db, "__class_getitem__"); + if !dunder_class_getitem_method.is_unbound() { + return dunder_class_getitem_method + .call(self.db, &[slice_ty]) + .unwrap_with_diagnostic(self.db, value.as_ref().into(), self); + } + } + + // Otherwise, emit a diagnostic. + self.non_subscriptable_diagnostic((&**value).into(), value_ty); + Type::Unknown + } } } @@ -6723,6 +6767,261 @@ mod tests { Ok(()) } + #[test] + fn subscript_getitem_unbound() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + class NotSubscriptable: + pass + + a = NotSubscriptable()[0] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "Unknown"); + assert_file_diagnostics( + &db, + "/src/a.py", + &["Cannot subscript object of type 'NotSubscriptable' with no `__getitem__` method."], + ); + + Ok(()) + } + + #[test] + fn subscript_not_callable_getitem() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + class NotSubscriptable: + __getitem__ = None + + a = NotSubscriptable()[0] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "Unknown"); + assert_file_diagnostics( + &db, + "/src/a.py", + &["Object of type 'None' is not callable."], + ); + + Ok(()) + } + + #[test] + fn subscript_str_literal() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + def add(x: int, y: int) -> int: + return x + y + + a = 'abcde'[add(0, 1)] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "str"); + + Ok(()) + } + + #[test] + fn subscript_getitem() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + class Identity: + def __getitem__(self, index: int) -> int: + return index + + a = Identity()[0] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "int"); + + Ok(()) + } + + #[test] + fn subscript_class_getitem() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + class Identity: + def __class_getitem__(cls, item: int) -> str: + return item + + a = Identity[0] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "str"); + + Ok(()) + } + + #[test] + fn subscript_getitem_union() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + flag = True + + class Identity: + if flag: + def __getitem__(self, index: int) -> int: + return index + else: + def __getitem__(self, index: int) -> str: + return str(index) + + a = Identity()[0] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "int | str"); + + Ok(()) + } + + #[test] + fn subscript_class_getitem_union() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + flag = True + + class Identity: + if flag: + def __class_getitem__(cls, item: int) -> str: + return item + else: + def __class_getitem__(cls, item: int) -> int: + return item + + a = Identity[0] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "str | int"); + + Ok(()) + } + + #[test] + fn subscript_class_getitem_class_union() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + flag = True + + class Identity1: + def __class_getitem__(cls, item: int) -> str: + return item + + class Identity2: + def __class_getitem__(cls, item: int) -> int: + return item + + if flag: + a = Identity1 + else: + a = Identity2 + + b = a[0] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "Literal[Identity1, Identity2]"); + assert_public_ty(&db, "/src/a.py", "b", "str | int"); + + Ok(()) + } + + #[test] + fn subscript_class_getitem_unbound_method_union() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + flag = True + + if flag: + class Identity: + def __class_getitem__(self, x: int) -> str: + pass + else: + class Identity: + pass + + a = Identity[42] + ", + )?; + + assert_public_ty(&db, "/src/a.py", "a", "str | Unknown"); + + assert_file_diagnostics( + &db, + "/src/a.py", + &["Object of type 'Literal[__class_getitem__] | Unbound' is not callable (due to union element 'Unbound')."], + ); + + Ok(()) + } + + #[test] + fn subscript_class_getitem_non_class_union() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + flag = True + + if flag: + class Identity: + def __class_getitem__(self, x: int) -> str: + pass + else: + Identity = 1 + + a = Identity[42] + ", + )?; + + // TODO this should _probably_ emit `str | Unknown` instead of `Unknown`. + assert_public_ty(&db, "/src/a.py", "a", "Unknown"); + + assert_file_diagnostics( + &db, + "/src/a.py", + &["Cannot subscript object of type 'Literal[Identity] | Literal[1]' with no `__getitem__` method."], + ); + + Ok(()) + } + #[test] fn dunder_call() -> anyhow::Result<()> { let mut db = setup_db();