diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md index e3fd368b1e..113c63f168 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclasses.md @@ -640,6 +640,8 @@ reveal_type(C.__init__) # revealed: (self: C, normal: int, conditionally_presen python-version = "3.12" ``` +### Basic + ```py from dataclasses import dataclass @@ -658,6 +660,34 @@ reveal_type(d_int.description) # revealed: str DataWithDescription[int](None, "description") ``` +### Deriving from generic dataclasses + +This is a regression test for . + +```py +from dataclasses import dataclass + +@dataclass +class Wrap[T]: + data: T + +reveal_type(Wrap[int].__init__) # revealed: (self: Wrap[int], data: int) -> None + +@dataclass +class WrappedInt(Wrap[int]): + other_field: str + +reveal_type(WrappedInt.__init__) # revealed: (self: WrappedInt, data: int, other_field: str) -> None + +# Make sure that another generic type parameter does not affect the `data` field +@dataclass +class WrappedIntAndExtraData[T](Wrap[int]): + extra_data: T + +# revealed: (self: WrappedIntAndExtraData[bytes], data: int, extra_data: bytes) -> None +reveal_type(WrappedIntAndExtraData[bytes].__init__) +``` + ## Descriptor-typed fields ### Same type in `__get__` and `__set__` diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 7f4d5a5ff5..1b391029f0 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -1668,16 +1668,16 @@ impl<'db> ClassLiteral<'db> { if field_policy == CodeGeneratorKind::NamedTuple { // NamedTuples do not allow multiple inheritance, so it is sufficient to enumerate the // fields of this class only. - return self.own_fields(db); + return self.own_fields(db, specialization); } let matching_classes_in_mro: Vec<_> = self .iter_mro(db, specialization) .filter_map(|superclass| { if let Some(class) = superclass.into_class() { - let class_literal = class.class_literal(db).0; + let (class_literal, specialization) = class.class_literal(db); if field_policy.matches(db, class_literal) { - Some(class_literal) + Some((class_literal, specialization)) } else { None } @@ -1691,7 +1691,7 @@ impl<'db> ClassLiteral<'db> { matching_classes_in_mro .into_iter() .rev() - .flat_map(|class| class.own_fields(db)) + .flat_map(|(class, specialization)| class.own_fields(db, specialization)) // We collect into a FxOrderMap here to deduplicate attributes .collect() } @@ -1707,7 +1707,11 @@ impl<'db> ClassLiteral<'db> { /// y: str = "a" /// ``` /// we return a map `{"x": (int, None), "y": (str, Some(Literal["a"]))}`. - fn own_fields(self, db: &'db dyn Db) -> FxOrderMap, Option>)> { + fn own_fields( + self, + db: &'db dyn Db, + specialization: Option>, + ) -> FxOrderMap, Option>)> { let mut attributes = FxOrderMap::default(); let class_body_scope = self.body_scope(db); @@ -1747,7 +1751,14 @@ impl<'db> ClassLiteral<'db> { let bindings = use_def.end_of_scope_bindings(place_id); let default_ty = place_from_bindings(db, bindings).ignore_possibly_unbound(); - attributes.insert(place_expr.expect_name().clone(), (attr_ty, default_ty)); + attributes.insert( + place_expr.expect_name().clone(), + ( + attr_ty.apply_optional_specialization(db, specialization), + default_ty + .map(|ty| ty.apply_optional_specialization(db, specialization)), + ), + ); } } }