[ty] Integrate type context for bidirectional inference (#20337)
Some checks are pending
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / test ruff-lsp (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

Adds the infrastructure necessary to perform bidirectional type
inference (https://github.com/astral-sh/ty/issues/168) without any
typing changes.
This commit is contained in:
Ibraheem Ahmed 2025-09-11 15:19:12 -04:00 committed by GitHub
parent c4cd5c00fd
commit 36888198a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 548 additions and 340 deletions

View file

@ -7,7 +7,7 @@ use ruff_python_ast::statement_visitor::{StatementVisitor, walk_stmt};
use ruff_python_ast::{self as ast}; use ruff_python_ast::{self as ast};
use crate::semantic_index::{SemanticIndex, semantic_index}; use crate::semantic_index::{SemanticIndex, semantic_index};
use crate::types::{Truthiness, Type, infer_expression_types}; use crate::types::{Truthiness, Type, TypeContext, infer_expression_types};
use crate::{Db, ModuleName, resolve_module}; use crate::{Db, ModuleName, resolve_module};
#[allow(clippy::ref_option)] #[allow(clippy::ref_option)]
@ -182,7 +182,8 @@ impl<'db> DunderAllNamesCollector<'db> {
/// ///
/// This function panics if `expr` was not marked as a standalone expression during semantic indexing. /// This function panics if `expr` was not marked as a standalone expression during semantic indexing.
fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> { fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> {
infer_expression_types(self.db, self.index.expression(expr)).expression_type(expr) infer_expression_types(self.db, self.index.expression(expr), TypeContext::default())
.expression_type(expr)
} }
/// Evaluate the given expression and return its truthiness. /// Evaluate the given expression and return its truthiness.

View file

@ -208,8 +208,8 @@ use crate::semantic_index::predicate::{
Predicates, ScopedPredicateId, Predicates, ScopedPredicateId,
}; };
use crate::types::{ use crate::types::{
IntersectionBuilder, Truthiness, Type, UnionBuilder, UnionType, infer_expression_type, IntersectionBuilder, Truthiness, Type, TypeContext, UnionBuilder, UnionType,
static_expression_truthiness, infer_expression_type, static_expression_truthiness,
}; };
/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula /// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
@ -328,10 +328,12 @@ fn singleton_to_type(db: &dyn Db, singleton: ruff_python_ast::Singleton) -> Type
fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> { fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> {
match kind { match kind {
PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton), PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton),
PatternPredicateKind::Value(value) => infer_expression_type(db, *value), PatternPredicateKind::Value(value) => {
infer_expression_type(db, *value, TypeContext::default())
}
PatternPredicateKind::Class(class_expr, kind) => { PatternPredicateKind::Class(class_expr, kind) => {
if kind.is_irrefutable() { if kind.is_irrefutable() {
infer_expression_type(db, *class_expr) infer_expression_type(db, *class_expr, TypeContext::default())
.to_instance(db) .to_instance(db)
.unwrap_or(Type::Never) .unwrap_or(Type::Never)
} else { } else {
@ -718,7 +720,7 @@ impl ReachabilityConstraints {
) -> Truthiness { ) -> Truthiness {
match predicate_kind { match predicate_kind {
PatternPredicateKind::Value(value) => { PatternPredicateKind::Value(value) => {
let value_ty = infer_expression_type(db, *value); let value_ty = infer_expression_type(db, *value, TypeContext::default());
if subject_ty.is_single_valued(db) { if subject_ty.is_single_valued(db) {
Truthiness::from(subject_ty.is_equivalent_to(db, value_ty)) Truthiness::from(subject_ty.is_equivalent_to(db, value_ty))
@ -769,7 +771,8 @@ impl ReachabilityConstraints {
truthiness truthiness
} }
PatternPredicateKind::Class(class_expr, kind) => { PatternPredicateKind::Class(class_expr, kind) => {
let class_ty = infer_expression_type(db, *class_expr).to_instance(db); let class_ty =
infer_expression_type(db, *class_expr, TypeContext::default()).to_instance(db);
class_ty.map_or(Truthiness::Ambiguous, |class_ty| { class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
if subject_ty.is_subtype_of(db, class_ty) { if subject_ty.is_subtype_of(db, class_ty) {
@ -797,7 +800,7 @@ impl ReachabilityConstraints {
} }
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness { fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
let subject_ty = infer_expression_type(db, predicate.subject(db)); let subject_ty = infer_expression_type(db, predicate.subject(db), TypeContext::default());
let narrowed_subject_ty = IntersectionBuilder::new(db) let narrowed_subject_ty = IntersectionBuilder::new(db)
.add_positive(subject_ty) .add_positive(subject_ty)
@ -837,7 +840,7 @@ impl ReachabilityConstraints {
// selection algorithm). // selection algorithm).
// Avoiding this on the happy-path is important because these constraints can be // Avoiding this on the happy-path is important because these constraints can be
// very large in number, since we add them on all statement level function calls. // very large in number, since we add them on all statement level function calls.
let ty = infer_expression_type(db, callable); let ty = infer_expression_type(db, callable, TypeContext::default());
// Short-circuit for well known types that are known not to return `Never` when called. // Short-circuit for well known types that are known not to return `Never` when called.
// Without the short-circuit, we've seen that threads keep blocking each other // Without the short-circuit, we've seen that threads keep blocking each other
@ -875,7 +878,7 @@ impl ReachabilityConstraints {
} else if all_overloads_return_never { } else if all_overloads_return_never {
Truthiness::AlwaysTrue Truthiness::AlwaysTrue
} else { } else {
let call_expr_ty = infer_expression_type(db, call_expr); let call_expr_ty = infer_expression_type(db, call_expr, TypeContext::default());
if call_expr_ty.is_equivalent_to(db, Type::Never) { if call_expr_ty.is_equivalent_to(db, Type::Never) {
Truthiness::AlwaysTrue Truthiness::AlwaysTrue
} else { } else {

View file

@ -23,8 +23,8 @@ pub(crate) use self::cyclic::{CycleDetector, PairVisitor, TypeTransformer};
pub use self::diagnostic::TypeCheckDiagnostics; pub use self::diagnostic::TypeCheckDiagnostics;
pub(crate) use self::diagnostic::register_lints; pub(crate) use self::diagnostic::register_lints;
pub(crate) use self::infer::{ pub(crate) use self::infer::{
infer_deferred_types, infer_definition_types, infer_expression_type, infer_expression_types, TypeContext, infer_deferred_types, infer_definition_types, infer_expression_type,
infer_scope_types, static_expression_truthiness, infer_expression_types, infer_scope_types, static_expression_truthiness,
}; };
pub(crate) use self::signatures::{CallableSignature, Signature}; pub(crate) use self::signatures::{CallableSignature, Signature};
pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType}; pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType};
@ -10824,12 +10824,10 @@ static_assertions::assert_eq_size!(Type, [u8; 16]);
pub(crate) mod tests { pub(crate) mod tests {
use super::*; use super::*;
use crate::db::tests::{TestDbBuilder, setup_db}; use crate::db::tests::{TestDbBuilder, setup_db};
use crate::place::{global_symbol, typing_extensions_symbol, typing_symbol}; use crate::place::{typing_extensions_symbol, typing_symbol};
use crate::semantic_index::FileScopeId; use crate::semantic_index::FileScopeId;
use ruff_db::files::system_path_to_file; use ruff_db::files::system_path_to_file;
use ruff_db::parsed::parsed_module;
use ruff_db::system::DbWithWritableSystem as _; use ruff_db::system::DbWithWritableSystem as _;
use ruff_db::testing::assert_function_query_was_not_run;
use ruff_python_ast::PythonVersion; use ruff_python_ast::PythonVersion;
use test_case::test_case; use test_case::test_case;
@ -10868,65 +10866,6 @@ pub(crate) mod tests {
); );
} }
/// Inferring the result of a call-expression shouldn't need to re-run after
/// a trivial change to the function's file (e.g. by adding a docstring to the function).
#[test]
fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/foo.py",
r#"
def foo() -> int:
return 5
"#,
)?;
db.write_dedented(
"src/bar.py",
r#"
from foo import foo
a = foo()
"#,
)?;
let bar = system_path_to_file(&db, "src/bar.py")?;
let a = global_symbol(&db, bar, "a").place;
assert_eq!(
a.expect_type(),
UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)])
);
// Add a docstring to foo to trigger a re-run.
// The bar-call site of foo should not be re-run because of that
db.write_dedented(
"src/foo.py",
r#"
def foo() -> int:
"Computes a value"
return 5
"#,
)?;
db.clear_salsa_events();
let a = global_symbol(&db, bar, "a").place;
assert_eq!(
a.expect_type(),
UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)])
);
let events = db.take_salsa_events();
let module = parsed_module(&db, bar).load(&db);
let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value;
let foo_call = semantic_index(&db, bar).expression(call);
assert_function_query_was_not_run(&db, infer_expression_types, foo_call, &events);
Ok(())
}
/// All other tests also make sure that `Type::Todo` works as expected. This particular /// All other tests also make sure that `Type::Todo` works as expected. This particular
/// test makes sure that we handle `Todo` types correctly, even if they originate from /// test makes sure that we handle `Todo` types correctly, even if they originate from
/// different sources. /// different sources.

View file

@ -28,9 +28,9 @@ use crate::types::{
ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType, ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType,
DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor, DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor,
IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind, IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind,
NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping, NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeContext,
TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams, TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind,
UnionBuilder, VarianceInferable, declaration_type, infer_definition_types, TypedDictParams, UnionBuilder, VarianceInferable, declaration_type, infer_definition_types,
}; };
use crate::{ use crate::{
Db, FxIndexMap, FxOrderSet, Program, Db, FxIndexMap, FxOrderSet, Program,
@ -2926,7 +2926,11 @@ impl<'db> ClassLiteral<'db> {
// `self.SOME_CONSTANT: Final = 1`, infer the type from the value // `self.SOME_CONSTANT: Final = 1`, infer the type from the value
// on the right-hand side. // on the right-hand side.
let inferred_ty = infer_expression_type(db, index.expression(value)); let inferred_ty = infer_expression_type(
db,
index.expression(value),
TypeContext::default(),
);
return Place::bound(inferred_ty).with_qualifiers(all_qualifiers); return Place::bound(inferred_ty).with_qualifiers(all_qualifiers);
} }
@ -3014,6 +3018,7 @@ impl<'db> ClassLiteral<'db> {
let inferred_ty = infer_expression_type( let inferred_ty = infer_expression_type(
db, db,
index.expression(assign.value(&module)), index.expression(assign.value(&module)),
TypeContext::default(),
); );
union_of_inferred_types = union_of_inferred_types.add(inferred_ty); union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
@ -3041,6 +3046,7 @@ impl<'db> ClassLiteral<'db> {
let iterable_ty = infer_expression_type( let iterable_ty = infer_expression_type(
db, db,
index.expression(for_stmt.iterable(&module)), index.expression(for_stmt.iterable(&module)),
TypeContext::default(),
); );
// TODO: Potential diagnostics resulting from the iterable are currently not reported. // TODO: Potential diagnostics resulting from the iterable are currently not reported.
let inferred_ty = let inferred_ty =
@ -3071,6 +3077,7 @@ impl<'db> ClassLiteral<'db> {
let context_ty = infer_expression_type( let context_ty = infer_expression_type(
db, db,
index.expression(with_item.context_expr(&module)), index.expression(with_item.context_expr(&module)),
TypeContext::default(),
); );
let inferred_ty = if with_item.is_async() { let inferred_ty = if with_item.is_async() {
context_ty.aenter(db) context_ty.aenter(db)
@ -3104,6 +3111,7 @@ impl<'db> ClassLiteral<'db> {
let iterable_ty = infer_expression_type( let iterable_ty = infer_expression_type(
db, db,
index.expression(comprehension.iterable(&module)), index.expression(comprehension.iterable(&module)),
TypeContext::default(),
); );
// TODO: Potential diagnostics resulting from the iterable are currently not reported. // TODO: Potential diagnostics resulting from the iterable are currently not reported.
let inferred_ty = let inferred_ty =

View file

@ -171,11 +171,21 @@ fn deferred_cycle_initial<'db>(
/// Use rarely; only for cases where we'd otherwise risk double-inferring an expression: RHS of an /// Use rarely; only for cases where we'd otherwise risk double-inferring an expression: RHS of an
/// assignment, which might be unpacking/multi-target and thus part of multiple definitions, or a /// assignment, which might be unpacking/multi-target and thus part of multiple definitions, or a
/// type narrowing guard expression (e.g. if statement test node). /// type narrowing guard expression (e.g. if statement test node).
#[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
pub(crate) fn infer_expression_types<'db>( pub(crate) fn infer_expression_types<'db>(
db: &'db dyn Db, db: &'db dyn Db,
expression: Expression<'db>, expression: Expression<'db>,
tcx: TypeContext<'db>,
) -> &'db ExpressionInference<'db> {
infer_expression_types_impl(db, InferExpression::new(db, expression, tcx))
}
#[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn infer_expression_types_impl<'db>(
db: &'db dyn Db,
input: InferExpression<'db>,
) -> ExpressionInference<'db> { ) -> ExpressionInference<'db> {
let (expression, tcx) = (input.expression(db), input.tcx(db));
let file = expression.file(db); let file = expression.file(db);
let module = parsed_module(db, file).load(db); let module = parsed_module(db, file).load(db);
let _span = tracing::trace_span!( let _span = tracing::trace_span!(
@ -188,8 +198,13 @@ pub(crate) fn infer_expression_types<'db>(
let index = semantic_index(db, file); let index = semantic_index(db, file);
TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index, &module) TypeInferenceBuilder::new(
.finish_expression() db,
InferenceRegion::Expression(expression, tcx),
index,
&module,
)
.finish_expression()
} }
/// How many fixpoint iterations to allow before falling back to Divergent type. /// How many fixpoint iterations to allow before falling back to Divergent type.
@ -199,11 +214,11 @@ fn expression_cycle_recover<'db>(
db: &'db dyn Db, db: &'db dyn Db,
_value: &ExpressionInference<'db>, _value: &ExpressionInference<'db>,
count: u32, count: u32,
expression: Expression<'db>, input: InferExpression<'db>,
) -> salsa::CycleRecoveryAction<ExpressionInference<'db>> { ) -> salsa::CycleRecoveryAction<ExpressionInference<'db>> {
if count == ITERATIONS_BEFORE_FALLBACK { if count == ITERATIONS_BEFORE_FALLBACK {
salsa::CycleRecoveryAction::Fallback(ExpressionInference::cycle_fallback( salsa::CycleRecoveryAction::Fallback(ExpressionInference::cycle_fallback(
expression.scope(db), input.expression(db).scope(db),
)) ))
} else { } else {
salsa::CycleRecoveryAction::Iterate salsa::CycleRecoveryAction::Iterate
@ -212,9 +227,9 @@ fn expression_cycle_recover<'db>(
fn expression_cycle_initial<'db>( fn expression_cycle_initial<'db>(
db: &'db dyn Db, db: &'db dyn Db,
expression: Expression<'db>, input: InferExpression<'db>,
) -> ExpressionInference<'db> { ) -> ExpressionInference<'db> {
ExpressionInference::cycle_initial(expression.scope(db)) ExpressionInference::cycle_initial(input.expression(db).scope(db))
} }
/// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query. /// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query.
@ -225,9 +240,10 @@ fn expression_cycle_initial<'db>(
pub(super) fn infer_same_file_expression_type<'db>( pub(super) fn infer_same_file_expression_type<'db>(
db: &'db dyn Db, db: &'db dyn Db,
expression: Expression<'db>, expression: Expression<'db>,
tcx: TypeContext<'db>,
parsed: &ParsedModuleRef, parsed: &ParsedModuleRef,
) -> Type<'db> { ) -> Type<'db> {
let inference = infer_expression_types(db, expression); let inference = infer_expression_types(db, expression, tcx);
inference.expression_type(expression.node_ref(db, parsed)) inference.expression_type(expression.node_ref(db, parsed))
} }
@ -238,34 +254,108 @@ pub(super) fn infer_same_file_expression_type<'db>(
/// ///
/// Use [`infer_same_file_expression_type`] if it is guaranteed that `expression` is in the same /// Use [`infer_same_file_expression_type`] if it is guaranteed that `expression` is in the same
/// to avoid unnecessary salsa ingredients. This is normally the case inside the `TypeInferenceBuilder`. /// to avoid unnecessary salsa ingredients. This is normally the case inside the `TypeInferenceBuilder`.
#[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
pub(crate) fn infer_expression_type<'db>( pub(crate) fn infer_expression_type<'db>(
db: &'db dyn Db, db: &'db dyn Db,
expression: Expression<'db>, expression: Expression<'db>,
tcx: TypeContext<'db>,
) -> Type<'db> { ) -> Type<'db> {
let file = expression.file(db); infer_expression_type_impl(db, InferExpression::new(db, expression, tcx))
}
#[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn infer_expression_type_impl<'db>(db: &'db dyn Db, input: InferExpression<'db>) -> Type<'db> {
let file = input.expression(db).file(db);
let module = parsed_module(db, file).load(db); let module = parsed_module(db, file).load(db);
// It's okay to call the "same file" version here because we're inside a salsa query. // It's okay to call the "same file" version here because we're inside a salsa query.
infer_same_file_expression_type(db, expression, &module) let inference = infer_expression_types_impl(db, input);
inference.expression_type(input.expression(db).node_ref(db, &module))
} }
fn single_expression_cycle_recover<'db>( fn single_expression_cycle_recover<'db>(
_db: &'db dyn Db, _db: &'db dyn Db,
_value: &Type<'db>, _value: &Type<'db>,
_count: u32, _count: u32,
_expression: Expression<'db>, _input: InferExpression<'db>,
) -> salsa::CycleRecoveryAction<Type<'db>> { ) -> salsa::CycleRecoveryAction<Type<'db>> {
salsa::CycleRecoveryAction::Iterate salsa::CycleRecoveryAction::Iterate
} }
fn single_expression_cycle_initial<'db>( fn single_expression_cycle_initial<'db>(
_db: &'db dyn Db, _db: &'db dyn Db,
_expression: Expression<'db>, _input: InferExpression<'db>,
) -> Type<'db> { ) -> Type<'db> {
Type::Never Type::Never
} }
/// An `Expression` with an optional `TypeContext`.
///
/// This is a Salsa supertype used as the input to `infer_expression_types` to avoid
/// interning an `ExpressionWithContext` unnecessarily when no type context is provided.
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, salsa::Supertype, salsa::Update)]
enum InferExpression<'db> {
Bare(Expression<'db>),
WithContext(ExpressionWithContext<'db>),
}
impl<'db> InferExpression<'db> {
fn new(
db: &'db dyn Db,
expression: Expression<'db>,
tcx: TypeContext<'db>,
) -> InferExpression<'db> {
if tcx.annotation.is_some() {
InferExpression::WithContext(ExpressionWithContext::new(db, expression, tcx))
} else {
// Drop the empty `TypeContext` to avoid the interning cost.
InferExpression::Bare(expression)
}
}
fn expression(self, db: &'db dyn Db) -> Expression<'db> {
match self {
InferExpression::Bare(expression) => expression,
InferExpression::WithContext(expression_with_context) => {
expression_with_context.expression(db)
}
}
}
fn tcx(self, db: &'db dyn Db) -> TypeContext<'db> {
match self {
InferExpression::Bare(_) => TypeContext::default(),
InferExpression::WithContext(expression_with_context) => {
expression_with_context.tcx(db)
}
}
}
}
/// An `Expression` with a `TypeContext`.
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
struct ExpressionWithContext<'db> {
expression: Expression<'db>,
tcx: TypeContext<'db>,
}
/// The type context for a given expression, namely the type annotation
/// in an annotated assignment.
///
/// Knowing the outer type context when inferring an expression can enable
/// more precise inference results, aka "bidirectional type inference".
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
pub(crate) struct TypeContext<'db> {
annotation: Option<Type<'db>>,
}
impl<'db> TypeContext<'db> {
pub(crate) fn new(annotation: Type<'db>) -> Self {
Self {
annotation: Some(annotation),
}
}
}
/// Returns the statically-known truthiness of a given expression. /// Returns the statically-known truthiness of a given expression.
/// ///
/// Returns [`Truthiness::Ambiguous`] in case any non-definitely bound places /// Returns [`Truthiness::Ambiguous`] in case any non-definitely bound places
@ -275,7 +365,7 @@ pub(crate) fn static_expression_truthiness<'db>(
db: &'db dyn Db, db: &'db dyn Db,
expression: Expression<'db>, expression: Expression<'db>,
) -> Truthiness { ) -> Truthiness {
let inference = infer_expression_types(db, expression); let inference = infer_expression_types_impl(db, InferExpression::Bare(expression));
if !inference.all_places_definitely_bound() { if !inference.all_places_definitely_bound() {
return Truthiness::Ambiguous; return Truthiness::Ambiguous;
@ -366,7 +456,7 @@ pub(crate) fn nearest_enclosing_class<'db>(
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub(crate) enum InferenceRegion<'db> { pub(crate) enum InferenceRegion<'db> {
/// infer types for a standalone [`Expression`] /// infer types for a standalone [`Expression`]
Expression(Expression<'db>), Expression(Expression<'db>, TypeContext<'db>),
/// infer types for a [`Definition`] /// infer types for a [`Definition`]
Definition(Definition<'db>), Definition(Definition<'db>),
/// infer deferred types for a [`Definition`] /// infer deferred types for a [`Definition`]
@ -378,7 +468,7 @@ pub(crate) enum InferenceRegion<'db> {
impl<'db> InferenceRegion<'db> { impl<'db> InferenceRegion<'db> {
fn scope(self, db: &'db dyn Db) -> ScopeId<'db> { fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
match self { match self {
InferenceRegion::Expression(expression) => expression.scope(db), InferenceRegion::Expression(expression, _) => expression.scope(db),
InferenceRegion::Definition(definition) | InferenceRegion::Deferred(definition) => { InferenceRegion::Definition(definition) | InferenceRegion::Deferred(definition) => {
definition.scope(db) definition.scope(db)
} }

View file

@ -90,8 +90,8 @@ use crate::types::{
IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy, IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy,
MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType, MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType,
SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers, SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers,
TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance, TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation,
TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type, TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type,
}; };
use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic}; use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
use crate::unpack::{EvaluationMode, UnpackPosition}; use crate::unpack::{EvaluationMode, UnpackPosition};
@ -440,7 +440,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
InferenceRegion::Scope(scope) => self.infer_region_scope(scope), InferenceRegion::Scope(scope) => self.infer_region_scope(scope),
InferenceRegion::Definition(definition) => self.infer_region_definition(definition), InferenceRegion::Definition(definition) => self.infer_region_definition(definition),
InferenceRegion::Deferred(definition) => self.infer_region_deferred(definition), InferenceRegion::Deferred(definition) => self.infer_region_deferred(definition),
InferenceRegion::Expression(expression) => self.infer_region_expression(expression), InferenceRegion::Expression(expression, tcx) => {
self.infer_region_expression(expression, tcx);
}
} }
} }
@ -1221,10 +1223,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
} }
fn infer_region_expression(&mut self, expression: Expression<'db>) { fn infer_region_expression(&mut self, expression: Expression<'db>, tcx: TypeContext<'db>) {
match expression.kind(self.db()) { match expression.kind(self.db()) {
ExpressionKind::Normal => { ExpressionKind::Normal => {
self.infer_expression_impl(expression.node_ref(self.db(), self.module())); self.infer_expression_impl(expression.node_ref(self.db(), self.module()), tcx);
} }
ExpressionKind::TypeExpression => { ExpressionKind::TypeExpression => {
self.infer_type_expression(expression.node_ref(self.db(), self.module())); self.infer_type_expression(expression.node_ref(self.db(), self.module()));
@ -1435,7 +1437,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let declared_ty = if resolved_place.is_unbound() && !place_table.place(place_id).is_symbol() let declared_ty = if resolved_place.is_unbound() && !place_table.place(place_id).is_symbol()
{ {
if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node { if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node {
let value_type = self.infer_maybe_standalone_expression(value); let value_type =
self.infer_maybe_standalone_expression(value, TypeContext::default());
if let Place::Type(ty, Boundness::Bound) = value_type.member(db, attr).place { if let Place::Type(ty, Boundness::Bound) = value_type.member(db, attr).place {
// TODO: also consider qualifiers on the attribute // TODO: also consider qualifiers on the attribute
ty ty
@ -1448,8 +1451,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}, },
) = node ) = node
{ {
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value, TypeContext::default());
let slice_ty = self.infer_expression(slice); let slice_ty = self.infer_expression(slice, TypeContext::default());
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx) self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx)
} else { } else {
unwrap_declared_ty() unwrap_declared_ty()
@ -1517,9 +1520,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
// In the following cases, the bound type may not be the same as the RHS value type. // In the following cases, the bound type may not be the same as the RHS value type.
if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node { if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node {
let value_ty = self let value_ty = self.try_expression_type(value).unwrap_or_else(|| {
.try_expression_type(value) self.infer_maybe_standalone_expression(value, TypeContext::default())
.unwrap_or_else(|| self.infer_maybe_standalone_expression(value)); });
// If the member is a data descriptor, the RHS value may differ from the value actually assigned. // If the member is a data descriptor, the RHS value may differ from the value actually assigned.
if value_ty if value_ty
.class_member(db, attr.id.clone()) .class_member(db, attr.id.clone())
@ -1532,7 +1535,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = node { } else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = node {
let value_ty = self let value_ty = self
.try_expression_type(value) .try_expression_type(value)
.unwrap_or_else(|| self.infer_expression(value)); .unwrap_or_else(|| self.infer_expression(value, TypeContext::default()));
if !value_ty.is_typed_dict() && !is_safe_mutable_class(db, value_ty) { if !value_ty.is_typed_dict() && !is_safe_mutable_class(db, value_ty) {
bound_ty = declared_ty; bound_ty = declared_ty;
@ -1719,7 +1722,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
std::mem::replace(&mut self.deferred_state, in_stub.into()); std::mem::replace(&mut self.deferred_state, in_stub.into());
let mut call_arguments = let mut call_arguments =
CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| { CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| {
let ty = self.infer_expression(splatted_value); let ty = self.infer_expression(splatted_value, TypeContext::default());
self.store_expression_type(argument, ty); self.store_expression_type(argument, ty);
ty ty
}); });
@ -1988,7 +1991,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}) => { }) => {
// If this is a call expression, we would have added a `ReturnsNever` constraint, // If this is a call expression, we would have added a `ReturnsNever` constraint,
// meaning this will be a standalone expression. // meaning this will be a standalone expression.
self.infer_maybe_standalone_expression(value); self.infer_maybe_standalone_expression(value, TypeContext::default());
} }
ast::Stmt::If(if_statement) => self.infer_if_statement(if_statement), ast::Stmt::If(if_statement) => self.infer_if_statement(if_statement),
ast::Stmt::Try(try_statement) => self.infer_try_statement(try_statement), ast::Stmt::Try(try_statement) => self.infer_try_statement(try_statement),
@ -2085,7 +2088,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.iter_non_variadic_params() .iter_non_variadic_params()
.filter_map(|param| param.default.as_deref()) .filter_map(|param| param.default.as_deref())
{ {
self.infer_expression(default); self.infer_expression(default, TypeContext::default());
} }
// If there are type params, parameters and returns are evaluated in that scope, that is, in // If there are type params, parameters and returns are evaluated in that scope, that is, in
@ -2517,7 +2520,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// and we don't need to run inference here // and we don't need to run inference here
if type_params.is_none() { if type_params.is_none() {
for keyword in class_node.keywords() { for keyword in class_node.keywords() {
self.infer_expression(&keyword.value); self.infer_expression(&keyword.value, TypeContext::default());
} }
// Inference of bases deferred in stubs, or if any are string literals. // Inference of bases deferred in stubs, or if any are string literals.
@ -2527,7 +2530,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let previous_typevar_binding_context = let previous_typevar_binding_context =
self.typevar_binding_context.replace(definition); self.typevar_binding_context.replace(definition);
for base in class_node.bases() { for base in class_node.bases() {
self.infer_expression(base); self.infer_expression(base, TypeContext::default());
} }
self.typevar_binding_context = previous_typevar_binding_context; self.typevar_binding_context = previous_typevar_binding_context;
} }
@ -2552,9 +2555,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let previous_typevar_binding_context = self.typevar_binding_context.replace(definition); let previous_typevar_binding_context = self.typevar_binding_context.replace(definition);
for base in class.bases() { for base in class.bases() {
if self.in_stub() { if self.in_stub() {
self.infer_expression_with_state(base, DeferredExpressionState::Deferred); self.infer_expression_with_state(
base,
TypeContext::default(),
DeferredExpressionState::Deferred,
);
} else { } else {
self.infer_expression(base); self.infer_expression(base, TypeContext::default());
} }
} }
self.typevar_binding_context = previous_typevar_binding_context; self.typevar_binding_context = previous_typevar_binding_context;
@ -2565,7 +2572,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
type_alias: &ast::StmtTypeAlias, type_alias: &ast::StmtTypeAlias,
definition: Definition<'db>, definition: Definition<'db>,
) { ) {
self.infer_expression(&type_alias.name); self.infer_expression(&type_alias.name, TypeContext::default());
let rhs_scope = self let rhs_scope = self
.index .index
@ -2597,7 +2604,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
elif_else_clauses, elif_else_clauses,
} = if_statement; } = if_statement;
let test_ty = self.infer_standalone_expression(test); let test_ty = self.infer_standalone_expression(test, TypeContext::default());
if let Err(err) = test_ty.try_bool(self.db()) { if let Err(err) = test_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, &**test); err.report_diagnostic(&self.context, &**test);
@ -2614,7 +2621,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = clause; } = clause;
if let Some(test) = &test { if let Some(test) = &test {
let test_ty = self.infer_standalone_expression(test); let test_ty = self.infer_standalone_expression(test, TypeContext::default());
if let Err(err) = test_ty.try_bool(self.db()) { if let Err(err) = test_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, test); err.report_diagnostic(&self.context, test);
@ -2681,15 +2688,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// but only if the target is a name. We should report a diagnostic here if the target isn't a name: // but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `with not_context_manager as a.x: ... // `with not_context_manager as a.x: ...
builder builder
.infer_standalone_expression(context_expr) .infer_standalone_expression(context_expr, TypeContext::default())
.enter(builder.db()) .enter(builder.db())
}); });
} else { } else {
// Call into the context expression inference to validate that it evaluates // Call into the context expression inference to validate that it evaluates
// to a valid context manager. // to a valid context manager.
let context_expression_ty = self.infer_expression(&item.context_expr); let context_expression_ty =
self.infer_expression(&item.context_expr, TypeContext::default());
self.infer_context_expression(&item.context_expr, context_expression_ty, *is_async); self.infer_context_expression(&item.context_expr, context_expression_ty, *is_async);
self.infer_optional_expression(target); self.infer_optional_expression(target, TypeContext::default());
} }
} }
@ -2713,7 +2721,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
unpacked.expression_type(target) unpacked.expression_type(target)
} }
TargetKind::Single => { TargetKind::Single => {
let context_expr_ty = self.infer_standalone_expression(context_expr); let context_expr_ty =
self.infer_standalone_expression(context_expr, TypeContext::default());
self.infer_context_expression(context_expr, context_expr_ty, with_item.is_async()) self.infer_context_expression(context_expr, context_expr_ty, with_item.is_async())
} }
}; };
@ -2755,7 +2764,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_exception(&mut self, node: Option<&ast::Expr>, is_star: bool) -> Type<'db> { fn infer_exception(&mut self, node: Option<&ast::Expr>, is_star: bool) -> Type<'db> {
// If there is no handled exception, it's invalid syntax; // If there is no handled exception, it's invalid syntax;
// a diagnostic will have already been emitted // a diagnostic will have already been emitted
let node_ty = node.map_or(Type::unknown(), |ty| self.infer_expression(ty)); let node_ty = node.map_or(Type::unknown(), |ty| {
self.infer_expression(ty, TypeContext::default())
});
let type_base_exception = KnownClass::BaseException.to_subclass_of(self.db()); let type_base_exception = KnownClass::BaseException.to_subclass_of(self.db());
// If it's an `except*` handler, this won't actually be the type of the bound symbol; // If it's an `except*` handler, this won't actually be the type of the bound symbol;
@ -2947,7 +2958,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
name: _, name: _,
default, default,
} = node; } = node;
self.infer_optional_expression(default.as_deref()); self.infer_optional_expression(default.as_deref(), TypeContext::default());
let pep_695_todo = Type::Dynamic(DynamicType::TodoPEP695ParamSpec); let pep_695_todo = Type::Dynamic(DynamicType::TodoPEP695ParamSpec);
self.add_declaration_with_binding( self.add_declaration_with_binding(
node.into(), node.into(),
@ -2967,7 +2978,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
name: _, name: _,
default, default,
} = node; } = node;
self.infer_optional_expression(default.as_deref()); self.infer_optional_expression(default.as_deref(), TypeContext::default());
let pep_695_todo = todo_type!("PEP-695 TypeVarTuple definition types"); let pep_695_todo = todo_type!("PEP-695 TypeVarTuple definition types");
self.add_declaration_with_binding( self.add_declaration_with_binding(
node.into(), node.into(),
@ -2984,7 +2995,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
cases, cases,
} = match_statement; } = match_statement;
self.infer_standalone_expression(subject); self.infer_standalone_expression(subject, TypeContext::default());
for case in cases { for case in cases {
let ast::MatchCase { let ast::MatchCase {
@ -2997,7 +3008,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.infer_match_pattern(pattern); self.infer_match_pattern(pattern);
if let Some(guard) = guard.as_deref() { if let Some(guard) = guard.as_deref() {
let guard_ty = self.infer_standalone_expression(guard); let guard_ty = self.infer_standalone_expression(guard, TypeContext::default());
if let Err(err) = guard_ty.try_bool(self.db()) { if let Err(err) = guard_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, guard); err.report_diagnostic(&self.context, guard);
@ -3052,7 +3063,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510 // the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510
match pattern { match pattern {
ast::Pattern::MatchValue(match_value) => { ast::Pattern::MatchValue(match_value) => {
self.infer_standalone_expression(&match_value.value); self.infer_standalone_expression(&match_value.value, TypeContext::default());
} }
ast::Pattern::MatchClass(match_class) => { ast::Pattern::MatchClass(match_class) => {
let ast::PatternMatchClass { let ast::PatternMatchClass {
@ -3067,7 +3078,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for keyword in &arguments.keywords { for keyword in &arguments.keywords {
self.infer_nested_match_pattern(&keyword.pattern); self.infer_nested_match_pattern(&keyword.pattern);
} }
self.infer_standalone_expression(cls); self.infer_standalone_expression(cls, TypeContext::default());
} }
ast::Pattern::MatchOr(match_or) => { ast::Pattern::MatchOr(match_or) => {
for pattern in &match_or.patterns { for pattern in &match_or.patterns {
@ -3083,7 +3094,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_nested_match_pattern(&mut self, pattern: &ast::Pattern) { fn infer_nested_match_pattern(&mut self, pattern: &ast::Pattern) {
match pattern { match pattern {
ast::Pattern::MatchValue(match_value) => { ast::Pattern::MatchValue(match_value) => {
self.infer_maybe_standalone_expression(&match_value.value); self.infer_maybe_standalone_expression(&match_value.value, TypeContext::default());
} }
ast::Pattern::MatchSequence(match_sequence) => { ast::Pattern::MatchSequence(match_sequence) => {
for pattern in &match_sequence.patterns { for pattern in &match_sequence.patterns {
@ -3099,7 +3110,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
rest: _, rest: _,
} = match_mapping; } = match_mapping;
for key in keys { for key in keys {
self.infer_expression(key); self.infer_expression(key, TypeContext::default());
} }
for pattern in patterns { for pattern in patterns {
self.infer_nested_match_pattern(pattern); self.infer_nested_match_pattern(pattern);
@ -3118,7 +3129,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for keyword in &arguments.keywords { for keyword in &arguments.keywords {
self.infer_nested_match_pattern(&keyword.pattern); self.infer_nested_match_pattern(&keyword.pattern);
} }
self.infer_maybe_standalone_expression(cls); self.infer_maybe_standalone_expression(cls, TypeContext::default());
} }
ast::Pattern::MatchAs(match_as) => { ast::Pattern::MatchAs(match_as) => {
if let Some(pattern) = &match_as.pattern { if let Some(pattern) = &match_as.pattern {
@ -3144,7 +3155,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for target in targets { for target in targets {
self.infer_target(target, value, |builder, value_expr| { self.infer_target(target, value, |builder, value_expr| {
builder.infer_standalone_expression(value_expr) builder.infer_standalone_expression(value_expr, TypeContext::default())
}); });
} }
} }
@ -3184,8 +3195,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ctx: _, ctx: _,
} = target; } = target;
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value, TypeContext::default());
let slice_ty = self.infer_expression(slice); let slice_ty = self.infer_expression(slice, TypeContext::default());
let db = self.db(); let db = self.db();
let context = &self.context; let context = &self.context;
@ -3878,7 +3889,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
) => { ) => {
self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown())); self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown()));
let object_ty = self.infer_expression(object); let object_ty = self.infer_expression(object, TypeContext::default());
if let Some(assigned_ty) = assigned_ty { if let Some(assigned_ty) = assigned_ty {
self.validate_attribute_assignment( self.validate_attribute_assignment(
@ -3899,7 +3910,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
_ => { _ => {
// TODO: Remove this once we handle all possible assignment targets. // TODO: Remove this once we handle all possible assignment targets.
self.infer_expression(target); self.infer_expression(target, TypeContext::default());
} }
} }
} }
@ -3924,7 +3935,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
unpacked.expression_type(target) unpacked.expression_type(target)
} }
TargetKind::Single => { TargetKind::Single => {
let value_ty = self.infer_standalone_expression(value); let value_ty = self.infer_standalone_expression(value, TypeContext::default());
// `TYPE_CHECKING` is a special variable that should only be assigned `False` // `TYPE_CHECKING` is a special variable that should only be assigned `False`
// at runtime, but is always considered `True` in type checking. // at runtime, but is always considered `True` in type checking.
@ -3988,12 +3999,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
if let Some(value) = value { if let Some(value) = value {
self.infer_maybe_standalone_expression(value); self.infer_maybe_standalone_expression(
value,
TypeContext::new(annotated.inner_type()),
);
} }
// If we have an annotated assignment like `self.attr: int = 1`, we still need to // If we have an annotated assignment like `self.attr: int = 1`, we still need to
// do type inference on the `self.attr` target to get types for all sub-expressions. // do type inference on the `self.attr` target to get types for all sub-expressions.
self.infer_expression(target); self.infer_expression(target, TypeContext::default());
// But here we explicitly overwrite the type for the overall `self.attr` node with // But here we explicitly overwrite the type for the overall `self.attr` node with
// the annotated type. We do no use `store_expression_type` here, because it checks // the annotated type. We do no use `store_expression_type` here, because it checks
@ -4080,7 +4094,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
debug_assert!(PlaceExpr::try_from_expr(target).is_some()); debug_assert!(PlaceExpr::try_from_expr(target).is_some());
if let Some(value) = value { if let Some(value) = value {
let inferred_ty = self.infer_maybe_standalone_expression(value); let inferred_ty = self
.infer_maybe_standalone_expression(value, TypeContext::new(declared.inner_type()));
let mut inferred_ty = if target let mut inferred_ty = if target
.as_name_expr() .as_name_expr()
.is_some_and(|name| &name.id == "TYPE_CHECKING") .is_some_and(|name| &name.id == "TYPE_CHECKING")
@ -4236,9 +4251,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.store_expression_type(target, previous_value); self.store_expression_type(target, previous_value);
previous_value previous_value
} }
_ => self.infer_expression(target), _ => self.infer_expression(target, TypeContext::default()),
}; };
let value_type = self.infer_expression(value); let value_type = self.infer_expression(value, TypeContext::default());
self.infer_augmented_op(assignment, target_type, value_type) self.infer_augmented_op(assignment, target_type, value_type)
} }
@ -4263,7 +4278,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// but only if the target is a name. We should report a diagnostic here if the target isn't a name: // but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `for a.x in not_iterable: ... // `for a.x in not_iterable: ...
builder builder
.infer_standalone_expression(iter_expr) .infer_standalone_expression(iter_expr, TypeContext::default())
.iterate(builder.db()) .iterate(builder.db())
.homogeneous_element_type(builder.db()) .homogeneous_element_type(builder.db())
}); });
@ -4290,7 +4305,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
unpacked.expression_type(target) unpacked.expression_type(target)
} }
TargetKind::Single => { TargetKind::Single => {
let iterable_type = self.infer_standalone_expression(iterable); let iterable_type =
self.infer_standalone_expression(iterable, TypeContext::default());
iterable_type iterable_type
.try_iterate_with_mode( .try_iterate_with_mode(
@ -4318,7 +4334,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
orelse, orelse,
} = while_statement; } = while_statement;
let test_ty = self.infer_standalone_expression(test); let test_ty = self.infer_standalone_expression(test, TypeContext::default());
if let Err(err) = test_ty.try_bool(self.db()) { if let Err(err) = test_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, &**test); err.report_diagnostic(&self.context, &**test);
@ -4500,13 +4516,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
msg, msg,
} = assert; } = assert;
let test_ty = self.infer_standalone_expression(test); let test_ty = self.infer_standalone_expression(test, TypeContext::default());
if let Err(err) = test_ty.try_bool(self.db()) { if let Err(err) = test_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, &**test); err.report_diagnostic(&self.context, &**test);
} }
self.infer_optional_expression(msg.as_deref()); self.infer_optional_expression(msg.as_deref(), TypeContext::default());
} }
fn infer_raise_statement(&mut self, raise: &ast::StmtRaise) { fn infer_raise_statement(&mut self, raise: &ast::StmtRaise) {
@ -4526,7 +4542,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
UnionType::from_elements(self.db(), [can_be_raised, Type::none(self.db())]); UnionType::from_elements(self.db(), [can_be_raised, Type::none(self.db())]);
if let Some(raised) = exc { if let Some(raised) = exc {
let raised_type = self.infer_expression(raised); let raised_type = self.infer_expression(raised, TypeContext::default());
if !raised_type.is_assignable_to(self.db(), can_be_raised) { if !raised_type.is_assignable_to(self.db(), can_be_raised) {
report_invalid_exception_raised(&self.context, raised, raised_type); report_invalid_exception_raised(&self.context, raised, raised_type);
@ -4534,7 +4550,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
if let Some(cause) = cause { if let Some(cause) = cause {
let cause_type = self.infer_expression(cause); let cause_type = self.infer_expression(cause, TypeContext::default());
if !cause_type.is_assignable_to(self.db(), can_be_exception_cause) { if !cause_type.is_assignable_to(self.db(), can_be_exception_cause) {
report_invalid_exception_cause(&self.context, cause, cause_type); report_invalid_exception_cause(&self.context, cause, cause_type);
@ -4740,7 +4756,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
fn infer_return_statement(&mut self, ret: &ast::StmtReturn) { fn infer_return_statement(&mut self, ret: &ast::StmtReturn) {
if let Some(ty) = self.infer_optional_expression(ret.value.as_deref()) { if let Some(ty) =
self.infer_optional_expression(ret.value.as_deref(), TypeContext::default())
{
let range = ret let range = ret
.value .value
.as_ref() .as_ref()
@ -4758,7 +4776,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
targets, targets,
} = delete; } = delete;
for target in targets { for target in targets {
self.infer_expression(target); self.infer_expression(target, TypeContext::default());
} }
} }
@ -4898,7 +4916,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
expression, expression,
} = decorator; } = decorator;
self.infer_expression(expression) self.infer_expression(expression, TypeContext::default())
} }
fn infer_argument_types<'a>( fn infer_argument_types<'a>(
@ -4920,7 +4938,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::ArgOrKeyword::Arg(arg) => arg, ast::ArgOrKeyword::Arg(arg) => arg,
ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value, ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value,
}; };
let ty = self.infer_argument_type(argument, form); let ty = self.infer_argument_type(argument, form, TypeContext::default());
*argument_type = Some(ty); *argument_type = Some(ty);
} }
} }
@ -4929,58 +4947,73 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&mut self, &mut self,
ast_argument: &ast::Expr, ast_argument: &ast::Expr,
form: Option<ParameterForm>, form: Option<ParameterForm>,
tcx: TypeContext<'db>,
) -> Type<'db> { ) -> Type<'db> {
match form { match form {
None | Some(ParameterForm::Value) => self.infer_expression(ast_argument), None | Some(ParameterForm::Value) => self.infer_expression(ast_argument, tcx),
Some(ParameterForm::Type) => self.infer_type_expression(ast_argument), Some(ParameterForm::Type) => self.infer_type_expression(ast_argument),
} }
} }
fn infer_optional_expression(&mut self, expression: Option<&ast::Expr>) -> Option<Type<'db>> { fn infer_optional_expression(
expression.map(|expr| self.infer_expression(expr)) &mut self,
expression: Option<&ast::Expr>,
tcx: TypeContext<'db>,
) -> Option<Type<'db>> {
expression.map(|expr| self.infer_expression(expr, tcx))
} }
#[track_caller] #[track_caller]
fn infer_expression(&mut self, expression: &ast::Expr) -> Type<'db> { fn infer_expression(&mut self, expression: &ast::Expr, tcx: TypeContext<'db>) -> Type<'db> {
debug_assert!( debug_assert!(
!self.index.is_standalone_expression(expression), !self.index.is_standalone_expression(expression),
"Calling `self.infer_expression` on a standalone-expression is not allowed because it can lead to double-inference. Use `self.infer_standalone_expression` instead." "Calling `self.infer_expression` on a standalone-expression is not allowed because it can lead to double-inference. Use `self.infer_standalone_expression` instead."
); );
self.infer_expression_impl(expression) self.infer_expression_impl(expression, tcx)
} }
fn infer_expression_with_state( fn infer_expression_with_state(
&mut self, &mut self,
expression: &ast::Expr, expression: &ast::Expr,
tcx: TypeContext<'db>,
state: DeferredExpressionState, state: DeferredExpressionState,
) -> Type<'db> { ) -> Type<'db> {
let previous_deferred_state = std::mem::replace(&mut self.deferred_state, state); let previous_deferred_state = std::mem::replace(&mut self.deferred_state, state);
let ty = self.infer_expression(expression); let ty = self.infer_expression(expression, tcx);
self.deferred_state = previous_deferred_state; self.deferred_state = previous_deferred_state;
ty ty
} }
fn infer_maybe_standalone_expression(&mut self, expression: &ast::Expr) -> Type<'db> { fn infer_maybe_standalone_expression(
&mut self,
expression: &ast::Expr,
tcx: TypeContext<'db>,
) -> Type<'db> {
if let Some(standalone_expression) = self.index.try_expression(expression) { if let Some(standalone_expression) = self.index.try_expression(expression) {
self.infer_standalone_expression_impl(expression, standalone_expression) self.infer_standalone_expression_impl(expression, standalone_expression, tcx)
} else { } else {
self.infer_expression(expression) self.infer_expression(expression, tcx)
} }
} }
#[track_caller] #[track_caller]
fn infer_standalone_expression(&mut self, expression: &ast::Expr) -> Type<'db> { fn infer_standalone_expression(
&mut self,
expression: &ast::Expr,
tcx: TypeContext<'db>,
) -> Type<'db> {
let standalone_expression = self.index.expression(expression); let standalone_expression = self.index.expression(expression);
self.infer_standalone_expression_impl(expression, standalone_expression) self.infer_standalone_expression_impl(expression, standalone_expression, tcx)
} }
fn infer_standalone_expression_impl( fn infer_standalone_expression_impl(
&mut self, &mut self,
expression: &ast::Expr, expression: &ast::Expr,
standalone_expression: Expression<'db>, standalone_expression: Expression<'db>,
tcx: TypeContext<'db>,
) -> Type<'db> { ) -> Type<'db> {
let types = infer_expression_types(self.db(), standalone_expression); let types = infer_expression_types(self.db(), standalone_expression, tcx);
self.extend_expression(types); self.extend_expression(types);
// Instead of calling `self.expression_type(expr)` after extending here, we get // Instead of calling `self.expression_type(expr)` after extending here, we get
@ -4990,7 +5023,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
types.expression_type(expression) types.expression_type(expression)
} }
fn infer_expression_impl(&mut self, expression: &ast::Expr) -> Type<'db> { fn infer_expression_impl(
&mut self,
expression: &ast::Expr,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ty = match expression { let ty = match expression {
ast::Expr::NoneLiteral(ast::ExprNoneLiteral { ast::Expr::NoneLiteral(ast::ExprNoneLiteral {
range: _, range: _,
@ -5005,10 +5042,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::FString(fstring) => self.infer_fstring_expression(fstring), ast::Expr::FString(fstring) => self.infer_fstring_expression(fstring),
ast::Expr::TString(tstring) => self.infer_tstring_expression(tstring), ast::Expr::TString(tstring) => self.infer_tstring_expression(tstring),
ast::Expr::EllipsisLiteral(literal) => self.infer_ellipsis_literal_expression(literal), ast::Expr::EllipsisLiteral(literal) => self.infer_ellipsis_literal_expression(literal),
ast::Expr::Tuple(tuple) => self.infer_tuple_expression(tuple), ast::Expr::Tuple(tuple) => self.infer_tuple_expression(tuple, tcx),
ast::Expr::List(list) => self.infer_list_expression(list), ast::Expr::List(list) => self.infer_list_expression(list, tcx),
ast::Expr::Set(set) => self.infer_set_expression(set), ast::Expr::Set(set) => self.infer_set_expression(set, tcx),
ast::Expr::Dict(dict) => self.infer_dict_expression(dict), ast::Expr::Dict(dict) => self.infer_dict_expression(dict, tcx),
ast::Expr::Generator(generator) => self.infer_generator_expression(generator), ast::Expr::Generator(generator) => self.infer_generator_expression(generator),
ast::Expr::ListComp(listcomp) => self.infer_list_comprehension_expression(listcomp), ast::Expr::ListComp(listcomp) => self.infer_list_comprehension_expression(listcomp),
ast::Expr::DictComp(dictcomp) => self.infer_dict_comprehension_expression(dictcomp), ast::Expr::DictComp(dictcomp) => self.infer_dict_comprehension_expression(dictcomp),
@ -5024,7 +5061,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::Named(named) => self.infer_named_expression(named), ast::Expr::Named(named) => self.infer_named_expression(named),
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression), ast::Expr::If(if_expression) => self.infer_if_expression(if_expression),
ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression), ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression),
ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression), ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx),
ast::Expr::Starred(starred) => self.infer_starred_expression(starred), ast::Expr::Starred(starred) => self.infer_starred_expression(starred),
ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression), ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression),
ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from), ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from),
@ -5038,7 +5075,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ty ty
} }
fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) { fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) {
if self.deferred_state.in_string_annotation() { if self.deferred_state.in_string_annotation() {
// Avoid storing the type of expressions that are part of a string annotation because // Avoid storing the type of expressions that are part of a string annotation because
@ -5120,11 +5156,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
conversion, conversion,
format_spec, format_spec,
} = expression; } = expression;
let ty = self.infer_expression(expression); let ty = self.infer_expression(expression, TypeContext::default());
if let Some(format_spec) = format_spec { if let Some(format_spec) = format_spec {
for element in format_spec.elements.interpolations() { for element in format_spec.elements.interpolations() {
self.infer_expression(&element.expression); self.infer_expression(
&element.expression,
TypeContext::default(),
);
} }
} }
@ -5166,10 +5205,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
format_spec, format_spec,
.. ..
} = tstring_interpolation_element; } = tstring_interpolation_element;
self.infer_expression(expression); self.infer_expression(expression, TypeContext::default());
if let Some(format_spec) = format_spec { if let Some(format_spec) = format_spec {
for element in format_spec.elements.interpolations() { for element in format_spec.elements.interpolations() {
self.infer_expression(&element.expression); self.infer_expression(&element.expression, TypeContext::default());
} }
} }
} }
@ -5187,7 +5226,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
KnownClass::EllipsisType.to_instance(self.db()) KnownClass::EllipsisType.to_instance(self.db())
} }
fn infer_tuple_expression(&mut self, tuple: &ast::ExprTuple) -> Type<'db> { fn infer_tuple_expression(
&mut self,
tuple: &ast::ExprTuple,
_tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprTuple { let ast::ExprTuple {
range: _, range: _,
node_index: _, node_index: _,
@ -5199,7 +5242,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let db = self.db(); let db = self.db();
let divergent = Type::divergent(self.scope()); let divergent = Type::divergent(self.scope());
let element_types = elts.iter().map(|element| { let element_types = elts.iter().map(|element| {
let element_type = self.infer_expression(element); // TODO: Use the type context for more precise inference.
let element_type = self.infer_expression(element, TypeContext::default());
if element_type.has_divergent_type(self.db(), divergent) { if element_type.has_divergent_type(self.db(), divergent) {
divergent divergent
} else { } else {
@ -5210,7 +5254,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::heterogeneous_tuple(db, element_types) Type::heterogeneous_tuple(db, element_types)
} }
fn infer_list_expression(&mut self, list: &ast::ExprList) -> Type<'db> { fn infer_list_expression(&mut self, list: &ast::ExprList, _tcx: TypeContext<'db>) -> Type<'db> {
let ast::ExprList { let ast::ExprList {
range: _, range: _,
node_index: _, node_index: _,
@ -5218,38 +5262,41 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ctx: _, ctx: _,
} = list; } = list;
// TODO: Use the type context for more precise inference.
for elt in elts { for elt in elts {
self.infer_expression(elt); self.infer_expression(elt, TypeContext::default());
} }
KnownClass::List KnownClass::List
.to_specialized_instance(self.db(), [todo_type!("list literal element type")]) .to_specialized_instance(self.db(), [todo_type!("list literal element type")])
} }
fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> { fn infer_set_expression(&mut self, set: &ast::ExprSet, _tcx: TypeContext<'db>) -> Type<'db> {
let ast::ExprSet { let ast::ExprSet {
range: _, range: _,
node_index: _, node_index: _,
elts, elts,
} = set; } = set;
// TODO: Use the type context for more precise inference.
for elt in elts { for elt in elts {
self.infer_expression(elt); self.infer_expression(elt, TypeContext::default());
} }
KnownClass::Set.to_specialized_instance(self.db(), [todo_type!("set literal element type")]) KnownClass::Set.to_specialized_instance(self.db(), [todo_type!("set literal element type")])
} }
fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> { fn infer_dict_expression(&mut self, dict: &ast::ExprDict, _tcx: TypeContext<'db>) -> Type<'db> {
let ast::ExprDict { let ast::ExprDict {
range: _, range: _,
node_index: _, node_index: _,
items, items,
} = dict; } = dict;
// TODO: Use the type context for more precise inference.
for item in items { for item in items {
self.infer_optional_expression(item.key.as_ref()); self.infer_optional_expression(item.key.as_ref(), TypeContext::default());
self.infer_expression(&item.value); self.infer_expression(&item.value, TypeContext::default());
} }
KnownClass::Dict.to_specialized_instance( KnownClass::Dict.to_specialized_instance(
@ -5260,14 +5307,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
], ],
) )
} }
/// Infer the type of the `iter` expression of the first comprehension. /// Infer the type of the `iter` expression of the first comprehension.
fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) { fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) {
let mut comprehensions_iter = comprehensions.iter(); let mut comprehensions_iter = comprehensions.iter();
let Some(first_comprehension) = comprehensions_iter.next() else { let Some(first_comprehension) = comprehensions_iter.next() else {
unreachable!("Comprehension must contain at least one generator"); unreachable!("Comprehension must contain at least one generator");
}; };
self.infer_standalone_expression(&first_comprehension.iter); self.infer_standalone_expression(&first_comprehension.iter, TypeContext::default());
} }
fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> { fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> {
@ -5348,7 +5394,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
parenthesized: _, parenthesized: _,
} = generator; } = generator;
self.infer_expression(elt); self.infer_expression(elt, TypeContext::default());
self.infer_comprehensions(generators); self.infer_comprehensions(generators);
} }
@ -5360,7 +5406,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
generators, generators,
} = listcomp; } = listcomp;
self.infer_expression(elt); self.infer_expression(elt, TypeContext::default());
self.infer_comprehensions(generators); self.infer_comprehensions(generators);
} }
@ -5373,8 +5419,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
generators, generators,
} = dictcomp; } = dictcomp;
self.infer_expression(key); self.infer_expression(key, TypeContext::default());
self.infer_expression(value); self.infer_expression(value, TypeContext::default());
self.infer_comprehensions(generators); self.infer_comprehensions(generators);
} }
@ -5386,7 +5432,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
generators, generators,
} = setcomp; } = setcomp;
self.infer_expression(elt); self.infer_expression(elt, TypeContext::default());
self.infer_comprehensions(generators); self.infer_comprehensions(generators);
} }
@ -5419,16 +5465,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
infer_same_file_expression_type( infer_same_file_expression_type(
builder.db(), builder.db(),
builder.index.expression(iter_expr), builder.index.expression(iter_expr),
TypeContext::default(),
builder.module(), builder.module(),
) )
} else { } else {
builder.infer_standalone_expression(iter_expr) builder.infer_standalone_expression(iter_expr, TypeContext::default())
} }
.iterate(builder.db()) .iterate(builder.db())
.homogeneous_element_type(builder.db()) .homogeneous_element_type(builder.db())
}); });
for expr in ifs { for expr in ifs {
self.infer_standalone_expression(expr); self.infer_standalone_expression(expr, TypeContext::default());
} }
} }
@ -5442,7 +5489,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut infer_iterable_type = || { let mut infer_iterable_type = || {
let expression = self.index.expression(iterable); let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db(), expression); let result = infer_expression_types(self.db(), expression, TypeContext::default());
// Two things are different if it's the first comprehension: // Two things are different if it's the first comprehension:
// (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope, // (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope,
@ -5496,8 +5543,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
result.binding_type(definition) result.binding_type(definition)
} else { } else {
// For syntactically invalid targets, we still need to run type inference: // For syntactically invalid targets, we still need to run type inference:
self.infer_expression(&named.target); self.infer_expression(&named.target, TypeContext::default());
self.infer_expression(&named.value); self.infer_expression(&named.value, TypeContext::default());
Type::unknown() Type::unknown()
} }
} }
@ -5514,8 +5561,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
value, value,
} = named; } = named;
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value, TypeContext::default());
self.infer_expression(target); self.infer_expression(target, TypeContext::default());
self.add_binding(named.into(), definition, value_ty); self.add_binding(named.into(), definition, value_ty);
@ -5531,9 +5578,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
orelse, orelse,
} = if_expression; } = if_expression;
let test_ty = self.infer_standalone_expression(test); let test_ty = self.infer_standalone_expression(test, TypeContext::default());
let body_ty = self.infer_expression(body); let body_ty = self.infer_expression(body, TypeContext::default());
let orelse_ty = self.infer_expression(orelse); let orelse_ty = self.infer_expression(orelse, TypeContext::default());
match test_ty.try_bool(self.db()).unwrap_or_else(|err| { match test_ty.try_bool(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, &**test); err.report_diagnostic(&self.context, &**test);
@ -5546,7 +5593,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) { fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) {
self.infer_expression(&lambda_expression.body); self.infer_expression(&lambda_expression.body, TypeContext::default());
} }
fn infer_lambda_expression(&mut self, lambda_expression: &ast::ExprLambda) -> Type<'db> { fn infer_lambda_expression(&mut self, lambda_expression: &ast::ExprLambda) -> Type<'db> {
@ -5564,7 +5611,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.map(|param| { .map(|param| {
let mut parameter = Parameter::positional_only(Some(param.name().id.clone())); let mut parameter = Parameter::positional_only(Some(param.name().id.clone()));
if let Some(default) = param.default() { if let Some(default) = param.default() {
parameter = parameter.with_default_type(self.infer_expression(default)); parameter = parameter.with_default_type(
self.infer_expression(default, TypeContext::default()),
);
} }
parameter parameter
}) })
@ -5575,7 +5624,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.map(|param| { .map(|param| {
let mut parameter = Parameter::positional_or_keyword(param.name().id.clone()); let mut parameter = Parameter::positional_or_keyword(param.name().id.clone());
if let Some(default) = param.default() { if let Some(default) = param.default() {
parameter = parameter.with_default_type(self.infer_expression(default)); parameter = parameter.with_default_type(
self.infer_expression(default, TypeContext::default()),
);
} }
parameter parameter
}) })
@ -5590,7 +5641,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.map(|param| { .map(|param| {
let mut parameter = Parameter::keyword_only(param.name().id.clone()); let mut parameter = Parameter::keyword_only(param.name().id.clone());
if let Some(default) = param.default() { if let Some(default) = param.default() {
parameter = parameter.with_default_type(self.infer_expression(default)); parameter = parameter.with_default_type(
self.infer_expression(default, TypeContext::default()),
);
} }
parameter parameter
}) })
@ -5618,7 +5671,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
CallableType::function_like(self.db(), Signature::new(parameters, Some(Type::unknown()))) CallableType::function_like(self.db(), Signature::new(parameters, Some(Type::unknown())))
} }
fn infer_call_expression(&mut self, call_expression: &ast::ExprCall) -> Type<'db> { fn infer_call_expression(
&mut self,
call_expression: &ast::ExprCall,
_tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprCall { let ast::ExprCall {
range: _, range: _,
node_index: _, node_index: _,
@ -5631,12 +5688,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// are assignable to any parameter annotations. // are assignable to any parameter annotations.
let mut call_arguments = let mut call_arguments =
CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| { CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| {
let ty = self.infer_expression(splatted_value); let ty = self.infer_expression(splatted_value, TypeContext::default());
self.store_expression_type(argument, ty); self.store_expression_type(argument, ty);
ty ty
}); });
let callable_type = self.infer_maybe_standalone_expression(func); // TODO: Use the type context for more precise inference.
let callable_type = self.infer_maybe_standalone_expression(func, TypeContext::default());
// Special handling for `TypedDict` method calls // Special handling for `TypedDict` method calls
if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) = func.as_ref() { if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) = func.as_ref() {
@ -5881,7 +5939,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ctx: _, ctx: _,
} = starred; } = starred;
let iterable_type = self.infer_expression(value); let iterable_type = self.infer_expression(value, TypeContext::default());
iterable_type iterable_type
.try_iterate(self.db()) .try_iterate(self.db())
.map(|tuple| tuple.homogeneous_element_type(self.db())) .map(|tuple| tuple.homogeneous_element_type(self.db()))
@ -5900,7 +5958,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
node_index: _, node_index: _,
value, value,
} = yield_expression; } = yield_expression;
self.infer_optional_expression(value.as_deref()); self.infer_optional_expression(value.as_deref(), TypeContext::default());
todo_type!("yield expressions") todo_type!("yield expressions")
} }
@ -5911,7 +5969,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
value, value,
} = yield_from; } = yield_from;
let iterable_type = self.infer_expression(value); let iterable_type = self.infer_expression(value, TypeContext::default());
iterable_type iterable_type
.try_iterate(self.db()) .try_iterate(self.db())
.map(|tuple| tuple.homogeneous_element_type(self.db())) .map(|tuple| tuple.homogeneous_element_type(self.db()))
@ -5931,7 +5989,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
node_index: _, node_index: _,
value, value,
} = await_expression; } = await_expression;
let expr_type = self.infer_expression(value); let expr_type = self.infer_expression(value, TypeContext::default());
expr_type.try_await(self.db()).unwrap_or_else(|err| { expr_type.try_await(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, expr_type, value.as_ref().into()); err.report_diagnostic(&self.context, expr_type, value.as_ref().into());
Type::unknown() Type::unknown()
@ -6576,7 +6634,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> { fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> {
let ast::ExprAttribute { value, attr, .. } = attribute; let ast::ExprAttribute { value, attr, .. } = attribute;
let value_type = self.infer_maybe_standalone_expression(value); let value_type = self.infer_maybe_standalone_expression(value, TypeContext::default());
let db = self.db(); let db = self.db();
let mut constraint_keys = vec![]; let mut constraint_keys = vec![];
@ -6687,7 +6745,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
match ctx { match ctx {
ExprContext::Load => self.infer_attribute_load(attribute), ExprContext::Load => self.infer_attribute_load(attribute),
ExprContext::Store => { ExprContext::Store => {
self.infer_expression(value); self.infer_expression(value, TypeContext::default());
Type::Never Type::Never
} }
ExprContext::Del => { ExprContext::Del => {
@ -6695,7 +6753,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::Never Type::Never
} }
ExprContext::Invalid => { ExprContext::Invalid => {
self.infer_expression(value); self.infer_expression(value, TypeContext::default());
Type::unknown() Type::unknown()
} }
} }
@ -6709,7 +6767,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
operand, operand,
} = unary; } = unary;
let operand_type = self.infer_expression(operand); let operand_type = self.infer_expression(operand, TypeContext::default());
self.infer_unary_expression_type(*op, operand_type, unary) self.infer_unary_expression_type(*op, operand_type, unary)
} }
@ -6830,8 +6888,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
node_index: _, node_index: _,
} = binary; } = binary;
let left_ty = self.infer_expression(left); let left_ty = self.infer_expression(left, TypeContext::default());
let right_ty = self.infer_expression(right); let right_ty = self.infer_expression(right, TypeContext::default());
self.infer_binary_expression_type(binary.into(), false, left_ty, right_ty, *op) self.infer_binary_expression_type(binary.into(), false, left_ty, right_ty, *op)
.unwrap_or_else(|| { .unwrap_or_else(|| {
@ -7276,9 +7334,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
values.iter().enumerate(), values.iter().enumerate(),
|builder, (index, value)| { |builder, (index, value)| {
let ty = if index == values.len() - 1 { let ty = if index == values.len() - 1 {
builder.infer_expression(value) builder.infer_expression(value, TypeContext::default())
} else { } else {
builder.infer_standalone_expression(value) builder.infer_standalone_expression(value, TypeContext::default())
}; };
(ty, value.range()) (ty, value.range())
@ -7359,7 +7417,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
comparators, comparators,
} = compare; } = compare;
self.infer_expression(left); self.infer_expression(left, TypeContext::default());
// https://docs.python.org/3/reference/expressions.html#comparisons // https://docs.python.org/3/reference/expressions.html#comparisons
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison // > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
@ -7376,7 +7434,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.zip(ops), .zip(ops),
|builder, ((left, right), op)| { |builder, ((left, right), op)| {
let left_ty = builder.expression_type(left); let left_ty = builder.expression_type(left);
let right_ty = builder.infer_expression(right); let right_ty = builder.infer_expression(right, TypeContext::default());
let range = TextRange::new(left.start(), right.end()); let range = TextRange::new(left.start(), right.end());
@ -8143,8 +8201,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
match ctx { match ctx {
ExprContext::Load => self.infer_subscript_load(subscript), ExprContext::Load => self.infer_subscript_load(subscript),
ExprContext::Store => { ExprContext::Store => {
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value, TypeContext::default());
let slice_ty = self.infer_expression(slice); let slice_ty = self.infer_expression(slice, TypeContext::default());
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx); self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
Type::Never Type::Never
} }
@ -8153,8 +8211,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::Never Type::Never
} }
ExprContext::Invalid => { ExprContext::Invalid => {
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value, TypeContext::default());
let slice_ty = self.infer_expression(slice); let slice_ty = self.infer_expression(slice, TypeContext::default());
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx); self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
Type::unknown() Type::unknown()
} }
@ -8169,7 +8227,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
slice, slice,
ctx, ctx,
} = subscript; } = subscript;
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value, TypeContext::default());
let mut constraint_keys = vec![]; let mut constraint_keys = vec![];
// If `value` is a valid reference, we attempt type narrowing by assignment. // If `value` is a valid reference, we attempt type narrowing by assignment.
@ -8183,7 +8241,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Place::Type(ty, Boundness::Bound) = place.place { if let Place::Type(ty, Boundness::Bound) = place.place {
// Even if we can obtain the subscript type based on the assignments, we still perform default type inference // Even if we can obtain the subscript type based on the assignments, we still perform default type inference
// (to store the expression type and to report errors). // (to store the expression type and to report errors).
let slice_ty = self.infer_expression(slice); let slice_ty = self.infer_expression(slice, TypeContext::default());
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx); self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
return ty; return ty;
} }
@ -8228,7 +8286,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} }
} }
let slice_ty = self.infer_expression(slice); let slice_ty = self.infer_expression(slice, TypeContext::default());
let result_ty = self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx); let result_ty = self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
self.narrow_expr_with_applicable_constraints(subscript, result_ty, &constraint_keys) self.narrow_expr_with_applicable_constraints(subscript, result_ty, &constraint_keys)
} }
@ -8767,9 +8825,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
step, step,
} = slice; } = slice;
let ty_lower = self.infer_optional_expression(lower.as_deref()); let ty_lower = self.infer_optional_expression(lower.as_deref(), TypeContext::default());
let ty_upper = self.infer_optional_expression(upper.as_deref()); let ty_upper = self.infer_optional_expression(upper.as_deref(), TypeContext::default());
let ty_step = self.infer_optional_expression(step.as_deref()); let ty_step = self.infer_optional_expression(step.as_deref(), TypeContext::default());
let type_to_slice_argument = |ty: Option<Type<'db>>| match ty { let type_to_slice_argument = |ty: Option<Type<'db>>| match ty {
Some(ty @ (Type::IntLiteral(_) | Type::BooleanLiteral(_))) => SliceArg::Arg(ty), Some(ty @ (Type::IntLiteral(_) | Type::BooleanLiteral(_))) => SliceArg::Arg(ty),

View file

@ -6,7 +6,7 @@ use crate::types::string_annotation::{
BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION, parse_string_annotation, BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION, parse_string_annotation,
}; };
use crate::types::{ use crate::types::{
KnownClass, SpecialFormType, Type, TypeAndQualifiers, TypeQualifiers, todo_type, KnownClass, SpecialFormType, Type, TypeAndQualifiers, TypeContext, TypeQualifiers, todo_type,
}; };
/// Annotation expressions. /// Annotation expressions.
@ -122,7 +122,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}, },
ast::Expr::Subscript(subscript @ ast::ExprSubscript { value, slice, .. }) => { ast::Expr::Subscript(subscript @ ast::ExprSubscript { value, slice, .. }) => {
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value, TypeContext::default());
let slice = &**slice; let slice = &**slice;
@ -141,7 +141,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
if let [inner_annotation, metadata @ ..] = &arguments[..] { if let [inner_annotation, metadata @ ..] = &arguments[..] {
for element in metadata { for element in metadata {
self.infer_expression(element); self.infer_expression(element, TypeContext::default());
} }
let inner_annotation_ty = let inner_annotation_ty =
@ -151,7 +151,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
inner_annotation_ty inner_annotation_ty
} else { } else {
for argument in arguments { for argument in arguments {
self.infer_expression(argument); self.infer_expression(argument, TypeContext::default());
} }
self.store_expression_type(slice, Type::unknown()); self.store_expression_type(slice, Type::unknown());
TypeAndQualifiers::unknown() TypeAndQualifiers::unknown()

View file

@ -14,7 +14,7 @@ use crate::types::visitor::any_over_type;
use crate::types::{ use crate::types::{
CallableType, DynamicType, IntersectionBuilder, KnownClass, KnownInstanceType, CallableType, DynamicType, IntersectionBuilder, KnownClass, KnownInstanceType,
LintDiagnosticGuard, Parameter, Parameters, SpecialFormType, SubclassOfType, Type, LintDiagnosticGuard, Parameter, Parameters, SpecialFormType, SubclassOfType, Type,
TypeAliasType, TypeIsType, UnionBuilder, UnionType, todo_type, TypeAliasType, TypeContext, TypeIsType, UnionBuilder, UnionType, todo_type,
}; };
/// Type expressions /// Type expressions
@ -114,7 +114,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
node_index: _, node_index: _,
} = subscript; } = subscript;
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value, TypeContext::default());
self.infer_subscript_type_expression_no_store(subscript, slice, value_ty) self.infer_subscript_type_expression_no_store(subscript, slice, value_ty)
} }
@ -324,7 +324,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
} }
ast::Expr::Dict(dict) => { ast::Expr::Dict(dict) => {
self.infer_dict_expression(dict); self.infer_dict_expression(dict, TypeContext::default());
self.report_invalid_type_expression( self.report_invalid_type_expression(
expression, expression,
format_args!("Dict literals are not allowed in type expressions"), format_args!("Dict literals are not allowed in type expressions"),
@ -333,7 +333,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
} }
ast::Expr::Set(set) => { ast::Expr::Set(set) => {
self.infer_set_expression(set); self.infer_set_expression(set, TypeContext::default());
self.report_invalid_type_expression( self.report_invalid_type_expression(
expression, expression,
format_args!("Set literals are not allowed in type expressions"), format_args!("Set literals are not allowed in type expressions"),
@ -414,7 +414,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
} }
ast::Expr::Call(call_expr) => { ast::Expr::Call(call_expr) => {
self.infer_call_expression(call_expr); self.infer_call_expression(call_expr, TypeContext::default());
self.report_invalid_type_expression( self.report_invalid_type_expression(
expression, expression,
format_args!("Function calls are not allowed in type expressions"), format_args!("Function calls are not allowed in type expressions"),
@ -544,7 +544,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
let value_ty = if builder.deferred_state.in_string_annotation() { let value_ty = if builder.deferred_state.in_string_annotation() {
// Using `.expression_type` does not work in string annotations, because // Using `.expression_type` does not work in string annotations, because
// we do not store types for sub-expressions. Re-infer the type here. // we do not store types for sub-expressions. Re-infer the type here.
builder.infer_expression(value) builder.infer_expression(value, TypeContext::default())
} else { } else {
builder.expression_type(value) builder.expression_type(value)
}; };
@ -559,7 +559,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
match tuple_slice { match tuple_slice {
ast::Expr::Tuple(elements) => { ast::Expr::Tuple(elements) => {
if let [element, ellipsis @ ast::Expr::EllipsisLiteral(_)] = &*elements.elts { if let [element, ellipsis @ ast::Expr::EllipsisLiteral(_)] = &*elements.elts {
self.infer_expression(ellipsis); self.infer_expression(ellipsis, TypeContext::default());
let result = let result =
TupleType::homogeneous(self.db(), self.infer_type_expression(element)); TupleType::homogeneous(self.db(), self.infer_type_expression(element));
self.store_expression_type(tuple_slice, Type::tuple(Some(result))); self.store_expression_type(tuple_slice, Type::tuple(Some(result)));
@ -617,7 +617,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
fn infer_subclass_of_type_expression(&mut self, slice: &ast::Expr) -> Type<'db> { fn infer_subclass_of_type_expression(&mut self, slice: &ast::Expr) -> Type<'db> {
match slice { match slice {
ast::Expr::Name(_) | ast::Expr::Attribute(_) => { ast::Expr::Name(_) | ast::Expr::Attribute(_) => {
let name_ty = self.infer_expression(slice); let name_ty = self.infer_expression(slice, TypeContext::default());
match name_ty { match name_ty {
Type::ClassLiteral(class_literal) => { Type::ClassLiteral(class_literal) => {
if class_literal.is_protocol(self.db()) { if class_literal.is_protocol(self.db()) {
@ -663,7 +663,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
slice: parameters, slice: parameters,
.. ..
}) => { }) => {
let parameters_ty = match self.infer_expression(value) { let parameters_ty = match self.infer_expression(value, TypeContext::default()) {
Type::SpecialForm(SpecialFormType::Union) => match &**parameters { Type::SpecialForm(SpecialFormType::Union) => match &**parameters {
ast::Expr::Tuple(tuple) => { ast::Expr::Tuple(tuple) => {
let ty = UnionType::from_elements_leave_aliases( let ty = UnionType::from_elements_leave_aliases(
@ -713,7 +713,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
// `infer_expression` (instead of `infer_type_expression`) here to avoid // `infer_expression` (instead of `infer_type_expression`) here to avoid
// false-positive `invalid-type-form` diagnostics (`1` is not a valid type // false-positive `invalid-type-form` diagnostics (`1` is not a valid type
// expression). // expression).
self.infer_expression(&subscript.slice); self.infer_expression(&subscript.slice, TypeContext::default());
Type::unknown() Type::unknown()
} }
Type::SpecialForm(special_form) => { Type::SpecialForm(special_form) => {
@ -912,14 +912,14 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
let [type_expr, metadata @ ..] = &arguments[..] else { let [type_expr, metadata @ ..] = &arguments[..] else {
for argument in arguments { for argument in arguments {
self.infer_expression(argument); self.infer_expression(argument, TypeContext::default());
} }
self.store_expression_type(arguments_slice, Type::unknown()); self.store_expression_type(arguments_slice, Type::unknown());
return Type::unknown(); return Type::unknown();
}; };
for element in metadata { for element in metadata {
self.infer_expression(element); self.infer_expression(element, TypeContext::default());
} }
let ty = self.infer_type_expression(type_expr); let ty = self.infer_type_expression(type_expr);
@ -1107,7 +1107,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
let num_arguments = arguments.len(); let num_arguments = arguments.len();
let type_of_type = if num_arguments == 1 { let type_of_type = if num_arguments == 1 {
// N.B. This uses `infer_expression` rather than `infer_type_expression` // N.B. This uses `infer_expression` rather than `infer_type_expression`
self.infer_expression(&arguments[0]) self.infer_expression(&arguments[0], TypeContext::default())
} else { } else {
for argument in arguments { for argument in arguments {
self.infer_type_expression(argument); self.infer_type_expression(argument);
@ -1137,7 +1137,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
if num_arguments != 1 { if num_arguments != 1 {
for argument in arguments { for argument in arguments {
self.infer_expression(argument); self.infer_expression(argument, TypeContext::default());
} }
report_invalid_argument_number_to_special_form( report_invalid_argument_number_to_special_form(
&self.context, &self.context,
@ -1152,7 +1152,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
return Type::unknown(); return Type::unknown();
} }
let argument_type = self.infer_expression(&arguments[0]); let argument_type = self.infer_expression(&arguments[0], TypeContext::default());
let bindings = argument_type.bindings(db); let bindings = argument_type.bindings(db);
// SAFETY: This is enforced by the constructor methods on `Bindings` even in // SAFETY: This is enforced by the constructor methods on `Bindings` even in
@ -1362,7 +1362,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
Type::tuple(self.infer_tuple_type_expression(arguments_slice)) Type::tuple(self.infer_tuple_type_expression(arguments_slice))
} }
SpecialFormType::Generic | SpecialFormType::Protocol => { SpecialFormType::Generic | SpecialFormType::Protocol => {
self.infer_expression(arguments_slice); self.infer_expression(arguments_slice, TypeContext::default());
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) { if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
builder.into_diagnostic(format_args!( builder.into_diagnostic(format_args!(
"`{special_form}` is not allowed in type expressions", "`{special_form}` is not allowed in type expressions",
@ -1380,7 +1380,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
Ok(match parameters { Ok(match parameters {
// TODO handle type aliases // TODO handle type aliases
ast::Expr::Subscript(ast::ExprSubscript { value, slice, .. }) => { ast::Expr::Subscript(ast::ExprSubscript { value, slice, .. }) => {
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value, TypeContext::default());
if matches!(value_ty, Type::SpecialForm(SpecialFormType::Literal)) { if matches!(value_ty, Type::SpecialForm(SpecialFormType::Literal)) {
let ty = self.infer_literal_parameter_type(slice)?; let ty = self.infer_literal_parameter_type(slice)?;
@ -1389,7 +1389,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
self.store_expression_type(parameters, ty); self.store_expression_type(parameters, ty);
ty ty
} else { } else {
self.infer_expression(slice); self.infer_expression(slice, TypeContext::default());
self.store_expression_type(parameters, Type::unknown()); self.store_expression_type(parameters, Type::unknown());
return Err(vec![parameters]); return Err(vec![parameters]);
@ -1426,13 +1426,13 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
literal @ (ast::Expr::StringLiteral(_) literal @ (ast::Expr::StringLiteral(_)
| ast::Expr::BytesLiteral(_) | ast::Expr::BytesLiteral(_)
| ast::Expr::BooleanLiteral(_) | ast::Expr::BooleanLiteral(_)
| ast::Expr::NoneLiteral(_)) => self.infer_expression(literal), | ast::Expr::NoneLiteral(_)) => self.infer_expression(literal, TypeContext::default()),
literal @ ast::Expr::NumberLiteral(number) if number.value.is_int() => { literal @ ast::Expr::NumberLiteral(number) if number.value.is_int() => {
self.infer_expression(literal) self.infer_expression(literal, TypeContext::default())
} }
// For enum values // For enum values
ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => { ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => {
let value_ty = self.infer_expression(value); let value_ty = self.infer_expression(value, TypeContext::default());
if is_enum_class(self.db(), value_ty) { if is_enum_class(self.db(), value_ty) {
let ty = value_ty let ty = value_ty
@ -1461,7 +1461,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
ty ty
} }
_ => { _ => {
self.infer_expression(parameters); self.infer_expression(parameters, TypeContext::default());
return Err(vec![parameters]); return Err(vec![parameters]);
} }
}) })
@ -1507,7 +1507,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}); });
} }
ast::Expr::Subscript(subscript) => { ast::Expr::Subscript(subscript) => {
let value_ty = self.infer_expression(&subscript.value); let value_ty = self.infer_expression(&subscript.value, TypeContext::default());
self.infer_subscript_type_expression(subscript, value_ty); self.infer_subscript_type_expression(subscript, value_ty);
// TODO: Support `Concatenate[...]` // TODO: Support `Concatenate[...]`
return Some(Parameters::todo()); return Some(Parameters::todo());

View file

@ -5,7 +5,7 @@ use crate::place::{ConsideredDefinitions, Place, global_symbol};
use crate::semantic_index::definition::Definition; use crate::semantic_index::definition::Definition;
use crate::semantic_index::scope::FileScopeId; use crate::semantic_index::scope::FileScopeId;
use crate::semantic_index::{global_scope, place_table, semantic_index, use_def_map}; use crate::semantic_index::{global_scope, place_table, semantic_index, use_def_map};
use crate::types::{KnownInstanceType, check_types}; use crate::types::{KnownClass, KnownInstanceType, UnionType, check_types};
use ruff_db::diagnostic::Diagnostic; use ruff_db::diagnostic::Diagnostic;
use ruff_db::files::{File, system_path_to_file}; use ruff_db::files::{File, system_path_to_file};
use ruff_db::system::DbWithWritableSystem as _; use ruff_db::system::DbWithWritableSystem as _;
@ -409,17 +409,17 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
db.write_dedented( db.write_dedented(
"/src/mod.py", "/src/mod.py",
r#" r#"
class C: class C:
def f(self): def f(self):
self.attr: int | None = None self.attr: int | None = None
"#, "#,
)?; )?;
db.write_dedented( db.write_dedented(
"/src/main.py", "/src/main.py",
r#" r#"
from mod import C from mod import C
x = C().attr x = C().attr
"#, "#,
)?; )?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap(); let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
@ -430,10 +430,10 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
db.write_dedented( db.write_dedented(
"/src/mod.py", "/src/mod.py",
r#" r#"
class C: class C:
def f(self): def f(self):
self.attr: str | None = None self.attr: str | None = None
"#, "#,
)?; )?;
let events = { let events = {
@ -442,17 +442,22 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events() db.take_salsa_events()
}; };
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events); assert_function_query_was_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
// Add a comment; this should not trigger the type of `x` to be re-inferred // Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented( db.write_dedented(
"/src/mod.py", "/src/mod.py",
r#" r#"
class C: class C:
def f(self): def f(self):
# a comment! # a comment!
self.attr: str | None = None self.attr: str | None = None
"#, "#,
)?; )?;
let events = { let events = {
@ -462,7 +467,12 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
db.take_salsa_events() db.take_salsa_events()
}; };
assert_function_query_was_not_run(&db, infer_expression_types, x_rhs_expression(&db), &events); assert_function_query_was_not_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
Ok(()) Ok(())
} }
@ -487,19 +497,19 @@ fn dependency_own_instance_member() -> anyhow::Result<()> {
db.write_dedented( db.write_dedented(
"/src/mod.py", "/src/mod.py",
r#" r#"
class C: class C:
if random.choice([True, False]): if random.choice([True, False]):
attr: int = 42 attr: int = 42
else: else:
attr: None = None attr: None = None
"#, "#,
)?; )?;
db.write_dedented( db.write_dedented(
"/src/main.py", "/src/main.py",
r#" r#"
from mod import C from mod import C
x = C().attr x = C().attr
"#, "#,
)?; )?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap(); let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
@ -510,12 +520,12 @@ fn dependency_own_instance_member() -> anyhow::Result<()> {
db.write_dedented( db.write_dedented(
"/src/mod.py", "/src/mod.py",
r#" r#"
class C: class C:
if random.choice([True, False]): if random.choice([True, False]):
attr: str = "42" attr: str = "42"
else: else:
attr: None = None attr: None = None
"#, "#,
)?; )?;
let events = { let events = {
@ -524,19 +534,24 @@ fn dependency_own_instance_member() -> anyhow::Result<()> {
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None"); assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events() db.take_salsa_events()
}; };
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events); assert_function_query_was_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
// Add a comment; this should not trigger the type of `x` to be re-inferred // Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented( db.write_dedented(
"/src/mod.py", "/src/mod.py",
r#" r#"
class C: class C:
# comment # comment
if random.choice([True, False]): if random.choice([True, False]):
attr: str = "42" attr: str = "42"
else: else:
attr: None = None attr: None = None
"#, "#,
)?; )?;
let events = { let events = {
@ -546,7 +561,12 @@ fn dependency_own_instance_member() -> anyhow::Result<()> {
db.take_salsa_events() db.take_salsa_events()
}; };
assert_function_query_was_not_run(&db, infer_expression_types, x_rhs_expression(&db), &events); assert_function_query_was_not_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
Ok(()) Ok(())
} }
@ -569,22 +589,22 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> {
db.write_dedented( db.write_dedented(
"/src/mod.py", "/src/mod.py",
r#" r#"
class C: class C:
def __init__(self): def __init__(self):
self.instance_attr: str = "24" self.instance_attr: str = "24"
@classmethod @classmethod
def method(cls): def method(cls):
cls.class_attr: int = 42 cls.class_attr: int = 42
"#, "#,
)?; )?;
db.write_dedented( db.write_dedented(
"/src/main.py", "/src/main.py",
r#" r#"
from mod import C from mod import C
C.method() C.method()
x = C().class_attr x = C().class_attr
"#, "#,
)?; )?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap(); let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
@ -595,14 +615,14 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> {
db.write_dedented( db.write_dedented(
"/src/mod.py", "/src/mod.py",
r#" r#"
class C: class C:
def __init__(self): def __init__(self):
self.instance_attr: str = "24" self.instance_attr: str = "24"
@classmethod @classmethod
def method(cls): def method(cls):
cls.class_attr: str = "42" cls.class_attr: str = "42"
"#, "#,
)?; )?;
let events = { let events = {
@ -611,21 +631,26 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> {
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str"); assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str");
db.take_salsa_events() db.take_salsa_events()
}; };
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events); assert_function_query_was_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
// Add a comment; this should not trigger the type of `x` to be re-inferred // Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented( db.write_dedented(
"/src/mod.py", "/src/mod.py",
r#" r#"
class C: class C:
def __init__(self): def __init__(self):
self.instance_attr: str = "24" self.instance_attr: str = "24"
@classmethod @classmethod
def method(cls): def method(cls):
# comment # comment
cls.class_attr: str = "42" cls.class_attr: str = "42"
"#, "#,
)?; )?;
let events = { let events = {
@ -635,7 +660,88 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> {
db.take_salsa_events() db.take_salsa_events()
}; };
assert_function_query_was_not_run(&db, infer_expression_types, x_rhs_expression(&db), &events); assert_function_query_was_not_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
Ok(())
}
/// Inferring the result of a call-expression shouldn't need to re-run after
/// a trivial change to the function's file (e.g. by adding a docstring to the function).
#[test]
fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/foo.py",
r#"
def foo() -> int:
return 5
"#,
)?;
db.write_dedented(
"src/bar.py",
r#"
from foo import foo
a = foo()
"#,
)?;
let bar = system_path_to_file(&db, "src/bar.py")?;
let a = global_symbol(&db, bar, "a").place;
assert_eq!(
a.expect_type(),
UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)])
);
let events = db.take_salsa_events();
let module = parsed_module(&db, bar).load(&db);
let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value;
let foo_call = semantic_index(&db, bar).expression(call);
assert_function_query_was_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(foo_call),
&events,
);
// Add a docstring to foo to trigger a re-run.
// The bar-call site of foo should not be re-run because of that
db.write_dedented(
"src/foo.py",
r#"
def foo() -> int:
"Computes a value"
return 5
"#,
)?;
db.clear_salsa_events();
let a = global_symbol(&db, bar, "a").place;
assert_eq!(
a.expect_type(),
UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)])
);
let events = db.take_salsa_events();
let module = parsed_module(&db, bar).load(&db);
let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value;
let foo_call = semantic_index(&db, bar).expression(call);
assert_function_query_was_not_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(foo_call),
&events,
);
Ok(()) Ok(())
} }

View file

@ -12,7 +12,7 @@ use crate::types::function::KnownFunction;
use crate::types::infer::infer_same_file_expression_type; use crate::types::infer::infer_same_file_expression_type;
use crate::types::{ use crate::types::{
ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SubclassOfInner, SubclassOfType, ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SubclassOfInner, SubclassOfType,
Truthiness, Type, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types, Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types,
}; };
use ruff_db::parsed::{ParsedModuleRef, parsed_module}; use ruff_db::parsed::{ParsedModuleRef, parsed_module};
@ -773,7 +773,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
return None; return None;
} }
let inference = infer_expression_types(self.db, expression); let inference = infer_expression_types(self.db, expression, TypeContext::default());
let comparator_tuples = std::iter::once(&**left) let comparator_tuples = std::iter::once(&**left)
.chain(comparators) .chain(comparators)
@ -863,7 +863,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expression: Expression<'db>, expression: Expression<'db>,
is_positive: bool, is_positive: bool,
) -> Option<NarrowingConstraints<'db>> { ) -> Option<NarrowingConstraints<'db>> {
let inference = infer_expression_types(self.db, expression); let inference = infer_expression_types(self.db, expression, TypeContext::default());
let callable_ty = inference.expression_type(&*expr_call.func); let callable_ty = inference.expression_type(&*expr_call.func);
@ -983,7 +983,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let subject = place_expr(subject.node_ref(self.db, self.module))?; let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject); let place = self.expect_place(&subject);
let ty = infer_same_file_expression_type(self.db, cls, self.module).to_instance(self.db)?; let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module)
.to_instance(self.db)?;
Some(NarrowingConstraints::from_iter([(place, ty)])) Some(NarrowingConstraints::from_iter([(place, ty)]))
} }
@ -996,7 +997,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let subject = place_expr(subject.node_ref(self.db, self.module))?; let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject); let place = self.expect_place(&subject);
let ty = infer_same_file_expression_type(self.db, value, self.module); let ty =
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
Some(NarrowingConstraints::from_iter([(place, ty)])) Some(NarrowingConstraints::from_iter([(place, ty)]))
} }
@ -1025,7 +1027,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expression: Expression<'db>, expression: Expression<'db>,
is_positive: bool, is_positive: bool,
) -> Option<NarrowingConstraints<'db>> { ) -> Option<NarrowingConstraints<'db>> {
let inference = infer_expression_types(self.db, expression); let inference = infer_expression_types(self.db, expression, TypeContext::default());
let mut sub_constraints = expr_bool_op let mut sub_constraints = expr_bool_op
.values .values
.iter() .iter()

View file

@ -9,7 +9,7 @@ use crate::Db;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey; use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::scope::ScopeId; use crate::semantic_index::scope::ScopeId;
use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleSpec, TupleUnpacker}; use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleSpec, TupleUnpacker};
use crate::types::{Type, TypeCheckDiagnostics, infer_expression_types}; use crate::types::{Type, TypeCheckDiagnostics, TypeContext, infer_expression_types};
use crate::unpack::{UnpackKind, UnpackValue}; use crate::unpack::{UnpackKind, UnpackValue};
use super::context::InferContext; use super::context::InferContext;
@ -48,8 +48,9 @@ impl<'db, 'ast> Unpacker<'db, 'ast> {
"Unpacking target must be a list or tuple expression" "Unpacking target must be a list or tuple expression"
); );
let value_type = infer_expression_types(self.db(), value.expression()) let value_type =
.expression_type(value.expression().node_ref(self.db(), self.module())); infer_expression_types(self.db(), value.expression(), TypeContext::default())
.expression_type(value.expression().node_ref(self.db(), self.module()));
let value_type = match value.kind() { let value_type = match value.kind() {
UnpackKind::Assign => { UnpackKind::Assign => {