mirror of
https://github.com/astral-sh/ruff.git
synced 2025-10-26 18:06:43 +00:00
[ty] Support dataclass_transform for base class models (#20783)
## Summary Support `dataclass_transform` when used on a (base) class. ## Typing conformance * The changes in `dataclasses_transform_class.py` look good, just a few mistakes due to missing `alias` support. * I didn't look closely at the changes in `dataclasses_transform_converter.py` since we don't support `converter` yet. ## Ecosystem impact The impact looks huge, but it's concentrated on a single project (ibis). Their setup looks more or less like this: * the real `Annotatable`:d7083c2c96/ibis/common/grounds.py (L100-L101)* the real `DataType`:d7083c2c96/ibis/expr/datatypes/core.py (L161-L179)* the real `Array`:d7083c2c96/ibis/expr/datatypes/core.py (L1003-L1006)```py from typing import dataclass_transform @dataclass_transform() class Annotatable: pass class DataType(Annotatable): nullable: bool = True class Array[T](DataType): value_type: T ``` They expect something like `Array([1, 2])` to work, but ty, pyright, mypy, and pyrefly would all expect there to be a first argument for the `nullable` field on `DataType`. I don't really understand on what grounds they expect the `nullable` field to be excluded from the signature, but this seems to be the main reason for the new diagnostics here. Not sure if related, but it looks like their typing setup is not really complete (https://github.com/ibis-project/ibis/issues/6844#issuecomment-1868274770, this thread also mentions `dataclass_transform`). ## Test Plan Update pre-existing tests.
This commit is contained in:
parent
fc3b341529
commit
cfbd42c22a
4 changed files with 56 additions and 45 deletions
|
|
@ -122,9 +122,6 @@ class CustomerModel(ModelBase):
|
||||||
id: int
|
id: int
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
# TODO: this is not supported yet
|
|
||||||
# error: [unknown-argument]
|
|
||||||
# error: [unknown-argument]
|
|
||||||
CustomerModel(id=1, name="Test")
|
CustomerModel(id=1, name="Test")
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -216,11 +213,7 @@ class OrderedModelBase: ...
|
||||||
class TestWithBase(OrderedModelBase):
|
class TestWithBase(OrderedModelBase):
|
||||||
inner: int
|
inner: int
|
||||||
|
|
||||||
# TODO: No errors here, should reveal `bool`
|
reveal_type(TestWithBase(1) < TestWithBase(2)) # revealed: bool
|
||||||
# error: [too-many-positional-arguments]
|
|
||||||
# error: [too-many-positional-arguments]
|
|
||||||
# error: [unsupported-operator]
|
|
||||||
reveal_type(TestWithBase(1) < TestWithBase(2)) # revealed: Unknown
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### `kw_only_default`
|
### `kw_only_default`
|
||||||
|
|
@ -277,8 +270,7 @@ class ModelBase: ...
|
||||||
class TestBase(ModelBase):
|
class TestBase(ModelBase):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
# TODO: This should be `(self: TestBase, *, name: str) -> None`
|
reveal_type(TestBase.__init__) # revealed: (self: TestBase, *, name: str) -> None
|
||||||
reveal_type(TestBase.__init__) # revealed: def __init__(self) -> None
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### `frozen_default`
|
### `frozen_default`
|
||||||
|
|
@ -333,12 +325,9 @@ class ModelBase: ...
|
||||||
class TestMeta(ModelBase):
|
class TestMeta(ModelBase):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
# TODO: no error here
|
|
||||||
# error: [unknown-argument]
|
|
||||||
t = TestMeta(name="test")
|
t = TestMeta(name="test")
|
||||||
|
|
||||||
# TODO: this should be an `invalid-assignment` error
|
t.name = "new" # error: [invalid-assignment]
|
||||||
t.name = "new"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Combining parameters
|
### Combining parameters
|
||||||
|
|
@ -437,19 +426,15 @@ class DefaultFrozenModel:
|
||||||
class Frozen(DefaultFrozenModel):
|
class Frozen(DefaultFrozenModel):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
# TODO: no error here
|
|
||||||
# error: [unknown-argument]
|
|
||||||
f = Frozen(name="test")
|
f = Frozen(name="test")
|
||||||
# TODO: this should be an `invalid-assignment` error
|
f.name = "new" # error: [invalid-assignment]
|
||||||
f.name = "new"
|
|
||||||
|
|
||||||
class Mutable(DefaultFrozenModel, frozen=False):
|
class Mutable(DefaultFrozenModel, frozen=False):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
# TODO: no error here
|
|
||||||
# error: [unknown-argument]
|
|
||||||
m = Mutable(name="test")
|
m = Mutable(name="test")
|
||||||
m.name = "new" # No error
|
# TODO: This should not be an error
|
||||||
|
m.name = "new" # error: [invalid-assignment]
|
||||||
```
|
```
|
||||||
|
|
||||||
## `field_specifiers`
|
## `field_specifiers`
|
||||||
|
|
@ -532,12 +517,8 @@ class Person(FancyBase):
|
||||||
name: str = fancy_field()
|
name: str = fancy_field()
|
||||||
age: int | None = fancy_field(kw_only=True)
|
age: int | None = fancy_field(kw_only=True)
|
||||||
|
|
||||||
# TODO: should be (self: Person, name: str = Unknown, *, age: int | None = Unknown) -> None
|
reveal_type(Person.__init__) # revealed: (self: Person, name: str, *, age: int | None) -> None
|
||||||
reveal_type(Person.__init__) # revealed: def __init__(self) -> None
|
|
||||||
|
|
||||||
# TODO: shouldn't be an error
|
|
||||||
# error: [too-many-positional-arguments]
|
|
||||||
# error: [unknown-argument]
|
|
||||||
alice = Person("Alice", age=30)
|
alice = Person("Alice", age=30)
|
||||||
|
|
||||||
reveal_type(alice.id) # revealed: int
|
reveal_type(alice.id) # revealed: int
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,11 @@ pub(crate) enum CodeGeneratorKind<'db> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'db> CodeGeneratorKind<'db> {
|
impl<'db> CodeGeneratorKind<'db> {
|
||||||
pub(crate) fn from_class(db: &'db dyn Db, class: ClassLiteral<'db>) -> Option<Self> {
|
pub(crate) fn from_class(
|
||||||
|
db: &'db dyn Db,
|
||||||
|
class: ClassLiteral<'db>,
|
||||||
|
specialization: Option<Specialization<'db>>,
|
||||||
|
) -> Option<Self> {
|
||||||
#[salsa::tracked(
|
#[salsa::tracked(
|
||||||
cycle_fn=code_generator_of_class_recover,
|
cycle_fn=code_generator_of_class_recover,
|
||||||
cycle_initial=code_generator_of_class_initial,
|
cycle_initial=code_generator_of_class_initial,
|
||||||
|
|
@ -199,11 +203,20 @@ impl<'db> CodeGeneratorKind<'db> {
|
||||||
fn code_generator_of_class<'db>(
|
fn code_generator_of_class<'db>(
|
||||||
db: &'db dyn Db,
|
db: &'db dyn Db,
|
||||||
class: ClassLiteral<'db>,
|
class: ClassLiteral<'db>,
|
||||||
|
specialization: Option<Specialization<'db>>,
|
||||||
) -> Option<CodeGeneratorKind<'db>> {
|
) -> Option<CodeGeneratorKind<'db>> {
|
||||||
if class.dataclass_params(db).is_some() {
|
if class.dataclass_params(db).is_some() {
|
||||||
Some(CodeGeneratorKind::DataclassLike(None))
|
Some(CodeGeneratorKind::DataclassLike(None))
|
||||||
} else if let Ok((_, Some(transformer_params))) = class.try_metaclass(db) {
|
} else if let Ok((_, Some(transformer_params))) = class.try_metaclass(db) {
|
||||||
Some(CodeGeneratorKind::DataclassLike(Some(transformer_params)))
|
Some(CodeGeneratorKind::DataclassLike(Some(transformer_params)))
|
||||||
|
} else if let Some(transformer_params) =
|
||||||
|
class.iter_mro(db, specialization).skip(1).find_map(|base| {
|
||||||
|
base.into_class().and_then(|class| {
|
||||||
|
class.class_literal(db).0.dataclass_transformer_params(db)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
{
|
||||||
|
Some(CodeGeneratorKind::DataclassLike(Some(transformer_params)))
|
||||||
} else if class
|
} else if class
|
||||||
.explicit_bases(db)
|
.explicit_bases(db)
|
||||||
.contains(&Type::SpecialForm(SpecialFormType::NamedTuple))
|
.contains(&Type::SpecialForm(SpecialFormType::NamedTuple))
|
||||||
|
|
@ -219,6 +232,7 @@ impl<'db> CodeGeneratorKind<'db> {
|
||||||
fn code_generator_of_class_initial<'db>(
|
fn code_generator_of_class_initial<'db>(
|
||||||
_db: &'db dyn Db,
|
_db: &'db dyn Db,
|
||||||
_class: ClassLiteral<'db>,
|
_class: ClassLiteral<'db>,
|
||||||
|
_specialization: Option<Specialization<'db>>,
|
||||||
) -> Option<CodeGeneratorKind<'db>> {
|
) -> Option<CodeGeneratorKind<'db>> {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
|
@ -229,21 +243,37 @@ impl<'db> CodeGeneratorKind<'db> {
|
||||||
_value: &Option<CodeGeneratorKind<'db>>,
|
_value: &Option<CodeGeneratorKind<'db>>,
|
||||||
_count: u32,
|
_count: u32,
|
||||||
_class: ClassLiteral<'db>,
|
_class: ClassLiteral<'db>,
|
||||||
|
_specialization: Option<Specialization<'db>>,
|
||||||
) -> salsa::CycleRecoveryAction<Option<CodeGeneratorKind<'db>>> {
|
) -> salsa::CycleRecoveryAction<Option<CodeGeneratorKind<'db>>> {
|
||||||
salsa::CycleRecoveryAction::Iterate
|
salsa::CycleRecoveryAction::Iterate
|
||||||
}
|
}
|
||||||
|
|
||||||
code_generator_of_class(db, class)
|
code_generator_of_class(db, class, specialization)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn matches(self, db: &'db dyn Db, class: ClassLiteral<'db>) -> bool {
|
pub(super) fn matches(
|
||||||
|
self,
|
||||||
|
db: &'db dyn Db,
|
||||||
|
class: ClassLiteral<'db>,
|
||||||
|
specialization: Option<Specialization<'db>>,
|
||||||
|
) -> bool {
|
||||||
matches!(
|
matches!(
|
||||||
(CodeGeneratorKind::from_class(db, class), self),
|
(
|
||||||
|
CodeGeneratorKind::from_class(db, class, specialization),
|
||||||
|
self
|
||||||
|
),
|
||||||
(Some(Self::DataclassLike(_)), Self::DataclassLike(_))
|
(Some(Self::DataclassLike(_)), Self::DataclassLike(_))
|
||||||
| (Some(Self::NamedTuple), Self::NamedTuple)
|
| (Some(Self::NamedTuple), Self::NamedTuple)
|
||||||
| (Some(Self::TypedDict), Self::TypedDict)
|
| (Some(Self::TypedDict), Self::TypedDict)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) fn dataclass_transformer_params(self) -> Option<DataclassTransformerParams<'db>> {
|
||||||
|
match self {
|
||||||
|
Self::DataclassLike(params) => params,
|
||||||
|
Self::NamedTuple | Self::TypedDict => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A specialization of a generic class with a particular assignment of types to typevars.
|
/// A specialization of a generic class with a particular assignment of types to typevars.
|
||||||
|
|
@ -2200,7 +2230,7 @@ impl<'db> ClassLiteral<'db> {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
if CodeGeneratorKind::NamedTuple.matches(db, self) {
|
if CodeGeneratorKind::NamedTuple.matches(db, self, specialization) {
|
||||||
if let Some(field) = self
|
if let Some(field) = self
|
||||||
.own_fields(db, specialization, CodeGeneratorKind::NamedTuple)
|
.own_fields(db, specialization, CodeGeneratorKind::NamedTuple)
|
||||||
.get(name)
|
.get(name)
|
||||||
|
|
@ -2262,7 +2292,7 @@ impl<'db> ClassLiteral<'db> {
|
||||||
) -> Option<Type<'db>> {
|
) -> Option<Type<'db>> {
|
||||||
let dataclass_params = self.dataclass_params(db);
|
let dataclass_params = self.dataclass_params(db);
|
||||||
|
|
||||||
let field_policy = CodeGeneratorKind::from_class(db, self)?;
|
let field_policy = CodeGeneratorKind::from_class(db, self, specialization)?;
|
||||||
|
|
||||||
let transformer_params =
|
let transformer_params =
|
||||||
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
|
if let CodeGeneratorKind::DataclassLike(Some(transformer_params)) = field_policy {
|
||||||
|
|
@ -2808,7 +2838,7 @@ impl<'db> ClassLiteral<'db> {
|
||||||
.filter_map(|superclass| {
|
.filter_map(|superclass| {
|
||||||
if let Some(class) = superclass.into_class() {
|
if let Some(class) = superclass.into_class() {
|
||||||
let (class_literal, specialization) = class.class_literal(db);
|
let (class_literal, specialization) = class.class_literal(db);
|
||||||
if field_policy.matches(db, class_literal) {
|
if field_policy.matches(db, class_literal, specialization) {
|
||||||
Some((class_literal, specialization))
|
Some((class_literal, specialization))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
|
|
@ -3623,7 +3653,7 @@ impl<'db> VarianceInferable<'db> for ClassLiteral<'db> {
|
||||||
.map(|class| class.variance_of(db, typevar));
|
.map(|class| class.variance_of(db, typevar));
|
||||||
|
|
||||||
let default_attribute_variance = {
|
let default_attribute_variance = {
|
||||||
let is_namedtuple = CodeGeneratorKind::NamedTuple.matches(db, self);
|
let is_namedtuple = CodeGeneratorKind::NamedTuple.matches(db, self, None);
|
||||||
// Python 3.13 introduced a synthesized `__replace__` method on dataclasses which uses
|
// Python 3.13 introduced a synthesized `__replace__` method on dataclasses which uses
|
||||||
// their field types in contravariant position, thus meaning a frozen dataclass must
|
// their field types in contravariant position, thus meaning a frozen dataclass must
|
||||||
// still be invariant in its field types. Other synthesized methods on dataclasses are
|
// still be invariant in its field types. Other synthesized methods on dataclasses are
|
||||||
|
|
|
||||||
|
|
@ -122,7 +122,7 @@ impl<'db> AllMembers<'db> {
|
||||||
self.extend_with_instance_members(db, ty, class_literal);
|
self.extend_with_instance_members(db, ty, class_literal);
|
||||||
|
|
||||||
// If this is a NamedTuple instance, include members from NamedTupleFallback
|
// If this is a NamedTuple instance, include members from NamedTupleFallback
|
||||||
if CodeGeneratorKind::NamedTuple.matches(db, class_literal) {
|
if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) {
|
||||||
self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db));
|
self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -142,7 +142,7 @@ impl<'db> AllMembers<'db> {
|
||||||
Type::ClassLiteral(class_literal) => {
|
Type::ClassLiteral(class_literal) => {
|
||||||
self.extend_with_class_members(db, ty, class_literal);
|
self.extend_with_class_members(db, ty, class_literal);
|
||||||
|
|
||||||
if CodeGeneratorKind::NamedTuple.matches(db, class_literal) {
|
if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) {
|
||||||
self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db));
|
self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -153,7 +153,7 @@ impl<'db> AllMembers<'db> {
|
||||||
|
|
||||||
Type::GenericAlias(generic_alias) => {
|
Type::GenericAlias(generic_alias) => {
|
||||||
let class_literal = generic_alias.origin(db);
|
let class_literal = generic_alias.origin(db);
|
||||||
if CodeGeneratorKind::NamedTuple.matches(db, class_literal) {
|
if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) {
|
||||||
self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db));
|
self.extend_with_type(db, KnownClass::NamedTupleFallback.to_class_literal(db));
|
||||||
}
|
}
|
||||||
self.extend_with_class_members(db, ty, class_literal);
|
self.extend_with_class_members(db, ty, class_literal);
|
||||||
|
|
@ -164,7 +164,7 @@ impl<'db> AllMembers<'db> {
|
||||||
let class_literal = class_type.class_literal(db).0;
|
let class_literal = class_type.class_literal(db).0;
|
||||||
self.extend_with_class_members(db, ty, class_literal);
|
self.extend_with_class_members(db, ty, class_literal);
|
||||||
|
|
||||||
if CodeGeneratorKind::NamedTuple.matches(db, class_literal) {
|
if CodeGeneratorKind::NamedTuple.matches(db, class_literal, None) {
|
||||||
self.extend_with_type(
|
self.extend_with_type(
|
||||||
db,
|
db,
|
||||||
KnownClass::NamedTupleFallback.to_class_literal(db),
|
KnownClass::NamedTupleFallback.to_class_literal(db),
|
||||||
|
|
|
||||||
|
|
@ -577,7 +577,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let is_named_tuple = CodeGeneratorKind::NamedTuple.matches(self.db(), class);
|
let is_named_tuple = CodeGeneratorKind::NamedTuple.matches(self.db(), class, None);
|
||||||
|
|
||||||
// (2) If it's a `NamedTuple` class, check that no field without a default value
|
// (2) If it's a `NamedTuple` class, check that no field without a default value
|
||||||
// appears after a field with a default value.
|
// appears after a field with a default value.
|
||||||
|
|
@ -898,7 +898,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
|
|
||||||
// (7) Check that a dataclass does not have more than one `KW_ONLY`.
|
// (7) Check that a dataclass does not have more than one `KW_ONLY`.
|
||||||
if let Some(field_policy @ CodeGeneratorKind::DataclassLike(_)) =
|
if let Some(field_policy @ CodeGeneratorKind::DataclassLike(_)) =
|
||||||
CodeGeneratorKind::from_class(self.db(), class)
|
CodeGeneratorKind::from_class(self.db(), class, None)
|
||||||
{
|
{
|
||||||
let specialization = None;
|
let specialization = None;
|
||||||
|
|
||||||
|
|
@ -4569,11 +4569,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
||||||
.dataclass_params(db)
|
.dataclass_params(db)
|
||||||
.map(|params| SmallVec::from(params.field_specifiers(db)))
|
.map(|params| SmallVec::from(params.field_specifiers(db)))
|
||||||
.or_else(|| {
|
.or_else(|| {
|
||||||
class_literal
|
Some(SmallVec::from(
|
||||||
.try_metaclass(db)
|
CodeGeneratorKind::from_class(db, class_literal, None)?
|
||||||
.ok()
|
.dataclass_transformer_params()?
|
||||||
.and_then(|(_, params)| params)
|
.field_specifiers(db),
|
||||||
.map(|params| SmallVec::from(params.field_specifiers(db)))
|
))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue