diff --git a/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md b/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md new file mode 100644 index 0000000000..552c0e4a9d --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/dataclasses.md @@ -0,0 +1,242 @@ +# Dataclasses + +## Basic + +Decorating a class with `@dataclass` is a convenient way to add special methods such as `__init__`, +`__repr__`, and `__eq__` to a class. The following example shows the basic usage of the `@dataclass` +decorator. By default, only the three mentioned methods are generated. + +```py +from dataclasses import dataclass + +@dataclass +class Person: + name: str + age: int | None = None + +alice1 = Person("Alice", 30) +alice2 = Person(name="Alice", age=30) +alice3 = Person(age=30, name="Alice") +alice4 = Person("Alice", age=30) + +reveal_type(alice1) # revealed: Person +reveal_type(type(alice1)) # revealed: type[Person] + +reveal_type(alice1.name) # revealed: str +reveal_type(alice1.age) # revealed: int | None + +reveal_type(repr(alice1)) # revealed: str + +reveal_type(alice1 == alice2) # revealed: bool +reveal_type(alice1 == "Alice") # revealed: bool + +bob = Person("Bob") +bob2 = Person("Bob", None) +bob3 = Person(name="Bob") +bob4 = Person(name="Bob", age=None) +``` + +The signature of the `__init__` method is generated based on the classes attributes. The following +calls are not valid: + +```py +# TODO: should be an error: too few arguments +Person() + +# TODO: should be an error: too many arguments +Person("Eve", 20, "too many arguments") + +# TODO: should be an error: wrong argument type +Person("Eve", "string instead of int") + +# TODO: should be an error: wrong argument types +Person(20, "Eve") +``` + +## `@dataclass` calls with arguments + +The `@dataclass` decorator can take several arguments to customize the existence of the generated +methods. The following test makes sure that we still treat the class as a dataclass if (the default) +arguments are passed in: + +```py +from dataclasses import dataclass + +@dataclass(init=True, repr=True, eq=True) +class Person: + name: str + age: int | None = None + +alice = Person("Alice", 30) +reveal_type(repr(alice)) # revealed: str +reveal_type(alice == alice) # revealed: bool +``` + +If `init` is set to `False`, no `__init__` method is generated: + +```py +from dataclasses import dataclass + +@dataclass(init=False) +class C: + x: int + +C() # Okay + +# error: [too-many-positional-arguments] +C(1) + +repr(C()) + +C() == C() +``` + +## Inheritance + +### Normal class inheriting from a dataclass + +```py +from dataclasses import dataclass + +@dataclass +class Base: + x: int + +class Derived(Base): ... + +d = Derived(1) # OK +reveal_type(d.x) # revealed: int +``` + +### Dataclass inheriting from normal class + +```py +from dataclasses import dataclass + +class Base: + x: int = 1 + +@dataclass +class Derived(Base): + y: str + +d = Derived("a") + +# TODO: should be an error: +Derived(1, "a") +``` + +### Dataclass inheriting from another dataclass + +```py +from dataclasses import dataclass + +@dataclass +class Base: + x: int + +@dataclass +class Derived(Base): + y: str + +d = Derived(1, "a") # OK + +reveal_type(d.x) # revealed: int +reveal_type(d.y) # revealed: str + +# TODO: should be an error: +Derived("a") +``` + +## Generic dataclasses + +```py +from dataclasses import dataclass + +@dataclass +class DataWithDescription[T]: + data: T + description: str + +reveal_type(DataWithDescription[int]) # revealed: Literal[DataWithDescription[int]] + +d_int = DataWithDescription[int](1, "description") # OK +reveal_type(d_int.data) # revealed: int +reveal_type(d_int.description) # revealed: str + +# TODO: should be an error: wrong argument type +DataWithDescription[int](None, "description") +``` + +## Frozen instances + +To do + +## Descriptor-typed fields + +To do + +## `dataclasses.field` + +To do + +## Other special cases + +### `dataclasses.dataclass` + +We also understand dataclasses if they are decorated with the fully qualified name: + +```py +import dataclasses + +@dataclasses.dataclass +class C: + x: str + +# TODO: should show the proper signature +reveal_type(C.__init__) # revealed: (*args: Any, **kwargs: Any) -> None +``` + +### Dataclass with `init=False` + +To do + +### Dataclass with custom `__init__` method + +To do + +### Dataclass with `ClassVar`s + +To do + +### Using `dataclass` as a function + +To do + +## Internals + +The `dataclass` decorator returns the class itself. This means that the type of `Person` is `type`, +and attributes like the MRO are unchanged: + +```py +from dataclasses import dataclass + +@dataclass +class Person: + name: str + age: int | None = None + +reveal_type(type(Person)) # revealed: Literal[type] +reveal_type(Person.__mro__) # revealed: tuple[Literal[Person], Literal[object]] +``` + +The generated methods have the following signatures: + +```py +# TODO: proper signature +reveal_type(Person.__init__) # revealed: (*args: Any, **kwargs: Any) -> None + +reveal_type(Person.__repr__) # revealed: def __repr__(self) -> str + +reveal_type(Person.__eq__) # revealed: def __eq__(self, value: object, /) -> bool +``` diff --git a/crates/red_knot_python_semantic/src/module_resolver/module.rs b/crates/red_knot_python_semantic/src/module_resolver/module.rs index 6256a6f98b..afcc6687ba 100644 --- a/crates/red_knot_python_semantic/src/module_resolver/module.rs +++ b/crates/red_knot_python_semantic/src/module_resolver/module.rs @@ -116,6 +116,7 @@ pub enum KnownModule { Sys, #[allow(dead_code)] Abc, // currently only used in tests + Dataclasses, Collections, Inspect, KnotExtensions, @@ -132,6 +133,7 @@ impl KnownModule { Self::TypingExtensions => "typing_extensions", Self::Sys => "sys", Self::Abc => "abc", + Self::Dataclasses => "dataclasses", Self::Collections => "collections", Self::Inspect => "inspect", Self::KnotExtensions => "knot_extensions", diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 7795658be7..33f7046cd2 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -311,6 +311,32 @@ impl<'db> PropertyInstanceType<'db> { } } +bitflags! { + /// Used as the return type of `dataclass(…)` calls. Keeps track of the arguments + /// that were passed in. For the precise meaning of the fields, see [1]. + /// + /// [1]: https://docs.python.org/3/library/dataclasses.html + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] + pub struct DataclassMetadata: u16 { + const INIT = 0b0000_0000_0001; + const REPR = 0b0000_0000_0010; + const EQ = 0b0000_0000_0100; + const ORDER = 0b0000_0000_1000; + const UNSAFE_HASH = 0b0000_0001_0000; + const FROZEN = 0b0000_0010_0000; + const MATCH_ARGS = 0b0000_0100_0000; + const KW_ONLY = 0b0000_1000_0000; + const SLOTS = 0b0001_0000_0000; + const WEAKREF_SLOT = 0b0010_0000_0000; + } +} + +impl Default for DataclassMetadata { + fn default() -> Self { + Self::INIT | Self::REPR | Self::EQ | Self::MATCH_ARGS + } +} + /// Representation of a type: a set of possible values at runtime. #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash, salsa::Update)] pub enum Type<'db> { @@ -348,6 +374,10 @@ pub enum Type<'db> { /// type. We currently add this as a separate variant because `FunctionType.__get__` /// is an overloaded method and we do not support `@overload` yet. WrapperDescriptor(WrapperDescriptorKind), + /// A special callable that is returned by a `dataclass(…)` call. It is usually + /// used as a decorator. Note that this is only used as a return type for actual + /// `dataclass` calls, not for the argumentless `@dataclass` decorator. + DataclassDecorator(DataclassMetadata), /// The type of an arbitrary callable object with a certain specified signature. Callable(CallableType<'db>), /// A specific module object @@ -458,7 +488,8 @@ impl<'db> Type<'db> { | Self::Dynamic(DynamicType::Unknown | DynamicType::Any) | Self::BoundMethod(_) | Self::WrapperDescriptor(_) - | Self::MethodWrapper(_) => false, + | Self::MethodWrapper(_) + | Self::DataclassDecorator(_) => false, Self::GenericAlias(generic) => generic .specialization(db) @@ -747,6 +778,7 @@ impl<'db> Type<'db> { | Type::MethodWrapper(_) | Type::BoundMethod(_) | Type::WrapperDescriptor(_) + | Self::DataclassDecorator(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) | Type::KnownInstance(_) @@ -979,6 +1011,11 @@ impl<'db> Type<'db> { .signature(db) .is_subtype_of(db, other_callable.signature(db)), + (Type::DataclassDecorator(_), _) => { + // TODO: Implement subtyping using an equivalent `Callable` type. + false + } + (Type::Callable(_), _) => { // TODO: Implement subtyping between callable types and other types like // function literals, bound methods, class literals, `type[]`, etc.) @@ -1507,6 +1544,7 @@ impl<'db> Type<'db> { | Type::BoundMethod(..) | Type::MethodWrapper(..) | Type::WrapperDescriptor(..) + | Type::DataclassDecorator(..) | Type::IntLiteral(..) | Type::SliceLiteral(..) | Type::StringLiteral(..) @@ -1522,6 +1560,7 @@ impl<'db> Type<'db> { | Type::BoundMethod(..) | Type::MethodWrapper(..) | Type::WrapperDescriptor(..) + | Type::DataclassDecorator(..) | Type::IntLiteral(..) | Type::SliceLiteral(..) | Type::StringLiteral(..) @@ -1716,7 +1755,8 @@ impl<'db> Type<'db> { true } - (Type::Callable(_), _) | (_, Type::Callable(_)) => { + (Type::Callable(_) | Type::DataclassDecorator(_), _) + | (_, Type::Callable(_) | Type::DataclassDecorator(_)) => { // TODO: Implement disjointness for general callable type with other types false } @@ -1773,6 +1813,7 @@ impl<'db> Type<'db> { | Type::BoundMethod(_) | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) + | Type::DataclassDecorator(_) | Type::ModuleLiteral(..) | Type::IntLiteral(_) | Type::BooleanLiteral(_) @@ -1891,6 +1932,7 @@ impl<'db> Type<'db> { // (this variant represents `f.__get__`, where `f` is any function) false } + Type::DataclassDecorator(_) => false, Type::Instance(InstanceType { class }) => { class.known(db).is_some_and(KnownClass::is_singleton) } @@ -1977,7 +2019,8 @@ impl<'db> Type<'db> { | Type::AlwaysTruthy | Type::AlwaysFalsy | Type::Callable(_) - | Type::PropertyInstance(_) => false, + | Type::PropertyInstance(_) + | Type::DataclassDecorator(_) => false, } } @@ -2106,6 +2149,7 @@ impl<'db> Type<'db> { | Type::BoundMethod(_) | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) + | Type::DataclassDecorator(_) | Type::ModuleLiteral(_) | Type::KnownInstance(_) | Type::AlwaysTruthy @@ -2198,6 +2242,9 @@ impl<'db> Type<'db> { Type::WrapperDescriptor(_) => KnownClass::WrapperDescriptorType .to_instance(db) .instance_member(db, name), + Type::DataclassDecorator(_) => KnownClass::FunctionType + .to_instance(db) + .instance_member(db, name), Type::Callable(_) => KnownClass::Object.to_instance(db).instance_member(db, name), Type::TypeVar(typevar) => match typevar.bound_or_constraints(db) { @@ -2604,6 +2651,9 @@ impl<'db> Type<'db> { Type::WrapperDescriptor(_) => KnownClass::WrapperDescriptorType .to_instance(db) .member(db, &name), + Type::DataclassDecorator(_) => { + KnownClass::FunctionType.to_instance(db).member(db, &name) + } Type::Callable(_) => KnownClass::Object.to_instance(db).member(db, &name), Type::Instance(InstanceType { class }) @@ -2898,6 +2948,7 @@ impl<'db> Type<'db> { | Type::BoundMethod(_) | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) + | Type::DataclassDecorator(_) | Type::ModuleLiteral(_) | Type::SliceLiteral(_) | Type::AlwaysTruthy => Truthiness::AlwaysTrue, @@ -3289,6 +3340,83 @@ impl<'db> Type<'db> { Signatures::single(signature) } + Some(KnownFunction::Dataclass) => { + let signature = CallableSignature::from_overloads( + self, + [ + // def dataclass(cls: None, /) -> Callable[[type[_T]], type[_T]]: ... + Signature::new( + Parameters::new([Parameter::positional_only(Some( + Name::new_static("cls"), + )) + .with_annotated_type(Type::none(db))]), + None, + ), + // def dataclass(cls: type[_T], /) -> type[_T]: ... + Signature::new( + Parameters::new([Parameter::positional_only(Some( + Name::new_static("cls"), + )) + // TODO: type[_T] + .with_annotated_type(Type::any())]), + None, + ), + // TODO: make this overload Python-version-dependent + + // def dataclass( + // *, + // init: bool = True, + // repr: bool = True, + // eq: bool = True, + // order: bool = False, + // unsafe_hash: bool = False, + // frozen: bool = False, + // match_args: bool = True, + // kw_only: bool = False, + // slots: bool = False, + // weakref_slot: bool = False, + // ) -> Callable[[type[_T]], type[_T]]: ... + Signature::new( + Parameters::new([ + Parameter::keyword_only(Name::new_static("init")) + .with_annotated_type(KnownClass::Bool.to_instance(db)) + .with_default_type(Type::BooleanLiteral(true)), + Parameter::keyword_only(Name::new_static("repr")) + .with_annotated_type(KnownClass::Bool.to_instance(db)) + .with_default_type(Type::BooleanLiteral(true)), + Parameter::keyword_only(Name::new_static("eq")) + .with_annotated_type(KnownClass::Bool.to_instance(db)) + .with_default_type(Type::BooleanLiteral(true)), + Parameter::keyword_only(Name::new_static("order")) + .with_annotated_type(KnownClass::Bool.to_instance(db)) + .with_default_type(Type::BooleanLiteral(false)), + Parameter::keyword_only(Name::new_static("unsafe_hash")) + .with_annotated_type(KnownClass::Bool.to_instance(db)) + .with_default_type(Type::BooleanLiteral(false)), + Parameter::keyword_only(Name::new_static("frozen")) + .with_annotated_type(KnownClass::Bool.to_instance(db)) + .with_default_type(Type::BooleanLiteral(false)), + Parameter::keyword_only(Name::new_static("match_args")) + .with_annotated_type(KnownClass::Bool.to_instance(db)) + .with_default_type(Type::BooleanLiteral(true)), + Parameter::keyword_only(Name::new_static("kw_only")) + .with_annotated_type(KnownClass::Bool.to_instance(db)) + .with_default_type(Type::BooleanLiteral(false)), + Parameter::keyword_only(Name::new_static("slots")) + .with_annotated_type(KnownClass::Bool.to_instance(db)) + .with_default_type(Type::BooleanLiteral(false)), + Parameter::keyword_only(Name::new_static("weakref_slot")) + .with_annotated_type(KnownClass::Bool.to_instance(db)) + .with_default_type(Type::BooleanLiteral(false)), + ]), + None, + ), + ], + ); + + Signatures::single(signature) + } + _ => Signatures::single(CallableSignature::single( self, function_type.signature(db).clone(), @@ -3911,6 +4039,7 @@ impl<'db> Type<'db> { | Type::MethodWrapper(_) | Type::BoundMethod(_) | Type::WrapperDescriptor(_) + | Type::DataclassDecorator(_) | Type::Instance(_) | Type::KnownInstance(_) | Type::PropertyInstance(_) @@ -3979,6 +4108,7 @@ impl<'db> Type<'db> { | Type::BoundMethod(_) | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) + | Type::DataclassDecorator(_) | Type::Never | Type::FunctionLiteral(_) | Type::PropertyInstance(_) => Err(InvalidTypeExpressionError { @@ -4188,6 +4318,7 @@ impl<'db> Type<'db> { Type::BoundMethod(_) => KnownClass::MethodType.to_class_literal(db), Type::MethodWrapper(_) => KnownClass::MethodWrapperType.to_class_literal(db), Type::WrapperDescriptor(_) => KnownClass::WrapperDescriptorType.to_class_literal(db), + Type::DataclassDecorator(_) => KnownClass::FunctionType.to_class_literal(db), Type::Callable(_) => KnownClass::Type.to_instance(db), Type::ModuleLiteral(_) => KnownClass::ModuleType.to_class_literal(db), Type::Tuple(_) => KnownClass::Tuple.to_class_literal(db), @@ -4326,6 +4457,7 @@ impl<'db> Type<'db> { | Type::AlwaysFalsy | Type::WrapperDescriptor(_) | Type::MethodWrapper(MethodWrapperKind::StrStartswith(_)) + | Type::DataclassDecorator(_) | Type::ModuleLiteral(_) // A non-generic class never needs to be specialized. A generic class is specialized // explicitly (via a subscript expression) or implicitly (via a call), and not because @@ -4430,6 +4562,7 @@ impl<'db> Type<'db> { | Self::SliceLiteral(_) | Self::MethodWrapper(_) | Self::WrapperDescriptor(_) + | Self::DataclassDecorator(_) | Self::PropertyInstance(_) | Self::Tuple(_) => self.to_meta_type(db).definition(db), @@ -5581,6 +5714,9 @@ pub enum KnownFunction { #[strum(serialize = "abstractmethod")] AbstractMethod, + /// `dataclasses.dataclass` + Dataclass, + /// `inspect.getattr_static` GetattrStatic, @@ -5640,6 +5776,9 @@ impl KnownFunction { Self::AbstractMethod => { matches!(module, KnownModule::Abc) } + Self::Dataclass => { + matches!(module, KnownModule::Dataclasses) + } Self::GetattrStatic => module.is_inspect(), Self::IsAssignableTo | Self::IsDisjointFrom @@ -6578,6 +6717,8 @@ pub(crate) mod tests { KnownFunction::AbstractMethod => KnownModule::Abc, + KnownFunction::Dataclass => KnownModule::Dataclasses, + KnownFunction::GetattrStatic => KnownModule::Inspect, KnownFunction::Cast diff --git a/crates/red_knot_python_semantic/src/types/call/bind.rs b/crates/red_knot_python_semantic/src/types/call/bind.rs index e270329636..da918d2249 100644 --- a/crates/red_knot_python_semantic/src/types/call/bind.rs +++ b/crates/red_knot_python_semantic/src/types/call/bind.rs @@ -18,8 +18,8 @@ use crate::types::diagnostic::{ }; use crate::types::signatures::{Parameter, ParameterForm}; use crate::types::{ - todo_type, BoundMethodType, FunctionDecorators, KnownClass, KnownFunction, KnownInstanceType, - MethodWrapperKind, PropertyInstanceType, UnionType, WrapperDescriptorKind, + todo_type, BoundMethodType, DataclassMetadata, FunctionDecorators, KnownClass, KnownFunction, + KnownInstanceType, MethodWrapperKind, PropertyInstanceType, UnionType, WrapperDescriptorKind, }; use ruff_db::diagnostic::{Annotation, Severity, Span, SubDiagnostic}; use ruff_python_ast as ast; @@ -573,6 +573,56 @@ impl<'db> Bindings<'db> { ); } + Some(KnownFunction::Dataclass) => { + if let [init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots, weakref_slot] = + overload.parameter_types() + { + let to_bool = |ty: &Option>, default: bool| -> bool { + if let Some(Type::BooleanLiteral(value)) = ty { + *value + } else { + // TODO: emit a diagnostic if we receive `bool` + default + } + }; + + let mut metadata = DataclassMetadata::empty(); + + if to_bool(init, true) { + metadata |= DataclassMetadata::INIT; + } + if to_bool(repr, true) { + metadata |= DataclassMetadata::REPR; + } + if to_bool(eq, true) { + metadata |= DataclassMetadata::EQ; + } + if to_bool(order, false) { + metadata |= DataclassMetadata::ORDER; + } + if to_bool(unsafe_hash, false) { + metadata |= DataclassMetadata::UNSAFE_HASH; + } + if to_bool(frozen, false) { + metadata |= DataclassMetadata::FROZEN; + } + if to_bool(match_args, true) { + metadata |= DataclassMetadata::MATCH_ARGS; + } + if to_bool(kw_only, false) { + metadata |= DataclassMetadata::KW_ONLY; + } + if to_bool(slots, false) { + metadata |= DataclassMetadata::SLOTS; + } + if to_bool(weakref_slot, false) { + metadata |= DataclassMetadata::WEAKREF_SLOT; + } + + overload.set_return_type(Type::DataclassDecorator(metadata)); + } + } + _ => {} }, diff --git a/crates/red_knot_python_semantic/src/types/class.rs b/crates/red_knot_python_semantic/src/types/class.rs index af84291f19..289bb3d67d 100644 --- a/crates/red_knot_python_semantic/src/types/class.rs +++ b/crates/red_knot_python_semantic/src/types/class.rs @@ -7,6 +7,8 @@ use super::{ }; use crate::semantic_index::definition::Definition; use crate::types::generics::{GenericContext, Specialization}; +use crate::types::signatures::{Parameter, Parameters}; +use crate::types::{CallableType, DataclassMetadata, Signature}; use crate::{ module_resolver::file_to_module, semantic_index::{ @@ -30,6 +32,7 @@ use crate::{ use indexmap::IndexSet; use itertools::Itertools as _; use ruff_db::files::File; +use ruff_python_ast::name::Name; use ruff_python_ast::{self as ast, PythonVersion}; use rustc_hash::FxHashSet; @@ -98,6 +101,8 @@ pub struct Class<'db> { pub(crate) body_scope: ScopeId<'db>, pub(crate) known: Option, + + pub(crate) dataclass_metadata: Option, } impl<'db> Class<'db> { @@ -364,6 +369,10 @@ impl<'db> ClassLiteralType<'db> { self.class(db).known } + pub(crate) fn dataclass_metadata(self, db: &'db dyn Db) -> Option { + self.class(db).dataclass_metadata + } + /// Return `true` if this class represents `known_class` pub(crate) fn is_known(self, db: &'db dyn Db, known_class: KnownClass) -> bool { self.class(db).known == Some(known_class) @@ -782,6 +791,26 @@ impl<'db> ClassLiteralType<'db> { /// directly. Use [`ClassLiteralType::class_member`] if you require a method that will /// traverse through the MRO until it finds the member. pub(super) fn own_class_member(self, db: &'db dyn Db, name: &str) -> SymbolAndQualifiers<'db> { + if let Some(metadata) = self.dataclass_metadata(db) { + if name == "__init__" { + if metadata.contains(DataclassMetadata::INIT) { + // TODO: Generate the signature from the attributes on the class + let init_signature = Signature::new( + Parameters::new([ + Parameter::variadic(Name::new_static("args")) + .with_annotated_type(Type::any()), + Parameter::keyword_variadic(Name::new_static("kwargs")) + .with_annotated_type(Type::any()), + ]), + Some(Type::none(db)), + ); + + return Symbol::bound(Type::Callable(CallableType::new(db, init_signature))) + .into(); + } + } + } + let body_scope = self.body_scope(db); class_symbol(db, body_scope, name) } diff --git a/crates/red_knot_python_semantic/src/types/class_base.rs b/crates/red_knot_python_semantic/src/types/class_base.rs index 26e563316a..0a173993f5 100644 --- a/crates/red_knot_python_semantic/src/types/class_base.rs +++ b/crates/red_knot_python_semantic/src/types/class_base.rs @@ -86,6 +86,7 @@ impl<'db> ClassBase<'db> { | Type::BoundMethod(_) | Type::MethodWrapper(_) | Type::WrapperDescriptor(_) + | Type::DataclassDecorator(_) | Type::BytesLiteral(_) | Type::IntLiteral(_) | Type::StringLiteral(_) diff --git a/crates/red_knot_python_semantic/src/types/display.rs b/crates/red_knot_python_semantic/src/types/display.rs index 599bceabd8..1f922ad4db 100644 --- a/crates/red_knot_python_semantic/src/types/display.rs +++ b/crates/red_knot_python_semantic/src/types/display.rs @@ -165,6 +165,9 @@ impl Display for DisplayRepresentation<'_> { }; write!(f, "") } + Type::DataclassDecorator(_) => { + f.write_str("") + } Type::Union(union) => union.display(self.db).fmt(f), Type::Intersection(intersection) => intersection.display(self.db).fmt(f), Type::IntLiteral(n) => n.fmt(f), diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index e793b5389e..a7c5698381 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -82,12 +82,13 @@ use crate::types::mro::MroErrorKind; use crate::types::unpacker::{UnpackResult, Unpacker}; use crate::types::{ todo_type, CallDunderError, CallableSignature, CallableType, Class, ClassLiteralType, - DynamicType, FunctionDecorators, FunctionType, GenericAlias, GenericClass, IntersectionBuilder, - IntersectionType, KnownClass, KnownFunction, KnownInstanceType, MemberLookupPolicy, - MetaclassCandidate, NonGenericClass, Parameter, ParameterForm, Parameters, Signature, - Signatures, SliceLiteralType, StringLiteralType, SubclassOfType, Symbol, SymbolAndQualifiers, - Truthiness, TupleType, Type, TypeAliasType, TypeAndQualifiers, TypeArrayDisplay, - TypeQualifiers, TypeVarBoundOrConstraints, TypeVarInstance, UnionBuilder, UnionType, + DataclassMetadata, DynamicType, FunctionDecorators, FunctionType, GenericAlias, GenericClass, + IntersectionBuilder, IntersectionType, KnownClass, KnownFunction, KnownInstanceType, + MemberLookupPolicy, MetaclassCandidate, NonGenericClass, Parameter, ParameterForm, Parameters, + Signature, Signatures, SliceLiteralType, StringLiteralType, SubclassOfType, Symbol, + SymbolAndQualifiers, Truthiness, TupleType, Type, TypeAliasType, TypeAndQualifiers, + TypeArrayDisplay, TypeQualifiers, TypeVarBoundOrConstraints, TypeVarInstance, UnionBuilder, + UnionType, }; use crate::unpack::{Unpack, UnpackPosition}; use crate::util::subscript::{PyIndex, PySlice}; @@ -1725,8 +1726,21 @@ impl<'db> TypeInferenceBuilder<'db> { body: _, } = class_node; + let mut dataclass_metadata = None; for decorator in decorator_list { - self.infer_decorator(decorator); + let decorator_ty = self.infer_decorator(decorator); + if decorator_ty + .into_function_literal() + .is_some_and(|function| function.is_known(self.db(), KnownFunction::Dataclass)) + { + dataclass_metadata = Some(DataclassMetadata::default()); + continue; + } + + if let Type::DataclassDecorator(metadata) = decorator_ty { + dataclass_metadata = Some(metadata); + continue; + } } let generic_context = type_params.as_ref().map(|type_params| { @@ -1744,6 +1758,7 @@ impl<'db> TypeInferenceBuilder<'db> { name: name.id.clone(), body_scope, known: maybe_known_class, + dataclass_metadata, }; let class_literal = match generic_context { Some(generic_context) => { @@ -2432,6 +2447,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::BoundMethod(_) | Type::MethodWrapper(_) | Type::WrapperDescriptor(_) + | Type::DataclassDecorator(_) | Type::TypeVar(..) | Type::AlwaysTruthy | Type::AlwaysFalsy => { @@ -4677,6 +4693,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::Callable(..) | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) + | Type::DataclassDecorator(_) | Type::BoundMethod(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) @@ -4955,6 +4972,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::BoundMethod(_) | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) + | Type::DataclassDecorator(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) | Type::GenericAlias(_) @@ -4977,6 +4995,7 @@ impl<'db> TypeInferenceBuilder<'db> { | Type::BoundMethod(_) | Type::WrapperDescriptor(_) | Type::MethodWrapper(_) + | Type::DataclassDecorator(_) | Type::ModuleLiteral(_) | Type::ClassLiteral(_) | Type::GenericAlias(_) diff --git a/crates/red_knot_python_semantic/src/types/type_ordering.rs b/crates/red_knot_python_semantic/src/types/type_ordering.rs index c3f8177266..0fbe974bcb 100644 --- a/crates/red_knot_python_semantic/src/types/type_ordering.rs +++ b/crates/red_knot_python_semantic/src/types/type_ordering.rs @@ -70,6 +70,12 @@ pub(super) fn union_or_intersection_elements_ordering<'db>( (Type::WrapperDescriptor(_), _) => Ordering::Less, (_, Type::WrapperDescriptor(_)) => Ordering::Greater, + (Type::DataclassDecorator(left), Type::DataclassDecorator(right)) => { + left.bits().cmp(&right.bits()) + } + (Type::DataclassDecorator(_), _) => Ordering::Less, + (_, Type::DataclassDecorator(_)) => Ordering::Greater, + (Type::Callable(left), Type::Callable(right)) => { debug_assert_eq!(*left, left.normalized(db)); debug_assert_eq!(*right, right.normalized(db));