[ty] Add basic support for dataclasses.field (#19553)

## Summary

Add basic support for `dataclasses.field`:
* remove fields with `init=False` from the signature of the synthesized
`__init__` method
* infer correct default value types from `default` or `default_factory`
arguments

```py
from dataclasses import dataclass, field

def default_roles() -> list[str]:
    return ["user"]

@dataclass
class Member:
    name: str
    roles: list[str] = field(default_factory=default_roles)
    tag: str | None = field(default=None, init=False)

# revealed: (self: Member, name: str, roles: list[str] = list[str]) -> None
reveal_type(Member.__init__)
```

Support for `kw_only` has **not** been added.

part of https://github.com/astral-sh/ty/issues/111

## Test Plan

New Markdown tests
This commit is contained in:
David Peter 2025-07-25 14:56:04 +02:00 committed by GitHub
parent b033fb6bfd
commit d4eb4277ad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 188 additions and 11 deletions

View file

@ -1279,6 +1279,15 @@ impl<'db> Type<'db> {
// handled above. It's always assignable, though.
(Type::Dynamic(_), _) | (_, Type::Dynamic(_)) => relation.is_assignability(),
// Pretend that instances of `dataclasses.Field` are assignable to their default type.
// This allows field definitions like `name: str = field(default="")` in dataclasses
// to pass the assignability check of the inferred type to the declared type.
(Type::KnownInstance(KnownInstanceType::Field(field)), right)
if relation.is_assignability() =>
{
field.default_type(db).has_relation_to(db, right, relation)
}
// In general, a TypeVar `T` is not a subtype of a type `S` unless one of the two conditions is satisfied:
// 1. `T` is a bound TypeVar and `T`'s upper bound is a subtype of `S`.
// TypeVars without an explicit upper bound are treated as having an implicit upper bound of `object`.
@ -5109,6 +5118,10 @@ impl<'db> Type<'db> {
invalid_expressions: smallvec::smallvec![InvalidTypeExpression::Deprecated],
fallback_type: Type::unknown(),
}),
KnownInstanceType::Field(__call__) => Err(InvalidTypeExpressionError {
invalid_expressions: smallvec::smallvec![InvalidTypeExpression::Field],
fallback_type: Type::unknown(),
}),
KnownInstanceType::SubscriptedProtocol(_) => Err(InvalidTypeExpressionError {
invalid_expressions: smallvec::smallvec_inline![
InvalidTypeExpression::Protocol
@ -5957,6 +5970,9 @@ pub enum KnownInstanceType<'db> {
/// A single instance of `warnings.deprecated` or `typing_extensions.deprecated`
Deprecated(DeprecatedInstance<'db>),
/// A single instance of `dataclasses.Field`
Field(FieldInstance<'db>),
}
fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
@ -5978,6 +5994,9 @@ fn walk_known_instance_type<'db, V: visitor::TypeVisitor<'db> + ?Sized>(
KnownInstanceType::Deprecated(_) => {
// Nothing to visit
}
KnownInstanceType::Field(field) => {
visitor.visit_type(db, field.default_type(db));
}
}
}
@ -5998,6 +6017,7 @@ impl<'db> KnownInstanceType<'db> {
// Nothing to normalize
Self::Deprecated(deprecated)
}
Self::Field(field) => Self::Field(field.normalized_impl(db, visitor)),
}
}
@ -6007,6 +6027,7 @@ impl<'db> KnownInstanceType<'db> {
Self::TypeVar(_) => KnownClass::TypeVar,
Self::TypeAliasType(_) => KnownClass::TypeAliasType,
Self::Deprecated(_) => KnownClass::Deprecated,
Self::Field(_) => KnownClass::Field,
}
}
@ -6052,6 +6073,11 @@ impl<'db> KnownInstanceType<'db> {
// have a `Type::TypeVar(_)`, which is rendered as the typevar's name.
KnownInstanceType::TypeVar(_) => f.write_str("typing.TypeVar"),
KnownInstanceType::Deprecated(_) => f.write_str("warnings.deprecated"),
KnownInstanceType::Field(field) => {
f.write_str("dataclasses.Field[")?;
field.default_type(self.db).display(self.db).fmt(f)?;
f.write_str("]")
}
}
}
}
@ -6261,6 +6287,8 @@ enum InvalidTypeExpression<'db> {
Generic,
/// Same for `@deprecated`
Deprecated,
/// Same for `dataclasses.Field`
Field,
/// Type qualifiers are always invalid in *type expressions*,
/// but these ones are okay with 0 arguments in *annotation expressions*
TypeQualifier(SpecialFormType),
@ -6305,6 +6333,9 @@ impl<'db> InvalidTypeExpression<'db> {
InvalidTypeExpression::Deprecated => {
f.write_str("`warnings.deprecated` is not allowed in type expressions")
}
InvalidTypeExpression::Field => {
f.write_str("`dataclasses.Field` is not allowed in type expressions")
}
InvalidTypeExpression::TypeQualifier(qualifier) => write!(
f,
"Type qualifier `{qualifier}` is not allowed in type expressions \
@ -6371,6 +6402,36 @@ pub struct DeprecatedInstance<'db> {
// The Salsa heap is tracked separately.
impl get_size2::GetSize for DeprecatedInstance<'_> {}
/// Contains information about instances of `dataclasses.Field`, typically created using
/// `dataclasses.field()`.
#[salsa::interned(debug)]
#[derive(PartialOrd, Ord)]
pub struct FieldInstance<'db> {
/// The type of the default value for this field. This is derived from the `default` or
/// `default_factory` arguments to `dataclasses.field()`.
pub default_type: Type<'db>,
/// Whether this field is part of the `__init__` signature, or not.
pub init: bool,
}
// The Salsa heap is tracked separately.
impl get_size2::GetSize for FieldInstance<'_> {}
impl<'db> FieldInstance<'db> {
pub(crate) fn normalized_impl(
self,
db: &'db dyn Db,
visitor: &mut TypeTransformer<'db>,
) -> Self {
FieldInstance::new(
db,
self.default_type(db).normalized_impl(db, visitor),
self.init(db),
)
}
}
/// Whether this typecar was created via the legacy `TypeVar` constructor, or using PEP 695 syntax.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum TypeVarKind {

View file

@ -28,7 +28,7 @@ use crate::types::generics::{Specialization, SpecializationBuilder, Specializati
use crate::types::signatures::{Parameter, ParameterForm, Parameters};
use crate::types::tuple::{Tuple, TupleLength, TupleType};
use crate::types::{
BoundMethodType, ClassLiteral, DataclassParams, KnownClass, KnownInstanceType,
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownClass, KnownInstanceType,
MethodWrapperKind, PropertyInstanceType, SpecialFormType, TypeMapping, UnionType,
WrapperDescriptorKind, enums, ide_support, todo_type,
};
@ -899,6 +899,36 @@ impl<'db> Bindings<'db> {
}
}
Some(KnownFunction::Field) => {
if let [default, default_factory, init, ..] = overload.parameter_types()
{
let default_ty = match (default, default_factory) {
(Some(default_ty), _) => *default_ty,
(_, Some(default_factory_ty)) => default_factory_ty
.try_call(db, &CallArguments::none())
.map_or(Type::unknown(), |binding| binding.return_type(db)),
_ => Type::unknown(),
};
let init = init
.map(|init| !init.bool(db).is_always_false())
.unwrap_or(true);
// `typeshed` pretends that `dataclasses.field()` returns the type of the
// default value directly. At runtime, however, this function returns an
// instance of `dataclasses.Field`. We also model it this way and return
// a known-instance type with information about the field. The drawback
// of this approach is that we need to pretend that instances of `Field`
// are assignable to `T` if the default type of the field is assignable
// to `T`. Otherwise, we would error on `name: str = field(default="")`.
overload.set_return_type(Type::KnownInstance(
KnownInstanceType::Field(FieldInstance::new(
db, default_ty, init,
)),
));
}
}
_ => {
// Ideally, either the implementation, or exactly one of the overloads
// of the function can have the dataclass_transform decorator applied.

View file

@ -894,6 +894,9 @@ pub(crate) struct DataclassField<'db> {
/// Whether or not this field is "init-only". If this is true, it only appears in the
/// `__init__` signature, but is not accessible as a real field
pub(crate) init_only: bool,
/// Whether or not this field should appear in the signature of `__init__`.
pub(crate) init: bool,
}
/// Representation of a class definition statement in the AST: either a non-generic class, or a
@ -1601,9 +1604,15 @@ impl<'db> ClassLiteral<'db> {
mut field_ty,
mut default_ty,
init_only: _,
init,
},
) in self.fields(db, specialization, field_policy)
{
if name == "__init__" && !init {
// Skip fields with `init=False`
continue;
}
if field_ty
.into_nominal_instance()
.is_some_and(|instance| instance.class.is_known(db, KnownClass::KwOnly))
@ -1852,15 +1861,25 @@ impl<'db> ClassLiteral<'db> {
if let Some(attr_ty) = attr.place.ignore_possibly_unbound() {
let bindings = use_def.end_of_scope_symbol_bindings(symbol_id);
let default_ty = place_from_bindings(db, bindings).ignore_possibly_unbound();
let mut default_ty =
place_from_bindings(db, bindings).ignore_possibly_unbound();
default_ty =
default_ty.map(|ty| ty.apply_optional_specialization(db, specialization));
let mut init = true;
if let Some(Type::KnownInstance(KnownInstanceType::Field(field))) = default_ty {
default_ty = Some(field.default_type(db));
init = field.init(db);
}
attributes.insert(
symbol.name().clone(),
DataclassField {
field_ty: attr_ty.apply_optional_specialization(db, specialization),
default_ty: default_ty
.map(|ty| ty.apply_optional_specialization(db, specialization)),
default_ty,
init_only: attr.is_init_var(),
init,
},
);
}

View file

@ -165,7 +165,8 @@ impl<'db> ClassBase<'db> {
KnownInstanceType::SubscriptedProtocol(_) => Some(Self::Protocol),
KnownInstanceType::TypeAliasType(_)
| KnownInstanceType::TypeVar(_)
| KnownInstanceType::Deprecated(_) => None,
| KnownInstanceType::Deprecated(_)
| KnownInstanceType::Field(_) => None,
},
Type::SpecialForm(special_form) => match special_form {

View file

@ -1049,6 +1049,8 @@ pub enum KnownFunction {
/// `dataclasses.dataclass`
Dataclass,
/// `dataclasses.field`
Field,
/// `inspect.getattr_static`
GetattrStatic,
@ -1127,7 +1129,7 @@ impl KnownFunction {
Self::AbstractMethod => {
matches!(module, KnownModule::Abc)
}
Self::Dataclass => {
Self::Dataclass | Self::Field => {
matches!(module, KnownModule::Dataclasses)
}
Self::GetattrStatic => module.is_inspect(),
@ -1408,7 +1410,7 @@ pub(crate) mod tests {
KnownFunction::AbstractMethod => KnownModule::Abc,
KnownFunction::Dataclass => KnownModule::Dataclasses,
KnownFunction::Dataclass | KnownFunction::Field => KnownModule::Dataclasses,
KnownFunction::GetattrStatic => KnownModule::Inspect,

View file

@ -10021,6 +10021,15 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
Type::unknown()
}
KnownInstanceType::Field(_) => {
self.infer_type_expression(&subscript.slice);
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
builder.into_diagnostic(format_args!(
"`dataclasses.Field` is not allowed in type expressions",
));
}
Type::unknown()
}
KnownInstanceType::TypeVar(_) => {
self.infer_type_expression(&subscript.slice);
todo_type!("TypeVar annotations")