diff --git a/crates/ty_python_semantic/resources/corpus/88_regression_pr_20962.py b/crates/ty_python_semantic/resources/corpus/88_regression_pr_20962.py new file mode 100644 index 0000000000..d0b9f706ce --- /dev/null +++ b/crates/ty_python_semantic/resources/corpus/88_regression_pr_20962.py @@ -0,0 +1,18 @@ +name_1 +{0: 0 for unique_name_0 in unique_name_1 if name_1} + + +@[name_2 for unique_name_2 in name_2] +def name_2(): + pass + + +def name_2(): + pass + + +match 0: + case name_2(): + pass + case []: + name_1 = 0 diff --git a/crates/ty_python_semantic/resources/mdtest/comprehensions/basic.md b/crates/ty_python_semantic/resources/mdtest/comprehensions/basic.md index bdd9ec435c..254ac03d73 100644 --- a/crates/ty_python_semantic/resources/mdtest/comprehensions/basic.md +++ b/crates/ty_python_semantic/resources/mdtest/comprehensions/basic.md @@ -103,3 +103,92 @@ async def _(): # revealed: Unknown [reveal_type(x) async for x in range(3)] ``` + +## Comprehension expression types + +The type of the comprehension expression itself should reflect the inferred element type: + +```py +from typing import TypedDict, Literal + +# revealed: list[int | Unknown] +reveal_type([x for x in range(10)]) + +# revealed: set[int | Unknown] +reveal_type({x for x in range(10)}) + +# revealed: dict[int | Unknown, str | Unknown] +reveal_type({x: str(x) for x in range(10)}) + +# revealed: list[tuple[int, Unknown | str] | Unknown] +reveal_type([(x, y) for x in range(5) for y in ["a", "b", "c"]]) + +squares: list[int | None] = [x**2 for x in range(10)] +reveal_type(squares) # revealed: list[int | None] +``` + +Inference for comprehensions takes the type context into account: + +```py +# Without type context: +reveal_type([x for x in [1, 2, 3]]) # revealed: list[Unknown | int] +reveal_type({x: "a" for x in [1, 2, 3]}) # revealed: dict[Unknown | int, str | Unknown] +reveal_type({str(x): x for x in [1, 2, 3]}) # revealed: dict[str | Unknown, Unknown | int] +reveal_type({x for x in [1, 2, 3]}) # revealed: set[Unknown | int] + +# With type context: +xs: list[int] = [x for x in [1, 2, 3]] +reveal_type(xs) # revealed: list[int] + +ys: dict[int, str] = {x: str(x) for x in [1, 2, 3]} +reveal_type(ys) # revealed: dict[int, str] + +zs: set[int] = {x for x in [1, 2, 3]} +``` + +This also works for nested comprehensions: + +```py +table = [[(x, y) for x in range(3)] for y in range(3)] +reveal_type(table) # revealed: list[list[tuple[int, int] | Unknown] | Unknown] + +table_with_content: list[list[tuple[int, int, str | None]]] = [[(x, y, None) for x in range(3)] for y in range(3)] +reveal_type(table_with_content) # revealed: list[list[tuple[int, int, str | None]]] +``` + +The type context is propagated down into the comprehension: + +```py +class Person(TypedDict): + name: str + +persons: list[Person] = [{"name": n} for n in ["Alice", "Bob"]] +reveal_type(persons) # revealed: list[Person] + +# TODO: This should be an error +invalid: list[Person] = [{"misspelled": n} for n in ["Alice", "Bob"]] +``` + +We promote literals to avoid overly-precise types in invariant positions: + +```py +reveal_type([x for x in ("a", "b", "c")]) # revealed: list[str | Unknown] +reveal_type({x for x in (1, 2, 3)}) # revealed: set[int | Unknown] +reveal_type({k: 0 for k in ("a", "b", "c")}) # revealed: dict[str | Unknown, int | Unknown] +``` + +Type context can prevent this promotion from happening: + +```py +list_of_literals: list[Literal["a", "b", "c"]] = [x for x in ("a", "b", "c")] +reveal_type(list_of_literals) # revealed: list[Literal["a", "b", "c"]] + +dict_with_literal_keys: dict[Literal["a", "b", "c"], int] = {k: 0 for k in ("a", "b", "c")} +reveal_type(dict_with_literal_keys) # revealed: dict[Literal["a", "b", "c"], int] + +dict_with_literal_values: dict[str, Literal[1, 2, 3]] = {str(k): k for k in (1, 2, 3)} +reveal_type(dict_with_literal_values) # revealed: dict[str, Literal[1, 2, 3]] + +set_with_literals: set[Literal[1, 2, 3]] = {k for k in (1, 2, 3)} +reveal_type(set_with_literals) # revealed: set[Literal[1, 2, 3]] +``` diff --git a/crates/ty_python_semantic/resources/mdtest/literal/collections/dictionary.md b/crates/ty_python_semantic/resources/mdtest/literal/collections/dictionary.md index 7e1acf4efb..ad5829da1f 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal/collections/dictionary.md +++ b/crates/ty_python_semantic/resources/mdtest/literal/collections/dictionary.md @@ -51,6 +51,6 @@ reveal_type({"a": 1, "b": (1, 2), "c": (1, 2, 3)}) ## Dict comprehensions ```py -# revealed: dict[@Todo(dict comprehension key type), @Todo(dict comprehension value type)] +# revealed: dict[int | Unknown, int | Unknown] reveal_type({x: y for x, y in enumerate(range(42))}) ``` diff --git a/crates/ty_python_semantic/resources/mdtest/literal/collections/list.md b/crates/ty_python_semantic/resources/mdtest/literal/collections/list.md index 15f385fa88..325caba10d 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal/collections/list.md +++ b/crates/ty_python_semantic/resources/mdtest/literal/collections/list.md @@ -41,5 +41,5 @@ reveal_type([1, (1, 2), (1, 2, 3)]) ## List comprehensions ```py -reveal_type([x for x in range(42)]) # revealed: list[@Todo(list comprehension element type)] +reveal_type([x for x in range(42)]) # revealed: list[int | Unknown] ``` diff --git a/crates/ty_python_semantic/resources/mdtest/literal/collections/set.md b/crates/ty_python_semantic/resources/mdtest/literal/collections/set.md index 6c6855e40e..d80112ee84 100644 --- a/crates/ty_python_semantic/resources/mdtest/literal/collections/set.md +++ b/crates/ty_python_semantic/resources/mdtest/literal/collections/set.md @@ -35,5 +35,5 @@ reveal_type({1, (1, 2), (1, 2, 3)}) ## Set comprehensions ```py -reveal_type({x for x in range(42)}) # revealed: set[@Todo(set comprehension element type)] +reveal_type({x for x in range(42)}) # revealed: set[int | Unknown] ``` diff --git a/crates/ty_python_semantic/resources/mdtest/regression/pr_20962_comprehension_panics.md b/crates/ty_python_semantic/resources/mdtest/regression/pr_20962_comprehension_panics.md new file mode 100644 index 0000000000..b011d95e8c --- /dev/null +++ b/crates/ty_python_semantic/resources/mdtest/regression/pr_20962_comprehension_panics.md @@ -0,0 +1,50 @@ +# Documentation of two fuzzer panics involving comprehensions + +Type inference for comprehensions was added in . It +added two new fuzzer panics that are documented here for regression testing. + +## Too many cycle iterations in `place_by_id` + + + +```py +name_5(name_3) +[0 for unique_name_0 in unique_name_1 for unique_name_2 in name_3] + +@{name_3 for unique_name_3 in unique_name_4} +class name_4[**name_3](0, name_2=name_5): + pass + +try: + name_0 = name_4 +except* 0: + pass +else: + match unique_name_12: + case 0: + from name_2 import name_3 + case name_0(): + + @name_4 + def name_3(): + pass + +(name_3 := 0) + +@name_3 +async def name_5(): + pass +``` + +## Too many cycle iterations in `infer_definition_types` + + + +```py +for name_1 in { + {{0: name_4 for unique_name_0 in unique_name_1}: 0 for unique_name_2 in unique_name_3 if name_4}: 0 + for unique_name_4 in name_1 + for name_4 in name_1 +}: + pass +``` diff --git a/crates/ty_python_semantic/src/types/function.rs b/crates/ty_python_semantic/src/types/function.rs index 0f5797ae7a..6244b0a85a 100644 --- a/crates/ty_python_semantic/src/types/function.rs +++ b/crates/ty_python_semantic/src/types/function.rs @@ -534,6 +534,14 @@ pub struct FunctionLiteral<'db> { // The Salsa heap is tracked separately. impl get_size2::GetSize for FunctionLiteral<'_> {} +fn overloads_and_implementation_cycle_initial<'db>( + _db: &'db dyn Db, + _id: salsa::Id, + _self: FunctionLiteral<'db>, +) -> (Box<[OverloadLiteral<'db>]>, Option>) { + (Box::new([]), None) +} + #[salsa::tracked] impl<'db> FunctionLiteral<'db> { fn name(self, db: &'db dyn Db) -> &'db ast::name::Name { @@ -576,7 +584,7 @@ impl<'db> FunctionLiteral<'db> { self.last_definition(db).spans(db) } - #[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size)] + #[salsa::tracked(returns(ref), heap_size=ruff_memory_usage::heap_size, cycle_initial=overloads_and_implementation_cycle_initial)] fn overloads_and_implementation( self, db: &'db dyn Db, diff --git a/crates/ty_python_semantic/src/types/infer/builder.rs b/crates/ty_python_semantic/src/types/infer/builder.rs index f6055c0a0e..ad0a103319 100644 --- a/crates/ty_python_semantic/src/types/infer/builder.rs +++ b/crates/ty_python_semantic/src/types/infer/builder.rs @@ -5943,9 +5943,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ast::Expr::Set(set) => self.infer_set_expression(set, tcx), ast::Expr::Dict(dict) => self.infer_dict_expression(dict, tcx), ast::Expr::Generator(generator) => self.infer_generator_expression(generator), - ast::Expr::ListComp(listcomp) => self.infer_list_comprehension_expression(listcomp), - ast::Expr::DictComp(dictcomp) => self.infer_dict_comprehension_expression(dictcomp), - ast::Expr::SetComp(setcomp) => self.infer_set_comprehension_expression(setcomp), + ast::Expr::ListComp(listcomp) => { + self.infer_list_comprehension_expression(listcomp, tcx) + } + ast::Expr::DictComp(dictcomp) => { + self.infer_dict_comprehension_expression(dictcomp, tcx) + } + ast::Expr::SetComp(setcomp) => self.infer_set_comprehension_expression(setcomp, tcx), ast::Expr::Name(name) => self.infer_name_expression(name), ast::Expr::Attribute(attribute) => self.infer_attribute_expression(attribute), ast::Expr::UnaryOp(unary_op) => self.infer_unary_expression(unary_op), @@ -6450,52 +6454,121 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> { ) } - fn infer_list_comprehension_expression(&mut self, listcomp: &ast::ExprListComp) -> Type<'db> { + /// Return a specialization of the collection class (list, dict, set) based on the type context and the inferred + /// element / key-value types from the comprehension expression. + fn infer_comprehension_specialization( + &self, + collection_class: KnownClass, + inferred_element_types: &[Type<'db>], + tcx: TypeContext<'db>, + ) -> Type<'db> { + // Remove any union elements of that are unrelated to the collection type. + let tcx = tcx.map(|annotation| { + annotation.filter_disjoint_elements( + self.db(), + collection_class.to_instance(self.db()), + InferableTypeVars::None, + ) + }); + + if let Some(annotated_element_types) = tcx + .known_specialization(self.db(), collection_class) + .map(|specialization| specialization.types(self.db())) + && annotated_element_types + .iter() + .zip(inferred_element_types.iter()) + .all(|(annotated, inferred)| inferred.is_assignable_to(self.db(), *annotated)) + { + collection_class + .to_specialized_instance(self.db(), annotated_element_types.iter().copied()) + } else { + collection_class.to_specialized_instance( + self.db(), + inferred_element_types.iter().map(|ty| { + UnionType::from_elements( + self.db(), + [ + ty.promote_literals(self.db(), TypeContext::default()), + Type::unknown(), + ], + ) + }), + ) + } + } + + fn infer_list_comprehension_expression( + &mut self, + listcomp: &ast::ExprListComp, + tcx: TypeContext<'db>, + ) -> Type<'db> { let ast::ExprListComp { range: _, node_index: _, - elt: _, + elt, generators, } = listcomp; self.infer_first_comprehension_iter(generators); - KnownClass::List - .to_specialized_instance(self.db(), [todo_type!("list comprehension element type")]) + let scope_id = self + .index + .node_scope(NodeWithScopeRef::ListComprehension(listcomp)); + let scope = scope_id.to_scope_id(self.db(), self.file()); + let inference = infer_scope_types(self.db(), scope); + let element_type = inference.expression_type(elt.as_ref()); + + self.infer_comprehension_specialization(KnownClass::List, &[element_type], tcx) } - fn infer_dict_comprehension_expression(&mut self, dictcomp: &ast::ExprDictComp) -> Type<'db> { + fn infer_dict_comprehension_expression( + &mut self, + dictcomp: &ast::ExprDictComp, + tcx: TypeContext<'db>, + ) -> Type<'db> { let ast::ExprDictComp { range: _, node_index: _, - key: _, - value: _, + key, + value, generators, } = dictcomp; self.infer_first_comprehension_iter(generators); - KnownClass::Dict.to_specialized_instance( - self.db(), - [ - todo_type!("dict comprehension key type"), - todo_type!("dict comprehension value type"), - ], - ) + let scope_id = self + .index + .node_scope(NodeWithScopeRef::DictComprehension(dictcomp)); + let scope = scope_id.to_scope_id(self.db(), self.file()); + let inference = infer_scope_types(self.db(), scope); + let key_type = inference.expression_type(key.as_ref()); + let value_type = inference.expression_type(value.as_ref()); + + self.infer_comprehension_specialization(KnownClass::Dict, &[key_type, value_type], tcx) } - fn infer_set_comprehension_expression(&mut self, setcomp: &ast::ExprSetComp) -> Type<'db> { + fn infer_set_comprehension_expression( + &mut self, + setcomp: &ast::ExprSetComp, + tcx: TypeContext<'db>, + ) -> Type<'db> { let ast::ExprSetComp { range: _, node_index: _, - elt: _, + elt, generators, } = setcomp; self.infer_first_comprehension_iter(generators); - KnownClass::Set - .to_specialized_instance(self.db(), [todo_type!("set comprehension element type")]) + let scope_id = self + .index + .node_scope(NodeWithScopeRef::SetComprehension(setcomp)); + let scope = scope_id.to_scope_id(self.db(), self.file()); + let inference = infer_scope_types(self.db(), scope); + let element_type = inference.expression_type(elt.as_ref()); + + self.infer_comprehension_specialization(KnownClass::Set, &[element_type], tcx) } fn infer_generator_expression_scope(&mut self, generator: &ast::ExprGenerator) { diff --git a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs index 3c7bdb5464..0d72548e49 100644 --- a/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs +++ b/crates/ty_python_semantic/src/types/infer/builder/type_expression.rs @@ -346,7 +346,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } ast::Expr::DictComp(dictcomp) => { - self.infer_dict_comprehension_expression(dictcomp); + self.infer_dict_comprehension_expression(dictcomp, TypeContext::default()); self.report_invalid_type_expression( expression, format_args!("Dict comprehensions are not allowed in type expressions"), @@ -355,7 +355,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } ast::Expr::ListComp(listcomp) => { - self.infer_list_comprehension_expression(listcomp); + self.infer_list_comprehension_expression(listcomp, TypeContext::default()); self.report_invalid_type_expression( expression, format_args!("List comprehensions are not allowed in type expressions"), @@ -364,7 +364,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> { } ast::Expr::SetComp(setcomp) => { - self.infer_set_comprehension_expression(setcomp); + self.infer_set_comprehension_expression(setcomp, TypeContext::default()); self.report_invalid_type_expression( expression, format_args!("Set comprehensions are not allowed in type expressions"),