[ty] Handle field specifier functions that accept **kwargs and recognize metaclass-based transformers as instances of DataclassInstance (#22018)

## Summary

This contains two bug fixes:

- [Handle field specifier functions that accept
`**kwargs`](ad6918d505)
- [Recognize metaclass-based transformers as instances of
`DataclassInstance`](1a8e29b23c)

closes https://github.com/astral-sh/ty/issues/1987

## Test Plan

* New Markdown tests
* Made sure that the example in 1987 checks without errors
This commit is contained in:
David Peter 2025-12-17 14:22:16 +01:00 committed by GitHub
parent 764ad8b29b
commit 2a61fe2353
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 201 additions and 21 deletions

View file

@ -643,6 +643,91 @@ reveal_type(Person.__init__) # revealed: (self: Person, name: str) -> None
Person(name="Alice")
```
### Field specifiers using `**kwargs`
Some field specifiers may use `**kwargs` to pass through standard parameters like `default`,
`default_factory`, `init`, `kw_only`, and `alias`. This section tests that all these parameters work
correctly when passed via `**kwargs` for all three kinds of transformers.
#### Function-based transformer
```py
from typing import Any
from typing_extensions import dataclass_transform
def field(**kwargs: Any) -> Any: ...
@dataclass_transform(field_specifiers=(field,))
def create_model[T](cls: type[T]) -> type[T]:
return cls
@create_model
class Person:
id: int = field(init=False)
name: str
age: int = field(default=0)
tags: list[str] = field(default_factory=list)
email: str = field(kw_only=True)
internal_notes: str = field(alias="notes")
# revealed: (self: Person, name: str, age: int = ..., tags: list[str] = ..., notes: str, *, email: str) -> None
reveal_type(Person.__init__)
Person("Alice", 30, [], "some notes", email="alice@example.com")
Person("Bob", email="bob@example.com", notes="other notes")
```
#### Metaclass-based transformer
```py
from typing import Any
from typing_extensions import dataclass_transform
def field(**kwargs: Any) -> Any: ...
@dataclass_transform(field_specifiers=(field,))
class ModelMeta(type): ...
class ModelBase(metaclass=ModelMeta): ...
class Person(ModelBase):
id: int = field(init=False)
name: str
age: int = field(default=0)
tags: list[str] = field(default_factory=list)
email: str = field(kw_only=True)
internal_notes: str = field(alias="notes")
# revealed: (self: Person, name: str, age: int = ..., tags: list[str] = ..., notes: str, *, email: str) -> None
reveal_type(Person.__init__)
Person("Alice", 30, [], "some notes", email="alice@example.com")
Person("Bob", email="bob@example.com", notes="other notes")
```
#### Base-class-based transformer
```py
from typing import Any
from typing_extensions import dataclass_transform
def field(**kwargs: Any) -> Any: ...
@dataclass_transform(field_specifiers=(field,))
class ModelBase: ...
class Person(ModelBase):
id: int = field(init=False)
name: str
age: int = field(default=0)
tags: list[str] = field(default_factory=list)
email: str = field(kw_only=True)
internal_notes: str = field(alias="notes")
# revealed: (self: Person, name: str, age: int = ..., tags: list[str] = ..., notes: str, *, email: str) -> None
reveal_type(Person.__init__)
Person("Alice", 30, [], "some notes", email="alice@example.com")
Person("Bob", email="bob@example.com", notes="other notes")
```
### Support for `alias`
The `alias` parameter in field specifiers allows providing an alternative name for the parameter in
@ -868,4 +953,83 @@ reveal_type(t.key) # revealed: int
reveal_type(t.name) # revealed: str
```
## `__dataclass_fields__` and `DataclassInstance` protocol
Classes created via `dataclass_transform` should have `__dataclass_fields__` and
`__dataclass_params__` attributes, allowing them to satisfy the `DataclassInstance` protocol. This
enables use of `dataclasses.fields`, `dataclasses.asdict`, `dataclasses.replace`, etc.
### Function-based transformer
```py
from dataclasses import fields, asdict, replace, Field
from typing import dataclass_transform, Any
@dataclass_transform()
def create_model[T](cls: type[T]) -> type[T]:
return cls
@create_model
class Person:
name: str
age: int
p = Person("Alice", 30)
reveal_type(Person.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(p.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(fields(Person)) # revealed: tuple[Field[Any], ...]
reveal_type(asdict(p)) # revealed: dict[str, Any]
reveal_type(replace(p, name="Bob")) # revealed: Person
```
### Metaclass-based transformer
```py
from dataclasses import fields, asdict, replace, Field
from typing import dataclass_transform, Any
@dataclass_transform()
class ModelMeta(type): ...
class ModelBase(metaclass=ModelMeta): ...
class Person(ModelBase):
name: str
age: int
p = Person("Alice", 30)
reveal_type(Person.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(p.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(fields(Person)) # revealed: tuple[Field[Any], ...]
reveal_type(asdict(p)) # revealed: dict[str, Any]
reveal_type(replace(p, name="Bob")) # revealed: Person
```
### Base-class-based transformer
```py
from dataclasses import fields, asdict, replace, Field
from typing import dataclass_transform, Any
@dataclass_transform()
class ModelBase: ...
class Person(ModelBase):
name: str
age: int
p = Person("Alice", 30)
reveal_type(Person.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(p.__dataclass_fields__) # revealed: dict[str, Field[Any]]
reveal_type(fields(Person)) # revealed: tuple[Field[Any], ...]
reveal_type(asdict(p)) # revealed: dict[str, Any]
reveal_type(replace(p, name="Bob")) # revealed: Person
```
[`typing.dataclass_transform`]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform

View file

@ -214,7 +214,7 @@ impl<'db> Bindings<'db> {
}
}
self.evaluate_known_cases(db, dataclass_field_specifiers);
self.evaluate_known_cases(db, argument_types, dataclass_field_specifiers);
// In order of precedence:
//
@ -337,7 +337,12 @@ impl<'db> Bindings<'db> {
/// Evaluates the return type of certain known callables, where we have special-case logic to
/// determine the return type in a way that isn't directly expressible in the type system.
fn evaluate_known_cases(&mut self, db: &'db dyn Db, dataclass_field_specifiers: &[Type<'db>]) {
fn evaluate_known_cases(
&mut self,
db: &'db dyn Db,
argument_types: &CallArguments<'_, 'db>,
dataclass_field_specifiers: &[Type<'db>],
) {
let to_bool = |ty: &Option<Type<'_>>, default: bool| -> bool {
if let Some(Type::BooleanLiteral(value)) = ty {
*value
@ -666,25 +671,32 @@ impl<'db> Bindings<'db> {
if dataclass_field_specifiers.contains(&function)
|| function_type.is_known(db, KnownFunction::Field) =>
{
let has_default_value = overload
.parameter_type_by_name("default", false)
.is_ok_and(|ty| ty.is_some())
|| overload
.parameter_type_by_name("default_factory", false)
.is_ok_and(|ty| ty.is_some())
|| overload
.parameter_type_by_name("factory", false)
.is_ok_and(|ty| ty.is_some());
// Helper to get the type of a keyword argument by name. We first try to get it from
// the parameter binding (for explicit parameters), and then fall back to checking the
// call site arguments (for field-specifier functions that use a `**kwargs` parameter,
// instead of specifying `init`, `default` etc. explicitly).
let get_argument_type = |name, fallback_to_default| -> Option<Type<'db>> {
if let Ok(ty) =
overload.parameter_type_by_name(name, fallback_to_default)
{
return ty;
}
argument_types.iter().find_map(|(arg, ty)| {
if matches!(arg, Argument::Keyword(arg_name) if arg_name == name) {
ty
} else {
None
}
})
};
let init = overload
.parameter_type_by_name("init", true)
.unwrap_or(None);
let kw_only = overload
.parameter_type_by_name("kw_only", true)
.unwrap_or(None);
let alias = overload
.parameter_type_by_name("alias", true)
.unwrap_or(None);
let has_default_value = get_argument_type("default", false).is_some()
|| get_argument_type("default_factory", false).is_some()
|| get_argument_type("factory", false).is_some();
let init = get_argument_type("init", true);
let kw_only = get_argument_type("kw_only", true);
let alias = get_argument_type("alias", true);
// `dataclasses.field` and field-specifier functions of commonly used
// libraries like `pydantic`, `attrs`, and `SQLAlchemy` all return

View file

@ -2277,7 +2277,11 @@ impl<'db> ClassLiteral<'db> {
specialization: Option<Specialization<'db>>,
name: &str,
) -> Member<'db> {
if self.dataclass_params(db).is_some() {
// Check if this class is dataclass-like (either via @dataclass or via dataclass_transform)
if matches!(
CodeGeneratorKind::from_class(db, self, specialization),
Some(CodeGeneratorKind::DataclassLike(_))
) {
if name == "__dataclass_fields__" {
// Make this class look like a subclass of the `DataClassInstance` protocol
return Member {