[ty] Infer more precise types for collection literals (#20360)
Some checks are pending
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / test ruff-lsp (push) Blocked by required conditions
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

Part of https://github.com/astral-sh/ty/issues/168. Infer more precise types for collection literals (currently, only `list` and `set`). For example,

```py
x = [1, 2, 3] # revealed: list[Unknown | int]
y: list[int] = [1, 2, 3] # revealed: list[int]
```

This could easily be extended to `dict` literals, but I am intentionally limiting scope for now.
This commit is contained in:
Ibraheem Ahmed 2025-09-17 18:51:50 -04:00 committed by GitHub
parent bfb0902446
commit e84d523bcf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 341 additions and 78 deletions

View file

@ -1130,11 +1130,30 @@ impl<'db> Type<'db> {
Type::IntLiteral(_) => Some(KnownClass::Int.to_instance(db)),
Type::BytesLiteral(_) => Some(KnownClass::Bytes.to_instance(db)),
Type::ModuleLiteral(_) => Some(KnownClass::ModuleType.to_instance(db)),
Type::FunctionLiteral(_) => Some(KnownClass::FunctionType.to_instance(db)),
Type::EnumLiteral(literal) => Some(literal.enum_class_instance(db)),
_ => None,
}
}
/// If this type is a literal, promote it to a type that this literal is an instance of.
///
/// Note that this function tries to promote literals to a more user-friendly form than their
/// fallback instance type. For example, `def _() -> int` is promoted to `Callable[[], int]`,
/// as opposed to `FunctionType`.
pub(crate) fn literal_promotion_type(self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Type::StringLiteral(_) | Type::LiteralString => Some(KnownClass::Str.to_instance(db)),
Type::BooleanLiteral(_) => Some(KnownClass::Bool.to_instance(db)),
Type::IntLiteral(_) => Some(KnownClass::Int.to_instance(db)),
Type::BytesLiteral(_) => Some(KnownClass::Bytes.to_instance(db)),
Type::ModuleLiteral(_) => Some(KnownClass::ModuleType.to_instance(db)),
Type::EnumLiteral(literal) => Some(literal.enum_class_instance(db)),
Type::FunctionLiteral(literal) => Some(Type::Callable(literal.into_callable_type(db))),
_ => None,
}
}
/// Return a "normalized" version of `self` that ensures that equivalent types have the same Salsa ID.
///
/// A normalized type:
@ -1704,18 +1723,13 @@ impl<'db> Type<'db> {
| Type::IntLiteral(_)
| Type::BytesLiteral(_)
| Type::ModuleLiteral(_)
| Type::EnumLiteral(_),
| Type::EnumLiteral(_)
| Type::FunctionLiteral(_),
_,
) => (self.literal_fallback_instance(db)).when_some_and(|instance| {
instance.has_relation_to_impl(db, target, relation, visitor)
}),
// A `FunctionLiteral` type is a single-valued type like the other literals handled above,
// so it also, for now, just delegates to its instance fallback.
(Type::FunctionLiteral(_), _) => KnownClass::FunctionType
.to_instance(db)
.has_relation_to_impl(db, target, relation, visitor),
// The same reasoning applies for these special callable types:
(Type::BoundMethod(_), _) => KnownClass::MethodType
.to_instance(db)
@ -5979,8 +5993,9 @@ impl<'db> Type<'db> {
self
}
}
TypeMapping::PromoteLiterals | TypeMapping::BindLegacyTypevars(_) |
TypeMapping::MarkTypeVarsInferable(_) => self,
TypeMapping::PromoteLiterals
| TypeMapping::BindLegacyTypevars(_)
| TypeMapping::MarkTypeVarsInferable(_) => self,
TypeMapping::Materialize(materialization_kind) => {
Type::TypeVar(bound_typevar.materialize_impl(db, *materialization_kind, visitor))
}
@ -6000,10 +6015,10 @@ impl<'db> Type<'db> {
self
}
}
TypeMapping::PromoteLiterals |
TypeMapping::BindLegacyTypevars(_) |
TypeMapping::BindSelf(_) |
TypeMapping::ReplaceSelf { .. }
TypeMapping::PromoteLiterals
| TypeMapping::BindLegacyTypevars(_)
| TypeMapping::BindSelf(_)
| TypeMapping::ReplaceSelf { .. }
=> self,
TypeMapping::Materialize(materialization_kind) => Type::NonInferableTypeVar(bound_typevar.materialize_impl(db, *materialization_kind, visitor))
@ -6023,7 +6038,13 @@ impl<'db> Type<'db> {
}
Type::FunctionLiteral(function) => {
Type::FunctionLiteral(function.with_type_mapping(db, type_mapping))
let function = Type::FunctionLiteral(function.with_type_mapping(db, type_mapping));
match type_mapping {
TypeMapping::PromoteLiterals => function.literal_promotion_type(db)
.expect("function literal should have a promotion type"),
_ => function
}
}
Type::BoundMethod(method) => Type::BoundMethod(BoundMethodType::new(
@ -6129,8 +6150,8 @@ impl<'db> Type<'db> {
TypeMapping::ReplaceSelf { .. } |
TypeMapping::MarkTypeVarsInferable(_) |
TypeMapping::Materialize(_) => self,
TypeMapping::PromoteLiterals => self.literal_fallback_instance(db)
.expect("literal type should have fallback instance type"),
TypeMapping::PromoteLiterals => self.literal_promotion_type(db)
.expect("literal type should have a promotion type"),
}
Type::Dynamic(_) => match type_mapping {
@ -6663,8 +6684,8 @@ pub enum TypeMapping<'a, 'db> {
Specialization(Specialization<'db>),
/// Applies a partial specialization to the type
PartialSpecialization(PartialSpecialization<'a, 'db>),
/// Promotes any literal types to their corresponding instance types (e.g. `Literal["string"]`
/// to `str`)
/// Replaces any literal types with their corresponding promoted type form (e.g. `Literal["string"]`
/// to `str`, or `def _() -> int` to `Callable[[], int]`).
PromoteLiterals,
/// Binds a legacy typevar with the generic context (class, function, type alias) that it is
/// being used in.

View file

@ -1048,7 +1048,7 @@ impl<'db> ClassType<'db> {
/// Return a callable type (or union of callable types) that represents the callable
/// constructor signature of this class.
#[salsa::tracked(heap_size=ruff_memory_usage::heap_size)]
#[salsa::tracked(cycle_fn=into_callable_cycle_recover, cycle_initial=into_callable_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
pub(super) fn into_callable(self, db: &'db dyn Db) -> Type<'db> {
let self_ty = Type::from(self);
let metaclass_dunder_call_function_symbol = self_ty
@ -1208,6 +1208,20 @@ impl<'db> ClassType<'db> {
}
}
#[allow(clippy::trivially_copy_pass_by_ref)]
fn into_callable_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &Type<'db>,
_count: u32,
_self: ClassType<'db>,
) -> salsa::CycleRecoveryAction<Type<'db>> {
salsa::CycleRecoveryAction::Iterate
}
fn into_callable_cycle_initial<'db>(_db: &'db dyn Db, _self: ClassType<'db>) -> Type<'db> {
Type::Never
}
impl<'db> From<GenericAlias<'db>> for ClassType<'db> {
fn from(generic: GenericAlias<'db>) -> ClassType<'db> {
ClassType::Generic(generic)

View file

@ -2626,7 +2626,7 @@ pub(crate) fn report_undeclared_protocol_member(
let binding_type = binding_type(db, definition);
let suggestion = binding_type
.literal_fallback_instance(db)
.literal_promotion_type(db)
.unwrap_or(binding_type);
if should_give_hint(db, suggestion) {

View file

@ -1081,16 +1081,13 @@ fn is_instance_truthiness<'db>(
| Type::StringLiteral(..)
| Type::LiteralString
| Type::ModuleLiteral(..)
| Type::EnumLiteral(..) => always_true_if(
| Type::EnumLiteral(..)
| Type::FunctionLiteral(..) => always_true_if(
ty.literal_fallback_instance(db)
.as_ref()
.is_some_and(is_instance),
),
Type::FunctionLiteral(..) => {
always_true_if(is_instance(&KnownClass::FunctionType.to_instance(db)))
}
Type::ClassLiteral(..) => always_true_if(is_instance(&KnownClass::Type.to_instance(db))),
Type::TypeAlias(alias) => is_instance_truthiness(db, alias.value_type(db), class),

View file

@ -49,8 +49,9 @@ use crate::semantic_index::expression::Expression;
use crate::semantic_index::scope::ScopeId;
use crate::semantic_index::{SemanticIndex, semantic_index};
use crate::types::diagnostic::TypeCheckDiagnostics;
use crate::types::generics::Specialization;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{ClassLiteral, Truthiness, Type, TypeAndQualifiers};
use crate::types::{ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers};
use crate::unpack::Unpack;
use builder::TypeInferenceBuilder;
@ -355,10 +356,31 @@ pub(crate) struct TypeContext<'db> {
}
impl<'db> TypeContext<'db> {
pub(crate) fn new(annotation: Type<'db>) -> Self {
Self {
annotation: Some(annotation),
pub(crate) fn new(annotation: Option<Type<'db>>) -> Self {
Self { annotation }
}
// If the type annotation is a specialized instance of the given `KnownClass`, returns the
// specialization.
fn known_specialization(
&self,
known_class: KnownClass,
db: &'db dyn Db,
) -> Option<Specialization<'db>> {
let class_type = match self.annotation? {
Type::NominalInstance(instance) => instance,
Type::TypeAlias(alias) => alias.value_type(db).into_nominal_instance()?,
_ => return None,
}
.class(db);
if !class_type.is_known(db, known_class) {
return None;
}
class_type
.into_generic_alias()
.map(|generic_alias| generic_alias.specialization(db))
}
}

View file

@ -73,13 +73,13 @@ use crate::types::diagnostic::{
use crate::types::function::{
FunctionDecorators, FunctionLiteral, FunctionType, KnownFunction, OverloadLiteral,
};
use crate::types::generics::LegacyGenericBase;
use crate::types::generics::{GenericContext, bind_typevar};
use crate::types::generics::{LegacyGenericBase, SpecializationBuilder};
use crate::types::instance::SliceLiteral;
use crate::types::mro::MroErrorKind;
use crate::types::signatures::Signature;
use crate::types::subclass_of::SubclassOfInner;
use crate::types::tuple::{Tuple, TupleSpec, TupleType};
use crate::types::tuple::{Tuple, TupleLength, TupleSpec, TupleType};
use crate::types::typed_dict::{
TypedDictAssignmentKind, validate_typed_dict_constructor, validate_typed_dict_dict_literal,
validate_typed_dict_key_assignment,
@ -90,8 +90,9 @@ use crate::types::{
IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy,
MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType,
SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers,
TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation,
TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type,
TypeContext, TypeMapping, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation,
TypeVarDefaultEvaluation, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type,
todo_type,
};
use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
use crate::unpack::{EvaluationMode, UnpackPosition};
@ -4008,7 +4009,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Some(value) = value {
self.infer_maybe_standalone_expression(
value,
TypeContext::new(annotated.inner_type()),
TypeContext::new(Some(annotated.inner_type())),
);
}
@ -4101,8 +4102,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
debug_assert!(PlaceExpr::try_from_expr(target).is_some());
if let Some(value) = value {
let inferred_ty = self
.infer_maybe_standalone_expression(value, TypeContext::new(declared.inner_type()));
let inferred_ty = self.infer_maybe_standalone_expression(
value,
TypeContext::new(Some(declared.inner_type())),
);
let mut inferred_ty = if target
.as_name_expr()
.is_some_and(|name| &name.id == "TYPE_CHECKING")
@ -5236,7 +5239,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_tuple_expression(
&mut self,
tuple: &ast::ExprTuple,
_tcx: TypeContext<'db>,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprTuple {
range: _,
@ -5246,11 +5249,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
parenthesized: _,
} = tuple;
let annotated_tuple = tcx
.known_specialization(KnownClass::Tuple, self.db())
.and_then(|specialization| {
specialization
.tuple(self.db())
.expect("the specialization of `KnownClass::Tuple` must have a tuple spec")
.resize(self.db(), TupleLength::Fixed(elts.len()))
.ok()
});
let mut annotated_elt_tys = annotated_tuple.as_ref().map(Tuple::all_elements);
let db = self.db();
let divergent = Type::divergent(self.scope());
let element_types = elts.iter().map(|element| {
// TODO: Use the type context for more precise inference.
let element_type = self.infer_expression(element, TypeContext::default());
let annotated_elt_ty = annotated_elt_tys.as_mut().and_then(Iterator::next).copied();
let element_type = self.infer_expression(element, TypeContext::new(annotated_elt_ty));
if element_type.has_divergent_type(self.db(), divergent) {
divergent
} else {
@ -5261,7 +5277,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::heterogeneous_tuple(db, element_types)
}
fn infer_list_expression(&mut self, list: &ast::ExprList, _tcx: TypeContext<'db>) -> Type<'db> {
fn infer_list_expression(&mut self, list: &ast::ExprList, tcx: TypeContext<'db>) -> Type<'db> {
let ast::ExprList {
range: _,
node_index: _,
@ -5269,28 +5285,102 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ctx: _,
} = list;
// TODO: Use the type context for more precise inference.
for elt in elts {
self.infer_expression(elt, TypeContext::default());
}
KnownClass::List
.to_specialized_instance(self.db(), [todo_type!("list literal element type")])
self.infer_collection_literal(elts, tcx, KnownClass::List)
.unwrap_or_else(|| {
KnownClass::List.to_specialized_instance(self.db(), [Type::unknown()])
})
}
fn infer_set_expression(&mut self, set: &ast::ExprSet, _tcx: TypeContext<'db>) -> Type<'db> {
fn infer_set_expression(&mut self, set: &ast::ExprSet, tcx: TypeContext<'db>) -> Type<'db> {
let ast::ExprSet {
range: _,
node_index: _,
elts,
} = set;
// TODO: Use the type context for more precise inference.
for elt in elts {
self.infer_expression(elt, TypeContext::default());
self.infer_collection_literal(elts, tcx, KnownClass::Set)
.unwrap_or_else(|| {
KnownClass::Set.to_specialized_instance(self.db(), [Type::unknown()])
})
}
// Infer the type of a collection literal expression.
fn infer_collection_literal(
&mut self,
elts: &[ast::Expr],
tcx: TypeContext<'db>,
collection_class: KnownClass,
) -> Option<Type<'db>> {
// Extract the type variable `T` from `list[T]` in typeshed.
fn elts_ty(
collection_class: KnownClass,
db: &dyn Db,
) -> Option<(ClassLiteral<'_>, Type<'_>)> {
let class_literal = collection_class.try_to_class_literal(db)?;
let generic_context = class_literal.generic_context(db)?;
let variables = generic_context.variables(db);
let elts_ty = variables.iter().exactly_one().ok()?;
Some((class_literal, Type::TypeVar(*elts_ty)))
}
KnownClass::Set.to_specialized_instance(self.db(), [todo_type!("set literal element type")])
let annotated_elts_ty = tcx
.known_specialization(collection_class, self.db())
.and_then(|specialization| specialization.types(self.db()).iter().exactly_one().ok())
.copied();
let (class_literal, elts_ty) = elts_ty(collection_class, self.db()).unwrap_or_else(|| {
let name = collection_class.name(self.db());
panic!("Typeshed should always have a `{name}` class in `builtins.pyi` with a single type variable")
});
let mut elements_are_assignable = true;
let mut inferred_elt_tys = Vec::with_capacity(elts.len());
// Infer the type of each element in the collection literal.
for elt in elts {
let inferred_elt_ty = self.infer_expression(elt, TypeContext::new(annotated_elts_ty));
inferred_elt_tys.push(inferred_elt_ty);
if let Some(annotated_elts_ty) = annotated_elts_ty {
elements_are_assignable &=
inferred_elt_ty.is_assignable_to(self.db(), annotated_elts_ty);
}
}
// Create a set of constraints to infer a precise type for `T`.
let mut builder = SpecializationBuilder::new(self.db());
match annotated_elts_ty {
// If the inferred type of any element is not assignable to the type annotation, we
// ignore it, as to provide a more precise error message.
Some(_) if !elements_are_assignable => {}
// Otherwise, the annotated type acts as a constraint for `T`.
//
// Note that we infer the annotated type _before_ the elements, to closer match the order
// of any unions written in the type annotation.
Some(annotated_elts_ty) => {
builder.infer(elts_ty, annotated_elts_ty).ok()?;
}
// If a valid type annotation was not provided, avoid restricting the type of the collection
// by unioning the inferred type with `Unknown`.
None => builder.infer(elts_ty, Type::unknown()).ok()?,
}
// The inferred type of each element acts as an additional constraint on `T`.
for inferred_elt_ty in inferred_elt_tys {
// Convert any element literals to their promoted type form to avoid excessively large
// unions for large nested list literals, which the constraint solver struggles with.
let inferred_elt_ty =
inferred_elt_ty.apply_type_mapping(self.db(), &TypeMapping::PromoteLiterals);
builder.infer(elts_ty, inferred_elt_ty).ok()?;
}
let class_type = class_literal
.apply_specialization(self.db(), |generic_context| builder.build(generic_context));
Type::from(class_type).to_instance(self.db())
}
fn infer_dict_expression(&mut self, dict: &ast::ExprDict, _tcx: TypeContext<'db>) -> Type<'db> {
@ -5314,6 +5404,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
],
)
}
/// Infer the type of the `iter` expression of the first comprehension.
fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) {
let mut comprehensions_iter = comprehensions.iter();

View file

@ -545,11 +545,15 @@ impl<T> VariableLengthTuple<T> {
})
}
fn prefix_elements(&self) -> impl DoubleEndedIterator<Item = &T> + ExactSizeIterator + '_ {
pub(crate) fn prefix_elements(
&self,
) -> impl DoubleEndedIterator<Item = &T> + ExactSizeIterator + '_ {
self.prefix.iter()
}
fn suffix_elements(&self) -> impl DoubleEndedIterator<Item = &T> + ExactSizeIterator + '_ {
pub(crate) fn suffix_elements(
&self,
) -> impl DoubleEndedIterator<Item = &T> + ExactSizeIterator + '_ {
self.suffix.iter()
}