mirror of
https://github.com/astral-sh/ruff.git
synced 2025-12-23 09:19:39 +00:00
[ty] Handle field specifier functions that accept **kwargs and recognize metaclass-based transformers as instances of DataclassInstance (#22018)
## Summary This contains two bug fixes: - [Handle field specifier functions that accept `**kwargs`](ad6918d505) - [Recognize metaclass-based transformers as instances of `DataclassInstance`](1a8e29b23c) closes https://github.com/astral-sh/ty/issues/1987 ## Test Plan * New Markdown tests * Made sure that the example in 1987 checks without errors
This commit is contained in:
parent
764ad8b29b
commit
2a61fe2353
3 changed files with 201 additions and 21 deletions
|
|
@ -643,6 +643,91 @@ reveal_type(Person.__init__) # revealed: (self: Person, name: str) -> None
|
|||
Person(name="Alice")
|
||||
```
|
||||
|
||||
### Field specifiers using `**kwargs`
|
||||
|
||||
Some field specifiers may use `**kwargs` to pass through standard parameters like `default`,
|
||||
`default_factory`, `init`, `kw_only`, and `alias`. This section tests that all these parameters work
|
||||
correctly when passed via `**kwargs` for all three kinds of transformers.
|
||||
|
||||
#### Function-based transformer
|
||||
|
||||
```py
|
||||
from typing import Any
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
def field(**kwargs: Any) -> Any: ...
|
||||
@dataclass_transform(field_specifiers=(field,))
|
||||
def create_model[T](cls: type[T]) -> type[T]:
|
||||
return cls
|
||||
|
||||
@create_model
|
||||
class Person:
|
||||
id: int = field(init=False)
|
||||
name: str
|
||||
age: int = field(default=0)
|
||||
tags: list[str] = field(default_factory=list)
|
||||
email: str = field(kw_only=True)
|
||||
internal_notes: str = field(alias="notes")
|
||||
|
||||
# revealed: (self: Person, name: str, age: int = ..., tags: list[str] = ..., notes: str, *, email: str) -> None
|
||||
reveal_type(Person.__init__)
|
||||
|
||||
Person("Alice", 30, [], "some notes", email="alice@example.com")
|
||||
Person("Bob", email="bob@example.com", notes="other notes")
|
||||
```
|
||||
|
||||
#### Metaclass-based transformer
|
||||
|
||||
```py
|
||||
from typing import Any
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
def field(**kwargs: Any) -> Any: ...
|
||||
@dataclass_transform(field_specifiers=(field,))
|
||||
class ModelMeta(type): ...
|
||||
|
||||
class ModelBase(metaclass=ModelMeta): ...
|
||||
|
||||
class Person(ModelBase):
|
||||
id: int = field(init=False)
|
||||
name: str
|
||||
age: int = field(default=0)
|
||||
tags: list[str] = field(default_factory=list)
|
||||
email: str = field(kw_only=True)
|
||||
internal_notes: str = field(alias="notes")
|
||||
|
||||
# revealed: (self: Person, name: str, age: int = ..., tags: list[str] = ..., notes: str, *, email: str) -> None
|
||||
reveal_type(Person.__init__)
|
||||
|
||||
Person("Alice", 30, [], "some notes", email="alice@example.com")
|
||||
Person("Bob", email="bob@example.com", notes="other notes")
|
||||
```
|
||||
|
||||
#### Base-class-based transformer
|
||||
|
||||
```py
|
||||
from typing import Any
|
||||
from typing_extensions import dataclass_transform
|
||||
|
||||
def field(**kwargs: Any) -> Any: ...
|
||||
@dataclass_transform(field_specifiers=(field,))
|
||||
class ModelBase: ...
|
||||
|
||||
class Person(ModelBase):
|
||||
id: int = field(init=False)
|
||||
name: str
|
||||
age: int = field(default=0)
|
||||
tags: list[str] = field(default_factory=list)
|
||||
email: str = field(kw_only=True)
|
||||
internal_notes: str = field(alias="notes")
|
||||
|
||||
# revealed: (self: Person, name: str, age: int = ..., tags: list[str] = ..., notes: str, *, email: str) -> None
|
||||
reveal_type(Person.__init__)
|
||||
|
||||
Person("Alice", 30, [], "some notes", email="alice@example.com")
|
||||
Person("Bob", email="bob@example.com", notes="other notes")
|
||||
```
|
||||
|
||||
### Support for `alias`
|
||||
|
||||
The `alias` parameter in field specifiers allows providing an alternative name for the parameter in
|
||||
|
|
@ -868,4 +953,83 @@ reveal_type(t.key) # revealed: int
|
|||
reveal_type(t.name) # revealed: str
|
||||
```
|
||||
|
||||
## `__dataclass_fields__` and `DataclassInstance` protocol
|
||||
|
||||
Classes created via `dataclass_transform` should have `__dataclass_fields__` and
|
||||
`__dataclass_params__` attributes, allowing them to satisfy the `DataclassInstance` protocol. This
|
||||
enables use of `dataclasses.fields`, `dataclasses.asdict`, `dataclasses.replace`, etc.
|
||||
|
||||
### Function-based transformer
|
||||
|
||||
```py
|
||||
from dataclasses import fields, asdict, replace, Field
|
||||
from typing import dataclass_transform, Any
|
||||
|
||||
@dataclass_transform()
|
||||
def create_model[T](cls: type[T]) -> type[T]:
|
||||
return cls
|
||||
|
||||
@create_model
|
||||
class Person:
|
||||
name: str
|
||||
age: int
|
||||
|
||||
p = Person("Alice", 30)
|
||||
|
||||
reveal_type(Person.__dataclass_fields__) # revealed: dict[str, Field[Any]]
|
||||
reveal_type(p.__dataclass_fields__) # revealed: dict[str, Field[Any]]
|
||||
|
||||
reveal_type(fields(Person)) # revealed: tuple[Field[Any], ...]
|
||||
reveal_type(asdict(p)) # revealed: dict[str, Any]
|
||||
reveal_type(replace(p, name="Bob")) # revealed: Person
|
||||
```
|
||||
|
||||
### Metaclass-based transformer
|
||||
|
||||
```py
|
||||
from dataclasses import fields, asdict, replace, Field
|
||||
from typing import dataclass_transform, Any
|
||||
|
||||
@dataclass_transform()
|
||||
class ModelMeta(type): ...
|
||||
|
||||
class ModelBase(metaclass=ModelMeta): ...
|
||||
|
||||
class Person(ModelBase):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
p = Person("Alice", 30)
|
||||
|
||||
reveal_type(Person.__dataclass_fields__) # revealed: dict[str, Field[Any]]
|
||||
reveal_type(p.__dataclass_fields__) # revealed: dict[str, Field[Any]]
|
||||
|
||||
reveal_type(fields(Person)) # revealed: tuple[Field[Any], ...]
|
||||
reveal_type(asdict(p)) # revealed: dict[str, Any]
|
||||
reveal_type(replace(p, name="Bob")) # revealed: Person
|
||||
```
|
||||
|
||||
### Base-class-based transformer
|
||||
|
||||
```py
|
||||
from dataclasses import fields, asdict, replace, Field
|
||||
from typing import dataclass_transform, Any
|
||||
|
||||
@dataclass_transform()
|
||||
class ModelBase: ...
|
||||
|
||||
class Person(ModelBase):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
p = Person("Alice", 30)
|
||||
|
||||
reveal_type(Person.__dataclass_fields__) # revealed: dict[str, Field[Any]]
|
||||
reveal_type(p.__dataclass_fields__) # revealed: dict[str, Field[Any]]
|
||||
|
||||
reveal_type(fields(Person)) # revealed: tuple[Field[Any], ...]
|
||||
reveal_type(asdict(p)) # revealed: dict[str, Any]
|
||||
reveal_type(replace(p, name="Bob")) # revealed: Person
|
||||
```
|
||||
|
||||
[`typing.dataclass_transform`]: https://docs.python.org/3/library/typing.html#typing.dataclass_transform
|
||||
|
|
|
|||
|
|
@ -214,7 +214,7 @@ impl<'db> Bindings<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
self.evaluate_known_cases(db, dataclass_field_specifiers);
|
||||
self.evaluate_known_cases(db, argument_types, dataclass_field_specifiers);
|
||||
|
||||
// In order of precedence:
|
||||
//
|
||||
|
|
@ -337,7 +337,12 @@ impl<'db> Bindings<'db> {
|
|||
|
||||
/// Evaluates the return type of certain known callables, where we have special-case logic to
|
||||
/// determine the return type in a way that isn't directly expressible in the type system.
|
||||
fn evaluate_known_cases(&mut self, db: &'db dyn Db, dataclass_field_specifiers: &[Type<'db>]) {
|
||||
fn evaluate_known_cases(
|
||||
&mut self,
|
||||
db: &'db dyn Db,
|
||||
argument_types: &CallArguments<'_, 'db>,
|
||||
dataclass_field_specifiers: &[Type<'db>],
|
||||
) {
|
||||
let to_bool = |ty: &Option<Type<'_>>, default: bool| -> bool {
|
||||
if let Some(Type::BooleanLiteral(value)) = ty {
|
||||
*value
|
||||
|
|
@ -666,25 +671,32 @@ impl<'db> Bindings<'db> {
|
|||
if dataclass_field_specifiers.contains(&function)
|
||||
|| function_type.is_known(db, KnownFunction::Field) =>
|
||||
{
|
||||
let has_default_value = overload
|
||||
.parameter_type_by_name("default", false)
|
||||
.is_ok_and(|ty| ty.is_some())
|
||||
|| overload
|
||||
.parameter_type_by_name("default_factory", false)
|
||||
.is_ok_and(|ty| ty.is_some())
|
||||
|| overload
|
||||
.parameter_type_by_name("factory", false)
|
||||
.is_ok_and(|ty| ty.is_some());
|
||||
// Helper to get the type of a keyword argument by name. We first try to get it from
|
||||
// the parameter binding (for explicit parameters), and then fall back to checking the
|
||||
// call site arguments (for field-specifier functions that use a `**kwargs` parameter,
|
||||
// instead of specifying `init`, `default` etc. explicitly).
|
||||
let get_argument_type = |name, fallback_to_default| -> Option<Type<'db>> {
|
||||
if let Ok(ty) =
|
||||
overload.parameter_type_by_name(name, fallback_to_default)
|
||||
{
|
||||
return ty;
|
||||
}
|
||||
argument_types.iter().find_map(|(arg, ty)| {
|
||||
if matches!(arg, Argument::Keyword(arg_name) if arg_name == name) {
|
||||
ty
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
let init = overload
|
||||
.parameter_type_by_name("init", true)
|
||||
.unwrap_or(None);
|
||||
let kw_only = overload
|
||||
.parameter_type_by_name("kw_only", true)
|
||||
.unwrap_or(None);
|
||||
let alias = overload
|
||||
.parameter_type_by_name("alias", true)
|
||||
.unwrap_or(None);
|
||||
let has_default_value = get_argument_type("default", false).is_some()
|
||||
|| get_argument_type("default_factory", false).is_some()
|
||||
|| get_argument_type("factory", false).is_some();
|
||||
|
||||
let init = get_argument_type("init", true);
|
||||
let kw_only = get_argument_type("kw_only", true);
|
||||
let alias = get_argument_type("alias", true);
|
||||
|
||||
// `dataclasses.field` and field-specifier functions of commonly used
|
||||
// libraries like `pydantic`, `attrs`, and `SQLAlchemy` all return
|
||||
|
|
|
|||
|
|
@ -2277,7 +2277,11 @@ impl<'db> ClassLiteral<'db> {
|
|||
specialization: Option<Specialization<'db>>,
|
||||
name: &str,
|
||||
) -> Member<'db> {
|
||||
if self.dataclass_params(db).is_some() {
|
||||
// Check if this class is dataclass-like (either via @dataclass or via dataclass_transform)
|
||||
if matches!(
|
||||
CodeGeneratorKind::from_class(db, self, specialization),
|
||||
Some(CodeGeneratorKind::DataclassLike(_))
|
||||
) {
|
||||
if name == "__dataclass_fields__" {
|
||||
// Make this class look like a subclass of the `DataClassInstance` protocol
|
||||
return Member {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue