[ty] Integrate type context for bidirectional inference (#20337)
Some checks are pending
CI / mkdocs (push) Waiting to run
CI / formatter instabilities and black similarity (push) Blocked by required conditions
CI / Determine changes (push) Waiting to run
CI / cargo fmt (push) Waiting to run
CI / cargo clippy (push) Blocked by required conditions
CI / cargo test (linux) (push) Blocked by required conditions
CI / cargo test (linux, release) (push) Blocked by required conditions
CI / cargo test (windows) (push) Blocked by required conditions
CI / cargo test (wasm) (push) Blocked by required conditions
CI / cargo build (release) (push) Waiting to run
CI / test ruff-lsp (push) Blocked by required conditions
CI / cargo build (msrv) (push) Blocked by required conditions
CI / cargo fuzz build (push) Blocked by required conditions
CI / fuzz parser (push) Blocked by required conditions
CI / test scripts (push) Blocked by required conditions
CI / ecosystem (push) Blocked by required conditions
CI / Fuzz for new ty panics (push) Blocked by required conditions
CI / cargo shear (push) Blocked by required conditions
CI / python package (push) Waiting to run
CI / pre-commit (push) Waiting to run
CI / check playground (push) Blocked by required conditions
CI / benchmarks-instrumented (push) Blocked by required conditions
CI / benchmarks-walltime (push) Blocked by required conditions
[ty Playground] Release / publish (push) Waiting to run

## Summary

Adds the infrastructure necessary to perform bidirectional type
inference (https://github.com/astral-sh/ty/issues/168) without any
typing changes.
This commit is contained in:
Ibraheem Ahmed 2025-09-11 15:19:12 -04:00 committed by GitHub
parent c4cd5c00fd
commit 36888198a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 548 additions and 340 deletions

View file

@ -7,7 +7,7 @@ use ruff_python_ast::statement_visitor::{StatementVisitor, walk_stmt};
use ruff_python_ast::{self as ast};
use crate::semantic_index::{SemanticIndex, semantic_index};
use crate::types::{Truthiness, Type, infer_expression_types};
use crate::types::{Truthiness, Type, TypeContext, infer_expression_types};
use crate::{Db, ModuleName, resolve_module};
#[allow(clippy::ref_option)]
@ -182,7 +182,8 @@ impl<'db> DunderAllNamesCollector<'db> {
///
/// This function panics if `expr` was not marked as a standalone expression during semantic indexing.
fn standalone_expression_type(&self, expr: &ast::Expr) -> Type<'db> {
infer_expression_types(self.db, self.index.expression(expr)).expression_type(expr)
infer_expression_types(self.db, self.index.expression(expr), TypeContext::default())
.expression_type(expr)
}
/// Evaluate the given expression and return its truthiness.

View file

@ -208,8 +208,8 @@ use crate::semantic_index::predicate::{
Predicates, ScopedPredicateId,
};
use crate::types::{
IntersectionBuilder, Truthiness, Type, UnionBuilder, UnionType, infer_expression_type,
static_expression_truthiness,
IntersectionBuilder, Truthiness, Type, TypeContext, UnionBuilder, UnionType,
infer_expression_type, static_expression_truthiness,
};
/// A ternary formula that defines under what conditions a binding is visible. (A ternary formula
@ -328,10 +328,12 @@ fn singleton_to_type(db: &dyn Db, singleton: ruff_python_ast::Singleton) -> Type
fn pattern_kind_to_type<'db>(db: &'db dyn Db, kind: &PatternPredicateKind<'db>) -> Type<'db> {
match kind {
PatternPredicateKind::Singleton(singleton) => singleton_to_type(db, *singleton),
PatternPredicateKind::Value(value) => infer_expression_type(db, *value),
PatternPredicateKind::Value(value) => {
infer_expression_type(db, *value, TypeContext::default())
}
PatternPredicateKind::Class(class_expr, kind) => {
if kind.is_irrefutable() {
infer_expression_type(db, *class_expr)
infer_expression_type(db, *class_expr, TypeContext::default())
.to_instance(db)
.unwrap_or(Type::Never)
} else {
@ -718,7 +720,7 @@ impl ReachabilityConstraints {
) -> Truthiness {
match predicate_kind {
PatternPredicateKind::Value(value) => {
let value_ty = infer_expression_type(db, *value);
let value_ty = infer_expression_type(db, *value, TypeContext::default());
if subject_ty.is_single_valued(db) {
Truthiness::from(subject_ty.is_equivalent_to(db, value_ty))
@ -769,7 +771,8 @@ impl ReachabilityConstraints {
truthiness
}
PatternPredicateKind::Class(class_expr, kind) => {
let class_ty = infer_expression_type(db, *class_expr).to_instance(db);
let class_ty =
infer_expression_type(db, *class_expr, TypeContext::default()).to_instance(db);
class_ty.map_or(Truthiness::Ambiguous, |class_ty| {
if subject_ty.is_subtype_of(db, class_ty) {
@ -797,7 +800,7 @@ impl ReachabilityConstraints {
}
fn analyze_single_pattern_predicate(db: &dyn Db, predicate: PatternPredicate) -> Truthiness {
let subject_ty = infer_expression_type(db, predicate.subject(db));
let subject_ty = infer_expression_type(db, predicate.subject(db), TypeContext::default());
let narrowed_subject_ty = IntersectionBuilder::new(db)
.add_positive(subject_ty)
@ -837,7 +840,7 @@ impl ReachabilityConstraints {
// selection algorithm).
// Avoiding this on the happy-path is important because these constraints can be
// very large in number, since we add them on all statement level function calls.
let ty = infer_expression_type(db, callable);
let ty = infer_expression_type(db, callable, TypeContext::default());
// Short-circuit for well known types that are known not to return `Never` when called.
// Without the short-circuit, we've seen that threads keep blocking each other
@ -875,7 +878,7 @@ impl ReachabilityConstraints {
} else if all_overloads_return_never {
Truthiness::AlwaysTrue
} else {
let call_expr_ty = infer_expression_type(db, call_expr);
let call_expr_ty = infer_expression_type(db, call_expr, TypeContext::default());
if call_expr_ty.is_equivalent_to(db, Type::Never) {
Truthiness::AlwaysTrue
} else {

View file

@ -23,8 +23,8 @@ pub(crate) use self::cyclic::{CycleDetector, PairVisitor, TypeTransformer};
pub use self::diagnostic::TypeCheckDiagnostics;
pub(crate) use self::diagnostic::register_lints;
pub(crate) use self::infer::{
infer_deferred_types, infer_definition_types, infer_expression_type, infer_expression_types,
infer_scope_types, static_expression_truthiness,
TypeContext, infer_deferred_types, infer_definition_types, infer_expression_type,
infer_expression_types, infer_scope_types, static_expression_truthiness,
};
pub(crate) use self::signatures::{CallableSignature, Signature};
pub(crate) use self::subclass_of::{SubclassOfInner, SubclassOfType};
@ -10824,12 +10824,10 @@ static_assertions::assert_eq_size!(Type, [u8; 16]);
pub(crate) mod tests {
use super::*;
use crate::db::tests::{TestDbBuilder, setup_db};
use crate::place::{global_symbol, typing_extensions_symbol, typing_symbol};
use crate::place::{typing_extensions_symbol, typing_symbol};
use crate::semantic_index::FileScopeId;
use ruff_db::files::system_path_to_file;
use ruff_db::parsed::parsed_module;
use ruff_db::system::DbWithWritableSystem as _;
use ruff_db::testing::assert_function_query_was_not_run;
use ruff_python_ast::PythonVersion;
use test_case::test_case;
@ -10868,65 +10866,6 @@ pub(crate) mod tests {
);
}
/// Inferring the result of a call-expression shouldn't need to re-run after
/// a trivial change to the function's file (e.g. by adding a docstring to the function).
#[test]
fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/foo.py",
r#"
def foo() -> int:
return 5
"#,
)?;
db.write_dedented(
"src/bar.py",
r#"
from foo import foo
a = foo()
"#,
)?;
let bar = system_path_to_file(&db, "src/bar.py")?;
let a = global_symbol(&db, bar, "a").place;
assert_eq!(
a.expect_type(),
UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)])
);
// Add a docstring to foo to trigger a re-run.
// The bar-call site of foo should not be re-run because of that
db.write_dedented(
"src/foo.py",
r#"
def foo() -> int:
"Computes a value"
return 5
"#,
)?;
db.clear_salsa_events();
let a = global_symbol(&db, bar, "a").place;
assert_eq!(
a.expect_type(),
UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)])
);
let events = db.take_salsa_events();
let module = parsed_module(&db, bar).load(&db);
let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value;
let foo_call = semantic_index(&db, bar).expression(call);
assert_function_query_was_not_run(&db, infer_expression_types, foo_call, &events);
Ok(())
}
/// All other tests also make sure that `Type::Todo` works as expected. This particular
/// test makes sure that we handle `Todo` types correctly, even if they originate from
/// different sources.

View file

@ -28,9 +28,9 @@ use crate::types::{
ApplyTypeMappingVisitor, Binding, BoundSuperError, BoundSuperType, CallableType,
DataclassParams, DeprecatedInstance, FindLegacyTypeVarsVisitor, HasRelationToVisitor,
IsEquivalentVisitor, KnownInstanceType, ManualPEP695TypeAliasType, MaterializationKind,
NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeMapping,
TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind, TypedDictParams,
UnionBuilder, VarianceInferable, declaration_type, infer_definition_types,
NormalizedVisitor, PropertyInstanceType, StringLiteralType, TypeAliasType, TypeContext,
TypeMapping, TypeRelation, TypeVarBoundOrConstraints, TypeVarInstance, TypeVarKind,
TypedDictParams, UnionBuilder, VarianceInferable, declaration_type, infer_definition_types,
};
use crate::{
Db, FxIndexMap, FxOrderSet, Program,
@ -2926,7 +2926,11 @@ impl<'db> ClassLiteral<'db> {
// `self.SOME_CONSTANT: Final = 1`, infer the type from the value
// on the right-hand side.
let inferred_ty = infer_expression_type(db, index.expression(value));
let inferred_ty = infer_expression_type(
db,
index.expression(value),
TypeContext::default(),
);
return Place::bound(inferred_ty).with_qualifiers(all_qualifiers);
}
@ -3014,6 +3018,7 @@ impl<'db> ClassLiteral<'db> {
let inferred_ty = infer_expression_type(
db,
index.expression(assign.value(&module)),
TypeContext::default(),
);
union_of_inferred_types = union_of_inferred_types.add(inferred_ty);
@ -3041,6 +3046,7 @@ impl<'db> ClassLiteral<'db> {
let iterable_ty = infer_expression_type(
db,
index.expression(for_stmt.iterable(&module)),
TypeContext::default(),
);
// TODO: Potential diagnostics resulting from the iterable are currently not reported.
let inferred_ty =
@ -3071,6 +3077,7 @@ impl<'db> ClassLiteral<'db> {
let context_ty = infer_expression_type(
db,
index.expression(with_item.context_expr(&module)),
TypeContext::default(),
);
let inferred_ty = if with_item.is_async() {
context_ty.aenter(db)
@ -3104,6 +3111,7 @@ impl<'db> ClassLiteral<'db> {
let iterable_ty = infer_expression_type(
db,
index.expression(comprehension.iterable(&module)),
TypeContext::default(),
);
// TODO: Potential diagnostics resulting from the iterable are currently not reported.
let inferred_ty =

View file

@ -171,11 +171,21 @@ fn deferred_cycle_initial<'db>(
/// Use rarely; only for cases where we'd otherwise risk double-inferring an expression: RHS of an
/// assignment, which might be unpacking/multi-target and thus part of multiple definitions, or a
/// type narrowing guard expression (e.g. if statement test node).
#[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
pub(crate) fn infer_expression_types<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
tcx: TypeContext<'db>,
) -> &'db ExpressionInference<'db> {
infer_expression_types_impl(db, InferExpression::new(db, expression, tcx))
}
#[salsa::tracked(returns(ref), cycle_fn=expression_cycle_recover, cycle_initial=expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn infer_expression_types_impl<'db>(
db: &'db dyn Db,
input: InferExpression<'db>,
) -> ExpressionInference<'db> {
let (expression, tcx) = (input.expression(db), input.tcx(db));
let file = expression.file(db);
let module = parsed_module(db, file).load(db);
let _span = tracing::trace_span!(
@ -188,8 +198,13 @@ pub(crate) fn infer_expression_types<'db>(
let index = semantic_index(db, file);
TypeInferenceBuilder::new(db, InferenceRegion::Expression(expression), index, &module)
.finish_expression()
TypeInferenceBuilder::new(
db,
InferenceRegion::Expression(expression, tcx),
index,
&module,
)
.finish_expression()
}
/// How many fixpoint iterations to allow before falling back to Divergent type.
@ -199,11 +214,11 @@ fn expression_cycle_recover<'db>(
db: &'db dyn Db,
_value: &ExpressionInference<'db>,
count: u32,
expression: Expression<'db>,
input: InferExpression<'db>,
) -> salsa::CycleRecoveryAction<ExpressionInference<'db>> {
if count == ITERATIONS_BEFORE_FALLBACK {
salsa::CycleRecoveryAction::Fallback(ExpressionInference::cycle_fallback(
expression.scope(db),
input.expression(db).scope(db),
))
} else {
salsa::CycleRecoveryAction::Iterate
@ -212,9 +227,9 @@ fn expression_cycle_recover<'db>(
fn expression_cycle_initial<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
input: InferExpression<'db>,
) -> ExpressionInference<'db> {
ExpressionInference::cycle_initial(expression.scope(db))
ExpressionInference::cycle_initial(input.expression(db).scope(db))
}
/// Infers the type of an `expression` that is guaranteed to be in the same file as the calling query.
@ -225,9 +240,10 @@ fn expression_cycle_initial<'db>(
pub(super) fn infer_same_file_expression_type<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
tcx: TypeContext<'db>,
parsed: &ParsedModuleRef,
) -> Type<'db> {
let inference = infer_expression_types(db, expression);
let inference = infer_expression_types(db, expression, tcx);
inference.expression_type(expression.node_ref(db, parsed))
}
@ -238,34 +254,108 @@ pub(super) fn infer_same_file_expression_type<'db>(
///
/// Use [`infer_same_file_expression_type`] if it is guaranteed that `expression` is in the same
/// to avoid unnecessary salsa ingredients. This is normally the case inside the `TypeInferenceBuilder`.
#[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
pub(crate) fn infer_expression_type<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
tcx: TypeContext<'db>,
) -> Type<'db> {
let file = expression.file(db);
infer_expression_type_impl(db, InferExpression::new(db, expression, tcx))
}
#[salsa::tracked(cycle_fn=single_expression_cycle_recover, cycle_initial=single_expression_cycle_initial, heap_size=ruff_memory_usage::heap_size)]
fn infer_expression_type_impl<'db>(db: &'db dyn Db, input: InferExpression<'db>) -> Type<'db> {
let file = input.expression(db).file(db);
let module = parsed_module(db, file).load(db);
// It's okay to call the "same file" version here because we're inside a salsa query.
infer_same_file_expression_type(db, expression, &module)
let inference = infer_expression_types_impl(db, input);
inference.expression_type(input.expression(db).node_ref(db, &module))
}
fn single_expression_cycle_recover<'db>(
_db: &'db dyn Db,
_value: &Type<'db>,
_count: u32,
_expression: Expression<'db>,
_input: InferExpression<'db>,
) -> salsa::CycleRecoveryAction<Type<'db>> {
salsa::CycleRecoveryAction::Iterate
}
fn single_expression_cycle_initial<'db>(
_db: &'db dyn Db,
_expression: Expression<'db>,
_input: InferExpression<'db>,
) -> Type<'db> {
Type::Never
}
/// An `Expression` with an optional `TypeContext`.
///
/// This is a Salsa supertype used as the input to `infer_expression_types` to avoid
/// interning an `ExpressionWithContext` unnecessarily when no type context is provided.
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq, salsa::Supertype, salsa::Update)]
enum InferExpression<'db> {
Bare(Expression<'db>),
WithContext(ExpressionWithContext<'db>),
}
impl<'db> InferExpression<'db> {
fn new(
db: &'db dyn Db,
expression: Expression<'db>,
tcx: TypeContext<'db>,
) -> InferExpression<'db> {
if tcx.annotation.is_some() {
InferExpression::WithContext(ExpressionWithContext::new(db, expression, tcx))
} else {
// Drop the empty `TypeContext` to avoid the interning cost.
InferExpression::Bare(expression)
}
}
fn expression(self, db: &'db dyn Db) -> Expression<'db> {
match self {
InferExpression::Bare(expression) => expression,
InferExpression::WithContext(expression_with_context) => {
expression_with_context.expression(db)
}
}
}
fn tcx(self, db: &'db dyn Db) -> TypeContext<'db> {
match self {
InferExpression::Bare(_) => TypeContext::default(),
InferExpression::WithContext(expression_with_context) => {
expression_with_context.tcx(db)
}
}
}
}
/// An `Expression` with a `TypeContext`.
#[salsa::interned(debug, heap_size=ruff_memory_usage::heap_size)]
struct ExpressionWithContext<'db> {
expression: Expression<'db>,
tcx: TypeContext<'db>,
}
/// The type context for a given expression, namely the type annotation
/// in an annotated assignment.
///
/// Knowing the outer type context when inferring an expression can enable
/// more precise inference results, aka "bidirectional type inference".
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Hash, get_size2::GetSize)]
pub(crate) struct TypeContext<'db> {
annotation: Option<Type<'db>>,
}
impl<'db> TypeContext<'db> {
pub(crate) fn new(annotation: Type<'db>) -> Self {
Self {
annotation: Some(annotation),
}
}
}
/// Returns the statically-known truthiness of a given expression.
///
/// Returns [`Truthiness::Ambiguous`] in case any non-definitely bound places
@ -275,7 +365,7 @@ pub(crate) fn static_expression_truthiness<'db>(
db: &'db dyn Db,
expression: Expression<'db>,
) -> Truthiness {
let inference = infer_expression_types(db, expression);
let inference = infer_expression_types_impl(db, InferExpression::Bare(expression));
if !inference.all_places_definitely_bound() {
return Truthiness::Ambiguous;
@ -366,7 +456,7 @@ pub(crate) fn nearest_enclosing_class<'db>(
#[derive(Copy, Clone, Debug)]
pub(crate) enum InferenceRegion<'db> {
/// infer types for a standalone [`Expression`]
Expression(Expression<'db>),
Expression(Expression<'db>, TypeContext<'db>),
/// infer types for a [`Definition`]
Definition(Definition<'db>),
/// infer deferred types for a [`Definition`]
@ -378,7 +468,7 @@ pub(crate) enum InferenceRegion<'db> {
impl<'db> InferenceRegion<'db> {
fn scope(self, db: &'db dyn Db) -> ScopeId<'db> {
match self {
InferenceRegion::Expression(expression) => expression.scope(db),
InferenceRegion::Expression(expression, _) => expression.scope(db),
InferenceRegion::Definition(definition) | InferenceRegion::Deferred(definition) => {
definition.scope(db)
}

View file

@ -90,8 +90,8 @@ use crate::types::{
IntersectionBuilder, IntersectionType, KnownClass, KnownInstanceType, MemberLookupPolicy,
MetaclassCandidate, PEP695TypeAliasType, Parameter, ParameterForm, Parameters, SpecialFormType,
SubclassOfType, TrackedConstraintSet, Truthiness, Type, TypeAliasType, TypeAndQualifiers,
TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation, TypeVarInstance,
TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type,
TypeContext, TypeQualifiers, TypeVarBoundOrConstraintsEvaluation, TypeVarDefaultEvaluation,
TypeVarInstance, TypeVarKind, UnionBuilder, UnionType, binding_type, todo_type,
};
use crate::types::{ClassBase, add_inferred_python_version_hint_to_diagnostic};
use crate::unpack::{EvaluationMode, UnpackPosition};
@ -440,7 +440,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
InferenceRegion::Scope(scope) => self.infer_region_scope(scope),
InferenceRegion::Definition(definition) => self.infer_region_definition(definition),
InferenceRegion::Deferred(definition) => self.infer_region_deferred(definition),
InferenceRegion::Expression(expression) => self.infer_region_expression(expression),
InferenceRegion::Expression(expression, tcx) => {
self.infer_region_expression(expression, tcx);
}
}
}
@ -1221,10 +1223,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
fn infer_region_expression(&mut self, expression: Expression<'db>) {
fn infer_region_expression(&mut self, expression: Expression<'db>, tcx: TypeContext<'db>) {
match expression.kind(self.db()) {
ExpressionKind::Normal => {
self.infer_expression_impl(expression.node_ref(self.db(), self.module()));
self.infer_expression_impl(expression.node_ref(self.db(), self.module()), tcx);
}
ExpressionKind::TypeExpression => {
self.infer_type_expression(expression.node_ref(self.db(), self.module()));
@ -1435,7 +1437,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let declared_ty = if resolved_place.is_unbound() && !place_table.place(place_id).is_symbol()
{
if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node {
let value_type = self.infer_maybe_standalone_expression(value);
let value_type =
self.infer_maybe_standalone_expression(value, TypeContext::default());
if let Place::Type(ty, Boundness::Bound) = value_type.member(db, attr).place {
// TODO: also consider qualifiers on the attribute
ty
@ -1448,8 +1451,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
},
) = node
{
let value_ty = self.infer_expression(value);
let slice_ty = self.infer_expression(slice);
let value_ty = self.infer_expression(value, TypeContext::default());
let slice_ty = self.infer_expression(slice, TypeContext::default());
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx)
} else {
unwrap_declared_ty()
@ -1517,9 +1520,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
// In the following cases, the bound type may not be the same as the RHS value type.
if let AnyNodeRef::ExprAttribute(ast::ExprAttribute { value, attr, .. }) = node {
let value_ty = self
.try_expression_type(value)
.unwrap_or_else(|| self.infer_maybe_standalone_expression(value));
let value_ty = self.try_expression_type(value).unwrap_or_else(|| {
self.infer_maybe_standalone_expression(value, TypeContext::default())
});
// If the member is a data descriptor, the RHS value may differ from the value actually assigned.
if value_ty
.class_member(db, attr.id.clone())
@ -1532,7 +1535,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} else if let AnyNodeRef::ExprSubscript(ast::ExprSubscript { value, .. }) = node {
let value_ty = self
.try_expression_type(value)
.unwrap_or_else(|| self.infer_expression(value));
.unwrap_or_else(|| self.infer_expression(value, TypeContext::default()));
if !value_ty.is_typed_dict() && !is_safe_mutable_class(db, value_ty) {
bound_ty = declared_ty;
@ -1719,7 +1722,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
std::mem::replace(&mut self.deferred_state, in_stub.into());
let mut call_arguments =
CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| {
let ty = self.infer_expression(splatted_value);
let ty = self.infer_expression(splatted_value, TypeContext::default());
self.store_expression_type(argument, ty);
ty
});
@ -1988,7 +1991,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}) => {
// If this is a call expression, we would have added a `ReturnsNever` constraint,
// meaning this will be a standalone expression.
self.infer_maybe_standalone_expression(value);
self.infer_maybe_standalone_expression(value, TypeContext::default());
}
ast::Stmt::If(if_statement) => self.infer_if_statement(if_statement),
ast::Stmt::Try(try_statement) => self.infer_try_statement(try_statement),
@ -2085,7 +2088,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.iter_non_variadic_params()
.filter_map(|param| param.default.as_deref())
{
self.infer_expression(default);
self.infer_expression(default, TypeContext::default());
}
// If there are type params, parameters and returns are evaluated in that scope, that is, in
@ -2517,7 +2520,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// and we don't need to run inference here
if type_params.is_none() {
for keyword in class_node.keywords() {
self.infer_expression(&keyword.value);
self.infer_expression(&keyword.value, TypeContext::default());
}
// Inference of bases deferred in stubs, or if any are string literals.
@ -2527,7 +2530,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let previous_typevar_binding_context =
self.typevar_binding_context.replace(definition);
for base in class_node.bases() {
self.infer_expression(base);
self.infer_expression(base, TypeContext::default());
}
self.typevar_binding_context = previous_typevar_binding_context;
}
@ -2552,9 +2555,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let previous_typevar_binding_context = self.typevar_binding_context.replace(definition);
for base in class.bases() {
if self.in_stub() {
self.infer_expression_with_state(base, DeferredExpressionState::Deferred);
self.infer_expression_with_state(
base,
TypeContext::default(),
DeferredExpressionState::Deferred,
);
} else {
self.infer_expression(base);
self.infer_expression(base, TypeContext::default());
}
}
self.typevar_binding_context = previous_typevar_binding_context;
@ -2565,7 +2572,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
type_alias: &ast::StmtTypeAlias,
definition: Definition<'db>,
) {
self.infer_expression(&type_alias.name);
self.infer_expression(&type_alias.name, TypeContext::default());
let rhs_scope = self
.index
@ -2597,7 +2604,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
elif_else_clauses,
} = if_statement;
let test_ty = self.infer_standalone_expression(test);
let test_ty = self.infer_standalone_expression(test, TypeContext::default());
if let Err(err) = test_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, &**test);
@ -2614,7 +2621,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
} = clause;
if let Some(test) = &test {
let test_ty = self.infer_standalone_expression(test);
let test_ty = self.infer_standalone_expression(test, TypeContext::default());
if let Err(err) = test_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, test);
@ -2681,15 +2688,16 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `with not_context_manager as a.x: ...
builder
.infer_standalone_expression(context_expr)
.infer_standalone_expression(context_expr, TypeContext::default())
.enter(builder.db())
});
} else {
// Call into the context expression inference to validate that it evaluates
// to a valid context manager.
let context_expression_ty = self.infer_expression(&item.context_expr);
let context_expression_ty =
self.infer_expression(&item.context_expr, TypeContext::default());
self.infer_context_expression(&item.context_expr, context_expression_ty, *is_async);
self.infer_optional_expression(target);
self.infer_optional_expression(target, TypeContext::default());
}
}
@ -2713,7 +2721,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
unpacked.expression_type(target)
}
TargetKind::Single => {
let context_expr_ty = self.infer_standalone_expression(context_expr);
let context_expr_ty =
self.infer_standalone_expression(context_expr, TypeContext::default());
self.infer_context_expression(context_expr, context_expr_ty, with_item.is_async())
}
};
@ -2755,7 +2764,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_exception(&mut self, node: Option<&ast::Expr>, is_star: bool) -> Type<'db> {
// If there is no handled exception, it's invalid syntax;
// a diagnostic will have already been emitted
let node_ty = node.map_or(Type::unknown(), |ty| self.infer_expression(ty));
let node_ty = node.map_or(Type::unknown(), |ty| {
self.infer_expression(ty, TypeContext::default())
});
let type_base_exception = KnownClass::BaseException.to_subclass_of(self.db());
// If it's an `except*` handler, this won't actually be the type of the bound symbol;
@ -2947,7 +2958,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
name: _,
default,
} = node;
self.infer_optional_expression(default.as_deref());
self.infer_optional_expression(default.as_deref(), TypeContext::default());
let pep_695_todo = Type::Dynamic(DynamicType::TodoPEP695ParamSpec);
self.add_declaration_with_binding(
node.into(),
@ -2967,7 +2978,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
name: _,
default,
} = node;
self.infer_optional_expression(default.as_deref());
self.infer_optional_expression(default.as_deref(), TypeContext::default());
let pep_695_todo = todo_type!("PEP-695 TypeVarTuple definition types");
self.add_declaration_with_binding(
node.into(),
@ -2984,7 +2995,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
cases,
} = match_statement;
self.infer_standalone_expression(subject);
self.infer_standalone_expression(subject, TypeContext::default());
for case in cases {
let ast::MatchCase {
@ -2997,7 +3008,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.infer_match_pattern(pattern);
if let Some(guard) = guard.as_deref() {
let guard_ty = self.infer_standalone_expression(guard);
let guard_ty = self.infer_standalone_expression(guard, TypeContext::default());
if let Err(err) = guard_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, guard);
@ -3052,7 +3063,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// the subject expression: https://github.com/astral-sh/ruff/pull/13147#discussion_r1739424510
match pattern {
ast::Pattern::MatchValue(match_value) => {
self.infer_standalone_expression(&match_value.value);
self.infer_standalone_expression(&match_value.value, TypeContext::default());
}
ast::Pattern::MatchClass(match_class) => {
let ast::PatternMatchClass {
@ -3067,7 +3078,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for keyword in &arguments.keywords {
self.infer_nested_match_pattern(&keyword.pattern);
}
self.infer_standalone_expression(cls);
self.infer_standalone_expression(cls, TypeContext::default());
}
ast::Pattern::MatchOr(match_or) => {
for pattern in &match_or.patterns {
@ -3083,7 +3094,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_nested_match_pattern(&mut self, pattern: &ast::Pattern) {
match pattern {
ast::Pattern::MatchValue(match_value) => {
self.infer_maybe_standalone_expression(&match_value.value);
self.infer_maybe_standalone_expression(&match_value.value, TypeContext::default());
}
ast::Pattern::MatchSequence(match_sequence) => {
for pattern in &match_sequence.patterns {
@ -3099,7 +3110,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
rest: _,
} = match_mapping;
for key in keys {
self.infer_expression(key);
self.infer_expression(key, TypeContext::default());
}
for pattern in patterns {
self.infer_nested_match_pattern(pattern);
@ -3118,7 +3129,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for keyword in &arguments.keywords {
self.infer_nested_match_pattern(&keyword.pattern);
}
self.infer_maybe_standalone_expression(cls);
self.infer_maybe_standalone_expression(cls, TypeContext::default());
}
ast::Pattern::MatchAs(match_as) => {
if let Some(pattern) = &match_as.pattern {
@ -3144,7 +3155,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
for target in targets {
self.infer_target(target, value, |builder, value_expr| {
builder.infer_standalone_expression(value_expr)
builder.infer_standalone_expression(value_expr, TypeContext::default())
});
}
}
@ -3184,8 +3195,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ctx: _,
} = target;
let value_ty = self.infer_expression(value);
let slice_ty = self.infer_expression(slice);
let value_ty = self.infer_expression(value, TypeContext::default());
let slice_ty = self.infer_expression(slice, TypeContext::default());
let db = self.db();
let context = &self.context;
@ -3878,7 +3889,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
) => {
self.store_expression_type(target, assigned_ty.unwrap_or(Type::unknown()));
let object_ty = self.infer_expression(object);
let object_ty = self.infer_expression(object, TypeContext::default());
if let Some(assigned_ty) = assigned_ty {
self.validate_attribute_assignment(
@ -3899,7 +3910,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
_ => {
// TODO: Remove this once we handle all possible assignment targets.
self.infer_expression(target);
self.infer_expression(target, TypeContext::default());
}
}
}
@ -3924,7 +3935,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
unpacked.expression_type(target)
}
TargetKind::Single => {
let value_ty = self.infer_standalone_expression(value);
let value_ty = self.infer_standalone_expression(value, TypeContext::default());
// `TYPE_CHECKING` is a special variable that should only be assigned `False`
// at runtime, but is always considered `True` in type checking.
@ -3988,12 +3999,15 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
if let Some(value) = value {
self.infer_maybe_standalone_expression(value);
self.infer_maybe_standalone_expression(
value,
TypeContext::new(annotated.inner_type()),
);
}
// If we have an annotated assignment like `self.attr: int = 1`, we still need to
// do type inference on the `self.attr` target to get types for all sub-expressions.
self.infer_expression(target);
self.infer_expression(target, TypeContext::default());
// But here we explicitly overwrite the type for the overall `self.attr` node with
// the annotated type. We do no use `store_expression_type` here, because it checks
@ -4080,7 +4094,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
debug_assert!(PlaceExpr::try_from_expr(target).is_some());
if let Some(value) = value {
let inferred_ty = self.infer_maybe_standalone_expression(value);
let inferred_ty = self
.infer_maybe_standalone_expression(value, TypeContext::new(declared.inner_type()));
let mut inferred_ty = if target
.as_name_expr()
.is_some_and(|name| &name.id == "TYPE_CHECKING")
@ -4236,9 +4251,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
self.store_expression_type(target, previous_value);
previous_value
}
_ => self.infer_expression(target),
_ => self.infer_expression(target, TypeContext::default()),
};
let value_type = self.infer_expression(value);
let value_type = self.infer_expression(value, TypeContext::default());
self.infer_augmented_op(assignment, target_type, value_type)
}
@ -4263,7 +4278,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// but only if the target is a name. We should report a diagnostic here if the target isn't a name:
// `for a.x in not_iterable: ...
builder
.infer_standalone_expression(iter_expr)
.infer_standalone_expression(iter_expr, TypeContext::default())
.iterate(builder.db())
.homogeneous_element_type(builder.db())
});
@ -4290,7 +4305,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
unpacked.expression_type(target)
}
TargetKind::Single => {
let iterable_type = self.infer_standalone_expression(iterable);
let iterable_type =
self.infer_standalone_expression(iterable, TypeContext::default());
iterable_type
.try_iterate_with_mode(
@ -4318,7 +4334,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
orelse,
} = while_statement;
let test_ty = self.infer_standalone_expression(test);
let test_ty = self.infer_standalone_expression(test, TypeContext::default());
if let Err(err) = test_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, &**test);
@ -4500,13 +4516,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
msg,
} = assert;
let test_ty = self.infer_standalone_expression(test);
let test_ty = self.infer_standalone_expression(test, TypeContext::default());
if let Err(err) = test_ty.try_bool(self.db()) {
err.report_diagnostic(&self.context, &**test);
}
self.infer_optional_expression(msg.as_deref());
self.infer_optional_expression(msg.as_deref(), TypeContext::default());
}
fn infer_raise_statement(&mut self, raise: &ast::StmtRaise) {
@ -4526,7 +4542,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
UnionType::from_elements(self.db(), [can_be_raised, Type::none(self.db())]);
if let Some(raised) = exc {
let raised_type = self.infer_expression(raised);
let raised_type = self.infer_expression(raised, TypeContext::default());
if !raised_type.is_assignable_to(self.db(), can_be_raised) {
report_invalid_exception_raised(&self.context, raised, raised_type);
@ -4534,7 +4550,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
if let Some(cause) = cause {
let cause_type = self.infer_expression(cause);
let cause_type = self.infer_expression(cause, TypeContext::default());
if !cause_type.is_assignable_to(self.db(), can_be_exception_cause) {
report_invalid_exception_cause(&self.context, cause, cause_type);
@ -4740,7 +4756,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
fn infer_return_statement(&mut self, ret: &ast::StmtReturn) {
if let Some(ty) = self.infer_optional_expression(ret.value.as_deref()) {
if let Some(ty) =
self.infer_optional_expression(ret.value.as_deref(), TypeContext::default())
{
let range = ret
.value
.as_ref()
@ -4758,7 +4776,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
targets,
} = delete;
for target in targets {
self.infer_expression(target);
self.infer_expression(target, TypeContext::default());
}
}
@ -4898,7 +4916,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
expression,
} = decorator;
self.infer_expression(expression)
self.infer_expression(expression, TypeContext::default())
}
fn infer_argument_types<'a>(
@ -4920,7 +4938,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::ArgOrKeyword::Arg(arg) => arg,
ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value,
};
let ty = self.infer_argument_type(argument, form);
let ty = self.infer_argument_type(argument, form, TypeContext::default());
*argument_type = Some(ty);
}
}
@ -4929,58 +4947,73 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
&mut self,
ast_argument: &ast::Expr,
form: Option<ParameterForm>,
tcx: TypeContext<'db>,
) -> Type<'db> {
match form {
None | Some(ParameterForm::Value) => self.infer_expression(ast_argument),
None | Some(ParameterForm::Value) => self.infer_expression(ast_argument, tcx),
Some(ParameterForm::Type) => self.infer_type_expression(ast_argument),
}
}
fn infer_optional_expression(&mut self, expression: Option<&ast::Expr>) -> Option<Type<'db>> {
expression.map(|expr| self.infer_expression(expr))
fn infer_optional_expression(
&mut self,
expression: Option<&ast::Expr>,
tcx: TypeContext<'db>,
) -> Option<Type<'db>> {
expression.map(|expr| self.infer_expression(expr, tcx))
}
#[track_caller]
fn infer_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
fn infer_expression(&mut self, expression: &ast::Expr, tcx: TypeContext<'db>) -> Type<'db> {
debug_assert!(
!self.index.is_standalone_expression(expression),
"Calling `self.infer_expression` on a standalone-expression is not allowed because it can lead to double-inference. Use `self.infer_standalone_expression` instead."
);
self.infer_expression_impl(expression)
self.infer_expression_impl(expression, tcx)
}
fn infer_expression_with_state(
&mut self,
expression: &ast::Expr,
tcx: TypeContext<'db>,
state: DeferredExpressionState,
) -> Type<'db> {
let previous_deferred_state = std::mem::replace(&mut self.deferred_state, state);
let ty = self.infer_expression(expression);
let ty = self.infer_expression(expression, tcx);
self.deferred_state = previous_deferred_state;
ty
}
fn infer_maybe_standalone_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
fn infer_maybe_standalone_expression(
&mut self,
expression: &ast::Expr,
tcx: TypeContext<'db>,
) -> Type<'db> {
if let Some(standalone_expression) = self.index.try_expression(expression) {
self.infer_standalone_expression_impl(expression, standalone_expression)
self.infer_standalone_expression_impl(expression, standalone_expression, tcx)
} else {
self.infer_expression(expression)
self.infer_expression(expression, tcx)
}
}
#[track_caller]
fn infer_standalone_expression(&mut self, expression: &ast::Expr) -> Type<'db> {
fn infer_standalone_expression(
&mut self,
expression: &ast::Expr,
tcx: TypeContext<'db>,
) -> Type<'db> {
let standalone_expression = self.index.expression(expression);
self.infer_standalone_expression_impl(expression, standalone_expression)
self.infer_standalone_expression_impl(expression, standalone_expression, tcx)
}
fn infer_standalone_expression_impl(
&mut self,
expression: &ast::Expr,
standalone_expression: Expression<'db>,
tcx: TypeContext<'db>,
) -> Type<'db> {
let types = infer_expression_types(self.db(), standalone_expression);
let types = infer_expression_types(self.db(), standalone_expression, tcx);
self.extend_expression(types);
// Instead of calling `self.expression_type(expr)` after extending here, we get
@ -4990,7 +5023,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
types.expression_type(expression)
}
fn infer_expression_impl(&mut self, expression: &ast::Expr) -> Type<'db> {
fn infer_expression_impl(
&mut self,
expression: &ast::Expr,
tcx: TypeContext<'db>,
) -> Type<'db> {
let ty = match expression {
ast::Expr::NoneLiteral(ast::ExprNoneLiteral {
range: _,
@ -5005,10 +5042,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::FString(fstring) => self.infer_fstring_expression(fstring),
ast::Expr::TString(tstring) => self.infer_tstring_expression(tstring),
ast::Expr::EllipsisLiteral(literal) => self.infer_ellipsis_literal_expression(literal),
ast::Expr::Tuple(tuple) => self.infer_tuple_expression(tuple),
ast::Expr::List(list) => self.infer_list_expression(list),
ast::Expr::Set(set) => self.infer_set_expression(set),
ast::Expr::Dict(dict) => self.infer_dict_expression(dict),
ast::Expr::Tuple(tuple) => self.infer_tuple_expression(tuple, tcx),
ast::Expr::List(list) => self.infer_list_expression(list, tcx),
ast::Expr::Set(set) => self.infer_set_expression(set, tcx),
ast::Expr::Dict(dict) => self.infer_dict_expression(dict, tcx),
ast::Expr::Generator(generator) => self.infer_generator_expression(generator),
ast::Expr::ListComp(listcomp) => self.infer_list_comprehension_expression(listcomp),
ast::Expr::DictComp(dictcomp) => self.infer_dict_comprehension_expression(dictcomp),
@ -5024,7 +5061,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ast::Expr::Named(named) => self.infer_named_expression(named),
ast::Expr::If(if_expression) => self.infer_if_expression(if_expression),
ast::Expr::Lambda(lambda_expression) => self.infer_lambda_expression(lambda_expression),
ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression),
ast::Expr::Call(call_expression) => self.infer_call_expression(call_expression, tcx),
ast::Expr::Starred(starred) => self.infer_starred_expression(starred),
ast::Expr::Yield(yield_expression) => self.infer_yield_expression(yield_expression),
ast::Expr::YieldFrom(yield_from) => self.infer_yield_from_expression(yield_from),
@ -5038,7 +5075,6 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ty
}
fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) {
if self.deferred_state.in_string_annotation() {
// Avoid storing the type of expressions that are part of a string annotation because
@ -5120,11 +5156,14 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
conversion,
format_spec,
} = expression;
let ty = self.infer_expression(expression);
let ty = self.infer_expression(expression, TypeContext::default());
if let Some(format_spec) = format_spec {
for element in format_spec.elements.interpolations() {
self.infer_expression(&element.expression);
self.infer_expression(
&element.expression,
TypeContext::default(),
);
}
}
@ -5166,10 +5205,10 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
format_spec,
..
} = tstring_interpolation_element;
self.infer_expression(expression);
self.infer_expression(expression, TypeContext::default());
if let Some(format_spec) = format_spec {
for element in format_spec.elements.interpolations() {
self.infer_expression(&element.expression);
self.infer_expression(&element.expression, TypeContext::default());
}
}
}
@ -5187,7 +5226,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
KnownClass::EllipsisType.to_instance(self.db())
}
fn infer_tuple_expression(&mut self, tuple: &ast::ExprTuple) -> Type<'db> {
fn infer_tuple_expression(
&mut self,
tuple: &ast::ExprTuple,
_tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprTuple {
range: _,
node_index: _,
@ -5199,7 +5242,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let db = self.db();
let divergent = Type::divergent(self.scope());
let element_types = elts.iter().map(|element| {
let element_type = self.infer_expression(element);
// TODO: Use the type context for more precise inference.
let element_type = self.infer_expression(element, TypeContext::default());
if element_type.has_divergent_type(self.db(), divergent) {
divergent
} else {
@ -5210,7 +5254,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::heterogeneous_tuple(db, element_types)
}
fn infer_list_expression(&mut self, list: &ast::ExprList) -> Type<'db> {
fn infer_list_expression(&mut self, list: &ast::ExprList, _tcx: TypeContext<'db>) -> Type<'db> {
let ast::ExprList {
range: _,
node_index: _,
@ -5218,38 +5262,41 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ctx: _,
} = list;
// TODO: Use the type context for more precise inference.
for elt in elts {
self.infer_expression(elt);
self.infer_expression(elt, TypeContext::default());
}
KnownClass::List
.to_specialized_instance(self.db(), [todo_type!("list literal element type")])
}
fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> {
fn infer_set_expression(&mut self, set: &ast::ExprSet, _tcx: TypeContext<'db>) -> Type<'db> {
let ast::ExprSet {
range: _,
node_index: _,
elts,
} = set;
// TODO: Use the type context for more precise inference.
for elt in elts {
self.infer_expression(elt);
self.infer_expression(elt, TypeContext::default());
}
KnownClass::Set.to_specialized_instance(self.db(), [todo_type!("set literal element type")])
}
fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> {
fn infer_dict_expression(&mut self, dict: &ast::ExprDict, _tcx: TypeContext<'db>) -> Type<'db> {
let ast::ExprDict {
range: _,
node_index: _,
items,
} = dict;
// TODO: Use the type context for more precise inference.
for item in items {
self.infer_optional_expression(item.key.as_ref());
self.infer_expression(&item.value);
self.infer_optional_expression(item.key.as_ref(), TypeContext::default());
self.infer_expression(&item.value, TypeContext::default());
}
KnownClass::Dict.to_specialized_instance(
@ -5260,14 +5307,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
],
)
}
/// Infer the type of the `iter` expression of the first comprehension.
fn infer_first_comprehension_iter(&mut self, comprehensions: &[ast::Comprehension]) {
let mut comprehensions_iter = comprehensions.iter();
let Some(first_comprehension) = comprehensions_iter.next() else {
unreachable!("Comprehension must contain at least one generator");
};
self.infer_standalone_expression(&first_comprehension.iter);
self.infer_standalone_expression(&first_comprehension.iter, TypeContext::default());
}
fn infer_generator_expression(&mut self, generator: &ast::ExprGenerator) -> Type<'db> {
@ -5348,7 +5394,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
parenthesized: _,
} = generator;
self.infer_expression(elt);
self.infer_expression(elt, TypeContext::default());
self.infer_comprehensions(generators);
}
@ -5360,7 +5406,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
generators,
} = listcomp;
self.infer_expression(elt);
self.infer_expression(elt, TypeContext::default());
self.infer_comprehensions(generators);
}
@ -5373,8 +5419,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
generators,
} = dictcomp;
self.infer_expression(key);
self.infer_expression(value);
self.infer_expression(key, TypeContext::default());
self.infer_expression(value, TypeContext::default());
self.infer_comprehensions(generators);
}
@ -5386,7 +5432,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
generators,
} = setcomp;
self.infer_expression(elt);
self.infer_expression(elt, TypeContext::default());
self.infer_comprehensions(generators);
}
@ -5419,16 +5465,17 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
infer_same_file_expression_type(
builder.db(),
builder.index.expression(iter_expr),
TypeContext::default(),
builder.module(),
)
} else {
builder.infer_standalone_expression(iter_expr)
builder.infer_standalone_expression(iter_expr, TypeContext::default())
}
.iterate(builder.db())
.homogeneous_element_type(builder.db())
});
for expr in ifs {
self.infer_standalone_expression(expr);
self.infer_standalone_expression(expr, TypeContext::default());
}
}
@ -5442,7 +5489,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
let mut infer_iterable_type = || {
let expression = self.index.expression(iterable);
let result = infer_expression_types(self.db(), expression);
let result = infer_expression_types(self.db(), expression, TypeContext::default());
// Two things are different if it's the first comprehension:
// (1) We must lookup the `ScopedExpressionId` of the iterable expression in the outer scope,
@ -5496,8 +5543,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
result.binding_type(definition)
} else {
// For syntactically invalid targets, we still need to run type inference:
self.infer_expression(&named.target);
self.infer_expression(&named.value);
self.infer_expression(&named.target, TypeContext::default());
self.infer_expression(&named.value, TypeContext::default());
Type::unknown()
}
}
@ -5514,8 +5561,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
value,
} = named;
let value_ty = self.infer_expression(value);
self.infer_expression(target);
let value_ty = self.infer_expression(value, TypeContext::default());
self.infer_expression(target, TypeContext::default());
self.add_binding(named.into(), definition, value_ty);
@ -5531,9 +5578,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
orelse,
} = if_expression;
let test_ty = self.infer_standalone_expression(test);
let body_ty = self.infer_expression(body);
let orelse_ty = self.infer_expression(orelse);
let test_ty = self.infer_standalone_expression(test, TypeContext::default());
let body_ty = self.infer_expression(body, TypeContext::default());
let orelse_ty = self.infer_expression(orelse, TypeContext::default());
match test_ty.try_bool(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, &**test);
@ -5546,7 +5593,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
fn infer_lambda_body(&mut self, lambda_expression: &ast::ExprLambda) {
self.infer_expression(&lambda_expression.body);
self.infer_expression(&lambda_expression.body, TypeContext::default());
}
fn infer_lambda_expression(&mut self, lambda_expression: &ast::ExprLambda) -> Type<'db> {
@ -5564,7 +5611,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.map(|param| {
let mut parameter = Parameter::positional_only(Some(param.name().id.clone()));
if let Some(default) = param.default() {
parameter = parameter.with_default_type(self.infer_expression(default));
parameter = parameter.with_default_type(
self.infer_expression(default, TypeContext::default()),
);
}
parameter
})
@ -5575,7 +5624,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.map(|param| {
let mut parameter = Parameter::positional_or_keyword(param.name().id.clone());
if let Some(default) = param.default() {
parameter = parameter.with_default_type(self.infer_expression(default));
parameter = parameter.with_default_type(
self.infer_expression(default, TypeContext::default()),
);
}
parameter
})
@ -5590,7 +5641,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.map(|param| {
let mut parameter = Parameter::keyword_only(param.name().id.clone());
if let Some(default) = param.default() {
parameter = parameter.with_default_type(self.infer_expression(default));
parameter = parameter.with_default_type(
self.infer_expression(default, TypeContext::default()),
);
}
parameter
})
@ -5618,7 +5671,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
CallableType::function_like(self.db(), Signature::new(parameters, Some(Type::unknown())))
}
fn infer_call_expression(&mut self, call_expression: &ast::ExprCall) -> Type<'db> {
fn infer_call_expression(
&mut self,
call_expression: &ast::ExprCall,
_tcx: TypeContext<'db>,
) -> Type<'db> {
let ast::ExprCall {
range: _,
node_index: _,
@ -5631,12 +5688,13 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
// are assignable to any parameter annotations.
let mut call_arguments =
CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| {
let ty = self.infer_expression(splatted_value);
let ty = self.infer_expression(splatted_value, TypeContext::default());
self.store_expression_type(argument, ty);
ty
});
let callable_type = self.infer_maybe_standalone_expression(func);
// TODO: Use the type context for more precise inference.
let callable_type = self.infer_maybe_standalone_expression(func, TypeContext::default());
// Special handling for `TypedDict` method calls
if let ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) = func.as_ref() {
@ -5881,7 +5939,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
ctx: _,
} = starred;
let iterable_type = self.infer_expression(value);
let iterable_type = self.infer_expression(value, TypeContext::default());
iterable_type
.try_iterate(self.db())
.map(|tuple| tuple.homogeneous_element_type(self.db()))
@ -5900,7 +5958,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
node_index: _,
value,
} = yield_expression;
self.infer_optional_expression(value.as_deref());
self.infer_optional_expression(value.as_deref(), TypeContext::default());
todo_type!("yield expressions")
}
@ -5911,7 +5969,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
value,
} = yield_from;
let iterable_type = self.infer_expression(value);
let iterable_type = self.infer_expression(value, TypeContext::default());
iterable_type
.try_iterate(self.db())
.map(|tuple| tuple.homogeneous_element_type(self.db()))
@ -5931,7 +5989,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
node_index: _,
value,
} = await_expression;
let expr_type = self.infer_expression(value);
let expr_type = self.infer_expression(value, TypeContext::default());
expr_type.try_await(self.db()).unwrap_or_else(|err| {
err.report_diagnostic(&self.context, expr_type, value.as_ref().into());
Type::unknown()
@ -6576,7 +6634,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
fn infer_attribute_load(&mut self, attribute: &ast::ExprAttribute) -> Type<'db> {
let ast::ExprAttribute { value, attr, .. } = attribute;
let value_type = self.infer_maybe_standalone_expression(value);
let value_type = self.infer_maybe_standalone_expression(value, TypeContext::default());
let db = self.db();
let mut constraint_keys = vec![];
@ -6687,7 +6745,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
match ctx {
ExprContext::Load => self.infer_attribute_load(attribute),
ExprContext::Store => {
self.infer_expression(value);
self.infer_expression(value, TypeContext::default());
Type::Never
}
ExprContext::Del => {
@ -6695,7 +6753,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::Never
}
ExprContext::Invalid => {
self.infer_expression(value);
self.infer_expression(value, TypeContext::default());
Type::unknown()
}
}
@ -6709,7 +6767,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
operand,
} = unary;
let operand_type = self.infer_expression(operand);
let operand_type = self.infer_expression(operand, TypeContext::default());
self.infer_unary_expression_type(*op, operand_type, unary)
}
@ -6830,8 +6888,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
node_index: _,
} = binary;
let left_ty = self.infer_expression(left);
let right_ty = self.infer_expression(right);
let left_ty = self.infer_expression(left, TypeContext::default());
let right_ty = self.infer_expression(right, TypeContext::default());
self.infer_binary_expression_type(binary.into(), false, left_ty, right_ty, *op)
.unwrap_or_else(|| {
@ -7276,9 +7334,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
values.iter().enumerate(),
|builder, (index, value)| {
let ty = if index == values.len() - 1 {
builder.infer_expression(value)
builder.infer_expression(value, TypeContext::default())
} else {
builder.infer_standalone_expression(value)
builder.infer_standalone_expression(value, TypeContext::default())
};
(ty, value.range())
@ -7359,7 +7417,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
comparators,
} = compare;
self.infer_expression(left);
self.infer_expression(left, TypeContext::default());
// https://docs.python.org/3/reference/expressions.html#comparisons
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
@ -7376,7 +7434,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
.zip(ops),
|builder, ((left, right), op)| {
let left_ty = builder.expression_type(left);
let right_ty = builder.infer_expression(right);
let right_ty = builder.infer_expression(right, TypeContext::default());
let range = TextRange::new(left.start(), right.end());
@ -8143,8 +8201,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
match ctx {
ExprContext::Load => self.infer_subscript_load(subscript),
ExprContext::Store => {
let value_ty = self.infer_expression(value);
let slice_ty = self.infer_expression(slice);
let value_ty = self.infer_expression(value, TypeContext::default());
let slice_ty = self.infer_expression(slice, TypeContext::default());
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
Type::Never
}
@ -8153,8 +8211,8 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
Type::Never
}
ExprContext::Invalid => {
let value_ty = self.infer_expression(value);
let slice_ty = self.infer_expression(slice);
let value_ty = self.infer_expression(value, TypeContext::default());
let slice_ty = self.infer_expression(slice, TypeContext::default());
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
Type::unknown()
}
@ -8169,7 +8227,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
slice,
ctx,
} = subscript;
let value_ty = self.infer_expression(value);
let value_ty = self.infer_expression(value, TypeContext::default());
let mut constraint_keys = vec![];
// If `value` is a valid reference, we attempt type narrowing by assignment.
@ -8183,7 +8241,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
if let Place::Type(ty, Boundness::Bound) = place.place {
// Even if we can obtain the subscript type based on the assignments, we still perform default type inference
// (to store the expression type and to report errors).
let slice_ty = self.infer_expression(slice);
let slice_ty = self.infer_expression(slice, TypeContext::default());
self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
return ty;
}
@ -8228,7 +8286,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
}
}
let slice_ty = self.infer_expression(slice);
let slice_ty = self.infer_expression(slice, TypeContext::default());
let result_ty = self.infer_subscript_expression_types(subscript, value_ty, slice_ty, *ctx);
self.narrow_expr_with_applicable_constraints(subscript, result_ty, &constraint_keys)
}
@ -8767,9 +8825,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
step,
} = slice;
let ty_lower = self.infer_optional_expression(lower.as_deref());
let ty_upper = self.infer_optional_expression(upper.as_deref());
let ty_step = self.infer_optional_expression(step.as_deref());
let ty_lower = self.infer_optional_expression(lower.as_deref(), TypeContext::default());
let ty_upper = self.infer_optional_expression(upper.as_deref(), TypeContext::default());
let ty_step = self.infer_optional_expression(step.as_deref(), TypeContext::default());
let type_to_slice_argument = |ty: Option<Type<'db>>| match ty {
Some(ty @ (Type::IntLiteral(_) | Type::BooleanLiteral(_))) => SliceArg::Arg(ty),

