diff --git a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md index bff8dba19d..f210e97b0d 100644 --- a/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md +++ b/crates/ty_python_semantic/resources/mdtest/dataclasses/dataclass_transform.md @@ -202,7 +202,7 @@ from typing import dataclass_transform from dataclasses import field @dataclass_transform(kw_only_default=True) -def create_model(*, init=True): ... +def create_model(*, kw_only: bool = True): ... @create_model() class A: name: str @@ -213,24 +213,69 @@ a = A(name="Harry") a = A("Harry") ``` -TODO: This can be overridden by the call to the decorator function. +This can be overridden by setting `kw_only=False` when applying the decorator: ```py -from typing import dataclass_transform - -@dataclass_transform(kw_only_default=True) -def create_model(*, kw_only: bool = True): ... @create_model(kw_only=False) class CustomerModel: id: int name: str -# TODO: Should not emit errors -# error: [missing-argument] -# error: [too-many-positional-arguments] c = CustomerModel(1, "Harry") ``` +### `frozen_default` + +When provided, sets the default value for the `frozen` parameter of `field()`. + +```py +from typing import dataclass_transform + +@dataclass_transform(frozen_default=True) +def create_model(*, frozen: bool = True): ... +@create_model() +class ImmutableModel: + name: str + +i = ImmutableModel(name="test") +i.name = "new" # error: [invalid-assignment] +``` + +Again, this can be overridden by setting `frozen=False` when applying the decorator: + +```py +@create_model(frozen=False) +class MutableModel: + name: str + +m = MutableModel(name="test") +m.name = "new" # No error +``` + +### Combining parameters + +Combining several of these parameters also works as expected: + +```py +from typing import dataclass_transform + +@dataclass_transform(eq_default=True, order_default=False, kw_only_default=True, frozen_default=True) +def create_model(*, eq: bool = True, order: bool = False, kw_only: bool = True, frozen: bool = True): ... +@create_model(eq=False, order=True, kw_only=False, frozen=False) +class OverridesAllParametersModel: + name: str + age: int + +# Positional arguments are allowed: +model = OverridesAllParametersModel("test", 25) + +# Mutation is allowed: +model.name = "new" # No error + +# Comparison methods are generated: +model < model # No error +``` + ### `field_specifiers` To do diff --git a/crates/ty_python_semantic/src/types/call/bind.rs b/crates/ty_python_semantic/src/types/call/bind.rs index fef791d38d..6ea7fd8752 100644 --- a/crates/ty_python_semantic/src/types/call/bind.rs +++ b/crates/ty_python_semantic/src/types/call/bind.rs @@ -1023,18 +1023,30 @@ impl<'db> Bindings<'db> { let mut dataclass_params = DataclassParams::from(params); - if let Some(Some(Type::BooleanLiteral(order))) = - overload - .signature - .parameters() - .keyword_by_name("order") - .map(|(idx, _)| idx) - .and_then(|idx| { - overload.parameter_types().get(idx) - }) + if let Ok(Some(Type::BooleanLiteral(order))) = + overload.parameter_type_by_name("order") + { + dataclass_params.set(DataclassParams::ORDER, order); + } + + if let Ok(Some(Type::BooleanLiteral(eq))) = + overload.parameter_type_by_name("eq") + { + dataclass_params.set(DataclassParams::EQ, eq); + } + + if let Ok(Some(Type::BooleanLiteral(kw_only))) = + overload.parameter_type_by_name("kw_only") { dataclass_params - .set(DataclassParams::ORDER, *order); + .set(DataclassParams::KW_ONLY, kw_only); + } + + if let Ok(Some(Type::BooleanLiteral(frozen))) = + overload.parameter_type_by_name("frozen") + { + dataclass_params + .set(DataclassParams::FROZEN, frozen); } Type::DataclassDecorator(dataclass_params) @@ -2933,12 +2945,11 @@ impl<'db> Binding<'db> { &self, parameter_name: &str, ) -> Result>, UnknownParameterNameError> { - let (index, _) = self + let index = self .signature .parameters() - .iter() - .enumerate() - .find(|(_, param)| param.name().is_some_and(|name| name == parameter_name)) + .keyword_by_name(parameter_name) + .map(|(i, _)| i) .ok_or(UnknownParameterNameError)?; Ok(self.parameter_tys[index])