[red-knot] Dataclasses: support order=True (#17406)

## Summary

Support dataclasses with `order=True`:

```py
@dataclass(order=True)
class WithOrder:
    x: int

WithOrder(1) < WithOrder(2)  # no error
```

Also adds some additional tests to `dataclasses.md`.

ticket: #16651

## Test Plan

New Markdown tests
This commit is contained in:
David Peter 2025-04-17 08:58:46 +02:00 committed by GitHub
parent 914095d08f
commit b4de245a5a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 221 additions and 32 deletions

View file

@ -68,7 +68,7 @@ jobs:
--type-checker knot \
--old base_commit \
--new "$GITHUB_SHA" \
--project-selector '/(mypy_primer|black|pyp|git-revise|zipp|arrow|isort|itsdangerous|rich|packaging|pybind11|pyinstrument|typeshed-stats|scrapy|werkzeug|bidict|async-utils)$' \
--project-selector '/(mypy_primer|black|pyp|git-revise|zipp|arrow|isort|itsdangerous|rich|packaging|pybind11|pyinstrument|typeshed-stats|scrapy|werkzeug|bidict|async-utils|python-chess|dacite|python-htmlgen|paroxython|porcupine|psycopg)$' \
--output concise \
--debug > mypy_primer.diff || [ $? -eq 1 ]

View file

@ -91,6 +91,125 @@ repr(C())
C() == C()
```
## Other dataclass parameters
### `repr`
A custom `__repr__` method is generated by default. It can be disabled by passing `repr=False`, but
in that case `__repr__` is still available via `object.__repr__`:
```py
from dataclasses import dataclass
@dataclass(repr=False)
class WithoutRepr:
x: int
reveal_type(WithoutRepr(1).__repr__) # revealed: bound method WithoutRepr.__repr__() -> str
```
### `eq`
The same is true for `__eq__`. Setting `eq=False` disables the generated `__eq__` method, but
`__eq__` is still available via `object.__eq__`:
```py
from dataclasses import dataclass
@dataclass(eq=False)
class WithoutEq:
x: int
reveal_type(WithoutEq(1) == WithoutEq(2)) # revealed: bool
```
### `order`
`order` is set to `False` by default. If `order=True`, `__lt__`, `__le__`, `__gt__`, and `__ge__`
methods will be generated:
```py
from dataclasses import dataclass
@dataclass
class WithoutOrder:
x: int
WithoutOrder(1) < WithoutOrder(2) # error: [unsupported-operator]
WithoutOrder(1) <= WithoutOrder(2) # error: [unsupported-operator]
WithoutOrder(1) > WithoutOrder(2) # error: [unsupported-operator]
WithoutOrder(1) >= WithoutOrder(2) # error: [unsupported-operator]
@dataclass(order=True)
class WithOrder:
x: int
WithOrder(1) < WithOrder(2)
WithOrder(1) <= WithOrder(2)
WithOrder(1) > WithOrder(2)
WithOrder(1) >= WithOrder(2)
```
Comparisons are only allowed for `WithOrder` instances:
```py
WithOrder(1) < 2 # error: [unsupported-operator]
WithOrder(1) <= 2 # error: [unsupported-operator]
WithOrder(1) > 2 # error: [unsupported-operator]
WithOrder(1) >= 2 # error: [unsupported-operator]
```
This also works for generic dataclasses:
```py
from dataclasses import dataclass
@dataclass(order=True)
class GenericWithOrder[T]:
x: T
GenericWithOrder[int](1) < GenericWithOrder[int](1)
GenericWithOrder[int](1) < GenericWithOrder[str]("a") # error: [unsupported-operator]
```
If a class already defines one of the comparison methods, a `TypeError` is raised at runtime.
Ideally, we would emit a diagnostic in that case:
```py
@dataclass(order=True)
class AlreadyHasCustomDunderLt:
x: int
# TODO: Ideally, we would emit a diagnostic here
def __lt__(self, other: object) -> bool:
return False
```
### `unsafe_hash`
To do
### `frozen`
To do
### `match_args`
To do
### `kw_only`
To do
### `slots`
To do
### `weakref_slot`
To do
## Inheritance
### Normal class inheriting from a dataclass
@ -168,13 +287,30 @@ reveal_type(d_int.description) # revealed: str
DataWithDescription[int](None, "description")
```
## Frozen instances
To do
## Descriptor-typed fields
To do
```py
from dataclasses import dataclass
class Descriptor:
_value: int = 0
def __get__(self, instance, owner) -> str:
return str(self._value)
def __set__(self, instance, value: int) -> None:
self._value = value
@dataclass
class C:
d: Descriptor = Descriptor()
c = C(1)
reveal_type(c.d) # revealed: str
# TODO: should be an error
C("a")
```
## `dataclasses.field`
@ -197,18 +333,61 @@ class C:
reveal_type(C.__init__) # revealed: (*args: Any, **kwargs: Any) -> None
```
### Dataclass with `init=False`
To do
### Dataclass with custom `__init__` method
To do
If a class already defines `__init__`, it is not replaced by the `dataclass` decorator.
```py
from dataclasses import dataclass
@dataclass(init=True)
class C:
x: str
def __init__(self, x: int) -> None:
self.x = str(x)
C(1) # OK
# TODO: should be an error
C("a")
```
Similarly, if we set `init=False`, we still recognize the custom `__init__` method:
```py
@dataclass(init=False)
class D:
def __init__(self, x: int) -> None:
self.x = str(x)
D(1) # OK
D() # error: [missing-argument]
```
### Dataclass with `ClassVar`s
To do
### Return type of `dataclass(...)`
A call like `dataclass(order=True)` returns a callable itself, which is then used as the decorator.
We can store the callable in a variable and later use it as a decorator:
```py
from dataclasses import dataclass
dataclass_with_order = dataclass(order=True)
reveal_type(dataclass_with_order) # revealed: <decorator produced by dataclasses.dataclass>
@dataclass_with_order
class C:
x: int
C(1) < C(2) # ok
```
### Using `dataclass` as a function
To do

View file

@ -823,21 +823,31 @@ impl<'db> ClassLiteralType<'db> {
name: &str,
) -> SymbolAndQualifiers<'db> {
if let Some(metadata) = self.dataclass_metadata(db) {
if name == "__init__" {
if metadata.contains(DataclassMetadata::INIT) {
// TODO: Generate the signature from the attributes on the class
let init_signature = Signature::new(
Parameters::new([
Parameter::variadic(Name::new_static("args"))
.with_annotated_type(Type::any()),
Parameter::keyword_variadic(Name::new_static("kwargs"))
.with_annotated_type(Type::any()),
]),
Some(Type::none(db)),
);
if name == "__init__" && metadata.contains(DataclassMetadata::INIT) {
// TODO: Generate the signature from the attributes on the class
let init_signature = Signature::new(
Parameters::new([
Parameter::variadic(Name::new_static("args"))
.with_annotated_type(Type::any()),
Parameter::keyword_variadic(Name::new_static("kwargs"))
.with_annotated_type(Type::any()),
]),
Some(Type::none(db)),
);
return Symbol::bound(Type::Callable(CallableType::new(db, init_signature)))
.into();
return Symbol::bound(Type::Callable(CallableType::new(db, init_signature))).into();
} else if matches!(name, "__lt__" | "__le__" | "__gt__" | "__ge__") {
if metadata.contains(DataclassMetadata::ORDER) {
let signature = Signature::new(
Parameters::new([Parameter::positional_or_keyword(Name::new_static(
"other",
))
.with_annotated_type(Type::instance(
self.apply_optional_specialization(db, specialization),
))]),
Some(KnownClass::Bool.to_instance(db)),
);
return Symbol::bound(Type::Callable(CallableType::new(db, signature))).into();
}
}
}

View file

@ -13,7 +13,7 @@ use crate::Db;
///
/// TODO: Handle nested generic contexts better, with actual parent links to the lexically
/// containing context.
#[salsa::tracked(debug)]
#[salsa::interned(debug)]
pub struct GenericContext<'db> {
#[return_ref]
pub(crate) variables: Box<[TypeVarInstance<'db>]>,
@ -25,7 +25,7 @@ impl<'db> GenericContext<'db> {
index: &'db SemanticIndex<'db>,
type_params_node: &ast::TypeParams,
) -> Self {
let variables = type_params_node
let variables: Box<[_]> = type_params_node
.iter()
.filter_map(|type_param| Self::variable_from_type_param(db, index, type_param))
.collect();
@ -116,7 +116,7 @@ impl<'db> GenericContext<'db> {
///
/// TODO: Handle nested specializations better, with actual parent links to the specialization of
/// the lexically containing context.
#[salsa::tracked(debug)]
#[salsa::interned(debug)]
pub struct Specialization<'db> {
pub(crate) generic_context: GenericContext<'db>,
#[return_ref]
@ -138,7 +138,7 @@ impl<'db> Specialization<'db> {
/// That lets us produce the generic alias `A[int]`, which is the corresponding entry in the
/// MRO of `B[int]`.
pub(crate) fn apply_specialization(self, db: &'db dyn Db, other: Specialization<'db>) -> Self {
let types = self
let types: Box<[_]> = self
.types(db)
.into_iter()
.map(|ty| ty.apply_specialization(db, other))
@ -154,7 +154,7 @@ impl<'db> Specialization<'db> {
pub(crate) fn combine(self, db: &'db dyn Db, other: Self) -> Self {
let generic_context = self.generic_context(db);
assert!(other.generic_context(db) == generic_context);
let types = self
let types: Box<[_]> = self
.types(db)
.into_iter()
.zip(other.types(db))
@ -167,7 +167,7 @@ impl<'db> Specialization<'db> {
}
pub(crate) fn normalized(self, db: &'db dyn Db) -> Self {
let types = self.types(db).iter().map(|ty| ty.normalized(db)).collect();
let types: Box<[_]> = self.types(db).iter().map(|ty| ty.normalized(db)).collect();
Self::new(db, self.generic_context(db), types)
}
@ -201,7 +201,7 @@ impl<'db> SpecializationBuilder<'db> {
}
pub(crate) fn build(mut self) -> Specialization<'db> {
let types = self
let types: Box<[_]> = self
.generic_context
.variables(self.db)
.iter()