[red-knot] resolve base class types (#11178)

## Summary

Resolve base class types, as long as they are simple names.

## Test Plan

cargo test
This commit is contained in:
Carl Meyer 2024-04-29 16:22:30 -06:00 committed by GitHub
parent 04a922866a
commit ce030a467f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 129 additions and 35 deletions

View file

@ -95,7 +95,7 @@ impl Symbol {
// TODO storing TypedNodeKey for definitions means we have to search to find them again in the AST;
// this is at best O(log n). If looking up definitions is a bottleneck we should look for
// alternatives here.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub(crate) enum Definition {
// For the import cases, we don't need reference to any arbitrary AST subtrees (annotations,
// RHS), and referencing just the import statement node is imprecise (a single import statement
@ -110,12 +110,12 @@ pub(crate) enum Definition {
// TODO with statements, except handlers, function args...
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub(crate) struct ImportDefinition {
pub(crate) module: ModuleName,
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub(crate) struct ImportFromDefinition {
pub(crate) module: Option<ModuleName>,
pub(crate) name: Name,
@ -160,6 +160,7 @@ impl SymbolTable {
let mut builder = SymbolTableBuilder {
table: SymbolTable::new(),
scopes: vec![root_scope_id],
current_definition: None,
};
builder.visit_body(&module.body);
builder.table
@ -386,6 +387,8 @@ where
struct SymbolTableBuilder {
table: SymbolTable,
scopes: Vec<ScopeId>,
/// the definition whose target(s) we are currently walking
current_definition: Option<Definition>,
}
impl SymbolTableBuilder {
@ -448,8 +451,13 @@ impl SymbolTableBuilder {
impl PreorderVisitor<'_> for SymbolTableBuilder {
fn visit_expr(&mut self, expr: &ast::Expr) {
if let ast::Expr::Name(ast::ExprName { id, .. }) = expr {
if let ast::Expr::Name(ast::ExprName { id, ctx, .. }) = expr {
self.add_symbol(id);
if matches!(ctx, ast::ExprContext::Store | ast::ExprContext::Del) {
if let Some(curdef) = self.current_definition.clone() {
self.add_symbol_with_def(id, curdef);
}
}
}
ast::visitor::preorder::walk_expr(self, expr);
}
@ -532,6 +540,13 @@ impl PreorderVisitor<'_> for SymbolTableBuilder {
self.table.dependencies.push(dependency);
}
ast::Stmt::Assign(node) => {
debug_assert!(self.current_definition.is_none());
self.current_definition =
Some(Definition::Assignment(TypedNodeKey::from_node(node)));
ast::visitor::preorder::walk_stmt(self, stmt);
self.current_definition = None;
}
_ => {
ast::visitor::preorder::walk_stmt(self, stmt);
}
@ -649,6 +664,19 @@ mod tests {
);
}
#[test]
fn assign() {
let parsed = parse("x = foo");
let table = SymbolTable::from_ast(parsed.ast());
assert_eq!(names(table.root_symbols()), vec!["foo", "x"]);
assert_eq!(
table
.definitions(table.root_symbol_id_by_name("x").unwrap())
.len(),
1
);
}
#[test]
fn class_scope() {
let parsed = parse(

View file

@ -53,12 +53,6 @@ impl From<FunctionTypeId> for Type {
}
}
impl From<ClassTypeId> for Type {
fn from(id: ClassTypeId) -> Self {
Type::Class(id)
}
}
impl From<UnionTypeId> for Type {
fn from(id: UnionTypeId) -> Self {
Type::Union(id)
@ -129,8 +123,8 @@ impl TypeStore {
self.add_or_get_module(file_id).add_function(name)
}
fn add_class(&self, file_id: FileId, name: &str) -> ClassTypeId {
self.add_or_get_module(file_id).add_class(name)
fn add_class(&self, file_id: FileId, name: &str, bases: Vec<Type>) -> ClassTypeId {
self.add_or_get_module(file_id).add_class(name, bases)
}
fn add_union(&mut self, file_id: FileId, elems: &[Type]) -> UnionTypeId {
@ -322,9 +316,11 @@ impl ModuleTypeStore {
}
}
fn add_class(&mut self, name: &str) -> ClassTypeId {
fn add_class(&mut self, name: &str, bases: Vec<Type>) -> ClassTypeId {
let class_id = self.classes.push(ClassType {
name: Name::new(name),
// TODO: if no bases are given, that should imply [object]
bases,
});
ClassTypeId {
file_id: self.file_id,
@ -408,12 +404,17 @@ impl std::fmt::Display for DisplayType<'_> {
#[derive(Debug)]
pub(crate) struct ClassType {
name: Name,
bases: Vec<Type>,
}
impl ClassType {
fn name(&self) -> &str {
self.name.as_str()
}
fn bases(&self) -> &[Type] {
self.bases.as_slice()
}
}
#[derive(Debug)]
@ -497,7 +498,7 @@ mod tests {
let store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let id = store.add_class(file_id, "C");
let id = store.add_class(file_id, "C", Vec::new());
assert_eq!(store.get_class(id).name(), "C");
let inst = Type::Instance(id);
assert_eq!(format!("{}", inst.display(&store)), "C");
@ -519,8 +520,8 @@ mod tests {
let mut store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1");
let c2 = store.add_class(file_id, "C2");
let c1 = store.add_class(file_id, "C1", Vec::new());
let c2 = store.add_class(file_id, "C2", Vec::new());
let elems = vec![Type::Instance(c1), Type::Instance(c2)];
let id = store.add_union(file_id, &elems);
assert_eq!(
@ -536,9 +537,9 @@ mod tests {
let mut store = TypeStore::default();
let files = Files::default();
let file_id = files.intern(Path::new("/foo"));
let c1 = store.add_class(file_id, "C1");
let c2 = store.add_class(file_id, "C2");
let c3 = store.add_class(file_id, "C3");
let c1 = store.add_class(file_id, "C1", Vec::new());
let c2 = store.add_class(file_id, "C2", Vec::new());
let c3 = store.add_class(file_id, "C3", Vec::new());
let pos = vec![Type::Instance(c1), Type::Instance(c2)];
let neg = vec![Type::Instance(c3)];
let id = store.add_intersection(file_id, &pos, &neg);

View file

@ -7,6 +7,7 @@ use crate::module::ModuleName;
use crate::symbols::{Definition, ImportFromDefinition, SymbolId};
use crate::types::Type;
use crate::FileId;
use ruff_python_ast as ast;
// FIXME: Figure out proper dead-lock free synchronisation now that this takes `&db` instead of `&mut db`.
#[tracing::instrument(level = "trace", skip(db))]
@ -58,8 +59,14 @@ where
let ast = parsed.ast();
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
let bases: Vec<_> = node
.bases()
.iter()
.map(|base_expr| infer_expr_type(db, file_id, base_expr))
.collect();
let store = &db.jar().type_store;
let ty: Type = store.add_class(file_id, &node.name.id).into();
let ty = Type::Class(store.add_class(file_id, &node.name.id, bases));
store.cache_node_type(file_id, *node_key.erased(), ty);
ty
}),
@ -79,6 +86,13 @@ where
store.cache_node_type(file_id, *node_key.erased(), ty);
ty
}),
Definition::Assignment(node_key) => {
let parsed = db.parse(file_id);
let ast = parsed.ast();
let node = node_key.resolve_unwrap(ast.as_any_node_ref());
// TODO handle unpacking assignment correctly
infer_expr_type(db, file_id, &node.value)
}
_ => todo!("other kinds of definitions"),
};
@ -89,6 +103,24 @@ where
ty
}
fn infer_expr_type<Db>(db: &Db, file_id: FileId, expr: &ast::Expr) -> Type
where
Db: SemanticDb + HasJar<SemanticJar>,
{
// TODO cache the resolution of the type on the node
let symbols = db.symbol_table(file_id);
match expr {
ast::Expr::Name(name) => {
if let Some(symbol_id) = symbols.root_symbol_id_by_name(&name.id) {
db.infer_symbol_type(file_id, symbol_id)
} else {
Type::Unknown
}
}
_ => todo!("full expression type resolution"),
}
}
#[cfg(test)]
mod tests {
use crate::db::tests::TestDb;
@ -123,32 +155,65 @@ mod tests {
#[test]
fn follow_import_to_class() -> std::io::Result<()> {
let TestCase {
src,
db,
temp_dir: _temp_dir,
} = create_test()?;
let case = create_test()?;
let db = &case.db;
let a_path = src.path().join("a.py");
let b_path = src.path().join("b.py");
std::fs::write(a_path, "from b import C as D")?;
let a_path = case.src.path().join("a.py");
let b_path = case.src.path().join("b.py");
std::fs::write(a_path, "from b import C as D; E = D")?;
std::fs::write(b_path, "class C: pass")?;
let a_file = db
.resolve_module(ModuleName::new("a"))
.expect("module should be found")
.path(&db)
.path(db)
.file();
let a_syms = db.symbol_table(a_file);
let d_sym = a_syms
.root_symbol_id_by_name("D")
.expect("D symbol should be found");
let e_sym = a_syms
.root_symbol_id_by_name("E")
.expect("E symbol should be found");
let ty = db.infer_symbol_type(a_file, d_sym);
let jar = HasJar::<SemanticJar>::jar(&db);
let ty = db.infer_symbol_type(a_file, e_sym);
let jar = HasJar::<SemanticJar>::jar(db);
assert!(matches!(ty, Type::Class(_)));
assert_eq!(format!("{}", ty.display(&jar.type_store)), "Literal[C]");
Ok(())
}
#[test]
fn resolve_base_class_by_name() -> std::io::Result<()> {
let case = create_test()?;
let db = &case.db;
let path = case.src.path().join("mod.py");
std::fs::write(path, "class Base: pass\nclass Sub(Base): pass")?;
let file = db
.resolve_module(ModuleName::new("mod"))
.expect("module should be found")
.path(db)
.file();
let syms = db.symbol_table(file);
let sym = syms
.root_symbol_id_by_name("Sub")
.expect("Sub symbol should be found");
let ty = db.infer_symbol_type(file, sym);
let Type::Class(class_id) = ty else {
panic!("Sub is not a Class")
};
let jar = HasJar::<SemanticJar>::jar(db);
let base_names: Vec<_> = jar
.type_store
.get_class(class_id)
.bases()
.iter()
.map(|base_ty| format!("{}", base_ty.display(&jar.type_store)))
.collect();
assert_eq!(base_names, vec!["Literal[Base]"]);
Ok(())
}
}