[ty] bidirectional type inference using function return type annotations (#20528)
Some checks failed
CI / cargo fmt (push) Has been cancelled
CI / mkdocs (push) Has been cancelled
CI / Determine changes (push) Has been cancelled
CI / cargo build (release) (push) Has been cancelled
CI / python package (push) Has been cancelled
CI / pre-commit (push) Has been cancelled
[ty Playground] Release / publish (push) Has been cancelled
CI / cargo clippy (push) Has been cancelled
CI / cargo test (linux) (push) Has been cancelled
CI / cargo test (linux, release) (push) Has been cancelled
CI / cargo test (windows) (push) Has been cancelled
CI / cargo test (wasm) (push) Has been cancelled
CI / cargo build (msrv) (push) Has been cancelled
CI / cargo fuzz build (push) Has been cancelled
CI / fuzz parser (push) Has been cancelled
CI / test scripts (push) Has been cancelled
CI / ecosystem (push) Has been cancelled
CI / Fuzz for new ty panics (push) Has been cancelled
CI / cargo shear (push) Has been cancelled
CI / ty completion evaluation (push) Has been cancelled
CI / formatter instabilities and black similarity (push) Has been cancelled
CI / test ruff-lsp (push) Has been cancelled
CI / check playground (push) Has been cancelled
CI / benchmarks instrumented (ruff) (push) Has been cancelled
CI / benchmarks instrumented (ty) (push) Has been cancelled
CI / benchmarks walltime (medium|multithreaded) (push) Has been cancelled
CI / benchmarks walltime (small|large) (push) Has been cancelled

## Summary

Implements bidirectional type inference using function return type
annotations.

This PR was originally proposed to solve astral-sh/ty#1167, but this
does not fully resolve it on its own.
Additionally, I believe we need to allow dataclasses to generate their
own `__new__` methods, to [use constructor return types for
inference](https://github.com/astral-sh/ruff/blob/5844c0103d/crates/ty_python_semantic/src/types.rs#L5326-L5328),
and to add a mechanism for discarding type narrowing like `& ~AlwaysFalsy`
if necessary (at a more general level than this PR).

## Test Plan

`mdtest/bidirectional.md` is added.

---------

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Ibraheem Ahmed <ibraheem@ibraheem.ca>
This commit is contained in:
Shunsuke Shibayama 2025-10-11 09:38:35 +09:00 committed by GitHub
parent 11a9e7ee44
commit dc64c08633
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 442 additions and 58 deletions

View file

@ -977,6 +977,10 @@ impl<'db> Type<'db> {
}
}
/// Returns `true` if `any_over_type` finds a `Type::TypeVar` when run over this type.
///
/// NOTE(review): the trailing `false` is forwarded to `any_over_type` as-is; what that flag
/// controls is not visible from here — confirm against `any_over_type`'s signature.
pub(crate) fn has_type_var(self, db: &'db dyn Db) -> bool {
any_over_type(db, self, &|ty| matches!(ty, Type::TypeVar(_)), false)
}
pub(crate) const fn into_class_literal(self) -> Option<ClassLiteral<'db>> {
match self {
Type::ClassLiteral(class_type) => Some(class_type),
@ -1167,6 +1171,15 @@ impl<'db> Type<'db> {
if yes { self.negate(db) } else { *self }
}
/// Drops every union element that is disjoint from `target`.
///
/// If `self` is a `Type::Union`, the union is filtered so that only the elements for which
/// `is_disjoint_from(db, target)` is `false` are kept. Any other type is returned unchanged.
pub(crate) fn filter_disjoint_elements(self, db: &'db dyn Db, target: Type<'db>) -> Type<'db> {
    match self {
        Type::Union(union_ty) => {
            union_ty.filter(db, |element| !element.is_disjoint_from(db, target))
        }
        non_union => non_union,
    }
}
/// Returns the fallback instance type that a literal is an instance of, or `None` if the type
/// is not a literal.
pub(crate) fn literal_fallback_instance(self, db: &'db dyn Db) -> Option<Type<'db>> {

View file

@ -341,6 +341,48 @@ impl<'db> OverloadLiteral<'db> {
/// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation.
pub(crate) fn signature(self, db: &'db dyn Db) -> Signature<'db> {
    // Start from the raw signature: its type variables are still non-inferable and the
    // return type of an async function has not been wrapped yet.
    let raw = self.raw_signature(db);

    let body_scope = self.body_scope(db);
    let parsed = parsed_module(db, self.file(db)).load(db);
    let func_def = body_scope.node(db).expect_function().node(&parsed);
    let semantic = semantic_index(db, body_scope.file(db));
    let is_generator = body_scope
        .file_scope_id(db)
        .is_generator_function(semantic);

    // Only `async` functions that are not generators get the coroutine wrapper.
    let wrapped = if func_def.is_async && !is_generator {
        raw.wrap_coroutine_return_type(db)
    } else {
        raw
    };
    let mut signature = wrapped.mark_typevars_inferable(db);

    let pep695_ctx = func_def.type_params.as_ref().map(|type_params| {
        GenericContext::from_type_params(db, semantic, self.definition(db), type_params)
    });
    let legacy_ctx = GenericContext::from_function_params(
        db,
        self.definition(db),
        signature.parameters(),
        signature.return_ty,
    );

    // We need to update `signature.generic_context` here,
    // because type variables in `GenericContext::variables` are still non-inferable.
    signature.generic_context =
        GenericContext::merge_pep695_and_legacy(db, pep695_ctx, legacy_ctx);

    signature
}
/// Typed internally-visible "raw" signature for this function.
/// That is, type variables in parameter types and the return type remain non-inferable,
/// and the return types of async functions are not wrapped in `CoroutineType[...]`.
///
/// ## Warning
///
/// This uses the semantic index to find the definition of the function. This means that if the
/// calling query is not in the same file as this function is defined in, then this will create
/// a cross-module dependency directly on the full AST which will lead to cache
/// over-invalidation.
fn raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
/// `self` or `cls` can be implicitly positional-only if:
/// - It is a method AND
/// - No parameters in the method use PEP-570 syntax AND
@ -402,11 +444,11 @@ impl<'db> OverloadLiteral<'db> {
let function_stmt_node = scope.node(db).expect_function().node(&module);
let definition = self.definition(db);
let index = semantic_index(db, scope.file(db));
let generic_context = function_stmt_node.type_params.as_ref().map(|type_params| {
let pep695_ctx = function_stmt_node.type_params.as_ref().map(|type_params| {
GenericContext::from_type_params(db, index, definition, type_params)
});
let file_scope_id = scope.file_scope_id(db);
let is_generator = file_scope_id.is_generator_function(index);
let has_implicitly_positional_first_parameter = has_implicitly_positional_only_first_param(
db,
self,
@ -417,10 +459,9 @@ impl<'db> OverloadLiteral<'db> {
Signature::from_function(
db,
generic_context,
pep695_ctx,
definition,
function_stmt_node,
is_generator,
has_implicitly_positional_first_parameter,
)
}
@ -599,6 +640,18 @@ impl<'db> FunctionLiteral<'db> {
fn last_definition_signature(self, db: &'db dyn Db) -> Signature<'db> {
self.last_definition(db).signature(db)
}
/// Raw signature of the last overload or implementation of this function.
///
/// Forwards to `raw_signature` on whatever `last_definition` returns.
///
/// ## Warning
///
/// The semantic index is used to locate the function's definition. If the calling query
/// lives in a different file than the function, this introduces a cross-module dependency
/// directly on the full AST, which leads to cache over-invalidation.
fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
    let last_overload = self.last_definition(db);
    last_overload.raw_signature(db)
}
}
/// Represents a function type, which might be a non-generic function, or a specialization of a
@ -877,6 +930,17 @@ impl<'db> FunctionType<'db> {
.unwrap_or_else(|| self.literal(db).last_definition_signature(db))
}
/// Raw signature of the last overload or implementation of this function.
///
/// Tracked by salsa; the computation itself is delegated to the function literal's
/// `last_definition_raw_signature`.
#[salsa::tracked(
    returns(ref),
    cycle_fn=last_definition_signature_cycle_recover,
    cycle_initial=last_definition_signature_cycle_initial,
    heap_size=ruff_memory_usage::heap_size,
)]
pub(crate) fn last_definition_raw_signature(self, db: &'db dyn Db) -> Signature<'db> {
    let function_literal = self.literal(db);
    function_literal.last_definition_raw_signature(db)
}
/// Convert the `FunctionType` into a [`CallableType`].
pub(crate) fn into_callable_type(self, db: &'db dyn Db) -> CallableType<'db> {
CallableType::new(db, self.signature(db), false)

View file

@ -291,6 +291,28 @@ impl<'db> GenericContext<'db> {
Some(Self::from_typevar_instances(db, variables))
}
/// Combines an optional PEP 695 generic context with an optional legacy one.
///
/// - If only one of the two is present, it is returned as-is (`None` if neither is).
/// - If both are present and the legacy context consists of exactly one variable that is
///   `Self`, the legacy context is merged with the PEP 695 context.
/// - If both are present otherwise, only the PEP 695 context is kept.
pub(crate) fn merge_pep695_and_legacy(
    db: &'db dyn Db,
    pep695_generic_context: Option<Self>,
    legacy_generic_context: Option<Self>,
) -> Option<Self> {
    match (legacy_generic_context, pep695_generic_context) {
        (Some(legacy), Some(pep695)) => {
            let legacy_is_only_self = legacy
                .variables(db)
                .exactly_one()
                .is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db));

            Some(if legacy_is_only_self {
                legacy.merge(db, pep695)
            } else {
                // TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed
                pep695
            })
        }
        (legacy, pep695) => legacy.or(pep695),
    }
}
/// Creates a generic context from the legacy `TypeVar`s that appear in class's base class
/// list.
pub(crate) fn from_base_classes(
@ -1174,7 +1196,7 @@ impl<'db> SpecializationBuilder<'db> {
pub(crate) fn infer(
&mut self,
formal: Type<'db>,
actual: Type<'db>,
mut actual: Type<'db>,
) -> Result<(), SpecializationError<'db>> {
if formal == actual {
return Ok(());
@ -1203,6 +1225,10 @@ impl<'db> SpecializationBuilder<'db> {
return Ok(());
}
// For example, if `formal` is `list[T]` and `actual` is `list[int] | None`, we want to specialize `T` to `int`.
// So, here we remove the union elements that are not related to `formal`.
actual = actual.filter_disjoint_elements(self.db, formal);
match (formal, actual) {
// TODO: We haven't implemented a full unification solver yet. If typevars appear in
// multiple union elements, we ideally want to express that _only one_ of them needs to
@ -1228,9 +1254,15 @@ impl<'db> SpecializationBuilder<'db> {
// def _(y: str | int | None):
// reveal_type(g(x)) # revealed: str | int
// ```
let formal_bound_typevars =
(formal_union.elements(self.db).iter()).filter_map(|ty| ty.into_type_var());
let Ok(formal_bound_typevar) = formal_bound_typevars.exactly_one() else {
// We do not handle cases where the `formal` types contain other types that contain type variables
// to prevent incorrect specialization: e.g. `T = int | list[int]` for `formal: T | list[T], actual: int | list[int]`
// (the correct specialization is `T = int`).
let types_have_typevars = formal_union
.elements(self.db)
.iter()
.filter(|ty| ty.has_type_var(self.db));
let Ok(Type::TypeVar(formal_bound_typevar)) = types_have_typevars.exactly_one()
else {
return Ok(());
};
if (actual_union.elements(self.db).iter()).any(|ty| ty.is_type_var()) {
@ -1241,7 +1273,7 @@ impl<'db> SpecializationBuilder<'db> {
if remaining_actual.is_never() {
return Ok(());
}
self.add_type_mapping(formal_bound_typevar, remaining_actual);
self.add_type_mapping(*formal_bound_typevar, remaining_actual);
}
(Type::Union(formal), _) => {
// Second, if the formal is a union, and precisely one union element _is_ a typevar (not

View file

@ -50,6 +50,7 @@ use crate::semantic_index::expression::Expression;
use crate::semantic_index::scope::ScopeId;
use crate::semantic_index::{SemanticIndex, semantic_index};
use crate::types::diagnostic::TypeCheckDiagnostics;
use crate::types::function::FunctionType;
use crate::types::generics::Specialization;
use crate::types::unpacker::{UnpackResult, Unpacker};
use crate::types::{ClassLiteral, KnownClass, Truthiness, Type, TypeAndQualifiers};
@ -389,6 +390,12 @@ impl<'db> TypeContext<'db> {
self.annotation
.and_then(|ty| ty.known_specialization(db, known_class))
}
/// Builds a new context whose annotation is the result of applying `f` to the current
/// annotation. When there is no annotation, `f` is not called and the result has none either.
pub(crate) fn map_annotation(self, f: impl FnOnce(Type<'db>) -> Type<'db>) -> Self {
    let annotation = self.annotation.map(f);
    Self { annotation }
}
}
/// Returns the statically-known truthiness of a given expression.
@ -487,6 +494,30 @@ pub(crate) fn nearest_enclosing_class<'db>(
})
}
/// Finds the function type of the nearest function enclosing `scope`.
///
/// Ancestor scopes are visited starting from `scope` itself. A scope is skipped when its
/// node is not a function, or when the type inferred for its definition cannot be converted
/// into a function literal; the first scope that yields one wins.
///
/// Returns `None` if no ancestor scope produces a function type.
pub(crate) fn nearest_enclosing_function<'db>(
    db: &'db dyn Db,
    semantic: &SemanticIndex<'db>,
    scope: ScopeId,
) -> Option<FunctionType<'db>> {
    for (_, ancestor_scope) in semantic.ancestor_scopes(scope.file_scope_id(db)) {
        let Some(func) = ancestor_scope.node().as_function() else {
            continue;
        };

        let definition = semantic.expect_single_definition(func);
        let inference = infer_definition_types(db, definition);

        // Prefer the undecorated type; fall back to the declared type of the definition.
        let function_ty = inference
            .undecorated_type()
            .unwrap_or_else(|| inference.declaration_type(definition).inner_type());

        if let Some(function) = function_ty.into_function_literal() {
            return Some(function);
        }
    }

    None
}
/// A region within which we can infer types.
#[derive(Copy, Clone, Debug)]
pub(crate) enum InferenceRegion<'db> {

View file

@ -79,6 +79,7 @@ use crate::types::function::{
};
use crate::types::generics::{GenericContext, bind_typevar};
use crate::types::generics::{LegacyGenericBase, SpecializationBuilder};
use crate::types::infer::nearest_enclosing_function;
use crate::types::instance::SliceLiteral;
use crate::types::mro::MroErrorKind;
use crate::types::signatures::Signature;
@ -5101,9 +5102,20 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
fn infer_return_statement(&mut self, ret: &ast::StmtReturn) {
if let Some(ty) =
self.infer_optional_expression(ret.value.as_deref(), TypeContext::default())
{
let tcx = if ret.value.is_some() {
nearest_enclosing_function(self.db(), self.index, self.scope())
.map(|func| {
// When inferring expressions within a function body,
// the expected type passed should be the "raw" type,
// i.e. type variables in the return type are non-inferable,
// and the return types of async functions are not wrapped in `CoroutineType[...]`.
TypeContext::new(func.last_definition_raw_signature(self.db()).return_ty)
})
.unwrap_or_default()
} else {
TypeContext::default()
};
if let Some(ty) = self.infer_optional_expression(ret.value.as_deref(), tcx) {
let range = ret
.value
.as_ref()
@ -5900,6 +5912,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
return None;
};
let tcx = tcx.map_annotation(|annotation| {
// Remove any union elements of `annotation` that are not related to `collection_ty`.
// e.g. `annotation: list[int] | None => list[int]` if `collection_ty: list`
let collection_ty = collection_class.to_instance(self.db());
annotation.filter_disjoint_elements(self.db(), collection_ty)
});
// Extract the annotated type of `T`, if provided.
let annotated_elt_tys = tcx
.known_specialization(self.db(), collection_class)

View file

@ -26,9 +26,10 @@ use crate::types::function::FunctionType;
use crate::types::generics::{GenericContext, typing_self, walk_generic_context};
use crate::types::infer::nearest_enclosing_class;
use crate::types::{
ApplyTypeMappingVisitor, BoundTypeVarInstance, ClassLiteral, FindLegacyTypeVarsVisitor,
HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor, KnownClass, MaterializationKind,
NormalizedVisitor, TypeContext, TypeMapping, TypeRelation, VarianceInferable, todo_type,
ApplyTypeMappingVisitor, BindingContext, BoundTypeVarInstance, ClassLiteral,
FindLegacyTypeVarsVisitor, HasRelationToVisitor, IsDisjointVisitor, IsEquivalentVisitor,
KnownClass, MaterializationKind, NormalizedVisitor, TypeContext, TypeMapping, TypeRelation,
VarianceInferable, todo_type,
};
use crate::{Db, FxOrderSet};
use ruff_python_ast::{self as ast, name::Name};
@ -419,10 +420,9 @@ impl<'db> Signature<'db> {
/// Return a typed signature from a function definition.
pub(super) fn from_function(
db: &'db dyn Db,
generic_context: Option<GenericContext<'db>>,
pep695_generic_context: Option<GenericContext<'db>>,
definition: Definition<'db>,
function_node: &ast::StmtFunctionDef,
is_generator: bool,
has_implicitly_positional_first_parameter: bool,
) -> Self {
let parameters = Parameters::from_parameters(
@ -431,38 +431,17 @@ impl<'db> Signature<'db> {
function_node.parameters.as_ref(),
has_implicitly_positional_first_parameter,
);
let return_ty = function_node.returns.as_ref().map(|returns| {
let plain_return_ty = definition_expression_type(db, definition, returns.as_ref())
.apply_type_mapping(
db,
&TypeMapping::MarkTypeVarsInferable(Some(definition.into())),
TypeContext::default(),
);
if function_node.is_async && !is_generator {
KnownClass::CoroutineType
.to_specialized_instance(db, [Type::any(), Type::any(), plain_return_ty])
} else {
plain_return_ty
}
});
let return_ty = function_node
.returns
.as_ref()
.map(|returns| definition_expression_type(db, definition, returns.as_ref()));
let legacy_generic_context =
GenericContext::from_function_params(db, definition, &parameters, return_ty);
let full_generic_context = match (legacy_generic_context, generic_context) {
(Some(legacy_ctx), Some(ctx)) => {
if legacy_ctx
.variables(db)
.exactly_one()
.is_ok_and(|bound_typevar| bound_typevar.typevar(db).is_self(db))
{
Some(legacy_ctx.merge(db, ctx))
} else {
// TODO: Raise a diagnostic — mixing PEP 695 and legacy typevars is not allowed
Some(ctx)
}
}
(left, right) => left.or(right),
};
let full_generic_context = GenericContext::merge_pep695_and_legacy(
db,
pep695_generic_context,
legacy_generic_context,
);
Self {
generic_context: full_generic_context,
@ -472,6 +451,27 @@ impl<'db> Signature<'db> {
}
}
/// Applies the `MarkTypeVarsInferable` type mapping, bound to this signature's definition,
/// to the whole signature.
///
/// A signature without a definition is returned unchanged.
pub(super) fn mark_typevars_inferable(self, db: &'db dyn Db) -> Self {
    let Some(definition) = self.definition else {
        return self;
    };

    self.apply_type_mapping_impl(
        db,
        &TypeMapping::MarkTypeVarsInferable(Some(BindingContext::Definition(definition))),
        TypeContext::default(),
        &ApplyTypeMappingVisitor::default(),
    )
}
pub(super) fn wrap_coroutine_return_type(self, db: &'db dyn Db) -> Self {
let return_ty = self.return_ty.map(|return_ty| {
KnownClass::CoroutineType
.to_specialized_instance(db, [Type::any(), Type::any(), return_ty])
});
Self { return_ty, ..self }
}
/// Returns the signature which accepts any parameters and returns an `Unknown` type.
pub(crate) fn unknown() -> Self {
Self::new(Parameters::unknown(), Some(Type::unknown()))
@ -1728,13 +1728,9 @@ impl<'db> Parameter<'db> {
kind: ParameterKind<'db>,
) -> Self {
Self {
annotated_type: parameter.annotation().map(|annotation| {
definition_expression_type(db, definition, annotation).apply_type_mapping(
db,
&TypeMapping::MarkTypeVarsInferable(Some(definition.into())),
TypeContext::default(),
)
}),
annotated_type: parameter
.annotation()
.map(|annotation| definition_expression_type(db, definition, annotation)),
kind,
form: ParameterForm::Value,
inferred_annotation: false,