[red-knot] Emit a diagnostic if the value of a starred expression or a yield from expression is not iterable (#13240)

This commit is contained in:
Alex Waygood 2024-09-04 15:19:11 +01:00 committed by GitHub
parent 46a457318d
commit 0512428a6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 133 additions and 25 deletions

View file

@ -1,3 +1,4 @@
use infer::TypeInferenceBuilder;
use ruff_db::files::File; use ruff_db::files::File;
use ruff_python_ast as ast; use ruff_python_ast as ast;
@ -400,28 +401,42 @@ impl<'db> Type<'db> {
/// for y in x: /// for y in x:
/// pass /// pass
/// ``` /// ```
/// fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
/// Returns `None` if `self` represents a type that is not iterable.
fn iterate(&self, db: &'db dyn Db) -> Option<Type<'db>> {
// `self` represents the type of the iterable; // `self` represents the type of the iterable;
// `__iter__` and `__next__` are both looked up on the class of the iterable: // `__iter__` and `__next__` are both looked up on the class of the iterable:
let type_of_class = self.to_meta_type(db); let iterable_meta_type = self.to_meta_type(db);
let dunder_iter_method = type_of_class.member(db, "__iter__"); let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
if !dunder_iter_method.is_unbound() { if !dunder_iter_method.is_unbound() {
let iterator_ty = dunder_iter_method.call(db)?; let Some(iterator_ty) = dunder_iter_method.call(db) else {
return IterationOutcome::NotIterable {
not_iterable_ty: *self,
};
};
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__"); let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method.call(db); return dunder_next_method
.call(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
});
} }
// Although it's not considered great practice, // Although it's not considered great practice,
// classes that define `__getitem__` are also iterable, // classes that define `__getitem__` are also iterable,
// even if they do not define `__iter__`. // even if they do not define `__iter__`.
// //
// TODO this is only valid if the `__getitem__` method is annotated as // TODO(Alex) this is only valid if the `__getitem__` method is annotated as
// accepting `int` or `SupportsIndex` // accepting `int` or `SupportsIndex`
let dunder_get_item_method = type_of_class.member(db, "__getitem__"); let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");
dunder_get_item_method.call(db)
dunder_get_item_method
.call(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
})
} }
#[must_use] #[must_use]
@ -463,6 +478,28 @@ impl<'db> Type<'db> {
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IterationOutcome<'db> {
Iterable { element_ty: Type<'db> },
NotIterable { not_iterable_ty: Type<'db> },
}
impl<'db> IterationOutcome<'db> {
fn unwrap_with_diagnostic(
self,
iterable_node: ast::AnyNodeRef,
inference_builder: &mut TypeInferenceBuilder<'db>,
) -> Type<'db> {
match self {
Self::Iterable { element_ty } => element_ty,
Self::NotIterable { not_iterable_ty } => {
inference_builder.not_iterable_diagnostic(iterable_node, not_iterable_ty);
Type::Unknown
}
}
}
}
#[salsa::interned] #[salsa::interned]
pub struct FunctionType<'db> { pub struct FunctionType<'db> {
/// name of the function at definition /// name of the function at definition
@ -789,4 +826,65 @@ mod tests {
&["Object of type 'NotIterable' is not iterable"], &["Object of type 'NotIterable' is not iterable"],
); );
} }
#[test]
fn starred_expressions_must_be_iterable() {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
class NotIterable: pass
class Iterator:
def __next__(self) -> int:
return 42
class Iterable:
def __iter__(self) -> Iterator:
x = [*NotIterable()]
y = [*Iterable()]
",
)
.unwrap();
let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}
#[test]
fn yield_from_expression_must_be_iterable() {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
class NotIterable: pass
class Iterator:
def __next__(self) -> int:
return 42
class Iterable:
def __iter__(self) -> Iterator:
def generator_function():
yield from Iterable()
yield from NotIterable()
",
)
.unwrap();
let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}
} }

View file

@ -243,7 +243,7 @@ impl<'db> TypeInference<'db> {
/// Similarly, when we encounter a standalone-inferable expression (right-hand side of an /// Similarly, when we encounter a standalone-inferable expression (right-hand side of an
/// assignment, type narrowing guard), we use the [`infer_expression_types()`] query to ensure we /// assignment, type narrowing guard), we use the [`infer_expression_types()`] query to ensure we
/// don't infer its types more than once. /// don't infer its types more than once.
struct TypeInferenceBuilder<'db> { pub(super) struct TypeInferenceBuilder<'db> {
db: &'db dyn Db, db: &'db dyn Db,
index: &'db SemanticIndex<'db>, index: &'db SemanticIndex<'db>,
region: InferenceRegion<'db>, region: InferenceRegion<'db>,
@ -1029,6 +1029,18 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_body(orelse); self.infer_body(orelse);
} }
/// Emit a diagnostic declaring that the object represented by `node` is not iterable
pub(super) fn not_iterable_diagnostic(&mut self, node: AnyNodeRef, not_iterable_ty: Type<'db>) {
self.add_diagnostic(
node,
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
not_iterable_ty.display(self.db)
),
);
}
fn infer_for_statement_definition( fn infer_for_statement_definition(
&mut self, &mut self,
target: &ast::ExprName, target: &ast::ExprName,
@ -1042,17 +1054,9 @@ impl<'db> TypeInferenceBuilder<'db> {
.types .types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope)); .expression_ty(iterable.scoped_ast_id(self.db, self.scope));
let loop_var_value_ty = iterable_ty.iterate(self.db).unwrap_or_else(|| { let loop_var_value_ty = iterable_ty
self.add_diagnostic( .iterate(self.db)
iterable.into(), .unwrap_with_diagnostic(iterable.into(), self);
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
iterable_ty.display(self.db)
),
);
Type::Unknown
});
self.types self.types
.expressions .expressions
@ -1812,7 +1816,10 @@ impl<'db> TypeInferenceBuilder<'db> {
ctx: _, ctx: _,
} = starred; } = starred;
self.infer_expression(value); let iterable_ty = self.infer_expression(value);
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(value.as_ref().into(), self);
// TODO // TODO
Type::Unknown Type::Unknown
@ -1830,9 +1837,12 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_yield_from_expression(&mut self, yield_from: &ast::ExprYieldFrom) -> Type<'db> { fn infer_yield_from_expression(&mut self, yield_from: &ast::ExprYieldFrom) -> Type<'db> {
let ast::ExprYieldFrom { range: _, value } = yield_from; let ast::ExprYieldFrom { range: _, value } = yield_from;
self.infer_expression(value); let iterable_ty = self.infer_expression(value);
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(value.as_ref().into(), self);
// TODO get type from awaitable // TODO get type from `ReturnType` of generator
Type::Unknown Type::Unknown
} }