[red-knot] Emit a diagnostic if the value of a starred expression or a yield from expression is not iterable (#13240)

This commit is contained in:
Alex Waygood 2024-09-04 15:19:11 +01:00 committed by GitHub
parent 46a457318d
commit 0512428a6f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 133 additions and 25 deletions

View file

@ -1,3 +1,4 @@
use infer::TypeInferenceBuilder;
use ruff_db::files::File;
use ruff_python_ast as ast;
@ -400,28 +401,42 @@ impl<'db> Type<'db> {
/// for y in x:
/// pass
/// ```
///
/// Returns `None` if `self` represents a type that is not iterable.
fn iterate(&self, db: &'db dyn Db) -> Option<Type<'db>> {
fn iterate(&self, db: &'db dyn Db) -> IterationOutcome<'db> {
// `self` represents the type of the iterable;
// `__iter__` and `__next__` are both looked up on the class of the iterable:
let type_of_class = self.to_meta_type(db);
let iterable_meta_type = self.to_meta_type(db);
let dunder_iter_method = type_of_class.member(db, "__iter__");
let dunder_iter_method = iterable_meta_type.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let iterator_ty = dunder_iter_method.call(db)?;
let Some(iterator_ty) = dunder_iter_method.call(db) else {
return IterationOutcome::NotIterable {
not_iterable_ty: *self,
};
};
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method.call(db);
return dunder_next_method
.call(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
});
}
// Although it's not considered great practice,
// classes that define `__getitem__` are also iterable,
// even if they do not define `__iter__`.
//
// TODO this is only valid if the `__getitem__` method is annotated as
// TODO(Alex) this is only valid if the `__getitem__` method is annotated as
// accepting `int` or `SupportsIndex`
let dunder_get_item_method = type_of_class.member(db, "__getitem__");
dunder_get_item_method.call(db)
let dunder_get_item_method = iterable_meta_type.member(db, "__getitem__");
dunder_get_item_method
.call(db)
.map(|element_ty| IterationOutcome::Iterable { element_ty })
.unwrap_or(IterationOutcome::NotIterable {
not_iterable_ty: *self,
})
}
#[must_use]
@ -463,6 +478,28 @@ impl<'db> Type<'db> {
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IterationOutcome<'db> {
Iterable { element_ty: Type<'db> },
NotIterable { not_iterable_ty: Type<'db> },
}
impl<'db> IterationOutcome<'db> {
fn unwrap_with_diagnostic(
self,
iterable_node: ast::AnyNodeRef,
inference_builder: &mut TypeInferenceBuilder<'db>,
) -> Type<'db> {
match self {
Self::Iterable { element_ty } => element_ty,
Self::NotIterable { not_iterable_ty } => {
inference_builder.not_iterable_diagnostic(iterable_node, not_iterable_ty);
Type::Unknown
}
}
}
}
#[salsa::interned]
pub struct FunctionType<'db> {
/// name of the function at definition
@ -789,4 +826,65 @@ mod tests {
&["Object of type 'NotIterable' is not iterable"],
);
}
#[test]
fn starred_expressions_must_be_iterable() {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
class NotIterable: pass
class Iterator:
def __next__(self) -> int:
return 42
class Iterable:
def __iter__(self) -> Iterator:
x = [*NotIterable()]
y = [*Iterable()]
",
)
.unwrap();
let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}
#[test]
fn yield_from_expression_must_be_iterable() {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
class NotIterable: pass
class Iterator:
def __next__(self) -> int:
return 42
class Iterable:
def __iter__(self) -> Iterator:
def generator_function():
yield from Iterable()
yield from NotIterable()
",
)
.unwrap();
let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}
}

View file

@ -243,7 +243,7 @@ impl<'db> TypeInference<'db> {
/// Similarly, when we encounter a standalone-inferable expression (right-hand side of an
/// assignment, type narrowing guard), we use the [`infer_expression_types()`] query to ensure we
/// don't infer its types more than once.
struct TypeInferenceBuilder<'db> {
pub(super) struct TypeInferenceBuilder<'db> {
db: &'db dyn Db,
index: &'db SemanticIndex<'db>,
region: InferenceRegion<'db>,
@ -1029,6 +1029,18 @@ impl<'db> TypeInferenceBuilder<'db> {
self.infer_body(orelse);
}
/// Emit a diagnostic declaring that the object represented by `node` is not iterable
pub(super) fn not_iterable_diagnostic(&mut self, node: AnyNodeRef, not_iterable_ty: Type<'db>) {
self.add_diagnostic(
node,
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
not_iterable_ty.display(self.db)
),
);
}
fn infer_for_statement_definition(
&mut self,
target: &ast::ExprName,
@ -1042,17 +1054,9 @@ impl<'db> TypeInferenceBuilder<'db> {
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
let loop_var_value_ty = iterable_ty.iterate(self.db).unwrap_or_else(|| {
self.add_diagnostic(
iterable.into(),
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
iterable_ty.display(self.db)
),
);
Type::Unknown
});
let loop_var_value_ty = iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(iterable.into(), self);
self.types
.expressions
@ -1812,7 +1816,10 @@ impl<'db> TypeInferenceBuilder<'db> {
ctx: _,
} = starred;
self.infer_expression(value);
let iterable_ty = self.infer_expression(value);
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(value.as_ref().into(), self);
// TODO
Type::Unknown
@ -1830,9 +1837,12 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_yield_from_expression(&mut self, yield_from: &ast::ExprYieldFrom) -> Type<'db> {
let ast::ExprYieldFrom { range: _, value } = yield_from;
self.infer_expression(value);
let iterable_ty = self.infer_expression(value);
iterable_ty
.iterate(self.db)
.unwrap_with_diagnostic(value.as_ref().into(), self);
// TODO get type from awaitable
// TODO get type from `ReturnType` of generator
Type::Unknown
}