[ty] Dataclass transform: complete set of parameters

This commit is contained in:
David Peter 2025-11-15 18:11:54 +01:00
parent 29acc1e860
commit 528afe2b05
4 changed files with 111 additions and 40 deletions

View file

@ -455,6 +455,82 @@ m.name = "new" # No error
reveal_type(Mutable(name="A") < Mutable(name="B")) # revealed: bool
```
## Other `dataclass` parameters
Other parameters from normal dataclasses can also be set on models created using
`dataclass_transform`.
### Using function-based transformers
```py
from typing_extensions import dataclass_transform, TypeVar, Callable
T = TypeVar("T", bound=type)
@dataclass_transform()
def fancy_model(*, slots: bool = False) -> Callable[[T], T]:
raise NotImplementedError
@fancy_model()
class NoSlots:
name: str
NoSlots.__slots__ # error: [unresolved-attribute]
@fancy_model(slots=True)
class WithSlots:
name: str
reveal_type(WithSlots.__slots__) # revealed: tuple[Literal["name"]]
```
### Using metaclass-based transformers
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
class FancyMeta(type):
def __new__(cls, name, bases, namespace, *, slots: bool = False):
...
return super().__new__(cls, name, bases, namespace)
class FancyBase(metaclass=FancyMeta): ...
class NoSlots(FancyBase):
name: str
# error: [unresolved-attribute]
NoSlots.__slots__
class WithSlots(FancyBase, slots=True):
name: str
reveal_type(WithSlots.__slots__) # revealed: tuple[Literal["name"]]
```
### Using base-class-based transformers
```py
from typing_extensions import dataclass_transform
@dataclass_transform()
class FancyBase:
def __init_subclass__(cls, *, slots: bool = False):
...
super().__init_subclass__()
class NoSlots(FancyBase):
name: str
NoSlots.__slots__ # error: [unresolved-attribute]
class WithSlots(FancyBase, slots=True):
name: str
reveal_type(WithSlots.__slots__) # revealed: tuple[Literal["name"]]
```
## `field_specifiers`
The `field_specifiers` argument can be used to specify a list of functions that should be treated

View file

@ -622,6 +622,19 @@ bitflags! {
}
}
pub(crate) const DATACLASS_FLAGS: &[(&str, DataclassFlags)] = &[
("init", DataclassFlags::INIT),
("repr", DataclassFlags::REPR),
("eq", DataclassFlags::EQ),
("order", DataclassFlags::ORDER),
("unsafe_hash", DataclassFlags::UNSAFE_HASH),
("frozen", DataclassFlags::FROZEN),
("match_args", DataclassFlags::MATCH_ARGS),
("kw_only", DataclassFlags::KW_ONLY),
("slots", DataclassFlags::SLOTS),
("weakref_slot", DataclassFlags::WEAKREF_SLOT),
];
impl get_size2::GetSize for DataclassFlags {}
impl Default for DataclassFlags {

View file

@ -35,11 +35,11 @@ use crate::types::generics::{
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DataclassFlags, DataclassParams,
FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy,
NominalInstanceType, PropertyInstanceType, SpecialFormType, TrackedConstraintSet,
TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType, WrapperDescriptorKind,
enums, ide_support, infer_isolated_expression, todo_type,
BoundMethodType, BoundTypeVarIdentity, ClassLiteral, DATACLASS_FLAGS, DataclassFlags,
DataclassParams, FieldInstance, KnownBoundMethodType, KnownClass, KnownInstanceType,
MemberLookupPolicy, NominalInstanceType, PropertyInstanceType, SpecialFormType,
TrackedConstraintSet, TypeAliasType, TypeContext, TypeVarVariance, UnionBuilder, UnionType,
WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type,
};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
@ -1134,28 +1134,12 @@ impl<'db> Bindings<'db> {
);
let mut flags = dataclass_params.flags(db);
if let Ok(Some(Type::BooleanLiteral(order))) =
overload.parameter_type_by_name("order", false)
{
flags.set(DataclassFlags::ORDER, order);
}
if let Ok(Some(Type::BooleanLiteral(eq))) =
overload.parameter_type_by_name("eq", false)
{
flags.set(DataclassFlags::EQ, eq);
}
if let Ok(Some(Type::BooleanLiteral(kw_only))) =
overload.parameter_type_by_name("kw_only", false)
{
flags.set(DataclassFlags::KW_ONLY, kw_only);
}
if let Ok(Some(Type::BooleanLiteral(frozen))) =
overload.parameter_type_by_name("frozen", false)
{
flags.set(DataclassFlags::FROZEN, frozen);
for (param, flag) in DATACLASS_FLAGS {
if let Ok(Some(Type::BooleanLiteral(value))) =
overload.parameter_type_by_name(param, false)
{
flags.set(*flag, value);
}
}
Type::DataclassDecorator(DataclassParams::new(

View file

@ -32,13 +32,13 @@ use crate::types::tuple::{TupleSpec, TupleType};
use crate::types::typed_dict::typed_dict_params_from_class_def;
use crate::types::visitor::{TypeCollector, TypeVisitor, walk_type_with_recursion_guard};
use crate::types::{
ApplyTypeMappingVisitor, Binding, BoundSuperType, CallableType, DataclassFlags,
DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor,
IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType,
MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType,
TypeContext, TypeMapping, TypeRelation, TypedDictParams, UnionBuilder, VarianceInferable,
declaration_type, determine_upper_bound, exceeds_max_specialization_depth,
infer_definition_types,
ApplyTypeMappingVisitor, Binding, BoundSuperType, CallableType, DATACLASS_FLAGS,
DataclassFlags, DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor,
HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType,
ManualPEP695TypeAliasType, MaterializationKind, NormalizedVisitor, PropertyInstanceType,
StringLiteralType, TypeAliasType, TypeContext, TypeMapping, TypeRelation, TypedDictParams,
UnionBuilder, VarianceInferable, declaration_type, determine_upper_bound,
exceeds_max_specialization_depth, infer_definition_types,
};
use crate::{
Db, FxIndexMap, FxIndexSet, FxOrderSet, Program,
@ -2229,12 +2229,10 @@ impl<'db> ClassLiteral<'db> {
if let Some(is_set) =
keyword.value.as_boolean_literal_expr().map(|b| b.value)
{
match arg_name.as_str() {
"eq" => flags.set(DataclassFlags::EQ, is_set),
"order" => flags.set(DataclassFlags::ORDER, is_set),
"kw_only" => flags.set(DataclassFlags::KW_ONLY, is_set),
"frozen" => flags.set(DataclassFlags::FROZEN, is_set),
_ => {}
for (flag_name, flag) in DATACLASS_FLAGS {
if arg_name.as_str() == *flag_name {
flags.set(*flag, is_set);
}
}
}
}