From 75f3c0e8e6dcd52e8194f9831aef25fddb646685 Mon Sep 17 00:00:00 2001 From: David Peter Date: Thu, 9 Oct 2025 15:24:20 +0200 Subject: [PATCH] [ty] Respect `dataclass_transform` parameters for metaclass-based models (#20780) ## Summary Respect parameters such as `frozen_default` for metaclass-based `@dataclass_transformer` models. Related to: https://github.com/astral-sh/ty/issues/1260 ## Typing conformance changes Those are all correct (new true positives) ## Test Plan New Markdown tests --- .../mdtest/dataclasses/dataclass_transform.md | 123 +++++++++++++++++- crates/ty_python_semantic/src/types/class.rs | 52 ++++---- .../src/types/infer/builder.rs | 2 +- 3 files changed, 151 insertions(+), 26 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md index f210e97b0d..240f7b0b17 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md @@ -165,7 +165,7 @@ Normal(1) < Normal(2) # error: [unsupported-operator] class NormalOverwritten: inner: int -NormalOverwritten(1) < NormalOverwritten(2) +reveal_type(NormalOverwritten(1) < NormalOverwritten(2)) # revealed: bool @order_default_false class OrderFalse: @@ -177,13 +177,13 @@ OrderFalse(1) < OrderFalse(2) # error: [unsupported-operator] class OrderFalseOverwritten: inner: int -OrderFalseOverwritten(1) < OrderFalseOverwritten(2) +reveal_type(OrderFalseOverwritten(1) < OrderFalseOverwritten(2)) # revealed: bool @order_default_true class OrderTrue: inner: int -OrderTrue(1) < OrderTrue(2) +reveal_type(OrderTrue(1) < OrderTrue(2)) # revealed: bool @order_default_true(order=False) class OrderTrueOverwritten: @@ -193,6 +193,36 @@ class OrderTrueOverwritten: OrderTrueOverwritten(1) < OrderTrueOverwritten(2) ``` +This also works for metaclass-based transformers: + +```py +@dataclass_transform(order_default=True) +class OrderedModelMeta(type): ... + +class OrderedModel(metaclass=OrderedModelMeta): ... + +class TestWithMeta(OrderedModel): + inner: int + +reveal_type(TestWithMeta(1) < TestWithMeta(2)) # revealed: bool +``` + +And for base-class-based transformers: + +```py +@dataclass_transform(order_default=True) +class OrderedModelBase: ... + +class TestWithBase(OrderedModelBase): + inner: int + +# TODO: No errors here, should reveal `bool` +# error: [too-many-positional-arguments] +# error: [too-many-positional-arguments] +# error: [unsupported-operator] +reveal_type(TestWithBase(1) < TestWithBase(2)) # revealed: Unknown +``` + ### `kw_only_default` When provided, sets the default value for the `kw_only` parameter of `field()`. @@ -224,6 +254,33 @@ class CustomerModel: c = CustomerModel(1, "Harry") ``` +This also works for metaclass-based transformers: + +```py +@dataclass_transform(kw_only_default=True) +class ModelMeta(type): ... + +class ModelBase(metaclass=ModelMeta): ... + +class TestMeta(ModelBase): + name: str + +reveal_type(TestMeta.__init__) # revealed: (self: TestMeta, *, name: str) -> None +``` + +And for base-class-based transformers: + +```py +@dataclass_transform(kw_only_default=True) +class ModelBase: ... + +class TestBase(ModelBase): + name: str + +# TODO: This should be `(self: TestBase, *, name: str) -> None` +reveal_type(TestBase.__init__) # revealed: def __init__(self) -> None +``` + ### `frozen_default` When provided, sets the default value for the `frozen` parameter of `field()`. @@ -252,6 +309,38 @@ m = MutableModel(name="test") m.name = "new" # No error ``` +This also works for metaclass-based transformers: + +```py +@dataclass_transform(frozen_default=True) +class ModelMeta(type): ... + +class ModelBase(metaclass=ModelMeta): ... + +class TestMeta(ModelBase): + name: str + +t = TestMeta(name="test") +t.name = "new" # error: [invalid-assignment] +``` + +And for base-class-based transformers: + +```py +@dataclass_transform(frozen_default=True) +class ModelBase: ... + +class TestMeta(ModelBase): + name: str + +# TODO: no error here +# error: [unknown-argument] +t = TestMeta(name="test") + +# TODO: this should be an `invalid-assignment` error +t.name = "new" +``` + ### Combining parameters Combining several of these parameters also works as expected: @@ -367,4 +456,32 @@ D1(1.2) # error: [invalid-argument-type] D2(1.2) # error: [invalid-argument-type] ``` +### Use cases + +#### Home Assistant + +Home Assistant uses a pattern like this, where a `@dataclass`-decorated class inherits from a base +class that is itself a `dataclass`-like construct via a metaclass-based dataclass transformer. Make +sure that we recognize all fields in a hierarchy like this: + +```py +from dataclasses import dataclass +from typing import dataclass_transform + +@dataclass_transform() +class ModelMeta(type): + pass + +class Sensor(metaclass=ModelMeta): + key: int + +@dataclass(frozen=True, kw_only=True) +class TemperatureSensor(Sensor): + name: str + +t = TemperatureSensor(key=1, name="Temperature Sensor") +reveal_type(t.key) # revealed: int +reveal_type(t.name) # revealed: str +``` + [`typing.dataclass_transform`]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index 40575f1097..93a8d8fcf6 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -176,7 +176,7 @@ fn try_metaclass_cycle_initial<'db>( #[derive(Clone, Copy, Debug, PartialEq, salsa::Update, get_size2::GetSize)] pub(crate) enum CodeGeneratorKind { /// Classes decorated with `@dataclass` or similar dataclass-like decorators - DataclassLike, + DataclassLike(Option), /// Classes inheriting from `typing.NamedTuple` NamedTuple, /// Classes inheriting from `typing.TypedDict` @@ -194,12 +194,10 @@ impl CodeGeneratorKind { db: &'db dyn Db, class: ClassLiteral<'db>, ) -> Option { - if class.dataclass_params(db).is_some() - || class - .try_metaclass(db) - .is_ok_and(|(_, transformer_params)| transformer_params.is_some()) - { - Some(CodeGeneratorKind::DataclassLike) + if class.dataclass_params(db).is_some() { + Some(CodeGeneratorKind::DataclassLike(None)) + } else if let Ok((_, Some(transformer_params))) = class.try_metaclass(db) { + Some(CodeGeneratorKind::DataclassLike(Some(transformer_params))) } else if class .explicit_bases(db) .contains(&Type::SpecialForm(SpecialFormType::NamedTuple)) @@ -233,7 +231,12 @@ impl CodeGeneratorKind { } pub(super) fn matches(self, db: &dyn Db, class: ClassLiteral<'_>) -> bool { - CodeGeneratorKind::from_class(db, class) == Some(self) + matches!( + (CodeGeneratorKind::from_class(db, class), self), + (Some(Self::DataclassLike(_)), Self::DataclassLike(_)) + | (Some(Self::NamedTuple), Self::NamedTuple) + | (Some(Self::TypedDict), Self::TypedDict) + ) } } @@ -2152,11 +2155,21 @@ impl<'db> ClassLiteral<'db> { name: &str, ) -> Option> { let dataclass_params = self.dataclass_params(db); - let has_dataclass_param = - |param| dataclass_params.is_some_and(|params| params.contains(param)); let field_policy = CodeGeneratorKind::from_class(db, self)?; + let transformer_params = + if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy { + Some(DataclassParams::from(transformer_params)) + } else { + None + }; + + let has_dataclass_param = |param| { + dataclass_params.is_some_and(|params| params.contains(param)) + || transformer_params.is_some_and(|params| params.contains(param)) + }; + let instance_ty = Type::instance(db, self.apply_optional_specialization(db, specialization)); @@ -2269,13 +2282,8 @@ impl<'db> ClassLiteral<'db> { }; match (field_policy, name) { - (CodeGeneratorKind::DataclassLike, "__init__") => { - let has_synthesized_dunder_init = has_dataclass_param(DataclassParams::INIT) - || self - .try_metaclass(db) - .is_ok_and(|(_, transformer_params)| transformer_params.is_some()); - - if !has_synthesized_dunder_init { + (CodeGeneratorKind::DataclassLike(_), "__init__") => { + if !has_dataclass_param(DataclassParams::INIT) { return None; } @@ -2289,7 +2297,7 @@ impl<'db> ClassLiteral<'db> { .with_annotated_type(KnownClass::Type.to_instance(db)); signature_from_fields(vec![cls_parameter], Some(Type::none(db))) } - (CodeGeneratorKind::DataclassLike, "__lt__" | "__le__" | "__gt__" | "__ge__") => { + (CodeGeneratorKind::DataclassLike(_), "__lt__" | "__le__" | "__gt__" | "__ge__") => { if !has_dataclass_param(DataclassParams::ORDER) { return None; } @@ -2332,7 +2340,7 @@ impl<'db> ClassLiteral<'db> { ) }) } - (CodeGeneratorKind::DataclassLike, "__replace__") + (CodeGeneratorKind::DataclassLike(_), "__replace__") if Program::get(db).python_version(db) >= PythonVersion::PY313 => { let self_parameter = Parameter::positional_or_keyword(Name::new_static("self")) @@ -2340,7 +2348,7 @@ impl<'db> ClassLiteral<'db> { signature_from_fields(vec![self_parameter], Some(instance_ty)) } - (CodeGeneratorKind::DataclassLike, "__setattr__") => { + (CodeGeneratorKind::DataclassLike(_), "__setattr__") => { if has_dataclass_param(DataclassParams::FROZEN) { let signature = Signature::new( Parameters::new([ @@ -2356,7 +2364,7 @@ impl<'db> ClassLiteral<'db> { } None } - (CodeGeneratorKind::DataclassLike, "__slots__") => { + (CodeGeneratorKind::DataclassLike(_), "__slots__") => { has_dataclass_param(DataclassParams::SLOTS).then(|| { let fields = self.fields(db, specialization, field_policy); let slots = fields.keys().map(|name| Type::string_literal(db, name)); @@ -2794,7 +2802,7 @@ impl<'db> ClassLiteral<'db> { let kind = match field_policy { CodeGeneratorKind::NamedTuple => FieldKind::NamedTuple { default_ty }, - CodeGeneratorKind::DataclassLike => FieldKind::Dataclass { + CodeGeneratorKind::DataclassLike(_) => FieldKind::Dataclass { default_ty, init_only: attr.is_init_var(), init, diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 7d438c85f5..9239718093 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -851,7 +851,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } // (6) Check that a dataclass does not have more than one `KW_ONLY`. - if let Some(field_policy @ CodeGeneratorKind::DataclassLike) = + if let Some(field_policy @ CodeGeneratorKind::DataclassLike(_)) = CodeGeneratorKind::from_class(self.db(), class) { let specialization = None;