From ce030a467f9833e89498cb2cdcf48d9c03f54f6d Mon Sep 17 00:00:00 2001 From: Carl Meyer Date: Mon, 29 Apr 2024 16:22:30 -0600 Subject: [PATCH] [red-knot] resolve base class types (#11178) ## Summary Resolve base class types, as long as they are simple names. ## Test Plan cargo test --- crates/red_knot/src/symbols.rs | 36 +++++++++-- crates/red_knot/src/types.rs | 31 +++++----- crates/red_knot/src/types/infer.rs | 97 +++++++++++++++++++++++++----- 3 files changed, 129 insertions(+), 35 deletions(-) diff --git a/crates/red_knot/src/symbols.rs b/crates/red_knot/src/symbols.rs index 0af5efb47e..c8eee4489d 100644 --- a/crates/red_knot/src/symbols.rs +++ b/crates/red_knot/src/symbols.rs @@ -95,7 +95,7 @@ impl Symbol { // TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST; // this is at best O(log n). If looking up definitions is a bottleneck we should look for // alternatives here. -#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) enum Definition { // For the import cases, we don't need reference to any arbitrary AST subtrees (annotations, // RHS), and referencing just the import statement node is imprecise (a single import statement @@ -110,12 +110,12 @@ pub(crate) enum Definition { // TODO with statements, except handlers, function args... } -#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) struct ImportDefinition { pub(crate) module: ModuleName, } -#[derive(Debug)] +#[derive(Clone, Debug)] pub(crate) struct ImportFromDefinition { pub(crate) module: Option, pub(crate) name: Name, @@ -160,6 +160,7 @@ impl SymbolTable { let mut builder = SymbolTableBuilder { table: SymbolTable::new(), scopes: vec![root_scope_id], + current_definition: None, }; builder.visit_body(&module.body); builder.table @@ -386,6 +387,8 @@ where struct SymbolTableBuilder { table: SymbolTable, scopes: Vec, + /// the definition whose target(s) we are currently walking + current_definition: Option, } impl SymbolTableBuilder { @@ -448,8 +451,13 @@ impl SymbolTableBuilder { impl PreorderVisitor<'_> for SymbolTableBuilder { fn visit_expr(&mut self, expr: &ast::Expr) { - if let ast::Expr::Name(ast::ExprName { id, .. }) = expr { + if let ast::Expr::Name(ast::ExprName { id, ctx, .. }) = expr { self.add_symbol(id); + if matches!(ctx, ast::ExprContext::Store | ast::ExprContext::Del) { + if let Some(curdef) = self.current_definition.clone() { + self.add_symbol_with_def(id, curdef); + } + } } ast::visitor::preorder::walk_expr(self, expr); } @@ -532,6 +540,13 @@ impl PreorderVisitor<'_> for SymbolTableBuilder { self.table.dependencies.push(dependency); } + ast::Stmt::Assign(node) => { + debug_assert!(self.current_definition.is_none()); + self.current_definition = + Some(Definition::Assignment(TypedNodeKey::from_node(node))); + ast::visitor::preorder::walk_stmt(self, stmt); + self.current_definition = None; + } _ => { ast::visitor::preorder::walk_stmt(self, stmt); } @@ -649,6 +664,19 @@ mod tests { ); } + #[test] + fn assign() { + let parsed = parse("x = foo"); + let table = SymbolTable::from_ast(parsed.ast()); + assert_eq!(names(table.root_symbols()), vec!["foo", "x"]); + assert_eq!( + table + .definitions(table.root_symbol_id_by_name("x").unwrap()) + .len(), + 1 + ); + } + #[test] fn class_scope() { let parsed = parse( diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index fb4832158c..7827c036d6 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -53,12 +53,6 @@ impl From for Type { } } -impl From for Type { - fn from(id: ClassTypeId) -> Self { - Type::Class(id) - } -} - impl From for Type { fn from(id: UnionTypeId) -> Self { Type::Union(id) @@ -129,8 +123,8 @@ impl TypeStore { self.add_or_get_module(file_id).add_function(name) } - fn add_class(&self, file_id: FileId, name: &str) -> ClassTypeId { - self.add_or_get_module(file_id).add_class(name) + fn add_class(&self, file_id: FileId, name: &str, bases: Vec) -> ClassTypeId { + self.add_or_get_module(file_id).add_class(name, bases) } fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId { @@ -322,9 +316,11 @@ impl ModuleTypeStore { } } - fn add_class(&mut self, name: &str) -> ClassTypeId { + fn add_class(&mut self, name: &str, bases: Vec) -> ClassTypeId { let class_id = self.classes.push(ClassType { name: Name::new(name), + // TODO: if no bases are given, that should imply [object] + bases, }); ClassTypeId { file_id: self.file_id, @@ -408,12 +404,17 @@ impl std::fmt::Display for DisplayType<'_> { #[derive(Debug)] pub(crate) struct ClassType { name: Name, + bases: Vec, } impl ClassType { fn name(&self) -> &str { self.name.as_str() } + + fn bases(&self) -> &[Type] { + self.bases.as_slice() + } } #[derive(Debug)] @@ -497,7 +498,7 @@ mod tests { let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); - let id = store.add_class(file_id, "C"); + let id = store.add_class(file_id, "C", Vec::new()); assert_eq!(store.get_class(id).name(), "C"); let inst = Type::Instance(id); assert_eq!(format!("{}", inst.display(&store)), "C"); @@ -519,8 +520,8 @@ mod tests { let mut store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); - let c1 = store.add_class(file_id, "C1"); - let c2 = store.add_class(file_id, "C2"); + let c1 = store.add_class(file_id, "C1", Vec::new()); + let c2 = store.add_class(file_id, "C2", Vec::new()); let elems = vec![Type::Instance(c1), Type::Instance(c2)]; let id = store.add_union(file_id, &elems); assert_eq!( @@ -536,9 +537,9 @@ mod tests { let mut store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); - let c1 = store.add_class(file_id, "C1"); - let c2 = store.add_class(file_id, "C2"); - let c3 = store.add_class(file_id, "C3"); + let c1 = store.add_class(file_id, "C1", Vec::new()); + let c2 = store.add_class(file_id, "C2", Vec::new()); + let c3 = store.add_class(file_id, "C3", Vec::new()); let pos = vec![Type::Instance(c1), Type::Instance(c2)]; let neg = vec![Type::Instance(c3)]; let id = store.add_intersection(file_id, &pos, &neg); diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index d017785321..aba308d14a 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -7,6 +7,7 @@ use crate::module::ModuleName; use crate::symbols::{Definition, ImportFromDefinition, SymbolId}; use crate::types::Type; use crate::FileId; +use ruff_python_ast as ast; // FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`. #[tracing::instrument(level = "trace", skip(db))] @@ -58,8 +59,14 @@ where let ast = parsed.ast(); let node = node_key.resolve_unwrap(ast.as_any_node_ref()); + let bases: Vec<_> = node + .bases() + .iter() + .map(|base_expr| infer_expr_type(db, file_id, base_expr)) + .collect(); + let store = &db.jar().type_store; - let ty: Type = store.add_class(file_id, &node.name.id).into(); + let ty = Type::Class(store.add_class(file_id, &node.name.id, bases)); store.cache_node_type(file_id, *node_key.erased(), ty); ty }), @@ -79,6 +86,13 @@ where store.cache_node_type(file_id, *node_key.erased(), ty); ty }), + Definition::Assignment(node_key) => { + let parsed = db.parse(file_id); + let ast = parsed.ast(); + let node = node_key.resolve_unwrap(ast.as_any_node_ref()); + // TODO handle unpacking assignment correctly + infer_expr_type(db, file_id, &node.value) + } _ => todo!("other kinds of definitions"), }; @@ -89,6 +103,24 @@ where ty } +fn infer_expr_type(db: &Db, file_id: FileId, expr: &ast::Expr) -> Type +where + Db: SemanticDb + HasJar, +{ + // TODO cache the resolution of the type on the node + let symbols = db.symbol_table(file_id); + match expr { + ast::Expr::Name(name) => { + if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) { + db.infer_symbol_type(file_id, symbol_id) + } else { + Type::Unknown + } + } + _ => todo!("full expression type resolution"), + } +} + #[cfg(test)] mod tests { use crate::db::tests::TestDb; @@ -123,32 +155,65 @@ mod tests { #[test] fn follow_import_to_class() -> std::io::Result<()> { - let TestCase { - src, - db, - temp_dir: _temp_dir, - } = create_test()?; + let case = create_test()?; + let db = &case.db; - let a_path = src.path().join("a.py"); - let b_path = src.path().join("b.py"); - std::fs::write(a_path, "from b import C as D")?; + let a_path = case.src.path().join("a.py"); + let b_path = case.src.path().join("b.py"); + std::fs::write(a_path, "from b import C as D; E = D")?; std::fs::write(b_path, "class C: pass")?; let a_file = db .resolve_module(ModuleName::new("a")) .expect("module should be found") - .path(&db) + .path(db) .file(); let a_syms = db.symbol_table(a_file); - let d_sym = a_syms - .root_symbol_id_by_name("D") - .expect("D symbol should be found"); + let e_sym = a_syms + .root_symbol_id_by_name("E") + .expect("E symbol should be found"); - let ty = db.infer_symbol_type(a_file, d_sym); - - let jar = HasJar::::jar(&db); + let ty = db.infer_symbol_type(a_file, e_sym); + let jar = HasJar::::jar(db); assert!(matches!(ty, Type::Class(_))); assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]"); + + Ok(()) + } + + #[test] + fn resolve_base_class_by_name() -> std::io::Result<()> { + let case = create_test()?; + let db = &case.db; + + let path = case.src.path().join("mod.py"); + std::fs::write(path, "class Base: pass\nclass Sub(Base): pass")?; + let file = db + .resolve_module(ModuleName::new("mod")) + .expect("module should be found") + .path(db) + .file(); + let syms = db.symbol_table(file); + let sym = syms + .root_symbol_id_by_name("Sub") + .expect("Sub symbol should be found"); + + let ty = db.infer_symbol_type(file, sym); + + let Type::Class(class_id) = ty else { + panic!("Sub is not a Class") + }; + let jar = HasJar::::jar(db); + let base_names: Vec<_> = jar + .type_store + .get_class(class_id) + .bases() + .iter() + .map(|base_ty| format!("{}", base_ty.display(&jar.type_store))) + .collect(); + + assert_eq!(base_names, vec!["Literal[Base]"]); + Ok(()) } }