View file

@ -6,7 +6,7 @@ use crate::types::string_annotation::{
BYTE_STRING_TYPE_ANNOTATION, FSTRING_TYPE_ANNOTATION, parse_string_annotation,
};
use crate::types::{
KnownClass, SpecialFormType, Type, TypeAndQualifiers, TypeQualifiers, todo_type,
KnownClass, SpecialFormType, Type, TypeAndQualifiers, TypeContext, TypeQualifiers, todo_type,
};
/// Annotation expressions.
@ -122,7 +122,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
},
ast::Expr::Subscript(subscript @ ast::ExprSubscript { value, slice, .. }) => {
let value_ty = self.infer_expression(value);
let value_ty = self.infer_expression(value, TypeContext::default());
let slice = &**slice;
@ -141,7 +141,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
if let [inner_annotation, metadata @ ..] = &arguments[..] {
for element in metadata {
self.infer_expression(element);
self.infer_expression(element, TypeContext::default());
}
let inner_annotation_ty =
@ -151,7 +151,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
inner_annotation_ty
} else {
for argument in arguments {
self.infer_expression(argument);
self.infer_expression(argument, TypeContext::default());
}
self.store_expression_type(slice, Type::unknown());
TypeAndQualifiers::unknown()

View file

@ -14,7 +14,7 @@ use crate::types::visitor::any_over_type;
use crate::types::{
CallableType, DynamicType, IntersectionBuilder, KnownClass, KnownInstanceType,
LintDiagnosticGuard, Parameter, Parameters, SpecialFormType, SubclassOfType, Type,
TypeAliasType, TypeIsType, UnionBuilder, UnionType, todo_type,
TypeAliasType, TypeContext, TypeIsType, UnionBuilder, UnionType, todo_type,
};
/// Type expressions
@ -114,7 +114,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
node_index: _,
} = subscript;
let value_ty = self.infer_expression(value);
let value_ty = self.infer_expression(value, TypeContext::default());
self.infer_subscript_type_expression_no_store(subscript, slice, value_ty)
}
@ -324,7 +324,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
ast::Expr::Dict(dict) => {
self.infer_dict_expression(dict);
self.infer_dict_expression(dict, TypeContext::default());
self.report_invalid_type_expression(
expression,
format_args!("Dict literals are not allowed in type expressions"),
@ -333,7 +333,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
ast::Expr::Set(set) => {
self.infer_set_expression(set);
self.infer_set_expression(set, TypeContext::default());
self.report_invalid_type_expression(
expression,
format_args!("Set literals are not allowed in type expressions"),
@ -414,7 +414,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
}
ast::Expr::Call(call_expr) => {
self.infer_call_expression(call_expr);
self.infer_call_expression(call_expr, TypeContext::default());
self.report_invalid_type_expression(
expression,
format_args!("Function calls are not allowed in type expressions"),
@ -544,7 +544,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
let value_ty = if builder.deferred_state.in_string_annotation() {
// Using `.expression_type` does not work in string annotations, because
// we do not store types for sub-expressions. Re-infer the type here.
builder.infer_expression(value)
builder.infer_expression(value, TypeContext::default())
} else {
builder.expression_type(value)
};
@ -559,7 +559,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
match tuple_slice {
ast::Expr::Tuple(elements) => {
if let [element, ellipsis @ ast::Expr::EllipsisLiteral(_)] = &*elements.elts {
self.infer_expression(ellipsis);
self.infer_expression(ellipsis, TypeContext::default());
let result =
TupleType::homogeneous(self.db(), self.infer_type_expression(element));
self.store_expression_type(tuple_slice, Type::tuple(Some(result)));
@ -617,7 +617,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
fn infer_subclass_of_type_expression(&mut self, slice: &ast::Expr) -> Type<'db> {
match slice {
ast::Expr::Name(_) | ast::Expr::Attribute(_) => {
let name_ty = self.infer_expression(slice);
let name_ty = self.infer_expression(slice, TypeContext::default());
match name_ty {
Type::ClassLiteral(class_literal) => {
if class_literal.is_protocol(self.db()) {
@ -663,7 +663,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
slice: parameters,
..
}) => {
let parameters_ty = match self.infer_expression(value) {
let parameters_ty = match self.infer_expression(value, TypeContext::default()) {
Type::SpecialForm(SpecialFormType::Union) => match &**parameters {
ast::Expr::Tuple(tuple) => {
let ty = UnionType::from_elements_leave_aliases(
@ -713,7 +713,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
// `infer_expression` (instead of `infer_type_expression`) here to avoid
// false-positive `invalid-type-form` diagnostics (`1` is not a valid type
// expression).
self.infer_expression(&subscript.slice);
self.infer_expression(&subscript.slice, TypeContext::default());
Type::unknown()
}
Type::SpecialForm(special_form) => {
@ -912,14 +912,14 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
let [type_expr, metadata @ ..] = &arguments[..] else {
for argument in arguments {
self.infer_expression(argument);
self.infer_expression(argument, TypeContext::default());
}
self.store_expression_type(arguments_slice, Type::unknown());
return Type::unknown();
};
for element in metadata {
self.infer_expression(element);
self.infer_expression(element, TypeContext::default());
}
let ty = self.infer_type_expression(type_expr);
@ -1107,7 +1107,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
let num_arguments = arguments.len();
let type_of_type = if num_arguments == 1 {
// N.B. This uses `infer_expression` rather than `infer_type_expression`
self.infer_expression(&arguments[0])
self.infer_expression(&arguments[0], TypeContext::default())
} else {
for argument in arguments {
self.infer_type_expression(argument);
@ -1137,7 +1137,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
if num_arguments != 1 {
for argument in arguments {
self.infer_expression(argument);
self.infer_expression(argument, TypeContext::default());
}
report_invalid_argument_number_to_special_form(
&self.context,
@ -1152,7 +1152,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
return Type::unknown();
}
let argument_type = self.infer_expression(&arguments[0]);
let argument_type = self.infer_expression(&arguments[0], TypeContext::default());
let bindings = argument_type.bindings(db);
// SAFETY: This is enforced by the constructor methods on `Bindings` even in
@ -1362,7 +1362,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
Type::tuple(self.infer_tuple_type_expression(arguments_slice))
}
SpecialFormType::Generic | SpecialFormType::Protocol => {
self.infer_expression(arguments_slice);
self.infer_expression(arguments_slice, TypeContext::default());
if let Some(builder) = self.context.report_lint(&INVALID_TYPE_FORM, subscript) {
builder.into_diagnostic(format_args!(
"`{special_form}` is not allowed in type expressions",
@ -1380,7 +1380,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
Ok(match parameters {
// TODO handle type aliases
ast::Expr::Subscript(ast::ExprSubscript { value, slice, .. }) => {
let value_ty = self.infer_expression(value);
let value_ty = self.infer_expression(value, TypeContext::default());
if matches!(value_ty, Type::SpecialForm(SpecialFormType::Literal)) {
let ty = self.infer_literal_parameter_type(slice)?;
@ -1389,7 +1389,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
self.store_expression_type(parameters, ty);
ty
} else {
self.infer_expression(slice);
self.infer_expression(slice, TypeContext::default());
self.store_expression_type(parameters, Type::unknown());
return Err(vec![parameters]);
@ -1426,13 +1426,13 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
literal @ (ast::Expr::StringLiteral(_)
| ast::Expr::BytesLiteral(_)
| ast::Expr::BooleanLiteral(_)
| ast::Expr::NoneLiteral(_)) => self.infer_expression(literal),
| ast::Expr::NoneLiteral(_)) => self.infer_expression(literal, TypeContext::default()),
literal @ ast::Expr::NumberLiteral(number) if number.value.is_int() => {
self.infer_expression(literal)
self.infer_expression(literal, TypeContext::default())
}
// For enum values
ast::Expr::Attribute(ast::ExprAttribute { value, attr, .. }) => {
let value_ty = self.infer_expression(value);
let value_ty = self.infer_expression(value, TypeContext::default());
if is_enum_class(self.db(), value_ty) {
let ty = value_ty
@ -1461,7 +1461,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
ty
}
_ => {
self.infer_expression(parameters);
self.infer_expression(parameters, TypeContext::default());
return Err(vec![parameters]);
}
})
@ -1507,7 +1507,7 @@ impl<'db> TypeInferenceBuilder<'db, '_> {
});
}
ast::Expr::Subscript(subscript) => {
let value_ty = self.infer_expression(&subscript.value);
let value_ty = self.infer_expression(&subscript.value, TypeContext::default());
self.infer_subscript_type_expression(subscript, value_ty);
// TODO: Support `Concatenate[...]`
return Some(Parameters::todo());

View file

@ -5,7 +5,7 @@ use crate::place::{ConsideredDefinitions, Place, global_symbol};
use crate::semantic_index::definition::Definition;
use crate::semantic_index::scope::FileScopeId;
use crate::semantic_index::{global_scope, place_table, semantic_index, use_def_map};
use crate::types::{KnownInstanceType, check_types};
use crate::types::{KnownClass, KnownInstanceType, UnionType, check_types};
use ruff_db::diagnostic::Diagnostic;
use ruff_db::files::{File, system_path_to_file};
use ruff_db::system::DbWithWritableSystem as _;
@ -409,17 +409,17 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
db.write_dedented(
"/src/mod.py",
r#"
class C:
def f(self):
self.attr: int | None = None
"#,
class C:
def f(self):
self.attr: int | None = None
"#,
)?;
db.write_dedented(
"/src/main.py",
r#"
from mod import C
x = C().attr
"#,
from mod import C
x = C().attr
"#,
)?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
@ -430,10 +430,10 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
db.write_dedented(
"/src/mod.py",
r#"
class C:
def f(self):
self.attr: str | None = None
"#,
class C:
def f(self):
self.attr: str | None = None
"#,
)?;
let events = {
@ -442,17 +442,22 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events()
};
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
assert_function_query_was_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
// Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented(
"/src/mod.py",
r#"
class C:
def f(self):
# a comment!
self.attr: str | None = None
"#,
class C:
def f(self):
# a comment!
self.attr: str | None = None
"#,
)?;
let events = {
@ -462,7 +467,12 @@ fn dependency_implicit_instance_attribute() -> anyhow::Result<()> {
db.take_salsa_events()
};
assert_function_query_was_not_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
assert_function_query_was_not_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
Ok(())
}
@ -487,19 +497,19 @@ fn dependency_own_instance_member() -> anyhow::Result<()> {
db.write_dedented(
"/src/mod.py",
r#"
class C:
if random.choice([True, False]):
attr: int = 42
else:
attr: None = None
"#,
class C:
if random.choice([True, False]):
attr: int = 42
else:
attr: None = None
"#,
)?;
db.write_dedented(
"/src/main.py",
r#"
from mod import C
x = C().attr
"#,
from mod import C
x = C().attr
"#,
)?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
@ -510,12 +520,12 @@ fn dependency_own_instance_member() -> anyhow::Result<()> {
db.write_dedented(
"/src/mod.py",
r#"
class C:
if random.choice([True, False]):
attr: str = "42"
else:
attr: None = None
"#,
class C:
if random.choice([True, False]):
attr: str = "42"
else:
attr: None = None
"#,
)?;
let events = {
@ -524,19 +534,24 @@ fn dependency_own_instance_member() -> anyhow::Result<()> {
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str | None");
db.take_salsa_events()
};
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
assert_function_query_was_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
// Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented(
"/src/mod.py",
r#"
class C:
# comment
if random.choice([True, False]):
attr: str = "42"
else:
attr: None = None
"#,
class C:
# comment
if random.choice([True, False]):
attr: str = "42"
else:
attr: None = None
"#,
)?;
let events = {
@ -546,7 +561,12 @@ fn dependency_own_instance_member() -> anyhow::Result<()> {
db.take_salsa_events()
};
assert_function_query_was_not_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
assert_function_query_was_not_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
Ok(())
}
@ -569,22 +589,22 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> {
db.write_dedented(
"/src/mod.py",
r#"
class C:
def __init__(self):
self.instance_attr: str = "24"
class C:
def __init__(self):
self.instance_attr: str = "24"
@classmethod
def method(cls):
cls.class_attr: int = 42
"#,
@classmethod
def method(cls):
cls.class_attr: int = 42
"#,
)?;
db.write_dedented(
"/src/main.py",
r#"
from mod import C
C.method()
x = C().class_attr
"#,
from mod import C
C.method()
x = C().class_attr
"#,
)?;
let file_main = system_path_to_file(&db, "/src/main.py").unwrap();
@ -595,14 +615,14 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> {
db.write_dedented(
"/src/mod.py",
r#"
class C:
def __init__(self):
self.instance_attr: str = "24"
class C:
def __init__(self):
self.instance_attr: str = "24"
@classmethod
def method(cls):
cls.class_attr: str = "42"
"#,
@classmethod
def method(cls):
cls.class_attr: str = "42"
"#,
)?;
let events = {
@ -611,21 +631,26 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> {
assert_eq!(attr_ty.display(&db).to_string(), "Unknown | str");
db.take_salsa_events()
};
assert_function_query_was_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
assert_function_query_was_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
// Add a comment; this should not trigger the type of `x` to be re-inferred
db.write_dedented(
"/src/mod.py",
r#"
class C:
def __init__(self):
self.instance_attr: str = "24"
class C:
def __init__(self):
self.instance_attr: str = "24"
@classmethod
def method(cls):
# comment
cls.class_attr: str = "42"
"#,
@classmethod
def method(cls):
# comment
cls.class_attr: str = "42"
"#,
)?;
let events = {
@ -635,7 +660,88 @@ fn dependency_implicit_class_member() -> anyhow::Result<()> {
db.take_salsa_events()
};
assert_function_query_was_not_run(&db, infer_expression_types, x_rhs_expression(&db), &events);
assert_function_query_was_not_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(x_rhs_expression(&db)),
&events,
);
Ok(())
}
/// Inferring the result of a call-expression shouldn't need to re-run after
/// a trivial change to the function's file (e.g. by adding a docstring to the function).
#[test]
fn call_type_doesnt_rerun_when_only_callee_changed() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/foo.py",
r#"
def foo() -> int:
return 5
"#,
)?;
db.write_dedented(
"src/bar.py",
r#"
from foo import foo
a = foo()
"#,
)?;
let bar = system_path_to_file(&db, "src/bar.py")?;
let a = global_symbol(&db, bar, "a").place;
assert_eq!(
a.expect_type(),
UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)])
);
let events = db.take_salsa_events();
let module = parsed_module(&db, bar).load(&db);
let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value;
let foo_call = semantic_index(&db, bar).expression(call);
assert_function_query_was_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(foo_call),
&events,
);
// Add a docstring to foo to trigger a re-run.
// The bar-call site of foo should not be re-run because of that
db.write_dedented(
"src/foo.py",
r#"
def foo() -> int:
"Computes a value"
return 5
"#,
)?;
db.clear_salsa_events();
let a = global_symbol(&db, bar, "a").place;
assert_eq!(
a.expect_type(),
UnionType::from_elements(&db, [Type::unknown(), KnownClass::Int.to_instance(&db)])
);
let events = db.take_salsa_events();
let module = parsed_module(&db, bar).load(&db);
let call = &*module.syntax().body[1].as_assign_stmt().unwrap().value;
let foo_call = semantic_index(&db, bar).expression(call);
assert_function_query_was_not_run(
&db,
infer_expression_types_impl,
InferExpression::Bare(foo_call),
&events,
);
Ok(())
}

View file

@ -12,7 +12,7 @@ use crate::types::function::KnownFunction;
use crate::types::infer::infer_same_file_expression_type;
use crate::types::{
ClassLiteral, ClassType, IntersectionBuilder, KnownClass, SubclassOfInner, SubclassOfType,
Truthiness, Type, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types,
Truthiness, Type, TypeContext, TypeVarBoundOrConstraints, UnionBuilder, infer_expression_types,
};
use ruff_db::parsed::{ParsedModuleRef, parsed_module};
@ -773,7 +773,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
return None;
}
let inference = infer_expression_types(self.db, expression);
let inference = infer_expression_types(self.db, expression, TypeContext::default());
let comparator_tuples = std::iter::once(&**left)
.chain(comparators)
@ -863,7 +863,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expression: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let inference = infer_expression_types(self.db, expression);
let inference = infer_expression_types(self.db, expression, TypeContext::default());
let callable_ty = inference.expression_type(&*expr_call.func);
@ -983,7 +983,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
let ty = infer_same_file_expression_type(self.db, cls, self.module).to_instance(self.db)?;
let ty = infer_same_file_expression_type(self.db, cls, TypeContext::default(), self.module)
.to_instance(self.db)?;
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
@ -996,7 +997,8 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
let subject = place_expr(subject.node_ref(self.db, self.module))?;
let place = self.expect_place(&subject);
let ty = infer_same_file_expression_type(self.db, value, self.module);
let ty =
infer_same_file_expression_type(self.db, value, TypeContext::default(), self.module);
Some(NarrowingConstraints::from_iter([(place, ty)]))
}
@ -1025,7 +1027,7 @@ impl<'db, 'ast> NarrowingConstraintsBuilder<'db, 'ast> {
expression: Expression<'db>,
is_positive: bool,
) -> Option<NarrowingConstraints<'db>> {
let inference = infer_expression_types(self.db, expression);
let inference = infer_expression_types(self.db, expression, TypeContext::default());
let mut sub_constraints = expr_bool_op
.values
.iter()

View file

@ -9,7 +9,7 @@ use crate::Db;
use crate::semantic_index::ast_ids::node_key::ExpressionNodeKey;
use crate::semantic_index::scope::ScopeId;
use crate::types::tuple::{ResizeTupleError, Tuple, TupleLength, TupleSpec, TupleUnpacker};
use crate::types::{Type, TypeCheckDiagnostics, infer_expression_types};
use crate::types::{Type, TypeCheckDiagnostics, TypeContext, infer_expression_types};
use crate::unpack::{UnpackKind, UnpackValue};
use super::context::InferContext;
@ -48,8 +48,9 @@ impl<'db, 'ast> Unpacker<'db, 'ast> {
"Unpacking target must be a list or tuple expression"
);
let value_type = infer_expression_types(self.db(), value.expression())
.expression_type(value.expression().node_ref(self.db(), self.module()));
let value_type =
infer_expression_types(self.db(), value.expression(), TypeContext::default())
.expression_type(value.expression().node_ref(self.db(), self.module()));
let value_type = match value.kind() {
UnpackKind::Assign => {