diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md b/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md index f8eb97bec5..06b81555b1 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/assignment.md @@ -213,23 +213,20 @@ reveal_type(l[0]) # revealed: Literal[0] reveal_type(d[0]) # revealed: Literal[0] reveal_type(b[0]) # revealed: Literal[0] reveal_type(dd[0]) # revealed: Literal[0] -# TODO: should be Literal[0] -reveal_type(cm[0]) # revealed: Unknown +reveal_type(cm[0]) # revealed: Literal[0] class C: reveal_type(l[0]) # revealed: Literal[0] reveal_type(d[0]) # revealed: Literal[0] reveal_type(b[0]) # revealed: Literal[0] reveal_type(dd[0]) # revealed: Literal[0] - # TODO: should be Literal[0] - reveal_type(cm[0]) # revealed: Unknown + reveal_type(cm[0]) # revealed: Literal[0] [reveal_type(l[0]) for _ in range(1)] # revealed: Literal[0] [reveal_type(d[0]) for _ in range(1)] # revealed: Literal[0] [reveal_type(b[0]) for _ in range(1)] # revealed: Literal[0] [reveal_type(dd[0]) for _ in range(1)] # revealed: Literal[0] -# TODO: should be Literal[0] -[reveal_type(cm[0]) for _ in range(1)] # revealed: Unknown +[reveal_type(cm[0]) for _ in range(1)] # revealed: Literal[0] def _(): reveal_type(l[0]) # revealed: int | None diff --git a/crates/ty_python_semantic/src/types/infer.rs b/crates/ty_python_semantic/src/types/infer.rs index 0f76eb2ddb..d1dec5e148 100644 --- a/crates/ty_python_semantic/src/types/infer.rs +++ b/crates/ty_python_semantic/src/types/infer.rs @@ -1785,6 +1785,39 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) { + /// Arbitrary `__getitem__`/`__setitem__` methods on a class do not + /// necessarily guarantee that the passed-in value for `__setitem__` is stored and + /// can be retrieved unmodified via `__getitem__`. Therefore, we currently only + /// perform assignment-based narrowing on a few built-in classes (`list`, `dict`, + /// `bytesarray`, `TypedDict` and `collections` types) where we are confident that + /// this kind of narrowing can be performed soundly. This is the same approach as + /// pyright. TODO: Other standard library classes may also be considered safe. Also, + /// subclasses of these safe classes that do not override `__getitem__/__setitem__` + /// may be considered safe. + fn is_safe_mutable_class<'db>(db: &'db dyn Db, ty: Type<'db>) -> bool { + const SAFE_MUTABLE_CLASSES: &[KnownClass] = &[ + KnownClass::List, + KnownClass::Dict, + KnownClass::Bytearray, + KnownClass::DefaultDict, + KnownClass::ChainMap, + KnownClass::Counter, + KnownClass::Deque, + KnownClass::OrderedDict, + ]; + + SAFE_MUTABLE_CLASSES + .iter() + .map(|class| class.to_instance(db)) + .any(|safe_mutable_class| { + ty.is_equivalent_to(db, safe_mutable_class) + || ty + .generic_origin(db) + .zip(safe_mutable_class.generic_origin(db)) + .is_some_and(|(l, r)| l == r) + }) + } + debug_assert!( binding .kind(self.db()) @@ -2026,38 +2059,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let value_ty = self .try_expression_type(value) .unwrap_or_else(|| self.infer_expression(value)); - // Arbitrary `__getitem__`/`__setitem__` methods on a class do not - // necessarily guarantee that the passed-in value for `__setitem__` is stored and - // can be retrieved unmodified via `__getitem__`. Therefore, we currently only - // perform assignment-based narrowing on a few built-in classes (`list`, `dict`, - // `bytesarray`, `TypedDict` and `collections` types) where we are confident that - // this kind of narrowing can be performed soundly. This is the same approach as - // pyright. TODO: Other standard library classes may also be considered safe. Also, - // subclasses of these safe classes that do not override `__getitem__/__setitem__` - // may be considered safe. - let is_safe_mutable_class = || { - let safe_mutable_classes = [ - KnownClass::List.to_instance(db), - KnownClass::Dict.to_instance(db), - KnownClass::Bytearray.to_instance(db), - KnownClass::DefaultDict.to_instance(db), - SpecialFormType::ChainMap.instance_fallback(db), - SpecialFormType::Counter.instance_fallback(db), - SpecialFormType::Deque.instance_fallback(db), - SpecialFormType::OrderedDict.instance_fallback(db), - SpecialFormType::TypedDict.instance_fallback(db), - ]; - safe_mutable_classes.iter().any(|safe_mutable_class| { - value_ty.is_equivalent_to(db, *safe_mutable_class) - || value_ty - .generic_origin(db) - .zip(safe_mutable_class.generic_origin(db)) - .is_some_and(|(l, r)| l == r) - }) - }; - - if !value_ty.is_typed_dict() && !is_safe_mutable_class() { + if !value_ty.is_typed_dict() && !is_safe_mutable_class(db, value_ty) { bound_ty = declared_ty; } }