mirror of
https://github.com/astral-sh/ruff.git
synced 2025-09-30 22:01:18 +00:00
use annotated parameters as type context
This commit is contained in:
parent
0092794302
commit
0babd8d9d4
8 changed files with 404 additions and 63 deletions
|
@ -1090,7 +1090,11 @@ impl<'db> InnerIntersectionBuilder<'db> {
|
|||
// don't need to worry about finding any particular constraint more than once.
|
||||
let constraints = constraints.elements(db);
|
||||
let mut positive_constraint_count = 0;
|
||||
for positive in &self.positive {
|
||||
for (i, positive) in self.positive.iter().enumerate() {
|
||||
if i == typevar_index {
|
||||
continue;
|
||||
}
|
||||
|
||||
// This linear search should be fine as long as we don't encounter typevars with
|
||||
// thousands of constraints.
|
||||
positive_constraint_count += constraints
|
||||
|
|
|
@ -33,10 +33,10 @@ use crate::types::{
|
|||
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
|
||||
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
|
||||
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType,
|
||||
WrapperDescriptorKind, enums, ide_support, todo_type,
|
||||
WrapperDescriptorKind, enums, ide_support, infer_isolated_expression, todo_type,
|
||||
};
|
||||
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
|
||||
use ruff_python_ast::{self as ast, PythonVersion};
|
||||
use ruff_python_ast::{self as ast, ArgOrKeyword, PythonVersion};
|
||||
|
||||
/// Binding information for a possible union of callables. At a call site, the arguments must be
|
||||
/// compatible with _all_ of the types in the union for the call to be valid.
|
||||
|
@ -1732,7 +1732,7 @@ impl<'db> CallableBinding<'db> {
|
|||
}
|
||||
|
||||
/// Returns the index of the matching overload in the form of [`MatchingOverloadIndex`].
|
||||
fn matching_overload_index(&self) -> MatchingOverloadIndex {
|
||||
pub(crate) fn matching_overload_index(&self) -> MatchingOverloadIndex {
|
||||
let mut matching_overloads = self.matching_overloads();
|
||||
match matching_overloads.next() {
|
||||
None => MatchingOverloadIndex::None,
|
||||
|
@ -1750,8 +1750,15 @@ impl<'db> CallableBinding<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns all overloads for this call binding, including overloads that did not match.
|
||||
pub(crate) fn overloads(&self) -> &[Binding<'db>] {
|
||||
self.overloads.as_slice()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all the overloads that matched for this call binding.
|
||||
pub(crate) fn matching_overloads(&self) -> impl Iterator<Item = (usize, &Binding<'db>)> {
|
||||
pub(crate) fn matching_overloads(
|
||||
&self,
|
||||
) -> impl Iterator<Item = (usize, &Binding<'db>)> + Clone {
|
||||
self.overloads
|
||||
.iter()
|
||||
.enumerate()
|
||||
|
@ -1982,7 +1989,7 @@ enum OverloadCallReturnType<'db> {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum MatchingOverloadIndex {
|
||||
pub(crate) enum MatchingOverloadIndex {
|
||||
/// No matching overloads found.
|
||||
None,
|
||||
|
||||
|
@ -1993,6 +2000,16 @@ enum MatchingOverloadIndex {
|
|||
Multiple(Vec<usize>),
|
||||
}
|
||||
|
||||
impl MatchingOverloadIndex {
|
||||
pub(crate) fn count(self) -> usize {
|
||||
match self {
|
||||
MatchingOverloadIndex::None => 0,
|
||||
MatchingOverloadIndex::Single(_) => 1,
|
||||
MatchingOverloadIndex::Multiple(items) => items.len(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
struct ArgumentForms {
|
||||
values: Vec<Option<ParameterForm>>,
|
||||
|
@ -2464,9 +2481,17 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
|
|||
if let Some(return_ty) = self.signature.return_ty
|
||||
&& let Some(call_expression_tcx) = self.call_expression_tcx.annotation
|
||||
{
|
||||
// Ignore any specialization errors here, because the type context is only used to
|
||||
// optionally widen the return type.
|
||||
let _ = builder.infer(return_ty, call_expression_tcx);
|
||||
match call_expression_tcx {
|
||||
// A type variable is not a useful type-context for expression inference, and applying it
|
||||
// to the return type can lead to confusing unions in nested generic calls.
|
||||
Type::TypeVar(_) => {}
|
||||
|
||||
_ => {
|
||||
// Ignore any specialization errors here, because the type context is only used to
|
||||
// optionally widen the return type.
|
||||
let _ = builder.infer(return_ty, call_expression_tcx);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let parameters = self.signature.parameters();
|
||||
|
@ -3278,6 +3303,23 @@ impl<'db> BindingError<'db> {
|
|||
return;
|
||||
};
|
||||
|
||||
// Re-infer the argument type of call expressions, ignoring the type context for more
|
||||
// precise error messages.
|
||||
let provided_ty = match Self::get_argument_node(node, *argument_index) {
|
||||
None => *provided_ty,
|
||||
|
||||
// Ignore starred arguments, as those are difficult to re-infer.
|
||||
Some(
|
||||
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
|
||||
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }),
|
||||
) => *provided_ty,
|
||||
|
||||
Some(
|
||||
ast::ArgOrKeyword::Arg(value)
|
||||
| ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }),
|
||||
) => infer_isolated_expression(context.db(), context.scope(), value),
|
||||
};
|
||||
|
||||
let provided_ty_display = provided_ty.display(context.db());
|
||||
let expected_ty_display = expected_ty.display(context.db());
|
||||
|
||||
|
@ -3613,22 +3655,29 @@ impl<'db> BindingError<'db> {
|
|||
}
|
||||
}
|
||||
|
||||
fn get_node(node: ast::AnyNodeRef, argument_index: Option<usize>) -> ast::AnyNodeRef {
|
||||
fn get_node(node: ast::AnyNodeRef<'_>, argument_index: Option<usize>) -> ast::AnyNodeRef<'_> {
|
||||
// If we have a Call node and an argument index, report the diagnostic on the correct
|
||||
// argument node; otherwise, report it on the entire provided node.
|
||||
match Self::get_argument_node(node, argument_index) {
|
||||
Some(ast::ArgOrKeyword::Arg(expr)) => expr.into(),
|
||||
Some(ast::ArgOrKeyword::Keyword(expr)) => expr.into(),
|
||||
None => node,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_argument_node(
|
||||
node: ast::AnyNodeRef<'_>,
|
||||
argument_index: Option<usize>,
|
||||
) -> Option<ArgOrKeyword<'_>> {
|
||||
match (node, argument_index) {
|
||||
(ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => {
|
||||
match call_node
|
||||
(ast::AnyNodeRef::ExprCall(call_node), Some(argument_index)) => Some(
|
||||
call_node
|
||||
.arguments
|
||||
.arguments_source_order()
|
||||
.nth(argument_index)
|
||||
.expect("argument index should not be out of range")
|
||||
{
|
||||
ast::ArgOrKeyword::Arg(expr) => expr.into(),
|
||||
ast::ArgOrKeyword::Keyword(keyword) => keyword.into(),
|
||||
}
|
||||
}
|
||||
_ => node,
|
||||
.expect("argument index should not be out of range"),
|
||||
),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1975,7 +1975,7 @@ pub(super) fn report_invalid_assignment<'db>(
|
|||
if let DefinitionKind::AnnotatedAssignment(annotated_assignment) = definition.kind(context.db())
|
||||
&& let Some(value) = annotated_assignment.value(context.module())
|
||||
{
|
||||
// Re-infer the RHS of the annotated assignment, ignoring the type context, for more precise
|
||||
// Re-infer the RHS of the annotated assignment, ignoring the type context for more precise
|
||||
// error messages.
|
||||
source_ty = infer_isolated_expression(context.db(), definition.scope(context.db()), value);
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::iter;
|
||||
use std::{iter, mem};
|
||||
|
||||
use itertools::{Either, Itertools};
|
||||
use ruff_db::diagnostic::{Annotation, DiagnosticId, Severity};
|
||||
|
@ -44,6 +44,7 @@ use crate::semantic_index::symbol::{ScopedSymbolId, Symbol};
|
|||
use crate::semantic_index::{
|
||||
ApplicableConstraints, EnclosingSnapshotResult, SemanticIndex, place_table,
|
||||
};
|
||||
use crate::types::call::bind::MatchingOverloadIndex;
|
||||
use crate::types::call::{Binding, Bindings, CallArguments, CallError, CallErrorKind};
|
||||
use crate::types::class::{CodeGeneratorKind, FieldKind, MetaclassErrorKind, MethodDecorator};
|
||||
use crate::types::context::{InNoTypeCheck, InferContext};
|
||||
|
@ -258,6 +259,8 @@ pub(super) struct TypeInferenceBuilder<'db, 'ast> {
|
|||
/// is a stub file but we're still in a non-deferred region.
|
||||
deferred_state: DeferredExpressionState,
|
||||
|
||||
multi_inference_state: MultiInferenceState,
|
||||
|
||||
/// For function definitions, the undecorated type of the function.
|
||||
undecorated_type: Option<Type<'db>>,
|
||||
|
||||
|
@ -288,10 +291,11 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
context: InferContext::new(db, scope, module),
|
||||
index,
|
||||
region,
|
||||
scope,
|
||||
return_types_and_ranges: vec![],
|
||||
called_functions: FxHashSet::default(),
|
||||
deferred_state: DeferredExpressionState::None,
|
||||
scope,
|
||||
multi_inference_state: MultiInferenceState::Panic,
|
||||
expressions: FxHashMap::default(),
|
||||
bindings: VecMap::default(),
|
||||
declarations: VecMap::default(),
|
||||
|
@ -3255,6 +3259,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
key,
|
||||
assigned_ty,
|
||||
value.as_ref(),
|
||||
true,
|
||||
slice.as_ref(),
|
||||
rhs,
|
||||
TypedDictAssignmentKind::Subscript,
|
||||
|
@ -4913,6 +4918,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
self.infer_expression(expression, TypeContext::default())
|
||||
}
|
||||
|
||||
/// Infer the argument types for a single binding.
|
||||
fn infer_argument_types<'a>(
|
||||
&mut self,
|
||||
ast_arguments: &ast::Arguments,
|
||||
|
@ -4922,22 +4928,135 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
debug_assert!(
|
||||
ast_arguments.len() == arguments.len() && arguments.len() == argument_forms.len()
|
||||
);
|
||||
let iter = (arguments.iter_mut())
|
||||
.zip(argument_forms.iter().copied())
|
||||
.zip(ast_arguments.arguments_source_order());
|
||||
for (((_, argument_type), form), arg_or_keyword) in iter {
|
||||
let argument = match arg_or_keyword {
|
||||
// We already inferred the type of splatted arguments.
|
||||
|
||||
let iter = itertools::izip!(
|
||||
arguments.iter_mut(),
|
||||
argument_forms.iter().copied(),
|
||||
ast_arguments.arguments_source_order()
|
||||
);
|
||||
|
||||
for ((_, argument_type), argument_form, ast_argument) in iter {
|
||||
let argument = match ast_argument {
|
||||
// Splatted arguments are inferred before parameter matching to
|
||||
// determine their length.
|
||||
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
|
||||
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue,
|
||||
|
||||
ast::ArgOrKeyword::Arg(arg) => arg,
|
||||
ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value,
|
||||
};
|
||||
let ty = self.infer_argument_type(argument, form, TypeContext::default());
|
||||
|
||||
let ty = self.infer_argument_type(argument, argument_form, TypeContext::default());
|
||||
*argument_type = Some(ty);
|
||||
}
|
||||
}
|
||||
|
||||
/// Infer the argument types for multiple potential bindings and overloads.
|
||||
fn infer_all_argument_types<'a>(
|
||||
&mut self,
|
||||
ast_arguments: &ast::Arguments,
|
||||
arguments: &mut CallArguments<'a, 'db>,
|
||||
bindings: &Bindings<'db>,
|
||||
) {
|
||||
debug_assert!(
|
||||
ast_arguments.len() == arguments.len()
|
||||
&& arguments.len() == bindings.argument_forms().len()
|
||||
);
|
||||
|
||||
let iter = itertools::izip!(
|
||||
0..,
|
||||
arguments.iter_mut(),
|
||||
bindings.argument_forms().iter().copied(),
|
||||
ast_arguments.arguments_source_order()
|
||||
);
|
||||
|
||||
let bindings_count = bindings.into_iter().count();
|
||||
|
||||
for (argument_index, (_, argument_type), argument_form, ast_argument) in iter {
|
||||
let ast_argument = match ast_argument {
|
||||
// Splatted arguments are inferred before parameter matching to
|
||||
// determine their length.
|
||||
//
|
||||
// TODO: Re-infer splatted arguments with their type context.
|
||||
ast::ArgOrKeyword::Arg(ast::Expr::Starred(_))
|
||||
| ast::ArgOrKeyword::Keyword(ast::Keyword { arg: None, .. }) => continue,
|
||||
|
||||
ast::ArgOrKeyword::Arg(arg) => arg,
|
||||
ast::ArgOrKeyword::Keyword(ast::Keyword { value, .. }) => value,
|
||||
};
|
||||
|
||||
// Type-form arguments are inferred without type context, so we can infer the argument type directly.
|
||||
if let Some(ParameterForm::Type) = argument_form {
|
||||
*argument_type = Some(self.infer_type_expression(ast_argument));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Otherwise, we infer the type of each argument once for each matching overload signature,
|
||||
// with the given annotated type as type context.
|
||||
for binding in bindings {
|
||||
let argument_index = if binding.bound_type.is_some() {
|
||||
argument_index + 1
|
||||
} else {
|
||||
argument_index
|
||||
};
|
||||
|
||||
let (overloads, overloads_count) = match binding.matching_overload_index() {
|
||||
index @ (MatchingOverloadIndex::Single(_)
|
||||
| MatchingOverloadIndex::Multiple(_)) => (
|
||||
Either::Right(binding.matching_overloads().map(|(_, overload)| overload)),
|
||||
index.count(),
|
||||
),
|
||||
|
||||
// If there is a single overload that does not match, we still infer the argument
|
||||
// types for better diagnostics.
|
||||
MatchingOverloadIndex::None => match binding.overloads() {
|
||||
[overload] => (Either::Left([overload].into_iter()), 1),
|
||||
_ => continue,
|
||||
},
|
||||
};
|
||||
|
||||
let multi_inference_state = if (bindings_count, overloads_count) == (1, 1) {
|
||||
// If there is only a single binding and overload, there is a unique parameter type annotation for
|
||||
// each argument.
|
||||
self.multi_inference_state
|
||||
} else {
|
||||
// Otherwise, each type is a valid independent inference of the given argument, and we may
|
||||
// require different permutations of argument types to correctly perform argument expansion
|
||||
// during overload evaluation, so we take the intersection of all the types we inferred for
|
||||
// each argument.
|
||||
MultiInferenceState::Intersect {
|
||||
// Note that the argument must be assignable to its parameter type for every binding in the union.
|
||||
//
|
||||
// However, if there are multiple overloads for a given binding, type-checking should not fail
|
||||
// if the parameter type annotation of a given overload is not fulfilled.
|
||||
fallback: overloads_count > 1,
|
||||
}
|
||||
};
|
||||
|
||||
// Update the state of the inference builder to apply intersections to all nested expressions.
|
||||
let old_multi_inference_state =
|
||||
mem::replace(&mut self.multi_inference_state, multi_inference_state);
|
||||
|
||||
for overload in overloads {
|
||||
let argument_matches = &overload.argument_matches()[argument_index];
|
||||
let [parameter_index] = argument_matches.parameters.as_slice() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let parameter_type =
|
||||
overload.signature.parameters()[*parameter_index].annotated_type();
|
||||
|
||||
self.infer_expression_impl(ast_argument, TypeContext::new(parameter_type));
|
||||
}
|
||||
|
||||
// Restore the multi-inference state.
|
||||
self.multi_inference_state = old_multi_inference_state;
|
||||
}
|
||||
|
||||
*argument_type = self.try_expression_type(ast_argument);
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_argument_type(
|
||||
&mut self,
|
||||
ast_argument: &ast::Expr,
|
||||
|
@ -5018,6 +5137,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
types.expression_type(expression)
|
||||
}
|
||||
|
||||
/// Infer the type of an expression.
|
||||
fn infer_expression_impl(
|
||||
&mut self,
|
||||
expression: &ast::Expr,
|
||||
|
@ -5070,6 +5190,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
|
||||
ty
|
||||
}
|
||||
|
||||
fn store_expression_type(&mut self, expression: &ast::Expr, ty: Type<'db>) {
|
||||
if self.deferred_state.in_string_annotation() {
|
||||
// Avoid storing the type of expressions that are part of a string annotation because
|
||||
|
@ -5077,8 +5198,24 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
// on the string expression itself that represents the annotation.
|
||||
return;
|
||||
}
|
||||
let previous = self.expressions.insert(expression.into(), ty);
|
||||
assert_eq!(previous, None);
|
||||
|
||||
let db = self.db();
|
||||
|
||||
match self.multi_inference_state {
|
||||
MultiInferenceState::Panic => {
|
||||
let previous = self.expressions.insert(expression.into(), ty);
|
||||
assert_eq!(previous, None);
|
||||
}
|
||||
|
||||
MultiInferenceState::Intersect { .. } => {
|
||||
self.expressions
|
||||
.entry(expression.into())
|
||||
.and_modify(|current| {
|
||||
*current = IntersectionType::from_elements(db, [*current, ty]);
|
||||
})
|
||||
.or_insert(ty);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_number_literal_expression(&mut self, literal: &ast::ExprNumberLiteral) -> Type<'db> {
|
||||
|
@ -5315,15 +5452,33 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
}
|
||||
}
|
||||
|
||||
validate_typed_dict_dict_literal(
|
||||
let report_diagnostics = match self.multi_inference_state {
|
||||
// Do not eagerly report diagnostics when performing overload evaluation
|
||||
// with multiple potential overloads, as we may fallback to an untyped
|
||||
// dictionary literal.
|
||||
MultiInferenceState::Intersect { fallback: true } => false,
|
||||
_ => true,
|
||||
};
|
||||
|
||||
let result = validate_typed_dict_dict_literal(
|
||||
&self.context,
|
||||
typed_dict,
|
||||
dict,
|
||||
dict.into(),
|
||||
report_diagnostics,
|
||||
|expr| self.expression_type(expr),
|
||||
);
|
||||
|
||||
return Type::TypedDict(typed_dict);
|
||||
match result {
|
||||
// Successfully validated the dictionary literal.
|
||||
Ok(_) => return Type::TypedDict(typed_dict),
|
||||
|
||||
// The dictionary is not valid, but we are eagerly reporting diagnostics.
|
||||
Err(_) if report_diagnostics => return Type::TypedDict(typed_dict),
|
||||
|
||||
// Otherwise, fallback to an untyped dictionary literal.
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Avoid false positives for the functional `TypedDict` form, which is currently
|
||||
|
@ -5976,7 +6131,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
let bindings = callable_type
|
||||
.bindings(self.db())
|
||||
.match_parameters(self.db(), &call_arguments);
|
||||
self.infer_argument_types(arguments, &mut call_arguments, bindings.argument_forms());
|
||||
self.infer_all_argument_types(arguments, &mut call_arguments, &bindings);
|
||||
|
||||
// Validate `TypedDict` constructor calls after argument type inference
|
||||
if let Some(class_literal) = callable_type.into_class_literal() {
|
||||
|
@ -9096,6 +9251,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
// builder only state
|
||||
typevar_binding_context: _,
|
||||
deferred_state: _,
|
||||
multi_inference_state: _,
|
||||
called_functions: _,
|
||||
index: _,
|
||||
region: _,
|
||||
|
@ -9158,6 +9314,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
// builder only state
|
||||
typevar_binding_context: _,
|
||||
deferred_state: _,
|
||||
multi_inference_state: _,
|
||||
called_functions: _,
|
||||
index: _,
|
||||
region: _,
|
||||
|
@ -9229,6 +9386,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
|
|||
// Builder only state
|
||||
typevar_binding_context: _,
|
||||
deferred_state: _,
|
||||
multi_inference_state: _,
|
||||
called_functions: _,
|
||||
index: _,
|
||||
region: _,
|
||||
|
@ -9274,6 +9432,22 @@ impl GenericContextError {
|
|||
}
|
||||
}
|
||||
|
||||
/// Dictates the behavior when an expression is inferred multiple times.
|
||||
#[derive(Default, Debug, Clone, Copy)]
|
||||
enum MultiInferenceState {
|
||||
/// Panic if the expression has already been inferred.
|
||||
#[default]
|
||||
Panic,
|
||||
|
||||
/// Store the intersection of all types inferred for the expression.
|
||||
Intersect {
|
||||
// Determines whether or not a given expression is required to be assignable to its type context
|
||||
// despite it being inferred multiple times, i.e. whether eager diagnostics are appropriate, or a
|
||||
// fallback type should be assumed.
|
||||
fallback: bool,
|
||||
},
|
||||
}
|
||||
|
||||
/// The deferred state of a specific expression in an inference region.
|
||||
#[derive(Default, Debug, Clone, Copy)]
|
||||
enum DeferredExpressionState {
|
||||
|
@ -9547,7 +9721,7 @@ impl<K, V> Default for VecMap<K, V> {
|
|||
|
||||
/// Set based on a `Vec`. It doesn't enforce
|
||||
/// uniqueness on insertion. Instead, it relies on the caller
|
||||
/// that elements are uniuqe. For example, the way we visit definitions
|
||||
/// that elements are unique. For example, the way we visit definitions
|
||||
/// in the `TypeInference` builder make already implicitly guarantees that each definition
|
||||
/// is only visited once.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
|
|
|
@ -132,6 +132,7 @@ impl TypedDictAssignmentKind {
|
|||
}
|
||||
|
||||
/// Validates assignment of a value to a specific key on a `TypedDict`.
|
||||
///
|
||||
/// Returns true if the assignment is valid, false otherwise.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
|
||||
|
@ -140,6 +141,7 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
|
|||
key: &str,
|
||||
value_ty: Type<'db>,
|
||||
typed_dict_node: impl Into<AnyNodeRef<'ast>>,
|
||||
report_diagnostics: bool,
|
||||
key_node: impl Into<AnyNodeRef<'ast>>,
|
||||
value_node: impl Into<AnyNodeRef<'ast>>,
|
||||
assignment_kind: TypedDictAssignmentKind,
|
||||
|
@ -149,14 +151,17 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
|
|||
|
||||
// Check if key exists in `TypedDict`
|
||||
let Some((_, item)) = items.iter().find(|(name, _)| *name == key) else {
|
||||
report_invalid_key_on_typed_dict(
|
||||
context,
|
||||
typed_dict_node.into(),
|
||||
key_node.into(),
|
||||
Type::TypedDict(typed_dict),
|
||||
Type::string_literal(db, key),
|
||||
&items,
|
||||
);
|
||||
if report_diagnostics {
|
||||
report_invalid_key_on_typed_dict(
|
||||
context,
|
||||
typed_dict_node.into(),
|
||||
key_node.into(),
|
||||
Type::TypedDict(typed_dict),
|
||||
Type::string_literal(db, key),
|
||||
&items,
|
||||
);
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
|
@ -177,8 +182,9 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
|
|||
};
|
||||
|
||||
if assignment_kind.is_subscript() && item.is_read_only() {
|
||||
if let Some(builder) =
|
||||
context.report_lint(assignment_kind.diagnostic_type(), key_node.into())
|
||||
if report_diagnostics
|
||||
&& let Some(builder) =
|
||||
context.report_lint(assignment_kind.diagnostic_type(), key_node.into())
|
||||
{
|
||||
let typed_dict_ty = Type::TypedDict(typed_dict);
|
||||
let typed_dict_d = typed_dict_ty.display(db);
|
||||
|
@ -207,7 +213,9 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
|
|||
}
|
||||
|
||||
// Invalid assignment - emit diagnostic
|
||||
if let Some(builder) = context.report_lint(assignment_kind.diagnostic_type(), value_node.into())
|
||||
if report_diagnostics
|
||||
&& let Some(builder) =
|
||||
context.report_lint(assignment_kind.diagnostic_type(), value_node.into())
|
||||
{
|
||||
let typed_dict_ty = Type::TypedDict(typed_dict);
|
||||
let typed_dict_d = typed_dict_ty.display(db);
|
||||
|
@ -240,13 +248,17 @@ pub(super) fn validate_typed_dict_key_assignment<'db, 'ast>(
|
|||
}
|
||||
|
||||
/// Validates that all required keys are provided in a `TypedDict` construction.
|
||||
///
|
||||
/// Reports errors for any keys that are required but not provided.
|
||||
///
|
||||
/// Returns true if the assignment is valid, false otherwise.
|
||||
pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
|
||||
context: &InferContext<'db, 'ast>,
|
||||
typed_dict: TypedDictType<'db>,
|
||||
provided_keys: &OrderSet<&str>,
|
||||
error_node: AnyNodeRef<'ast>,
|
||||
) {
|
||||
report_diagnostics: bool,
|
||||
) -> bool {
|
||||
let db = context.db();
|
||||
let items = typed_dict.items(db);
|
||||
|
||||
|
@ -255,14 +267,23 @@ pub(super) fn validate_typed_dict_required_keys<'db, 'ast>(
|
|||
.filter_map(|(key_name, field)| field.is_required().then_some(key_name.as_str()))
|
||||
.collect();
|
||||
|
||||
for missing_key in required_keys.difference(provided_keys) {
|
||||
report_missing_typed_dict_key(
|
||||
context,
|
||||
error_node,
|
||||
Type::TypedDict(typed_dict),
|
||||
missing_key,
|
||||
);
|
||||
let missing_keys = required_keys.difference(provided_keys);
|
||||
|
||||
let mut has_missing_key = false;
|
||||
for missing_key in missing_keys {
|
||||
has_missing_key = true;
|
||||
|
||||
if report_diagnostics {
|
||||
report_missing_typed_dict_key(
|
||||
context,
|
||||
error_node,
|
||||
Type::TypedDict(typed_dict),
|
||||
missing_key,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
!has_missing_key
|
||||
}
|
||||
|
||||
pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
|
||||
|
@ -292,7 +313,7 @@ pub(super) fn validate_typed_dict_constructor<'db, 'ast>(
|
|||
)
|
||||
};
|
||||
|
||||
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
|
||||
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node, true);
|
||||
}
|
||||
|
||||
/// Validates a `TypedDict` constructor call with a single positional dictionary argument
|
||||
|
@ -325,6 +346,7 @@ fn validate_from_dict_literal<'db, 'ast>(
|
|||
key_str,
|
||||
value_type,
|
||||
error_node,
|
||||
true,
|
||||
key_expr,
|
||||
&dict_item.value,
|
||||
TypedDictAssignmentKind::Constructor,
|
||||
|
@ -363,6 +385,7 @@ fn validate_from_keywords<'db, 'ast>(
|
|||
arg_name.as_str(),
|
||||
arg_type,
|
||||
error_node,
|
||||
true,
|
||||
keyword,
|
||||
&keyword.value,
|
||||
TypedDictAssignmentKind::Constructor,
|
||||
|
@ -373,15 +396,17 @@ fn validate_from_keywords<'db, 'ast>(
|
|||
provided_keys
|
||||
}
|
||||
|
||||
/// Validates a `TypedDict` dictionary literal assignment
|
||||
/// Validates a `TypedDict` dictionary literal assignment,
|
||||
/// e.g. `person: Person = {"name": "Alice", "age": 30}`
|
||||
pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
|
||||
context: &InferContext<'db, 'ast>,
|
||||
typed_dict: TypedDictType<'db>,
|
||||
dict_expr: &'ast ast::ExprDict,
|
||||
error_node: AnyNodeRef<'ast>,
|
||||
report_diagnostics: bool,
|
||||
expression_type_fn: impl Fn(&ast::Expr) -> Type<'db>,
|
||||
) -> OrderSet<&'ast str> {
|
||||
) -> Result<OrderSet<&'ast str>, OrderSet<&'ast str>> {
|
||||
let mut valid = true;
|
||||
let mut provided_keys = OrderSet::new();
|
||||
|
||||
// Validate each key-value pair in the dictionary literal
|
||||
|
@ -392,12 +417,14 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
|
|||
provided_keys.insert(key_str);
|
||||
|
||||
let value_type = expression_type_fn(&item.value);
|
||||
validate_typed_dict_key_assignment(
|
||||
|
||||
valid &= validate_typed_dict_key_assignment(
|
||||
context,
|
||||
typed_dict,
|
||||
key_str,
|
||||
value_type,
|
||||
error_node,
|
||||
report_diagnostics,
|
||||
key_expr,
|
||||
&item.value,
|
||||
TypedDictAssignmentKind::Constructor,
|
||||
|
@ -406,7 +433,17 @@ pub(super) fn validate_typed_dict_dict_literal<'db, 'ast>(
|
|||
}
|
||||
}
|
||||
|
||||
validate_typed_dict_required_keys(context, typed_dict, &provided_keys, error_node);
|
||||
valid &= validate_typed_dict_required_keys(
|
||||
context,
|
||||
typed_dict,
|
||||
&provided_keys,
|
||||
error_node,
|
||||
report_diagnostics,
|
||||
);
|
||||
|
||||
provided_keys
|
||||
if valid {
|
||||
Ok(provided_keys)
|
||||
} else {
|
||||
Err(provided_keys)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue