[red-knot] Add type inference for basic for loops (#13195)

This commit is contained in:
Alex Waygood 2024-09-04 11:19:50 +01:00 committed by GitHub
parent 57289099bb
commit 46a457318d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 331 additions and 131 deletions

View file

@ -1,16 +0,0 @@
use crate::module_name::ModuleName;
use crate::module_resolver::resolve_module;
use crate::semantic_index::global_scope;
use crate::semantic_index::symbol::ScopeId;
use crate::Db;
/// Salsa query to get the builtins scope.
///
/// Can return None if a custom typeshed is used that is missing `builtins.pyi`.
#[salsa::tracked]
pub(crate) fn builtins_scope(db: &dyn Db) -> Option<ScopeId<'_>> {
let builtins_name =
ModuleName::new_static("builtins").expect("Expected 'builtins' to be a valid module name");
let builtins_file = resolve_module(db, builtins_name)?.file();
Some(global_scope(db, builtins_file))
}

View file

@ -10,7 +10,6 @@ pub use python_version::PythonVersion;
pub use semantic_model::{HasTy, SemanticModel};
pub mod ast_node_ref;
mod builtins;
mod db;
mod module_name;
mod module_resolver;
@ -20,6 +19,7 @@ mod python_version;
pub mod semantic_index;
mod semantic_model;
pub(crate) mod site_packages;
mod stdlib;
pub mod types;
type FxOrderSet<V> = ordermap::set::OrderSet<V, BuildHasherDefault<FxHasher>>;

View file

@ -8,7 +8,7 @@ use crate::module_name::ModuleName;
use crate::module_resolver::{resolve_module, Module};
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::semantic_index;
use crate::types::{definition_ty, global_symbol_ty_by_name, infer_scope_types, Type};
use crate::types::{definition_ty, global_symbol_ty, infer_scope_types, Type};
use crate::Db;
pub struct SemanticModel<'db> {
@ -40,7 +40,7 @@ impl<'db> SemanticModel<'db> {
}
pub fn global_symbol_ty(&self, module: &Module, symbol_name: &str) -> Type<'db> {
global_symbol_ty_by_name(self.db, module.file(), symbol_name)
global_symbol_ty(self.db, module.file(), symbol_name)
}
}

View file

@ -0,0 +1,77 @@
use crate::module_name::ModuleName;
use crate::module_resolver::resolve_module;
use crate::semantic_index::global_scope;
use crate::semantic_index::symbol::ScopeId;
use crate::types::{global_symbol_ty, Type};
use crate::Db;
/// Enumeration of various core stdlib modules, for which we have dedicated Salsa queries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CoreStdlibModule {
Builtins,
Types,
Typeshed,
}
impl CoreStdlibModule {
fn name(self) -> ModuleName {
let module_name = match self {
Self::Builtins => "builtins",
Self::Types => "types",
Self::Typeshed => "_typeshed",
};
ModuleName::new_static(module_name)
.unwrap_or_else(|| panic!("{module_name} should be a valid module name!"))
}
}
/// Lookup the type of `symbol` in a given core module
///
/// Returns `Unbound` if the given core module cannot be resolved for some reason
fn core_module_symbol_ty<'db>(
db: &'db dyn Db,
core_module: CoreStdlibModule,
symbol: &str,
) -> Type<'db> {
resolve_module(db, core_module.name())
.map(|module| global_symbol_ty(db, module.file(), symbol))
.unwrap_or(Type::Unbound)
}
/// Lookup the type of `symbol` in the builtins namespace.
///
/// Returns `Unbound` if the `builtins` module isn't available for some reason.
#[inline]
pub(crate) fn builtins_symbol_ty<'db>(db: &'db dyn Db, symbol: &str) -> Type<'db> {
core_module_symbol_ty(db, CoreStdlibModule::Builtins, symbol)
}
/// Lookup the type of `symbol` in the `types` module namespace.
///
/// Returns `Unbound` if the `types` module isn't available for some reason.
#[inline]
pub(crate) fn types_symbol_ty<'db>(db: &'db dyn Db, symbol: &str) -> Type<'db> {
core_module_symbol_ty(db, CoreStdlibModule::Types, symbol)
}
/// Lookup the type of `symbol` in the `_typeshed` module namespace.
///
/// Returns `Unbound` if the `_typeshed` module isn't available for some reason.
#[inline]
pub(crate) fn typeshed_symbol_ty<'db>(db: &'db dyn Db, symbol: &str) -> Type<'db> {
core_module_symbol_ty(db, CoreStdlibModule::Typeshed, symbol)
}
/// Get the scope of a core stdlib module.
///
/// Can return `None` if a custom typeshed is used that is missing the core module in question.
fn core_module_scope(db: &dyn Db, core_module: CoreStdlibModule) -> Option<ScopeId<'_>> {
resolve_module(db, core_module.name()).map(|module| global_scope(db, module.file()))
}
/// Get the `builtins` module scope.
///
/// Can return `None` if a custom typeshed is used that is missing `builtins.pyi`.
pub(crate) fn builtins_module_scope(db: &dyn Db) -> Option<ScopeId<'_>> {
core_module_scope(db, CoreStdlibModule::Builtins)
}

View file

@ -1,7 +1,6 @@
use ruff_db::files::File;
use ruff_python_ast as ast;
use crate::builtins::builtins_scope;
use crate::semantic_index::ast_ids::HasScopedAstId;
use crate::semantic_index::definition::{Definition, DefinitionKind};
use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId};
@ -9,6 +8,7 @@ use crate::semantic_index::{
global_scope, semantic_index, symbol_table, use_def_map, DefinitionWithConstraints,
DefinitionWithConstraintsIterator,
};
use crate::stdlib::{builtins_symbol_ty, types_symbol_ty, typeshed_symbol_ty};
use crate::types::narrow::narrowing_constraint;
use crate::{Db, FxOrderSet};
@ -40,7 +40,7 @@ pub fn check_types(db: &dyn Db, file: File) -> TypeCheckDiagnostics {
}
/// Infer the public type of a symbol (its type as seen from outside its scope).
pub(crate) fn symbol_ty<'db>(
pub(crate) fn symbol_ty_by_id<'db>(
db: &'db dyn Db,
scope: ScopeId<'db>,
symbol: ScopedSymbolId,
@ -58,30 +58,17 @@ pub(crate) fn symbol_ty<'db>(
}
/// Shorthand for `symbol_ty` that takes a symbol name instead of an ID.
pub(crate) fn symbol_ty_by_name<'db>(
db: &'db dyn Db,
scope: ScopeId<'db>,
name: &str,
) -> Type<'db> {
pub(crate) fn symbol_ty<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str) -> Type<'db> {
let table = symbol_table(db, scope);
table
.symbol_id_by_name(name)
.map(|symbol| symbol_ty(db, scope, symbol))
.map(|symbol| symbol_ty_by_id(db, scope, symbol))
.unwrap_or(Type::Unbound)
}
/// Shorthand for `symbol_ty` that looks up a module-global symbol by name in a file.
pub(crate) fn global_symbol_ty_by_name<'db>(db: &'db dyn Db, file: File, name: &str) -> Type<'db> {
symbol_ty_by_name(db, global_scope(db, file), name)
}
/// Shorthand for `symbol_ty` that looks up a symbol in the builtins.
///
/// Returns `Unbound` if the builtins module isn't available for some reason.
pub(crate) fn builtins_symbol_ty_by_name<'db>(db: &'db dyn Db, name: &str) -> Type<'db> {
builtins_scope(db)
.map(|builtins| symbol_ty_by_name(db, builtins, name))
.unwrap_or(Type::Unbound)
pub(crate) fn global_symbol_ty<'db>(db: &'db dyn Db, file: File, name: &str) -> Type<'db> {
symbol_ty(db, global_scope(db, file), name)
}
/// Infer the type of a [`Definition`].
@ -306,13 +293,9 @@ impl<'db> Type<'db> {
pub fn replace_unbound_with(&self, db: &'db dyn Db, replacement: Type<'db>) -> Type<'db> {
match self {
Type::Unbound => replacement,
Type::Union(union) => union
.elements(db)
.into_iter()
.fold(UnionBuilder::new(db), |builder, ty| {
builder.add(ty.replace_unbound_with(db, replacement))
})
.build(),
Type::Union(union) => {
union.map(db, |element| element.replace_unbound_with(db, replacement))
}
ty => *ty,
}
}
@ -331,7 +314,7 @@ impl<'db> Type<'db> {
/// us to explicitly consider whether to handle an error or propagate
/// it up the call stack.
#[must_use]
pub fn member(&self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> {
pub fn member(&self, db: &'db dyn Db, name: &str) -> Type<'db> {
match self {
Type::Any => Type::Any,
Type::Never => {
@ -348,19 +331,13 @@ impl<'db> Type<'db> {
// TODO: attribute lookup on function type
Type::Unknown
}
Type::Module(file) => global_symbol_ty_by_name(db, *file, name),
Type::Module(file) => global_symbol_ty(db, *file, name),
Type::Class(class) => class.class_member(db, name),
Type::Instance(_) => {
// TODO MRO? get_own_instance_member, get_instance_member
Type::Unknown
}
Type::Union(union) => union
.elements(db)
.iter()
.fold(UnionBuilder::new(db), |builder, element_ty| {
builder.add(element_ty.member(db, name))
})
.build(),
Type::Union(union) => union.map(db, |element| element.member(db, name)),
Type::Intersection(_) => {
// TODO perform the get_member on each type in the intersection
// TODO return the intersection of those results
@ -415,6 +392,38 @@ impl<'db> Type<'db> {
}
}
/// Given the type of an object that is iterated over in some way,
/// return the type of objects that are yielded by that iteration.
///
/// E.g., for the following loop, given the type of `x`, infer the type of `y`:
/// ```python
/// for y in x:
/// pass
/// ```
///
/// Returns `None` if `self` represents a type that is not iterable.
fn iterate(&self, db: &'db dyn Db) -> Option<Type<'db>> {
// `self` represents the type of the iterable;
// `__iter__` and `__next__` are both looked up on the class of the iterable:
let type_of_class = self.to_meta_type(db);
let dunder_iter_method = type_of_class.member(db, "__iter__");
if !dunder_iter_method.is_unbound() {
let iterator_ty = dunder_iter_method.call(db)?;
let dunder_next_method = iterator_ty.to_meta_type(db).member(db, "__next__");
return dunder_next_method.call(db);
}
// Although it's not considered great practice,
// classes that define `__getitem__` are also iterable,
// even if they do not define `__iter__`.
//
// TODO this is only valid if the `__getitem__` method is annotated as
// accepting `int` or `SupportsIndex`
let dunder_get_item_method = type_of_class.member(db, "__getitem__");
dunder_get_item_method.call(db)
}
#[must_use]
pub fn to_instance(&self) -> Type<'db> {
match self {
@ -424,6 +433,34 @@ impl<'db> Type<'db> {
_ => Type::Unknown, // TODO type errors
}
}
/// Given a type that is assumed to represent an instance of a class,
/// return a type that represents that class itself.
#[must_use]
pub fn to_meta_type(&self, db: &'db dyn Db) -> Type<'db> {
match self {
Type::Unbound => Type::Unbound,
Type::Never => Type::Never,
Type::Instance(class) => Type::Class(*class),
Type::Union(union) => union.map(db, |ty| ty.to_meta_type(db)),
Type::BooleanLiteral(_) => builtins_symbol_ty(db, "bool"),
Type::BytesLiteral(_) => builtins_symbol_ty(db, "bytes"),
Type::IntLiteral(_) => builtins_symbol_ty(db, "int"),
Type::Function(_) => types_symbol_ty(db, "FunctionType"),
Type::Module(_) => types_symbol_ty(db, "ModuleType"),
Type::None => typeshed_symbol_ty(db, "NoneType"),
// TODO not accurate if there's a custom metaclass...
Type::Class(_) => builtins_symbol_ty(db, "type"),
// TODO can we do better here? `type[LiteralString]`?
Type::StringLiteral(_) | Type::LiteralString => builtins_symbol_ty(db, "str"),
// TODO: `type[Any]`?
Type::Any => Type::Any,
// TODO: `type[Unknown]`?
Type::Unknown => Type::Unknown,
// TODO intersections
Type::Intersection(_) => Type::Unknown,
}
}
}
#[salsa::interned]
@ -504,7 +541,7 @@ impl<'db> ClassType<'db> {
/// Returns the class member of this class named `name`.
///
/// The member resolves to a member of the class itself or any of its bases.
pub fn class_member(self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> {
pub fn class_member(self, db: &'db dyn Db, name: &str) -> Type<'db> {
let member = self.own_class_member(db, name);
if !member.is_unbound() {
return member;
@ -514,12 +551,12 @@ impl<'db> ClassType<'db> {
}
/// Returns the inferred type of the class member named `name`.
pub fn own_class_member(self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> {
pub fn own_class_member(self, db: &'db dyn Db, name: &str) -> Type<'db> {
let scope = self.body_scope(db);
symbol_ty_by_name(db, scope, name)
symbol_ty(db, scope, name)
}
pub fn inherited_class_member(self, db: &'db dyn Db, name: &ast::name::Name) -> Type<'db> {
pub fn inherited_class_member(self, db: &'db dyn Db, name: &str) -> Type<'db> {
for base in self.bases(db) {
let member = base.member(db, name);
if !member.is_unbound() {
@ -542,6 +579,21 @@ impl<'db> UnionType<'db> {
pub fn contains(&self, db: &'db dyn Db, ty: Type<'db>) -> bool {
self.elements(db).contains(&ty)
}
/// Apply a transformation function to all elements of the union,
/// and create a new union from the resulting set of types
pub fn map(
&self,
db: &'db dyn Db,
mut transform_fn: impl FnMut(&Type<'db>) -> Type<'db>,
) -> Type<'db> {
self.elements(db)
.into_iter()
.fold(UnionBuilder::new(db), |builder, element| {
builder.add(transform_fn(element))
})
.build()
}
}
#[salsa::interned]
@ -688,4 +740,53 @@ mod tests {
&["Object of type 'Literal[123]' is not callable"],
);
}
#[test]
fn invalid_iterable() {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
nonsense = 123
for x in nonsense:
pass
",
)
.unwrap();
let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'Literal[123]' is not iterable"],
);
}
#[test]
fn new_iteration_protocol_takes_precedence_over_old_style() {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
class NotIterable:
def __getitem__(self, key: int) -> int:
return 42
__iter__ = None
for x in NotIterable():
pass
",
)
.unwrap();
let a_file = system_path_to_file(&db, "/src/a.py").unwrap();
let a_file_diagnostics = super::check_types(&db, a_file);
assert_diagnostic_messages(
&a_file_diagnostics,
&["Object of type 'NotIterable' is not iterable"],
);
}
}

View file

@ -25,13 +25,10 @@
//! * No type in an intersection can be a supertype of any other type in the intersection (just
//! eliminate the supertype from the intersection).
//! * An intersection containing two non-overlapping types should simplify to [`Type::Never`].
use crate::types::{IntersectionType, Type, UnionType};
use crate::types::{builtins_symbol_ty, IntersectionType, Type, UnionType};
use crate::{Db, FxOrderSet};
use ordermap::set::MutableValues;
use super::builtins_symbol_ty_by_name;
pub(crate) struct UnionBuilder<'db> {
elements: FxOrderSet<Type<'db>>,
db: &'db dyn Db,
@ -68,7 +65,7 @@ impl<'db> UnionBuilder<'db> {
if let Some(true_index) = self.elements.get_index_of(&Type::BooleanLiteral(true)) {
if self.elements.contains(&Type::BooleanLiteral(false)) {
*self.elements.get_index_mut2(true_index).unwrap() =
builtins_symbol_ty_by_name(self.db, "bool");
builtins_symbol_ty(self.db, "bool");
self.elements.remove(&Type::BooleanLiteral(false));
}
}
@ -278,7 +275,7 @@ mod tests {
use crate::db::tests::TestDb;
use crate::program::{Program, SearchPathSettings};
use crate::python_version::PythonVersion;
use crate::types::builtins_symbol_ty_by_name;
use crate::types::builtins_symbol_ty;
use crate::ProgramSettings;
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
@ -351,7 +348,7 @@ mod tests {
#[test]
fn build_union_bool() {
let db = setup_db();
let bool_ty = builtins_symbol_ty_by_name(&db, "bool");
let bool_ty = builtins_symbol_ty(&db, "bool");
let t0 = Type::BooleanLiteral(true);
let t1 = Type::BooleanLiteral(true);

View file

@ -236,9 +236,7 @@ mod tests {
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
use crate::db::tests::TestDb;
use crate::types::{
global_symbol_ty_by_name, BytesLiteralType, StringLiteralType, Type, UnionBuilder,
};
use crate::types::{global_symbol_ty, BytesLiteralType, StringLiteralType, Type, UnionBuilder};
use crate::{Program, ProgramSettings, PythonVersion, SearchPathSettings};
fn setup_db() -> TestDb {
@ -283,16 +281,16 @@ mod tests {
let vec: Vec<Type<'_>> = vec![
Type::Unknown,
Type::IntLiteral(-1),
global_symbol_ty_by_name(&db, mod_file, "A"),
global_symbol_ty(&db, mod_file, "A"),
Type::StringLiteral(StringLiteralType::new(&db, Box::from("A"))),
Type::BytesLiteral(BytesLiteralType::new(&db, Box::from([0]))),
Type::BytesLiteral(BytesLiteralType::new(&db, Box::from([7]))),
Type::IntLiteral(0),
Type::IntLiteral(1),
Type::StringLiteral(StringLiteralType::new(&db, Box::from("B"))),
global_symbol_ty_by_name(&db, mod_file, "foo"),
global_symbol_ty_by_name(&db, mod_file, "bar"),
global_symbol_ty_by_name(&db, mod_file, "B"),
global_symbol_ty(&db, mod_file, "foo"),
global_symbol_ty(&db, mod_file, "bar"),
global_symbol_ty(&db, mod_file, "B"),
Type::BooleanLiteral(true),
Type::None,
];

View file

@ -37,7 +37,6 @@ use ruff_db::parsed::parsed_module;
use ruff_python_ast::{self as ast, AnyNodeRef, ExprContext, UnaryOp};
use ruff_text_size::Ranged;
use crate::builtins::builtins_scope;
use crate::module_name::ModuleName;
use crate::module_resolver::{file_to_module, resolve_module};
use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId};
@ -46,11 +45,11 @@ use crate::semantic_index::expression::Expression;
use crate::semantic_index::semantic_index;
use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId};
use crate::semantic_index::SemanticIndex;
use crate::stdlib::builtins_module_scope;
use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
use crate::types::{
builtins_symbol_ty_by_name, definitions_ty, global_symbol_ty_by_name, symbol_ty,
symbol_ty_by_name, BytesLiteralType, ClassType, FunctionType, StringLiteralType, Type,
UnionBuilder,
builtins_symbol_ty, definitions_ty, global_symbol_ty, symbol_ty, symbol_ty_by_id,
BytesLiteralType, ClassType, FunctionType, StringLiteralType, Type, UnionBuilder,
};
use crate::Db;
@ -1043,18 +1042,17 @@ impl<'db> TypeInferenceBuilder<'db> {
.types
.expression_ty(iterable.scoped_ast_id(self.db, self.scope));
// TODO(Alex): only a valid iterable if the *type* of `iterable_ty` has an `__iter__`
// member (dunders are never looked up on an instance)
let _dunder_iter_ty = iterable_ty.member(self.db, &ast::name::Name::from("__iter__"));
// TODO(Alex):
// - infer the return type of the `__iter__` method, which gives us the iterator
// - lookup the `__next__` method on the iterator
// - infer the return type of the iterator's `__next__` method,
// which gives us the type of the variable being bound here
// (...or the type of the object being unpacked into multiple definitions, if it's something like
// `for k, v in d.items(): ...`)
let loop_var_value_ty = Type::Unknown;
let loop_var_value_ty = iterable_ty.iterate(self.db).unwrap_or_else(|| {
self.add_diagnostic(
iterable.into(),
"not-iterable",
format_args!(
"Object of type '{}' is not iterable",
iterable_ty.display(self.db)
),
);
Type::Unknown
});
self.types
.expressions
@ -1400,11 +1398,9 @@ impl<'db> TypeInferenceBuilder<'db> {
ast::Number::Int(n) => n
.as_i64()
.map(Type::IntLiteral)
.unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()),
ast::Number::Float(_) => builtins_symbol_ty_by_name(self.db, "float").to_instance(),
ast::Number::Complex { .. } => {
builtins_symbol_ty_by_name(self.db, "complex").to_instance()
}
.unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()),
ast::Number::Float(_) => builtins_symbol_ty(self.db, "float").to_instance(),
ast::Number::Complex { .. } => builtins_symbol_ty(self.db, "complex").to_instance(),
}
}
@ -1482,12 +1478,11 @@ impl<'db> TypeInferenceBuilder<'db> {
}
}
#[allow(clippy::unused_self)]
fn infer_ellipsis_literal_expression(
&mut self,
_literal: &ast::ExprEllipsisLiteral,
) -> Type<'db> {
builtins_symbol_ty_by_name(self.db, "Ellipsis")
builtins_symbol_ty(self.db, "Ellipsis")
}
fn infer_tuple_expression(&mut self, tuple: &ast::ExprTuple) -> Type<'db> {
@ -1503,7 +1498,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// TODO generic
builtins_symbol_ty_by_name(self.db, "tuple").to_instance()
builtins_symbol_ty(self.db, "tuple").to_instance()
}
fn infer_list_expression(&mut self, list: &ast::ExprList) -> Type<'db> {
@ -1518,7 +1513,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// TODO generic
builtins_symbol_ty_by_name(self.db, "list").to_instance()
builtins_symbol_ty(self.db, "list").to_instance()
}
fn infer_set_expression(&mut self, set: &ast::ExprSet) -> Type<'db> {
@ -1529,7 +1524,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// TODO generic
builtins_symbol_ty_by_name(self.db, "set").to_instance()
builtins_symbol_ty(self.db, "set").to_instance()
}
fn infer_dict_expression(&mut self, dict: &ast::ExprDict) -> Type<'db> {
@ -1541,7 +1536,7 @@ impl<'db> TypeInferenceBuilder<'db> {
}
// TODO generic
builtins_symbol_ty_by_name(self.db, "dict").to_instance()
builtins_symbol_ty(self.db, "dict").to_instance()
}
/// Infer the type of the `iter` expression of the first comprehension.
@ -1884,7 +1879,7 @@ impl<'db> TypeInferenceBuilder<'db> {
// runtime, it is the scope that creates the cell for our closure.) If the name
// isn't bound in that scope, we should get an unbound name, not continue
// falling back to other scopes / globals / builtins.
return symbol_ty_by_name(self.db, enclosing_scope_id, name);
return symbol_ty(self.db, enclosing_scope_id, name);
}
}
// No nonlocal binding, check module globals. Avoid infinite recursion if `self.scope`
@ -1892,11 +1887,11 @@ impl<'db> TypeInferenceBuilder<'db> {
let ty = if file_scope_id.is_global() {
Type::Unbound
} else {
global_symbol_ty_by_name(self.db, self.file, name)
global_symbol_ty(self.db, self.file, name)
};
// Fallback to builtins (without infinite recursion if we're already in builtins.)
if ty.may_be_unbound(self.db) && Some(self.scope) != builtins_scope(self.db) {
ty.replace_unbound_with(self.db, builtins_symbol_ty_by_name(self.db, name))
if ty.may_be_unbound(self.db) && Some(self.scope) != builtins_module_scope(self.db) {
ty.replace_unbound_with(self.db, builtins_symbol_ty(self.db, name))
} else {
ty
}
@ -1915,7 +1910,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let symbol = symbols
.symbol_id_by_name(id)
.expect("Expected the symbol table to create a symbol for every Name node");
return symbol_ty(self.db, self.scope, symbol);
return symbol_ty_by_id(self.db, self.scope, symbol);
}
match ctx {
@ -1986,22 +1981,22 @@ impl<'db> TypeInferenceBuilder<'db> {
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Add) => n
.checked_add(m)
.map(Type::IntLiteral)
.unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()),
.unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Sub) => n
.checked_sub(m)
.map(Type::IntLiteral)
.unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()),
.unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mult) => n
.checked_mul(m)
.map(Type::IntLiteral)
.unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()),
.unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Div) => n
.checked_div(m)
.map(Type::IntLiteral)
.unwrap_or_else(|| builtins_symbol_ty_by_name(self.db, "int").to_instance()),
.unwrap_or_else(|| builtins_symbol_ty(self.db, "int").to_instance()),
(Type::IntLiteral(n), Type::IntLiteral(m), ast::Operator::Mod) => n
.checked_rem(m)
@ -2380,14 +2375,14 @@ mod tests {
use ruff_db::testing::assert_function_query_was_not_run;
use ruff_python_ast::name::Name;
use crate::builtins::builtins_scope;
use crate::db::tests::TestDb;
use crate::program::{Program, SearchPathSettings};
use crate::python_version::PythonVersion;
use crate::semantic_index::definition::Definition;
use crate::semantic_index::symbol::FileScopeId;
use crate::semantic_index::{global_scope, semantic_index, symbol_table, use_def_map};
use crate::types::{global_symbol_ty_by_name, infer_definition_types, symbol_ty_by_name};
use crate::stdlib::builtins_module_scope;
use crate::types::{global_symbol_ty, infer_definition_types, symbol_ty};
use crate::{HasTy, ProgramSettings, SemanticModel};
use super::TypeInferenceBuilder;
@ -2440,7 +2435,7 @@ mod tests {
fn assert_public_ty(db: &TestDb, file_name: &str, symbol_name: &str, expected: &str) {
let file = system_path_to_file(db, file_name).expect("Expected file to exist.");
let ty = global_symbol_ty_by_name(db, file, symbol_name);
let ty = global_symbol_ty(db, file, symbol_name);
assert_eq!(ty.display(db).to_string(), expected);
}
@ -2465,7 +2460,7 @@ mod tests {
assert_eq!(scope.name(db), *expected_scope_name);
}
let ty = symbol_ty_by_name(db, scope, symbol_name);
let ty = symbol_ty(db, scope, symbol_name);
assert_eq!(ty.display(db).to_string(), expected);
}
@ -2669,7 +2664,7 @@ mod tests {
)?;
let mod_file = system_path_to_file(&db, "src/mod.py").expect("Expected file to exist.");
let ty = global_symbol_ty_by_name(&db, mod_file, "Sub");
let ty = global_symbol_ty(&db, mod_file, "Sub");
let class = ty.expect_class();
@ -2696,7 +2691,7 @@ mod tests {
)?;
let mod_file = system_path_to_file(&db, "src/mod.py").unwrap();
let ty = global_symbol_ty_by_name(&db, mod_file, "C");
let ty = global_symbol_ty(&db, mod_file, "C");
let class_id = ty.expect_class();
let member_ty = class_id.class_member(&db, &Name::new_static("f"));
let func = member_ty.expect_function();
@ -2900,7 +2895,7 @@ mod tests {
db.write_file("src/a.py", "def example() -> int: return 42")?;
let mod_file = system_path_to_file(&db, "src/a.py").unwrap();
let function = global_symbol_ty_by_name(&db, mod_file, "example").expect_function();
let function = global_symbol_ty(&db, mod_file, "example").expect_function();
let returns = function.return_type(&db);
assert_eq!(returns.display(&db).to_string(), "int");
@ -2975,6 +2970,52 @@ mod tests {
Ok(())
}
#[test]
fn basic_for_loop() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
class IntIterator:
def __next__(self) -> int:
return 42
class IntIterable:
def __iter__(self) -> IntIterator:
return IntIterator()
for x in IntIterable():
pass
",
)?;
assert_public_ty(&db, "src/a.py", "x", "int");
Ok(())
}
#[test]
fn for_loop_with_old_style_iteration_protocol() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
"
class OldStyleIterable:
def __getitem__(self, key: int) -> int:
return 42
for x in OldStyleIterable():
pass
",
)?;
assert_public_ty(&db, "src/a.py", "x", "int");
Ok(())
}
#[test]
fn class_constructor_call_expression() -> anyhow::Result<()> {
let mut db = setup_db();
@ -3317,7 +3358,7 @@ mod tests {
)?;
let a = system_path_to_file(&db, "src/a.py").expect("Expected file to exist.");
let c_ty = global_symbol_ty_by_name(&db, a, "C");
let c_ty = global_symbol_ty(&db, a, "C");
let c_class = c_ty.expect_class();
let mut c_bases = c_class.bases(&db);
let b_ty = c_bases.next().unwrap();
@ -3354,8 +3395,8 @@ mod tests {
.unwrap()
.0
.to_scope_id(&db, file);
let y_ty = symbol_ty_by_name(&db, function_scope, "y");
let x_ty = symbol_ty_by_name(&db, function_scope, "x");
let y_ty = symbol_ty(&db, function_scope, "y");
let x_ty = symbol_ty(&db, function_scope, "x");
assert_eq!(y_ty.display(&db).to_string(), "Unbound");
assert_eq!(x_ty.display(&db).to_string(), "Literal[2]");
@ -3385,8 +3426,8 @@ mod tests {
.unwrap()
.0
.to_scope_id(&db, file);
let y_ty = symbol_ty_by_name(&db, function_scope, "y");
let x_ty = symbol_ty_by_name(&db, function_scope, "x");
let y_ty = symbol_ty(&db, function_scope, "y");
let x_ty = symbol_ty(&db, function_scope, "x");
assert_eq!(x_ty.display(&db).to_string(), "Unbound");
assert_eq!(y_ty.display(&db).to_string(), "Literal[1]");
@ -3416,7 +3457,7 @@ mod tests {
.unwrap()
.0
.to_scope_id(&db, file);
let y_ty = symbol_ty_by_name(&db, function_scope, "y");
let y_ty = symbol_ty(&db, function_scope, "y");
assert_eq!(
y_ty.display(&db).to_string(),
@ -3450,8 +3491,8 @@ mod tests {
.unwrap()
.0
.to_scope_id(&db, file);
let y_ty = symbol_ty_by_name(&db, class_scope, "y");
let x_ty = symbol_ty_by_name(&db, class_scope, "x");
let y_ty = symbol_ty(&db, class_scope, "y");
let x_ty = symbol_ty(&db, class_scope, "x");
assert_eq!(x_ty.display(&db).to_string(), "Unbound | Literal[2]");
assert_eq!(y_ty.display(&db).to_string(), "Literal[1]");
@ -3544,9 +3585,11 @@ mod tests {
assert_public_ty(&db, "/src/a.py", "x", "Literal[copyright]");
// imported builtins module is the same file as the implicit builtins
let file = system_path_to_file(&db, "/src/a.py").expect("Expected file to exist.");
let builtins_ty = global_symbol_ty_by_name(&db, file, "builtins");
let builtins_ty = global_symbol_ty(&db, file, "builtins");
let builtins_file = builtins_ty.expect_module();
let implicit_builtins_file = builtins_scope(&db).expect("builtins to exist").file(&db);
let implicit_builtins_file = builtins_module_scope(&db)
.expect("builtins module should exist")
.file(&db);
assert_eq!(builtins_file, implicit_builtins_file);
Ok(())
@ -3850,7 +3893,7 @@ mod tests {
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = global_symbol_ty_by_name(&db, a, "x");
let x_ty = global_symbol_ty(&db, a, "x");
assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
@ -3859,7 +3902,7 @@ mod tests {
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty_2 = global_symbol_ty_by_name(&db, a, "x");
let x_ty_2 = global_symbol_ty(&db, a, "x");
assert_eq!(x_ty_2.display(&db).to_string(), "Literal[20]");
@ -3876,7 +3919,7 @@ mod tests {
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = global_symbol_ty_by_name(&db, a, "x");
let x_ty = global_symbol_ty(&db, a, "x");
assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
@ -3886,7 +3929,7 @@ mod tests {
db.clear_salsa_events();
let x_ty_2 = global_symbol_ty_by_name(&db, a, "x");
let x_ty_2 = global_symbol_ty(&db, a, "x");
assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]");
@ -3912,7 +3955,7 @@ mod tests {
])?;
let a = system_path_to_file(&db, "/src/a.py").unwrap();
let x_ty = global_symbol_ty_by_name(&db, a, "x");
let x_ty = global_symbol_ty(&db, a, "x");
assert_eq!(x_ty.display(&db).to_string(), "Literal[10]");
@ -3922,7 +3965,7 @@ mod tests {
db.clear_salsa_events();
let x_ty_2 = global_symbol_ty_by_name(&db, a, "x");
let x_ty_2 = global_symbol_ty(&db, a, "x");
assert_eq!(x_ty_2.display(&db).to_string(), "Literal[10]");