diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md index 8b7749359b..4805db9d0d 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md @@ -455,6 +455,82 @@ m.name = "new" # No error reveal_type(Mutable(name="A") < Mutable(name="B")) # revealed: bool ``` +## Other `dataclass` parameters + +Other parameters from normal dataclasses can also be set on models created using +`dataclass_transform`. + +### Using function-based transformers + +```py +from typing_extensions import dataclass_transform, TypeVar, Callable + +T = TypeVar("T", bound=type) + +@dataclass_transform() +def fancy_model(*, slots: bool = False) -> Callable[[T], T]: + raise NotImplementedError + +@fancy_model() +class NoSlots: + name: str + +NoSlots.__slots__ # error: [unresolved-attribute] + +@fancy_model(slots=True) +class WithSlots: + name: str + +reveal_type(WithSlots.__slots__) # revealed: tuple[Literal["name"]] +``` + +### Using metaclass-based transformers + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +class FancyMeta(type): + def __new__(cls, name, bases, namespace, *, slots: bool = False): + ... + return super().__new__(cls, name, bases, namespace) + +class FancyBase(metaclass=FancyMeta): ... + +class NoSlots(FancyBase): + name: str + +# error: [unresolved-attribute] +NoSlots.__slots__ + +class WithSlots(FancyBase, slots=True): + name: str + +reveal_type(WithSlots.__slots__) # revealed: tuple[Literal["name"]] +``` + +### Using base-class-based transformers + +```py +from typing_extensions import dataclass_transform + +@dataclass_transform() +class FancyBase: + def __init_subclass__(cls, *, slots: bool = False): + ... + super().__init_subclass__() + +class NoSlots(FancyBase): + name: str + +NoSlots.__slots__ # error: [unresolved-attribute] + +class WithSlots(FancyBase, slots=True): + name: str + +reveal_type(WithSlots.__slots__) # revealed: tuple[Literal["name"]] +``` + ## `field_specifiers` The `field_specifiers` argument can be used to specify a list of functions that should be treated diff --git a/crates/ty_python_semantic/src/types.rs b/crates/ty_python_semantic/src/types.rs index 518c0f1c59..cb63e03ef6 100644 --- a/crates/ty_python_semantic/src/types.rs +++ b/crates/ty_python_semantic/src/types.rs @@ -622,6 +622,19 @@ bitflags! { } } +pub(crate) const DATACLASS_FLAGS: &[(&str, DataclassFlags)] = &[ + ("init", DataclassFlags::INIT), + ("repr", DataclassFlags::REPR), + ("eq", DataclassFlags::EQ), + ("order", DataclassFlags::ORDER), + ("unsafe_hash", DataclassFlags::UNSAFE_HASH), + ("frozen", DataclassFlags::FROZEN), + ("match_args", DataclassFlags::MATCH_ARGS), + ("kw_only", DataclassFlags::KW_ONLY), + ("slots", DataclassFlags::SLOTS), + ("weakref_slot", DataclassFlags::WEAKREF_SLOT), +]; + impl get_size2::GetSize for DataclassFlags {} impl Default for DataclassFlags { diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index 54a17e4b2b..a99ac8b1ef 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -35,11 +35,11 @@ use crate::types::generics::{ use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters}; use crate::types::tuple::{TupleLength, TupleType}; use crate::types::{ - BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams, - FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, - NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet, - TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, WrapperDescriptorKind, - enums, ide_support, infer_isolated_expression, todo_type, + BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DATACLASS_FLAGS, DataclassFlags, + DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, + MemberLookupPolicy, NominalInstanceType, PropertyInstanceType, SpecialFormType, + TrackedConstraintSet, TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, + WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type, }; use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity}; use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion}; @@ -1134,28 +1134,12 @@ impl<'db> Bindings<'db> { ); let mut flags = dataclass_params.flags(db); - if let Ok(Some(Type::BooleanLiteral(order))) = - overload.parameter_type_by_name("order", false) - { - flags.set(DataclassFlags::ORDER, order); - } - - if let Ok(Some(Type::BooleanLiteral(eq))) = - overload.parameter_type_by_name("eq", false) - { - flags.set(DataclassFlags::EQ, eq); - } - - if let Ok(Some(Type::BooleanLiteral(kw_only))) = - overload.parameter_type_by_name("kw_only", false) - { - flags.set(DataclassFlags::KW_ONLY, kw_only); - } - - if let Ok(Some(Type::BooleanLiteral(frozen))) = - overload.parameter_type_by_name("frozen", false) - { - flags.set(DataclassFlags::FROZEN, frozen); + for (param, flag) in DATACLASS_FLAGS { + if let Ok(Some(Type::BooleanLiteral(value))) = + overload.parameter_type_by_name(param, false) + { + flags.set(*flag, value); + } } Type::DataclassDecorator(DataclassParams::new( diff --git a/crates/ty_python_semantic/src/types/class.rs b/crates/ty_python_semantic/src/types/class.rs index ca83e6d981..184a2d8e15 100644 --- a/crates/ty_python_semantic/src/types/class.rs +++ b/crates/ty_python_semantic/src/types/class.rs @@ -32,13 +32,13 @@ use crate::types::tuple::{TupleSpec, TupleType}; use crate::types::typed_dict::typed_dict_params_from_class_def; use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard}; use crate::types::{ - ApplyTypeMappingVisitor, Binding, BoundSuperType, CallableType, DataclassFlags, - DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, - IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, - MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, - TypeContext, TypeMapping, TypeRelation, TypedDictParams, UnionBuilder, VarianceInferable, - declaration_type, determine_upper_bound, exceeds_max_specialization_depth, - infer_definition_types, + ApplyTypeMappingVisitor, Binding, BoundSuperType, CallableType, DATACLASS_FLAGS, + DataclassFlags, DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, + HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType, + ManualPEP695TypeAliasType, MaterializationKind, NormalizedVisitor, PropertyInstanceType, + StringLiteralType, TypeAliasType, TypeContext, TypeMapping, TypeRelation, TypedDictParams, + UnionBuilder, VarianceInferable, declaration_type, determine_upper_bound, + exceeds_max_specialization_depth, infer_definition_types, }; use crate::{ Db, FxIndexMap, FxIndexSet, FxOrderSet, Program, @@ -2229,12 +2229,10 @@ impl<'db> ClassLiteral<'db> { if let Some(is_set) = keyword.value.as_boolean_literal_expr().map(|b| b.value) { - match arg_name.as_str() { - "eq" => flags.set(DataclassFlags::EQ, is_set), - "order" => flags.set(DataclassFlags::ORDER, is_set), - "kw_only" => flags.set(DataclassFlags::KW_ONLY, is_set), - "frozen" => flags.set(DataclassFlags::FROZEN, is_set), - _ => {} + for (flag_name, flag) in DATACLASS_FLAGS { + if arg_name.as_str() == *flag_name { + flags.set(*flag, is_set); + } } } }