[ty] Support dataclass-transform field_specifiers (#20888)

## Summary

Add support for the `field_specifiers` parameter on
`dataclass_transform` decorator calls.

closes https://github.com/astral-sh/ty/issues/1068
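
For context, `field_specifiers` lets a `dataclass_transform` decorator declare which callables mark fields on the transformed class. A minimal illustration (the `fancy_field`/`fancy_model` names are illustrative and mirror the new mdtest cases below):

```py
from typing_extensions import dataclass_transform, Any

def fancy_field(*, init: bool = True, kw_only: bool = False) -> Any: ...

@dataclass_transform(field_specifiers=(fancy_field,))
def fancy_model[T](cls: type[T]) -> type[T]:
    return cls

@fancy_model
class Person:
    id: int = fancy_field(init=False)  # `init=False` excludes `id` from `__init__`
    name: str = fancy_field()

Person(name="Alice")  # OK: the synthesized `__init__` only takes `name`
```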

## Conformance test results

All true positives ✔️ 

## Ecosystem analysis

* `trio`: This is the kind of change I would expect from this PR. The code
uses a dataclass-like `Outcome` class with a `_unwrapped: bool =
attr.ib(default=False, eq=False, init=False)` field that is excluded from
the `__init__` signature, so a bunch of constructor-call-related errors now
go away (see the sketch after this list).
* `home-assistant/core`: They have a `domain: str = attr.ib(init=False,
repr=False)` field and then use
  ```py
    @domain.default
    def _domain_default(self) -> str:
        # …
  ```
This accesses the `default` attribute on `dataclasses.Field[…]`, whose
declared type is `default: _T | Literal[_MISSING_TYPE.MISSING]`, so we get
those "Object of type `_MISSING_TYPE` is not callable" errors. I don't
really understand how that is supposed to work: even if `_MISSING_TYPE` were
absent from that union, what would this try to call? pyright also issues an
error here, and it doesn't seem to work at runtime either, so this looks
like a true positive.
* `attrs`: Similar story here. There are some new diagnostics on code that
tries to access `.validator` on a field. This *does* work at runtime, but
I'm not sure how it is supposed to type-check without a [custom
plugin](2c6c395935/mypy/plugins/attrs.py (L575-L602)); pyright errors on
this as well.
* A handful of new false positives show up because we don't support the `alias` parameter yet.
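
A minimal sketch of the `trio`-style pattern from the first bullet (the class layout is simplified and illustrative; it assumes attrs' stubs declare `attr.s` as a dataclass transformer with `attr.ib` among its `field_specifiers`):

```py
import attr

@attr.s(auto_attribs=True)
class Value:
    value: object
    # `init=False` excludes this field from the synthesized `__init__`,
    # so the constructor signature is `(self, value: object) -> None`.
    _unwrapped: bool = attr.ib(default=False, eq=False, init=False)

Value(42)  # accepted: `_unwrapped` is not an `__init__` parameter
```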

## Test Plan

Updated tests.
David Peter 2025-10-16 20:49:11 +02:00 committed by GitHub
parent 2bffef5966
commit 8dad58de37
7 changed files with 475 additions and 198 deletions


@ -461,7 +461,7 @@ The [`typing.dataclass_transform`] specification also allows classes (such as `d
to be listed in `field_specifiers`, but it is currently unclear how this should work, and other type
checkers do not seem to support this either.
### Basic example
### For function-based transformers
```py
from typing_extensions import dataclass_transform, Any
@ -478,11 +478,8 @@ class Person:
name: str = fancy_field()
age: int | None = fancy_field(kw_only=True)
# TODO: Should be `(self: Person, name: str, *, age: int | None) -> None`
reveal_type(Person.__init__) # revealed: (self: Person, id: int = Any, name: str = Any, age: int | None = Any) -> None
reveal_type(Person.__init__) # revealed: (self: Person, name: str, *, age: int | None) -> None
# TODO: No error here
# error: [invalid-argument-type]
alice = Person("Alice", age=30)
reveal_type(alice.id) # revealed: int
@ -490,6 +487,145 @@ reveal_type(alice.name) # revealed: str
reveal_type(alice.age) # revealed: int | None
```
### For metaclass-based transformers
```py
from typing_extensions import dataclass_transform, Any
def fancy_field(*, init: bool = True, kw_only: bool = False) -> Any: ...
@dataclass_transform(field_specifiers=(fancy_field,))
class FancyMeta(type):
def __new__(cls, name, bases, namespace):
...
return super().__new__(cls, name, bases, namespace)
class FancyBase(metaclass=FancyMeta): ...
class Person(FancyBase):
id: int = fancy_field(init=False)
name: str = fancy_field()
age: int | None = fancy_field(kw_only=True)
reveal_type(Person.__init__) # revealed: (self: Person, name: str, *, age: int | None) -> None
alice = Person("Alice", age=30)
reveal_type(alice.id) # revealed: int
reveal_type(alice.name) # revealed: str
reveal_type(alice.age) # revealed: int | None
```
### For base-class-based transformers
```py
from typing_extensions import dataclass_transform, Any
def fancy_field(*, init: bool = True, kw_only: bool = False) -> Any: ...
@dataclass_transform(field_specifiers=(fancy_field,))
class FancyBase:
def __init_subclass__(cls):
...
super().__init_subclass__()
class Person(FancyBase):
id: int = fancy_field(init=False)
name: str = fancy_field()
age: int | None = fancy_field(kw_only=True)
# TODO: should be (self: Person, name: str = Unknown, *, age: int | None = Unknown) -> None
reveal_type(Person.__init__) # revealed: def __init__(self) -> None
# TODO: shouldn't be an error
# error: [too-many-positional-arguments]
# error: [unknown-argument]
alice = Person("Alice", age=30)
reveal_type(alice.id) # revealed: int
reveal_type(alice.name) # revealed: str
reveal_type(alice.age) # revealed: int | None
```
### With default arguments
Field specifiers can have default arguments that should be respected:
```py
from typing_extensions import dataclass_transform, Any
def fancy_field(*, init: bool = False) -> Any: ...
@dataclass_transform(field_specifiers=(fancy_field,))
def fancy_model[T](cls: type[T]) -> type[T]:
...
return cls
@fancy_model
class Person:
id: int = fancy_field()
name: str = fancy_field(init=True)
reveal_type(Person.__init__) # revealed: (self: Person, name: str) -> None
Person(name="Alice")
```
### With overloaded field specifiers
```py
from typing_extensions import dataclass_transform, overload, Any
@overload
def fancy_field(*, init: bool = True) -> Any: ...
@overload
def fancy_field(*, kw_only: bool = False) -> Any: ...
def fancy_field(*, init: bool = True, kw_only: bool = False) -> Any: ...
@dataclass_transform(field_specifiers=(fancy_field,))
def fancy_model[T](cls: type[T]) -> type[T]:
...
return cls
@fancy_model
class Person:
id: int = fancy_field(init=False)
name: str = fancy_field()
age: int | None = fancy_field(kw_only=True)
reveal_type(Person.__init__) # revealed: (self: Person, name: str, *, age: int | None) -> None
```
### Nested dataclass-transformers
Make sure that models are only affected by the field specifiers of their own transformer:
```py
from typing_extensions import dataclass_transform, Any
from dataclasses import field
def outer_field(*, init: bool = True, kw_only: bool = False) -> Any: ...
@dataclass_transform(field_specifiers=(outer_field,))
def outer_model[T](cls: type[T]) -> type[T]:
# ...
return cls
def inner_field(*, init: bool = True, kw_only: bool = False) -> Any: ...
@dataclass_transform(field_specifiers=(inner_field,))
def inner_model[T](cls: type[T]) -> type[T]:
# ...
return cls
@outer_model
class Outer:
@inner_model
class Inner:
inner_a: int = inner_field(init=False)
inner_b: str = outer_field(init=False)
outer_a: int = outer_field(init=False)
outer_b: str = inner_field(init=False)
reveal_type(Outer.__init__) # revealed: (self: Outer, outer_b: str = Any) -> None
reveal_type(Outer.Inner.__init__) # revealed: (self: Inner, inner_b: str = Any) -> None
```
## Overloaded dataclass-like decorators
In the case of an overloaded decorator, the `dataclass_transform` decorator can be applied to the


@ -32,7 +32,9 @@ pub(crate) use self::signatures::{CallableSignature, Parameter, Parameters, Sign
pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType};
use crate::module_name::ModuleName;
use crate::module_resolver::{KnownModule, resolve_module};
use crate::place::{Definedness, Place, PlaceAndQualifiers, TypeOrigin, imported_symbol};
use crate::place::{
Definedness, Place, PlaceAndQualifiers, TypeOrigin, imported_symbol, known_module_symbol,
};
use crate::semantic_index::definition::{Definition, DefinitionKind};
use crate::semantic_index::place::ScopedPlaceId;
use crate::semantic_index::scope::ScopeId;
@ -50,7 +52,8 @@ pub use crate::types::display::DisplaySettings;
use crate::types::display::TupleSpecialization;
use crate::types::enums::{enum_metadata, is_single_member_enum};
use crate::types::function::{
DataclassTransformerParams, FunctionSpans, FunctionType, KnownFunction,
DataclassTransformerFlags, DataclassTransformerParams, FunctionSpans, FunctionType,
KnownFunction,
};
use crate::types::generics::{
GenericContext, InferableTypeVars, PartialSpecialization, Specialization, bind_typevar,
@ -618,67 +621,95 @@ impl<'db> PropertyInstanceType<'db> {
}
bitflags! {
/// Used for the return type of `dataclass(…)` calls. Keeps track of the arguments
/// that were passed in. For the precise meaning of the fields, see [1].
/// Used to store metadata about a dataclass or dataclass-like class.
/// For the precise meaning of the fields, see [1].
///
/// [1]: https://docs.python.org/3/library/dataclasses.html
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DataclassParams: u16 {
const INIT = 0b0000_0000_0001;
const REPR = 0b0000_0000_0010;
const EQ = 0b0000_0000_0100;
const ORDER = 0b0000_0000_1000;
const UNSAFE_HASH = 0b0000_0001_0000;
const FROZEN = 0b0000_0010_0000;
const MATCH_ARGS = 0b0000_0100_0000;
const KW_ONLY = 0b0000_1000_0000;
const SLOTS = 0b0001_0000_0000;
const WEAKREF_SLOT = 0b0010_0000_0000;
// This is not an actual argument from `dataclass(...)` but a flag signaling that no
// `field_specifiers` was specified for the `dataclass_transform`, see [1].
// [1]: https://typing.python.org/en/latest/spec/dataclasses.html#dataclass-transform-parameters
const NO_FIELD_SPECIFIERS = 0b0100_0000_0000;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct DataclassFlags: u16 {
const INIT = 1 << 0;
const REPR = 1 << 1;
const EQ = 1 << 2;
const ORDER = 1 << 3;
const UNSAFE_HASH = 1 << 4;
const FROZEN = 1 << 5;
const MATCH_ARGS = 1 << 6;
const KW_ONLY = 1 << 7;
const SLOTS = 1 << 8;
const WEAKREF_SLOT = 1 << 9;
}
}
impl get_size2::GetSize for DataclassParams {}
impl get_size2::GetSize for DataclassFlags {}
impl Default for DataclassParams {
impl Default for DataclassFlags {
fn default() -> Self {
Self::INIT | Self::REPR | Self::EQ | Self::MATCH_ARGS
}
}
impl From<DataclassTransformerParams> for DataclassParams {
fn from(params: DataclassTransformerParams) -> Self {
impl From<DataclassTransformerFlags> for DataclassFlags {
fn from(params: DataclassTransformerFlags) -> Self {
let mut result = Self::default();
result.set(
Self::EQ,
params.contains(DataclassTransformerParams::EQ_DEFAULT),
params.contains(DataclassTransformerFlags::EQ_DEFAULT),
);
result.set(
Self::ORDER,
params.contains(DataclassTransformerParams::ORDER_DEFAULT),
params.contains(DataclassTransformerFlags::ORDER_DEFAULT),
);
result.set(
Self::KW_ONLY,
params.contains(DataclassTransformerParams::KW_ONLY_DEFAULT),
params.contains(DataclassTransformerFlags::KW_ONLY_DEFAULT),
);
result.set(
Self::FROZEN,
params.contains(DataclassTransformerParams::FROZEN_DEFAULT),
);
result.set(
Self::NO_FIELD_SPECIFIERS,
!params.contains(DataclassTransformerParams::FIELD_SPECIFIERS),
params.contains(DataclassTransformerFlags::FROZEN_DEFAULT),
);
result
}
}
/// Metadata for a dataclass. Stored inside a `Type::DataclassDecorator(…)`
/// instance that we use as the return type of `dataclasses.dataclass` and
/// dataclass-transformer decorator calls.
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
#[derive(PartialOrd, Ord)]
pub struct DataclassParams<'db> {
flags: DataclassFlags,
#[returns(deref)]
field_specifiers: Box<[Type<'db>]>,
}
impl get_size2::GetSize for DataclassParams<'_> {}
impl<'db> DataclassParams<'db> {
fn default_params(db: &'db dyn Db) -> Self {
Self::from_flags(db, DataclassFlags::default())
}
fn from_flags(db: &'db dyn Db, flags: DataclassFlags) -> Self {
let dataclasses_field = known_module_symbol(db, KnownModule::Dataclasses, "field")
.place
.ignore_possibly_undefined()
.unwrap_or_else(Type::unknown);
Self::new(db, flags, vec![dataclasses_field].into_boxed_slice())
}
fn from_transformer_params(db: &'db dyn Db, params: DataclassTransformerParams<'db>) -> Self {
Self::new(
db,
DataclassFlags::from(params.flags(db)),
params.field_specifiers(db),
)
}
}
/// Representation of a type: a set of possible values at runtime.
///
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update, get_size2::GetSize)]
@ -719,9 +750,9 @@ pub enum Type<'db> {
/// A special callable that is returned by a `dataclass(…)` call. It is usually
/// used as a decorator. Note that this is only used as a return type for actual
/// `dataclass` calls, not for the argumentless `@dataclass` decorator.
DataclassDecorator(DataclassParams),
DataclassDecorator(DataclassParams<'db>),
/// A special callable that is returned by a `dataclass_transform(…)` call.
DataclassTransformer(DataclassTransformerParams),
DataclassTransformer(DataclassTransformerParams<'db>),
/// The type of an arbitrary callable object with a certain specified signature.
Callable(CallableType<'db>),
/// A specific module object
@ -5449,7 +5480,7 @@ impl<'db> Type<'db> {
) -> Result<Bindings<'db>, CallError<'db>> {
self.bindings(db)
.match_parameters(db, argument_types)
.check_types(db, argument_types, &TypeContext::default())
.check_types(db, argument_types, &TypeContext::default(), &[])
}
/// Look up a dunder method on the meta-type of `self` and call it.
@ -5501,7 +5532,7 @@ impl<'db> Type<'db> {
let bindings = dunder_callable
.bindings(db)
.match_parameters(db, argument_types)
.check_types(db, argument_types, &tcx)?;
.check_types(db, argument_types, &tcx, &[])?;
if boundness == Definedness::PossiblyUndefined {
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
}


@ -24,7 +24,8 @@ use crate::types::diagnostic::{
};
use crate::types::enums::is_enum_class;
use crate::types::function::{
DataclassTransformerParams, FunctionDecorators, FunctionType, KnownFunction, OverloadLiteral,
DataclassTransformerFlags, DataclassTransformerParams, FunctionDecorators, FunctionType,
KnownFunction, OverloadLiteral,
};
use crate::types::generics::{
InferableTypeVars, Specialization, SpecializationBuilder, SpecializationError,
@ -32,9 +33,9 @@ use crate::types::generics::{
use crate::types::signatures::{Parameter, ParameterForm, ParameterKind, Parameters};
use crate::types::tuple::{TupleLength, TupleType};
use crate::types::{
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType,
BoundMethodType, ClassLiteral, DataclassFlags, DataclassParams, FieldInstance,
KnownBoundMethodType, KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType,
SpecialFormType, TrackedConstraintSet, TypeAliasType, TypeContext, UnionBuilder, UnionType,
WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type,
};
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
@ -135,6 +136,7 @@ impl<'db> Bindings<'db> {
db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>,
call_expression_tcx: &TypeContext<'db>,
dataclass_field_specifiers: &[Type<'db>],
) -> Result<Self, CallError<'db>> {
for element in &mut self.elements {
if let Some(mut updated_argument_forms) =
@ -147,7 +149,7 @@ impl<'db> Bindings<'db> {
}
}
self.evaluate_known_cases(db);
self.evaluate_known_cases(db, dataclass_field_specifiers);
// In order of precedence:
//
@ -269,7 +271,7 @@ impl<'db> Bindings<'db> {
/// Evaluates the return type of certain known callables, where we have special-case logic to
/// determine the return type in a way that isn't directly expressible in the type system.
fn evaluate_known_cases(&mut self, db: &'db dyn Db) {
fn evaluate_known_cases(&mut self, db: &'db dyn Db, dataclass_field_specifiers: &[Type<'db>]) {
let to_bool = |ty: &Option<Type<'_>>, default: bool| -> bool {
if let Some(Type::BooleanLiteral(value)) = ty {
*value
@ -596,6 +598,70 @@ impl<'db> Bindings<'db> {
}
}
function @ Type::FunctionLiteral(function_type)
if dataclass_field_specifiers.contains(&function)
|| function_type.is_known(db, KnownFunction::Field) =>
{
let has_default_value = overload
.parameter_type_by_name("default", false)
.is_ok_and(|ty| ty.is_some())
|| overload
.parameter_type_by_name("default_factory", false)
.is_ok_and(|ty| ty.is_some())
|| overload
.parameter_type_by_name("factory", false)
.is_ok_and(|ty| ty.is_some());
let init = overload
.parameter_type_by_name("init", true)
.unwrap_or(None);
let kw_only = overload
.parameter_type_by_name("kw_only", true)
.unwrap_or(None);
// `dataclasses.field` and field-specifier functions of commonly used
// libraries like `pydantic`, `attrs`, and `SQLAlchemy` all return
// the default type for the field (or `Any`) instead of an actual `Field`
// instance, even if this is not what happens at runtime (see also below).
// We still make use of this fact and pretend that all field specifiers
// return the type of the default value:
let default_ty = if has_default_value {
Some(overload.return_ty)
} else {
None
};
let init = init
.map(|init| !init.bool(db).is_always_false())
.unwrap_or(true);
let kw_only = if Program::get(db).python_version(db) >= PythonVersion::PY310
{
match kw_only {
// We are more conservative here when turning the type for `kw_only`
// into a bool, because a field specifier in a stub might use
// `kw_only: bool = ...` and the truthiness of `...` is always true.
// This is different from `init` above because we may need to fall back
// to `kw_only_default`, whereas `init_default` does not exist.
Some(Type::BooleanLiteral(yes)) => Some(yes),
_ => None,
}
} else {
None
};
// `typeshed` pretends that `dataclasses.field()` returns the type of the
// default value directly. At runtime, however, this function returns an
// instance of `dataclasses.Field`. We also model it this way and return
// a known-instance type with information about the field. The drawback
// of this approach is that we need to pretend that instances of `Field`
// are assignable to `T` if the default type of the field is assignable
// to `T`. Otherwise, we would error on `name: str = field(default="")`.
overload.set_return_type(Type::KnownInstance(KnownInstanceType::Field(
FieldInstance::new(db, default_ty, init, kw_only),
)));
}
Type::FunctionLiteral(function_type) => match function_type.known(db) {
Some(KnownFunction::IsEquivalentTo) => {
if let [Some(ty_a), Some(ty_b)] = overload.parameter_types() {
@ -871,43 +937,45 @@ impl<'db> Bindings<'db> {
weakref_slot,
] = overload.parameter_types()
{
let mut params = DataclassParams::empty();
let mut flags = DataclassFlags::empty();
if to_bool(init, true) {
params |= DataclassParams::INIT;
flags |= DataclassFlags::INIT;
}
if to_bool(repr, true) {
params |= DataclassParams::REPR;
flags |= DataclassFlags::REPR;
}
if to_bool(eq, true) {
params |= DataclassParams::EQ;
flags |= DataclassFlags::EQ;
}
if to_bool(order, false) {
params |= DataclassParams::ORDER;
flags |= DataclassFlags::ORDER;
}
if to_bool(unsafe_hash, false) {
params |= DataclassParams::UNSAFE_HASH;
flags |= DataclassFlags::UNSAFE_HASH;
}
if to_bool(frozen, false) {
params |= DataclassParams::FROZEN;
flags |= DataclassFlags::FROZEN;
}
if to_bool(match_args, true) {
params |= DataclassParams::MATCH_ARGS;
flags |= DataclassFlags::MATCH_ARGS;
}
if to_bool(kw_only, false) {
if Program::get(db).python_version(db) >= PythonVersion::PY310 {
params |= DataclassParams::KW_ONLY;
flags |= DataclassFlags::KW_ONLY;
} else {
// TODO: emit diagnostic
}
}
if to_bool(slots, false) {
params |= DataclassParams::SLOTS;
flags |= DataclassFlags::SLOTS;
}
if to_bool(weakref_slot, false) {
params |= DataclassParams::WEAKREF_SLOT;
flags |= DataclassFlags::WEAKREF_SLOT;
}
let params = DataclassParams::from_flags(db, flags);
overload.set_return_type(Type::DataclassDecorator(params));
}
@ -915,7 +983,7 @@ impl<'db> Bindings<'db> {
if let [Some(Type::ClassLiteral(class_literal))] =
overload.parameter_types()
{
let params = DataclassParams::default();
let params = DataclassParams::default_params(db);
overload.set_return_type(Type::from(ClassLiteral::new(
db,
class_literal.name(db),
@ -938,82 +1006,39 @@ impl<'db> Bindings<'db> {
_kwargs,
] = overload.parameter_types()
{
let mut params = DataclassTransformerParams::empty();
let mut flags = DataclassTransformerFlags::empty();
if to_bool(eq_default, true) {
params |= DataclassTransformerParams::EQ_DEFAULT;
flags |= DataclassTransformerFlags::EQ_DEFAULT;
}
if to_bool(order_default, false) {
params |= DataclassTransformerParams::ORDER_DEFAULT;
flags |= DataclassTransformerFlags::ORDER_DEFAULT;
}
if to_bool(kw_only_default, false) {
params |= DataclassTransformerParams::KW_ONLY_DEFAULT;
flags |= DataclassTransformerFlags::KW_ONLY_DEFAULT;
}
if to_bool(frozen_default, false) {
params |= DataclassTransformerParams::FROZEN_DEFAULT;
flags |= DataclassTransformerFlags::FROZEN_DEFAULT;
}
if let Some(field_specifiers_type) = field_specifiers {
// For now, we'll do a simple check: if field_specifiers is not
// None/empty, we assume it might contain dataclasses.field
// TODO: Implement proper parsing to check for
// dataclasses.field/Field specifically
if !field_specifiers_type.is_none(db) {
params |= DataclassTransformerParams::FIELD_SPECIFIERS;
}
}
let field_specifiers: Box<[Type<'db>]> = field_specifiers
.map(|tuple_type| {
tuple_type
.exact_tuple_instance_spec(db)
.iter()
.flat_map(|tuple_spec| tuple_spec.fixed_elements())
.copied()
.collect()
})
.unwrap_or_default();
let params =
DataclassTransformerParams::new(db, flags, field_specifiers);
overload.set_return_type(Type::DataclassTransformer(params));
}
}
Some(KnownFunction::Field) => {
let default =
overload.parameter_type_by_name("default").unwrap_or(None);
let default_factory = overload
.parameter_type_by_name("default_factory")
.unwrap_or(None);
let init = overload.parameter_type_by_name("init").unwrap_or(None);
let kw_only =
overload.parameter_type_by_name("kw_only").unwrap_or(None);
// `dataclasses.field` and field-specifier functions of commonly used
// libraries like `pydantic`, `attrs`, and `SQLAlchemy` all return
// the default type for the field (or `Any`) instead of an actual `Field`
// instance, even if this is not what happens at runtime (see also below).
// We still make use of this fact and pretend that all field specifiers
// return the type of the default value:
let default_ty = if default.is_some() || default_factory.is_some() {
Some(overload.return_ty)
} else {
None
};
let init = init
.map(|init| !init.bool(db).is_always_false())
.unwrap_or(true);
let kw_only =
if Program::get(db).python_version(db) >= PythonVersion::PY310 {
kw_only.map(|kw_only| !kw_only.bool(db).is_always_false())
} else {
None
};
// `typeshed` pretends that `dataclasses.field()` returns the type of the
// default value directly. At runtime, however, this function returns an
// instance of `dataclasses.Field`. We also model it this way and return
// a known-instance type with information about the field. The drawback
// of this approach is that we need to pretend that instances of `Field`
// are assignable to `T` if the default type of the field is assignable
// to `T`. Otherwise, we would error on `name: str = field(default="")`.
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::Field(FieldInstance::new(
db, default_ty, init, kw_only,
)),
));
}
_ => {
// Ideally, either the implementation, or exactly one of the overloads
// of the function can have the dataclass_transform decorator applied.
@ -1030,36 +1055,41 @@ impl<'db> Bindings<'db> {
// the argument type and overwrite the corresponding flag in `dataclass_params` after
// constructing them from the `dataclass_transformer`-parameter defaults.
let mut dataclass_params =
DataclassParams::from(params);
let dataclass_params =
DataclassParams::from_transformer_params(
db, params,
);
let mut flags = dataclass_params.flags(db);
if let Ok(Some(Type::BooleanLiteral(order))) =
overload.parameter_type_by_name("order")
overload.parameter_type_by_name("order", false)
{
dataclass_params.set(DataclassParams::ORDER, order);
flags.set(DataclassFlags::ORDER, order);
}
if let Ok(Some(Type::BooleanLiteral(eq))) =
overload.parameter_type_by_name("eq")
overload.parameter_type_by_name("eq", false)
{
dataclass_params.set(DataclassParams::EQ, eq);
flags.set(DataclassFlags::EQ, eq);
}
if let Ok(Some(Type::BooleanLiteral(kw_only))) =
overload.parameter_type_by_name("kw_only")
overload.parameter_type_by_name("kw_only", false)
{
dataclass_params
.set(DataclassParams::KW_ONLY, kw_only);
flags.set(DataclassFlags::KW_ONLY, kw_only);
}
if let Ok(Some(Type::BooleanLiteral(frozen))) =
overload.parameter_type_by_name("frozen")
overload.parameter_type_by_name("frozen", false)
{
dataclass_params
.set(DataclassParams::FROZEN, frozen);
flags.set(DataclassFlags::FROZEN, frozen);
}
Type::DataclassDecorator(dataclass_params)
Type::DataclassDecorator(DataclassParams::new(
db,
flags,
dataclass_params.field_specifiers(db),
))
},
)
})
@ -2843,6 +2873,7 @@ impl<'db> MatchedArgument<'db> {
}
/// Indicates that a parameter of the given name was not found.
#[derive(Debug, Clone, Copy)]
pub(crate) struct UnknownParameterNameError;
/// Binding information for one of the overloads of a callable.
@ -2993,15 +3024,24 @@ impl<'db> Binding<'db> {
pub(crate) fn parameter_type_by_name(
&self,
parameter_name: &str,
fallback_to_default: bool,
) -> Result<Option<Type<'db>>, UnknownParameterNameError> {
let index = self
.signature
.parameters()
let parameters = self.signature.parameters();
let index = parameters
.keyword_by_name(parameter_name)
.map(|(i, _)| i)
.ok_or(UnknownParameterNameError)?;
Ok(self.parameter_tys[index])
let parameter_ty = self.parameter_tys[index];
if parameter_ty.is_some() {
Ok(parameter_ty)
} else if fallback_to_default {
Ok(parameters[index].default_type())
} else {
Ok(None)
}
}
pub(crate) fn arguments_for_parameter<'a>(


@ -32,12 +32,12 @@ use crate::types::tuple::{TupleSpec, TupleType};
use crate::types::typed_dict::typed_dict_params_from_class_def;
use crate::types::visitor::{NonAtomicType, TypeKind, TypeVisitor, walk_non_atomic_type};
use crate::types::{
ApplyTypeMappingVisitor, Binding, BoundSuperType, CallableType, DataclassParams,
DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor,
IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind,
NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeContext,
TypeMapping, TypeRelation, TypedDictParams, UnionBuilder, VarianceInferable, declaration_type,
determine_upper_bound, infer_definition_types,
ApplyTypeMappingVisitor, Binding, BoundSuperType, CallableType, DataclassFlags,
DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor,
IsDisjointVisitor, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType,
MaterializationKind, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType,
TypeContext, TypeMapping, TypeRelation, TypedDictParams, UnionBuilder, VarianceInferable,
declaration_type, determine_upper_bound, infer_definition_types,
};
use crate::{
Db, FxIndexMap, FxIndexSet, FxOrderSet, Program,
@ -163,7 +163,7 @@ fn try_metaclass_cycle_recover<'db>(
_count: u32,
_self: ClassLiteral<'db>,
) -> salsa::CycleRecoveryAction<
Result<(Type<'db>, Option<DataclassTransformerParams>), MetaclassError<'db>>,
Result<(Type<'db>, Option<DataclassTransformerParams<'db>>), MetaclassError<'db>>,
> {
salsa::CycleRecoveryAction::Iterate
}
@ -172,7 +172,7 @@ fn try_metaclass_cycle_recover<'db>(
fn try_metaclass_cycle_initial<'db>(
_db: &'db dyn Db,
_self_: ClassLiteral<'db>,
) -> Result<(Type<'db>, Option<DataclassTransformerParams>), MetaclassError<'db>> {
) -> Result<(Type<'db>, Option<DataclassTransformerParams<'db>>), MetaclassError<'db>> {
Err(MetaclassError {
kind: MetaclassErrorKind::Cycle,
})
@ -180,17 +180,17 @@ fn try_metaclass_cycle_initial<'db>(
/// A category of classes with code generation capabilities (with synthesized methods).
#[derive(Clone, Copy, Debug, PartialEq, salsa::Update, get_size2::GetSize)]
pub(crate) enum CodeGeneratorKind {
pub(crate) enum CodeGeneratorKind<'db> {
/// Classes decorated with `@dataclass` or similar dataclass-like decorators
DataclassLike(Option<DataclassTransformerParams>),
DataclassLike(Option<DataclassTransformerParams<'db>>),
/// Classes inheriting from `typing.NamedTuple`
NamedTuple,
/// Classes inheriting from `typing.TypedDict`
TypedDict,
}
impl CodeGeneratorKind {
pub(crate) fn from_class(db: &dyn Db, class: ClassLiteral<'_>) -> Option<Self> {
impl<'db> CodeGeneratorKind<'db> {
pub(crate) fn from_class(db: &'db dyn Db, class: ClassLiteral<'db>) -> Option<Self> {
#[salsa::tracked(
cycle_fn=code_generator_of_class_recover,
cycle_initial=code_generator_of_class_initial,
@ -199,7 +199,7 @@ impl CodeGeneratorKind {
fn code_generator_of_class<'db>(
db: &'db dyn Db,
class: ClassLiteral<'db>,
) -> Option<CodeGeneratorKind> {
) -> Option<CodeGeneratorKind<'db>> {
if class.dataclass_params(db).is_some() {
Some(CodeGeneratorKind::DataclassLike(None))
} else if let Ok((_, Some(transformer_params))) = class.try_metaclass(db) {
@ -216,27 +216,27 @@ impl CodeGeneratorKind {
}
}
fn code_generator_of_class_initial(
_db: &dyn Db,
_class: ClassLiteral<'_>,
) -> Option<CodeGeneratorKind> {
fn code_generator_of_class_initial<'db>(
_db: &'db dyn Db,
_class: ClassLiteral<'db>,
) -> Option<CodeGeneratorKind<'db>> {
None
}
#[expect(clippy::ref_option, clippy::trivially_copy_pass_by_ref)]
fn code_generator_of_class_recover(
_db: &dyn Db,
_value: &Option<CodeGeneratorKind>,
#[expect(clippy::ref_option)]
fn code_generator_of_class_recover<'db>(
_db: &'db dyn Db,
_value: &Option<CodeGeneratorKind<'db>>,
_count: u32,
_class: ClassLiteral<'_>,
) -> salsa::CycleRecoveryAction<Option<CodeGeneratorKind>> {
_class: ClassLiteral<'db>,
) -> salsa::CycleRecoveryAction<Option<CodeGeneratorKind<'db>>> {
salsa::CycleRecoveryAction::Iterate
}
code_generator_of_class(db, class)
}
pub(super) fn matches(self, db: &dyn Db, class: ClassLiteral<'_>) -> bool {
pub(super) fn matches(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> bool {
matches!(
(CodeGeneratorKind::from_class(db, class), self),
(Some(Self::DataclassLike(_)), Self::DataclassLike(_))
@ -1387,8 +1387,8 @@ pub struct ClassLiteral<'db> {
/// If this class is deprecated, this holds the deprecation message.
pub(crate) deprecated: Option<DeprecatedInstance<'db>>,
pub(crate) dataclass_params: Option<DataclassParams>,
pub(crate) dataclass_transformer_params: Option<DataclassTransformerParams>,
pub(crate) dataclass_params: Option<DataclassParams<'db>>,
pub(crate) dataclass_transformer_params: Option<DataclassTransformerParams<'db>>,
}
// The Salsa heap is tracked separately.
@ -1909,7 +1909,7 @@ impl<'db> ClassLiteral<'db> {
pub(super) fn try_metaclass(
self,
db: &'db dyn Db,
) -> Result<(Type<'db>, Option<DataclassTransformerParams>), MetaclassError<'db>> {
) -> Result<(Type<'db>, Option<DataclassTransformerParams<'db>>), MetaclassError<'db>> {
tracing::trace!("ClassLiteral::try_metaclass: {}", self.name(db));
// Identify the class's own metaclass (or take the first base class's metaclass).
@ -2271,14 +2271,17 @@ impl<'db> ClassLiteral<'db> {
let transformer_params =
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
Some(DataclassParams::from(transformer_params))
Some(DataclassParams::from_transformer_params(
db,
transformer_params,
))
} else {
None
};
let has_dataclass_param = |param| {
dataclass_params.is_some_and(|params| params.contains(param))
|| transformer_params.is_some_and(|params| params.contains(param))
dataclass_params.is_some_and(|params| params.flags(db).contains(param))
|| transformer_params.is_some_and(|params| params.flags(db).contains(param))
};
let instance_ty =
@ -2357,7 +2360,7 @@ impl<'db> ClassLiteral<'db> {
}
let is_kw_only = name == "__replace__"
|| kw_only.unwrap_or(has_dataclass_param(DataclassParams::KW_ONLY));
|| kw_only.unwrap_or(has_dataclass_param(DataclassFlags::KW_ONLY));
let mut parameter = if is_kw_only {
Parameter::keyword_only(field_name)
@ -2395,7 +2398,7 @@ impl<'db> ClassLiteral<'db> {
match (field_policy, name) {
(CodeGeneratorKind::DataclassLike(_), "__init__") => {
if !has_dataclass_param(DataclassParams::INIT) {
if !has_dataclass_param(DataclassFlags::INIT) {
return None;
}
@ -2410,7 +2413,7 @@ impl<'db> ClassLiteral<'db> {
signature_from_fields(vec![cls_parameter], Some(Type::none(db)))
}
(CodeGeneratorKind::DataclassLike(_), "__lt__" | "__le__" | "__gt__" | "__ge__") => {
if !has_dataclass_param(DataclassParams::ORDER) {
if !has_dataclass_param(DataclassFlags::ORDER) {
return None;
}
@ -2461,7 +2464,7 @@ impl<'db> ClassLiteral<'db> {
signature_from_fields(vec![self_parameter], Some(instance_ty))
}
(CodeGeneratorKind::DataclassLike(_), "__setattr__") => {
if has_dataclass_param(DataclassParams::FROZEN) {
if has_dataclass_param(DataclassFlags::FROZEN) {
let signature = Signature::new(
Parameters::new([
Parameter::positional_or_keyword(Name::new_static("self"))
@ -2477,7 +2480,7 @@ impl<'db> ClassLiteral<'db> {
None
}
(CodeGeneratorKind::DataclassLike(_), "__slots__") => {
has_dataclass_param(DataclassParams::SLOTS).then(|| {
has_dataclass_param(DataclassFlags::SLOTS).then(|| {
let fields = self.fields(db, specialization, field_policy);
let slots = fields.keys().map(|name| Type::string_literal(db, name));
Type::heterogeneous_tuple(db, slots)
@ -2901,7 +2904,7 @@ impl<'db> ClassLiteral<'db> {
default_ty = field.default_type(db);
if self
.dataclass_params(db)
.map(|params| params.contains(DataclassParams::NO_FIELD_SPECIFIERS))
.map(|params| params.field_specifiers(db).is_empty())
.unwrap_or(false)
{
// This happens when constructing a `dataclass` with a `dataclass_transform`
@ -3635,7 +3638,7 @@ impl<'db> VarianceInferable<'db> for ClassLiteral<'db> {
let is_frozen_dataclass = Program::get(db).python_version(db) <= PythonVersion::PY312
&& self
.dataclass_params(db)
.is_some_and(|params| params.contains(DataclassParams::FROZEN));
.is_some_and(|params| params.flags(db).contains(DataclassFlags::FROZEN));
if is_namedtuple || is_frozen_dataclass {
TypeVarVariance::Covariant
} else {


@ -152,24 +152,36 @@ bitflags! {
/// arguments that were passed in. For the precise meaning of the fields, see [1].
///
/// [1]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, salsa::Update)]
pub struct DataclassTransformerParams: u8 {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, salsa::Update)]
pub struct DataclassTransformerFlags: u8 {
const EQ_DEFAULT = 1 << 0;
const ORDER_DEFAULT = 1 << 1;
const KW_ONLY_DEFAULT = 1 << 2;
const FROZEN_DEFAULT = 1 << 3;
const FIELD_SPECIFIERS = 1 << 4;
}
}
impl get_size2::GetSize for DataclassTransformerParams {}
impl get_size2::GetSize for DataclassTransformerFlags {}
impl Default for DataclassTransformerParams {
impl Default for DataclassTransformerFlags {
fn default() -> Self {
Self::EQ_DEFAULT
}
}
/// Metadata for a dataclass-transformer. Stored inside a `Type::DataclassTransformer(…)`
/// instance that we use as the return type for `dataclass_transform(…)` calls.
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
#[derive(PartialOrd, Ord)]
pub struct DataclassTransformerParams<'db> {
pub flags: DataclassTransformerFlags,
#[returns(deref)]
pub field_specifiers: Box<[Type<'db>]>,
}
impl get_size2::GetSize for DataclassTransformerParams<'_> {}
/// Representation of a function definition in the AST: either a non-generic function, or a generic
/// function that has not been specialized.
///
@ -201,7 +213,7 @@ pub struct OverloadLiteral<'db> {
/// The arguments to `dataclass_transformer`, if this function was annotated
/// with `@dataclass_transformer(...)`.
pub(crate) dataclass_transformer_params: Option<DataclassTransformerParams>,
pub(crate) dataclass_transformer_params: Option<DataclassTransformerParams<'db>>,
}
// The Salsa heap is tracked separately.
@ -212,7 +224,7 @@ impl<'db> OverloadLiteral<'db> {
fn with_dataclass_transformer_params(
self,
db: &'db dyn Db,
params: DataclassTransformerParams,
params: DataclassTransformerParams<'db>,
) -> Self {
Self::new(
db,
@ -740,7 +752,7 @@ impl<'db> FunctionType<'db> {
pub(crate) fn with_dataclass_transformer_params(
self,
db: &'db dyn Db,
params: DataclassTransformerParams,
params: DataclassTransformerParams<'db>,
) -> Self {
// A decorator only applies to the specific overload that it is attached to, not to all
// previous overloads.


@ -9,6 +9,7 @@ use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, PythonVersion};
use ruff_python_stdlib::builtins::version_builtin_was_added;
use ruff_text_size::{Ranged, TextRange};
use rustc_hash::{FxHashMap, FxHashSet};
use smallvec::SmallVec;
use super::{
CycleRecovery, DefinitionInference, DefinitionInferenceExtra, ExpressionInference,
@ -152,6 +153,12 @@ type BinaryComparisonVisitor<'db> = CycleDetector<
Result<Type<'db>, CompareUnsupportedError<'db>>,
>;
/// We currently store one dataclass field specifier inline, because that covers standard
/// dataclasses. attrs uses 2 specifiers, pydantic and strawberry use 3 specifiers. SQLAlchemy
/// uses 7 field specifiers. We could probably store more inline if this turns out to be a
/// performance problem. For now, we optimize for memory usage.
const NUM_FIELD_SPECIFIERS_INLINE: usize = 1;
/// Builder to infer all types in a region.
///
/// A builder is used by creating it with [`new()`](TypeInferenceBuilder::new), and then calling
@ -277,6 +284,10 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
/// `true` if all places in this expression are definitely bound
all_definitely_bound: bool,
/// A list of `dataclass_transform` field specifiers that are "active" (when inferring
/// the right hand side of an annotated assignment in a class that is a dataclass).
dataclass_field_specifiers: SmallVec<[Type<'db>; NUM_FIELD_SPECIFIERS_INLINE]>,
}
impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
@ -312,6 +323,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
undecorated_type: None,
cycle_recovery: None,
all_definitely_bound: true,
dataclass_field_specifiers: SmallVec::new(),
}
}
@ -2574,7 +2586,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.as_function_literal()
.is_some_and(|function| function.is_known(self.db(), KnownFunction::Dataclass))
{
dataclass_params = Some(DataclassParams::default());
dataclass_params = Some(DataclassParams::default_params(self.db()));
continue;
}
@ -2595,11 +2607,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// overload, or an overload and the implementation both. Nevertheless, this is not
// allowed. We do not try to treat the offenders intelligently -- just use the
// params of the last seen usage of `@dataclass_transform`
let params = f
let transformer_params = f
.iter_overloads_and_implementation(self.db())
.find_map(|overload| overload.dataclass_transformer_params(self.db()));
if let Some(params) = params {
dataclass_params = Some(params.into());
if let Some(transformer_params) = transformer_params {
dataclass_params = Some(DataclassParams::from_transformer_params(
self.db(),
transformer_params,
));
continue;
}
}
@ -4518,10 +4533,42 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
debug_assert!(PlaceExpr::try_from_expr(target).is_some());
if let Some(value) = value {
fn field_specifiers<'db>(
db: &'db dyn Db,
index: &'db SemanticIndex<'db>,
scope: ScopeId<'db>,
) -> Option<SmallVec<[Type<'db>; NUM_FIELD_SPECIFIERS_INLINE]>> {
let enclosing_scope = index.scope(scope.file_scope_id(db));
let class_node = enclosing_scope.node().as_class()?;
let class_definition = index.expect_single_definition(class_node);
let class_literal = infer_definition_types(db, class_definition)
.declaration_type(class_definition)
.inner_type()
.as_class_literal()?;
class_literal
.dataclass_params(db)
.map(|params| SmallVec::from(params.field_specifiers(db)))
.or_else(|| {
class_literal
.try_metaclass(db)
.ok()
.and_then(|(_, params)| params)
.map(|params| SmallVec::from(params.field_specifiers(db)))
})
}
if let Some(specifiers) = field_specifiers(self.db(), self.index, self.scope()) {
self.dataclass_field_specifiers = specifiers;
}
let inferred_ty = self.infer_maybe_standalone_expression(
value,
TypeContext::new(Some(declared.inner_type())),
);
self.dataclass_field_specifiers.clear();
let inferred_ty = if target
.as_name_expr()
.is_some_and(|name| &name.id == "TYPE_CHECKING")
@ -6650,7 +6697,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
let mut bindings = match bindings.check_types(self.db(), &call_arguments, &tcx) {
let mut bindings = match bindings.check_types(
self.db(),
&call_arguments,
&tcx,
&self.dataclass_field_specifiers[..],
) {
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => {
bindings.report_diagnostics(&self.context, call_expression.into());
@ -9238,8 +9290,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let binding = Binding::single(value_ty, generic_context.signature(self.db()));
let bindings = match Bindings::from(binding)
.match_parameters(self.db(), &call_argument_types)
.check_types(self.db(), &call_argument_types, &TypeContext::default())
{
.check_types(
self.db(),
&call_argument_types,
&TypeContext::default(),
&self.dataclass_field_specifiers[..],
) {
Ok(bindings) => bindings,
Err(CallError(_, bindings)) => {
bindings.report_diagnostics(&self.context, subscript.into());
@ -9771,6 +9827,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
deferred,
cycle_recovery,
all_definitely_bound,
dataclass_field_specifiers: _,
// Ignored; only relevant to definition regions
undecorated_type: _,
@ -9837,8 +9894,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
deferred,
cycle_recovery,
undecorated_type,
all_definitely_bound: _,
// builder only state
dataclass_field_specifiers: _,
all_definitely_bound: _,
typevar_binding_context: _,
deferred_state: _,
multi_inference_state: _,
@ -9905,12 +9963,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
deferred: _,
bindings: _,
declarations: _,
all_definitely_bound: _,
// Ignored; only relevant to definition regions
undecorated_type: _,
// Builder only state
dataclass_field_specifiers: _,
all_definitely_bound: _,
typevar_binding_context: _,
deferred_state: _,
multi_inference_state: _,


@ -83,15 +83,11 @@ pub(super) fn union_or_intersection_elements_ordering<'db>(
(Type::WrapperDescriptor(_), _) => Ordering::Less,
(_, Type::WrapperDescriptor(_)) => Ordering::Greater,
(Type::DataclassDecorator(left), Type::DataclassDecorator(right)) => {
left.bits().cmp(&right.bits())
}
(Type::DataclassDecorator(left), Type::DataclassDecorator(right)) => left.cmp(right),
(Type::DataclassDecorator(_), _) => Ordering::Less,
(_, Type::DataclassDecorator(_)) => Ordering::Greater,
(Type::DataclassTransformer(left), Type::DataclassTransformer(right)) => {
left.bits().cmp(&right.bits())
}
(Type::DataclassTransformer(left), Type::DataclassTransformer(right)) => left.cmp(right),
(Type::DataclassTransformer(_), _) => Ordering::Less,
(_, Type::DataclassTransformer(_)) => Ordering::Greater,