From 25023cc0ea779750e2b7c6a2c2a4aa5c4815ec33 Mon Sep 17 00:00:00 2001 From: Ibraheem Ahmed Date: Thu, 16 Oct 2025 15:40:39 -0400 Subject: [PATCH] [ty] Use declared variable types as bidirectional type context (#20796) ## Summary Use the declared type of variables as type context for the RHS of assignment expressions, e.g., ```py x: list[int | str] x = [1] reveal_type(x) # revealed: list[int | str] ``` --- .../mdtest/assignment/annotations.md | 12 ++ .../mdtest/narrow/conditionals/nested.md | 4 +- .../resources/mdtest/typed_dict.md | 14 +- .../src/types/infer/builder.rs | 135 ++++++++++-------- 4 files changed, 102 insertions(+), 63 deletions(-) diff --git a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md index fe0bbf84fd..3d5e75ab99 100644 --- a/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md +++ b/crates/ty_python_semantic/resources/mdtest/assignment/annotations.md @@ -144,6 +144,12 @@ reveal_type(q) # revealed: dict[int | str, int] r: dict[int | str, int | str] = {1: 1, 2: 2, 3: 3} reveal_type(r) # revealed: dict[int | str, int | str] + +s: dict[int | str, int | str] +s = {1: 1, 2: 2, 3: 3} +reveal_type(s) # revealed: dict[int | str, int | str] +(s := {1: 1, 2: 2, 3: 3}) +reveal_type(s) # revealed: dict[int | str, int | str] ``` ## Optional collection literal annotations are understood @@ -296,6 +302,12 @@ reveal_type(q) # revealed: list[int] r: list[Literal[1, 2, 3, 4]] = [1, 2] reveal_type(r) # revealed: list[Literal[1, 2, 3, 4]] + +s: list[Literal[1]] +s = [1] +reveal_type(s) # revealed: list[Literal[1]] +(s := [1]) +reveal_type(s) # revealed: list[Literal[1]] ``` ## PEP-604 annotations are supported diff --git a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md index f27b3deb08..2cb4585b3b 100644 --- a/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md +++ b/crates/ty_python_semantic/resources/mdtest/narrow/conditionals/nested.md @@ -310,13 +310,13 @@ no longer valid in the inner lazy scope. def f(l: list[str | None]): if l[0] is not None: def _(): - reveal_type(l[0]) # revealed: str | None | Unknown + reveal_type(l[0]) # revealed: str | None l = [None] def f(l: list[str | None]): l[0] = "a" def _(): - reveal_type(l[0]) # revealed: str | None | Unknown + reveal_type(l[0]) # revealed: str | None l = [None] def f(l: list[str | None]): diff --git a/crates/ty_python_semantic/resources/mdtest/typed_dict.md b/crates/ty_python_semantic/resources/mdtest/typed_dict.md index b9e3015c9f..5cefed9b28 100644 --- a/crates/ty_python_semantic/resources/mdtest/typed_dict.md +++ b/crates/ty_python_semantic/resources/mdtest/typed_dict.md @@ -233,10 +233,12 @@ Person({"name": "Alice"}) # error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor" accepts_person({"name": "Alice"}) + # TODO: this should be an error, similar to the above house.owner = {"name": "Alice"} + a_person: Person -# TODO: this should be an error, similar to the above +# error: [missing-typed-dict-key] "Missing required key 'age' in TypedDict `Person` constructor" a_person = {"name": "Alice"} ``` @@ -254,9 +256,12 @@ Person({"name": None, "age": 30}) accepts_person({"name": None, "age": 30}) # TODO: this should be an error, similar to the above house.owner = {"name": None, "age": 30} + a_person: Person -# TODO: this should be an error, similar to the above +# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`" a_person = {"name": None, "age": 30} +# error: [invalid-argument-type] "Invalid argument to key "name" with declared type `str` on TypedDict `Person`: value of type `None`" +(a_person := {"name": None, "age": 30}) ``` All of these have an extra field that is not defined in the `TypedDict`: @@ -273,9 +278,12 @@ Person({"name": "Alice", "age": 30, "extra": True}) accepts_person({"name": "Alice", "age": 30, "extra": True}) # TODO: this should be an error house.owner = {"name": "Alice", "age": 30, "extra": True} -# TODO: this should be an error + a_person: Person +# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra"" a_person = {"name": "Alice", "age": 30, "extra": True} +# error: [invalid-key] "Invalid key access on TypedDict `Person`: Unknown key "extra"" +(a_person := {"name": "Alice", "age": 30, "extra": True}) ``` ## Type ignore compatibility issues diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index 8ea2ca9988..f84922a5f1 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -1345,7 +1345,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { true } - fn add_binding(&mut self, node: AnyNodeRef, binding: Definition<'db>, ty: Type<'db>) { + /// Add a binding for the given definition. + /// + /// Returns the result of the `infer_value_ty` closure, which is called with the declared type + /// as type context. + fn add_binding( + &mut self, + node: AnyNodeRef, + binding: Definition<'db>, + infer_value_ty: impl FnOnce(&mut Self, TypeContext<'db>) -> Type<'db>, + ) -> Type<'db> { /// Arbitrary `__getitem__`/`__setitem__` methods on a class do not /// necessarily guarantee that the passed-in value for `__setitem__` is stored and /// can be retrieved unmodified via `__getitem__`. Therefore, we currently only @@ -1390,7 +1399,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { let file_scope_id = binding.file_scope(db); let place_table = self.index.place_table(file_scope_id); let use_def = self.index.use_def_map(file_scope_id); - let mut bound_ty = ty; let global_use_def_map = self.index.use_def_map(FileScopeId::global()); let place_id = binding.place(self.db()); @@ -1501,12 +1509,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { qualifiers, } = place_and_quals; - let unwrap_declared_ty = || { - resolved_place - .ignore_possibly_undefined() - .unwrap_or(Type::unknown()) - }; - // If the place is unbound and its an attribute or subscript place, fall back to normal // attribute/subscript inference on the root type. let declared_ty = @@ -1518,9 +1520,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { value_type.member(db, attr).place { // TODO: also consider qualifiers on the attribute - ty + Some(ty) } else { - unwrap_declared_ty() + None } } else if let AnyNodeRef::ExprSubscript( subscript @ ast::ExprSubscript { @@ -1530,13 +1532,19 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { { let value_ty = self.infer_expression(value, TypeContext::default()); let slice_ty = self.infer_expression(slice, TypeContext::default()); - self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx) + Some(self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx)) } else { - unwrap_declared_ty() + None } } else { - unwrap_declared_ty() - }; + None + } + .or_else(|| resolved_place.ignore_possibly_undefined()); + + let inferred_ty = infer_value_ty(self, TypeContext::new(declared_ty)); + + let declared_ty = declared_ty.unwrap_or(Type::unknown()); + let mut bound_ty = inferred_ty; if qualifiers.contains(TypeQualifiers::FINAL) { let mut previous_bindings = use_def.bindings_at_definition(binding); @@ -1592,7 +1600,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { if !bound_ty.is_assignable_to(db, declared_ty) { report_invalid_assignment(&self.context, node, binding, declared_ty, bound_ty); - // allow declarations to override inference in case of invalid assignment + + // Allow declarations to override inference in case of invalid assignment. bound_ty = declared_ty; } // In the following cases, the bound type may not be the same as the RHS value type. @@ -1620,6 +1629,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } self.bindings.insert(binding, bound_ty); + + inferred_ty } /// Returns `true` if `symbol_id` should be looked up in the global scope, skipping intervening @@ -2485,7 +2496,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } else { Type::unknown() }; - self.add_binding(parameter.into(), definition, ty); + + self.add_binding(parameter.into(), definition, |_, _| ty); } } @@ -2515,11 +2527,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &DeclaredAndInferredType::are_the_same_type(ty), ); } else { - self.add_binding( - parameter.into(), - definition, - Type::homogeneous_tuple(self.db(), Type::unknown()), - ); + self.add_binding(parameter.into(), definition, |builder, _| { + Type::homogeneous_tuple(builder.db(), Type::unknown()) + }); } } @@ -2547,14 +2557,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { &DeclaredAndInferredType::are_the_same_type(ty), ); } else { - self.add_binding( - parameter.into(), - definition, + self.add_binding(parameter.into(), definition, |builder, _| { KnownClass::Dict.to_specialized_instance( - self.db(), - [KnownClass::Str.to_instance(self.db()), Type::unknown()], - ), - ); + builder.db(), + [KnownClass::Str.to_instance(builder.db()), Type::unknown()], + ) + }); } } @@ -2828,12 +2836,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { for item in items { let target = item.optional_vars.as_deref(); if let Some(target) = target { - self.infer_target(target, &item.context_expr, |builder, context_expr| { + self.infer_target(target, &item.context_expr, |builder| { // TODO: `infer_with_statement_definition` reports a diagnostic if `ctx_manager_ty` isn't a context manager // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `with not_context_manager as a.x: ... builder - .infer_standalone_expression(context_expr, TypeContext::default()) + .infer_standalone_expression(&item.context_expr, TypeContext::default()) .enter(builder.db()) }); } else { @@ -2873,7 +2881,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; self.store_expression_type(target, target_ty); - self.add_binding(target.into(), definition, target_ty); + self.add_binding(target.into(), definition, |_, _| target_ty); } /// Infers the type of a context expression (`with expr`) and returns the target's type @@ -3005,7 +3013,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { self.add_binding( except_handler_definition.node(self.module()).into(), definition, - symbol_ty, + |_, _| symbol_ty, ); } @@ -3174,11 +3182,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // against the subject expression type (which we can query via `infer_expression_types`) // and extract the type at the `index` position if the pattern matches. This will be // similar to the logic in `self.infer_assignment_definition`. - self.add_binding( - pattern.into(), - definition, - todo_type!("`match` pattern definition types"), - ); + self.add_binding(pattern.into(), definition, |_, _| { + todo_type!("`match` pattern definition types") + }); } fn infer_match_pattern(&mut self, pattern: &ast::Pattern) { @@ -3299,8 +3305,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } = assignment; for target in targets { - self.infer_target(target, value, |builder, value_expr| { - builder.infer_standalone_expression(value_expr, TypeContext::default()) + self.infer_target(target, value, |builder| { + builder.infer_standalone_expression(value, TypeContext::default()) }); } } @@ -3316,11 +3322,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { /// `target`. fn infer_target(&mut self, target: &ast::Expr, value: &ast::Expr, infer_value_expr: F) where - F: Fn(&mut TypeInferenceBuilder<'db, '_>, &ast::Expr) -> Type<'db>, + F: Fn(&mut Self) -> Type<'db>, { let assigned_ty = match target { ast::Expr::Name(_) => None, - _ => Some(infer_value_expr(self, value)), + _ => Some(infer_value_expr(self)), }; self.infer_target_impl(target, value, assigned_ty); } @@ -4069,6 +4075,21 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { assignment: &AssignmentDefinitionKind<'db>, definition: Definition<'db>, ) { + let target = assignment.target(self.module()); + + self.add_binding(target.into(), definition, |builder, tcx| { + let target_ty = builder.infer_assignment_definition_impl(assignment, definition, tcx); + builder.store_expression_type(target, target_ty); + target_ty + }); + } + + fn infer_assignment_definition_impl( + &mut self, + assignment: &AssignmentDefinitionKind<'db>, + definition: Definition<'db>, + tcx: TypeContext<'db>, + ) -> Type<'db> { let value = assignment.value(self.module()); let target = assignment.target(self.module()); @@ -4084,7 +4105,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { unpacked.expression_type(target) } TargetKind::Single => { - let tcx = TypeContext::default(); let value_ty = if let Some(standalone_expression) = self.index.try_expression(value) { self.infer_standalone_expression_impl(value, standalone_expression, tcx) @@ -4109,6 +4129,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { } else { self.infer_call_expression_impl(call_expr, callable_type, tcx) }; + self.store_expression_type(value, ty); ty } else { @@ -4140,8 +4161,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { target_ty = Type::SpecialForm(special_form); } - self.store_expression_type(target, target_ty); - self.add_binding(target.into(), definition, target_ty); + target_ty } fn infer_legacy_typevar( @@ -4678,7 +4698,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { definition: Definition<'db>, ) { let target_ty = self.infer_augment_assignment(assignment); - self.add_binding(assignment.into(), definition, target_ty); + self.add_binding(assignment.into(), definition, |_, _| target_ty); } fn infer_augment_assignment(&mut self, assignment: &ast::StmtAugAssign) -> Type<'db> { @@ -4729,12 +4749,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { is_async: _, } = for_statement; - self.infer_target(target, iter, |builder, iter_expr| { + self.infer_target(target, iter, |builder| { // TODO: `infer_for_statement_definition` reports a diagnostic if `iter_ty` isn't iterable // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `for a.x in not_iterable: ... builder - .infer_standalone_expression(iter_expr, TypeContext::default()) + .infer_standalone_expression(iter, TypeContext::default()) .iterate(builder.db()) .homogeneous_element_type(builder.db()) }); @@ -4778,7 +4798,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; self.store_expression_type(target, loop_var_value_type); - self.add_binding(target.into(), definition, loop_var_value_type); + self.add_binding(target.into(), definition, |_, _| loop_var_value_type); } fn infer_while_statement(&mut self, while_statement: &ast::StmtWhile) { @@ -6291,23 +6311,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { is_async: _, } = comprehension; - self.infer_target(target, iter, |builder, iter_expr| { + self.infer_target(target, iter, |builder| { // TODO: `infer_comprehension_definition` reports a diagnostic if `iter_ty` isn't iterable // but only if the target is a name. We should report a diagnostic here if the target isn't a name: // `[... for a.x in not_iterable] if is_first { infer_same_file_expression_type( builder.db(), - builder.index.expression(iter_expr), + builder.index.expression(iter), TypeContext::default(), builder.module(), ) } else { - builder.infer_standalone_expression(iter_expr, TypeContext::default()) + builder.infer_standalone_expression(iter, TypeContext::default()) } .iterate(builder.db()) .homogeneous_element_type(builder.db()) }); + for expr in ifs { self.infer_standalone_expression(expr, TypeContext::default()); } @@ -6365,7 +6386,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { }; self.expressions.insert(target.into(), target_type); - self.add_binding(target.into(), definition, target_type); + self.add_binding(target.into(), definition, |_, _| target_type); } fn infer_named_expression(&mut self, named: &ast::ExprNamed) -> Type<'db> { @@ -6395,12 +6416,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { value, } = named; - let value_ty = self.infer_expression(value, TypeContext::default()); self.infer_expression(target, TypeContext::default()); - self.add_binding(named.into(), definition, value_ty); - - value_ty + self.add_binding(named.into(), definition, |builder, tcx| { + builder.infer_expression(value, tcx) + }) } fn infer_if_expression(&mut self, if_expression: &ast::ExprIf) -> Type<'db> { @@ -8549,8 +8569,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { // - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal // - `[ast::CompOp::IsNot]`: return `true` if unequal, `bool` if equal let db = self.db(); - let try_dunder = |inference: &mut TypeInferenceBuilder<'db, '_>, - policy: MemberLookupPolicy| { + let try_dunder = |inference: &mut Self, policy: MemberLookupPolicy| { let rich_comparison = |op| inference.infer_rich_comparison(left, right, op, policy); let membership_test_comparison = |op, range: TextRange| { inference.infer_membership_test_comparison(left, right, op, range)