diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md index b9cd306a35..f8246c883b 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md @@ -122,9 +122,6 @@ class CustomerModel(ModelBase): id: int name: str -# TODO: this is not supported yet -# error: [unknown-argument] -# error: [unknown-argument] CustomerModel(id=1, name="Test") ``` @@ -216,11 +213,7 @@ class OrderedModelBase: ... class TestWithBase(OrderedModelBase): inner: int -# TODO: No errors here, should reveal `bool` -# error: [too-many-positional-arguments] -# error: [too-many-positional-arguments] -# error: [unsupported-operator] -reveal_type(TestWithBase(1) < TestWithBase(2)) # revealed: Unknown +reveal_type(TestWithBase(1) < TestWithBase(2)) # revealed: bool ``` ### `kw_only_default` @@ -277,8 +270,7 @@ class ModelBase: ... class TestBase(ModelBase): name: str -# TODO: This should be `(self: TestBase, *, name: str) -> None` -reveal_type(TestBase.__init__) # revealed: def __init__(self) -> None +reveal_type(TestBase.__init__) # revealed: (self: TestBase, *, name: str) -> None ``` ### `frozen_default` @@ -333,12 +325,9 @@ class ModelBase: ... class TestMeta(ModelBase): name: str -# TODO: no error here -# error: [unknown-argument] t = TestMeta(name="test") -# TODO: this should be an `invalid-assignment` error -t.name = "new" +t.name = "new" # error: [invalid-assignment] ``` ### Combining parameters @@ -437,19 +426,15 @@ class DefaultFrozenModel: class Frozen(DefaultFrozenModel): name: str -# TODO: no error here -# error: [unknown-argument] f = Frozen(name="test") -# TODO: this should be an `invalid-assignment` error -f.name = "new" +f.name = "new" # error: [invalid-assignment] class Mutable(DefaultFrozenModel, frozen=False): name: str -# TODO: no error here -# error: [unknown-argument] m = Mutable(name="test") -m.name = "new" # No error +# TODO: This should not be an error +m.name = "new" # error: [invalid-assignment] ``` ## `field_specifiers` @@ -532,12 +517,8 @@ class Person(FancyBase): name: str = fancy_field() age: int | None = fancy_field(kw_only=True) -# TODO: should be (self: Person, name: str = Unknown, *, age: int | None = Unknown) -> None -reveal_type(Person.__init__) # revealed: def __init__(self) -> None +reveal_type(Person.__init__) # revealed: (self: Person, name: str, *, age: int | None) -> None -# TODO: shouldn't be an error -# error: [too-many-positional-arguments] -# error: [unknown-argument] alice = Person("Alice", age=30) reveal_type(alice.id) # revealed: int diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index d62ee839ce..c2a2ccd823 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -190,7 +190,11 @@ pub(crate) enum CodeGeneratorKind<'db> { } impl<'db> CodeGeneratorKind<'db> { - pub(crate) fn from_class(db: &'db dyn Db, class: ClassLiteral<'db>) -> Option { + pub(crate) fn from_class( + db: &'db dyn Db, + class: ClassLiteral<'db>, + specialization: Option>, + ) -> Option { #[salsa::tracked( cycle_fn=code_generator_of_class_recover, cycle_initial=code_generator_of_class_initial, @@ -199,11 +203,20 @@ impl<'db> CodeGeneratorKind<'db> { fn code_generator_of_class<'db>( db: &'db dyn Db, class: ClassLiteral<'db>, + specialization: Option>, ) -> Option> { if class.dataclass_params(db).is_some() { Some(CodeGeneratorKind::DataclassLike(None)) } else if let Ok((_, Some(transformer_params))) = class.try_metaclass(db) { Some(CodeGeneratorKind::DataclassLike(Some(transformer_params))) + } else if let Some(transformer_params) = + class.iter_mro(db, specialization).skip(1).find_map(|base| { + base.into_class().and_then(|class| { + class.class_literal(db).0.dataclass_transformer_params(db) + }) + }) + { + Some(CodeGeneratorKind::DataclassLike(Some(transformer_params))) } else if class .explicit_bases(db) .contains(&Type::SpecialForm(SpecialFormType::NamedTuple)) @@ -219,6 +232,7 @@ impl<'db> CodeGeneratorKind<'db> { fn code_generator_of_class_initial<'db>( _db: &'db dyn Db, _class: ClassLiteral<'db>, + _specialization: Option>, ) -> Option> { None } @@ -229,21 +243,37 @@ impl<'db> CodeGeneratorKind<'db> { _value: &Option>, _count: u32, _class: ClassLiteral<'db>, + _specialization: Option>, ) -> salsa::CycleRecoveryAction>> { salsa::CycleRecoveryAction::Iterate } - code_generator_of_class(db, class) + code_generator_of_class(db, class, specialization) } - pub(super) fn matches(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> bool { + pub(super) fn matches( + self, + db: &'db dyn Db, + class: ClassLiteral<'db>, + specialization: Option>, + ) -> bool { matches!( - (CodeGeneratorKind::from_class(db, class), self), + ( + CodeGeneratorKind::from_class(db, class, specialization), + self + ), (Some(Self::DataclassLike(_)), Self::DataclassLike(_)) | (Some(Self::NamedTuple), Self::NamedTuple) | (Some(Self::TypedDict), Self::TypedDict) ) } + + pub(super) fn dataclass_transformer_params(self) -> Option> { + match self { + Self::DataclassLike(params) => params, + Self::NamedTuple | Self::TypedDict => None, + } + } } /// A specialization of a generic class with a particular assignment of types to typevars. @@ -2200,7 +2230,7 @@ impl<'db> ClassLiteral<'db> { }; } - if CodeGeneratorKind::NamedTuple.matches(db, self) { + if CodeGeneratorKind::NamedTuple.matches(db, self, specialization) { if let Some(field) = self .own_fields(db, specialization, CodeGeneratorKind::NamedTuple) .get(name) @@ -2262,7 +2292,7 @@ impl<'db> ClassLiteral<'db> { ) -> Option> { let dataclass_params = self.dataclass_params(db); - let field_policy = CodeGeneratorKind::from_class(db, self)?; + let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?; let transformer_params = if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy { @@ -2808,7 +2838,7 @@ impl<'db> ClassLiteral<'db> { .filter_map(|superclass| { if let Some(class) = superclass.into_class() { let (class_literal, specialization) = class.class_literal(db); - if field_policy.matches(db, class_literal) { + if field_policy.matches(db, class_literal, specialization) { Some((class_literal, specialization)) } else { None @@ -3623,7 +3653,7 @@ impl<'db> VarianceInferable<'db> for ClassLiteral<'db> { .map(|class| class.variance_of(db, typevar)); let default_attribute_variance = { - let is_namedtuple = CodeGeneratorKind::NamedTuple.matches(db, self); + let is_namedtuple = CodeGeneratorKind::NamedTuple.matches(db, self, None); // Python 3.13 introduced a synthesized `__replace__` method on dataclasses which uses // their field types in contravariant position, thus meaning a frozen dataclass must // still be invariant in its field types. Other synthesized methods on dataclasses are diff --git a/crates/ty_python_semantic/src/types/ide_support.rs b/crates/ty_python_semantic/src/types/ide_support.rs index c47a508d61..ea436b4163 100644 --- a/crates/ty_python_semantic/src/types/ide_support.rs +++ b/crates/ty_python_semantic/src/types/ide_support.rs @@ -122,7 +122,7 @@ impl<'db> AllMembers<'db> { self.extend_with_instance_members(db, ty, class_literal); // If this is a NamedTuple instance, include members from NamedTupleFallback - if CodeGeneratorKind::NamedTuple.matches(db, class_literal) { + if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) { self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db)); } } @@ -142,7 +142,7 @@ impl<'db> AllMembers<'db> { Type::ClassLiteral(class_literal) => { self.extend_with_class_members(db, ty, class_literal); - if CodeGeneratorKind::NamedTuple.matches(db, class_literal) { + if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) { self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db)); } @@ -153,7 +153,7 @@ impl<'db> AllMembers<'db> { Type::GenericAlias(generic_alias) => { let class_literal = generic_alias.origin(db); - if CodeGeneratorKind::NamedTuple.matches(db, class_literal) { + if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) { self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db)); } self.extend_with_class_members(db, ty, class_literal); @@ -164,7 +164,7 @@ impl<'db> AllMembers<'db> { let class_literal = class_type.class_literal(db).0; self.extend_with_class_members(db, ty, class_literal); - if CodeGeneratorKind::NamedTuple.matches(db, class_literal) { + if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) { self.extend_with_type( db, KnownClass::NamedTupleFallback.to_class_literal(db), diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 5ebcfcda15..2dee39100e 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -577,7 +577,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { continue; } - let is_named_tuple = CodeGeneratorKind::NamedTuple.matches(self.db(), class); + let is_named_tuple = CodeGeneratorKind::NamedTuple.matches(self.db(), class, None); // (2) If it's a `NamedTuple` class, check that no field without a default value // appears after a field with a default value. @@ -898,7 +898,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // (7) Check that a dataclass does not have more than one `KW_ONLY`. if let Some(field_policy @ CodeGeneratorKind::DataclassLike(_)) = - CodeGeneratorKind::from_class(self.db(), class) + CodeGeneratorKind::from_class(self.db(), class, None) { let specialization = None; @@ -4569,11 +4569,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { .dataclass_params(db) .map(|params| SmallVec::from(params.field_specifiers(db))) .or_else(|| { - class_literal - .try_metaclass(db) - .ok() - .and_then(|(_, params)| params) - .map(|params| SmallVec::from(params.field_specifiers(db))) + Some(SmallVec::from( + CodeGeneratorKind::from_class(db, class_literal, None)? + .dataclass_transformer_params()? + .field_specifiers(db), + )) }) }