diff --git a/crates/red_knot/src/types.rs b/crates/red_knot/src/types.rs index 7827c036d6..bce8eff114 100644 --- a/crates/red_knot/src/types.rs +++ b/crates/red_knot/src/types.rs @@ -119,8 +119,9 @@ impl TypeStore { self.modules.get(&file_id) } - fn add_function(&self, file_id: FileId, name: &str) -> FunctionTypeId { - self.add_or_get_module(file_id).add_function(name) + fn add_function(&self, file_id: FileId, name: &str, decorators: Vec) -> FunctionTypeId { + self.add_or_get_module(file_id) + .add_function(name, decorators) } fn add_class(&self, file_id: FileId, name: &str, bases: Vec) -> ClassTypeId { @@ -306,9 +307,10 @@ impl ModuleTypeStore { } } - fn add_function(&mut self, name: &str) -> FunctionTypeId { + fn add_function(&mut self, name: &str, decorators: Vec) -> FunctionTypeId { let func_id = self.functions.push(FunctionType { name: Name::new(name), + decorators, }); FunctionTypeId { file_id: self.file_id, @@ -420,12 +422,17 @@ impl ClassType { #[derive(Debug)] pub(crate) struct FunctionType { name: Name, + decorators: Vec, } impl FunctionType { fn name(&self) -> &str { self.name.as_str() } + + fn decorators(&self) -> &[Type] { + self.decorators.as_slice() + } } #[derive(Debug)] @@ -509,8 +516,9 @@ mod tests { let store = TypeStore::default(); let files = Files::default(); let file_id = files.intern(Path::new("/foo")); - let id = store.add_function(file_id, "func"); + let id = store.add_function(file_id, "func", vec![Type::Unknown]); assert_eq!(store.get_function(id).name(), "func"); + assert_eq!(store.get_function(id).decorators(), vec![Type::Unknown]); let func = Type::Function(id); assert_eq!(format!("{}", func.display(&store)), "func"); } diff --git a/crates/red_knot/src/types/infer.rs b/crates/red_knot/src/types/infer.rs index da327c5b57..f7e890fbb1 100644 --- a/crates/red_knot/src/types/infer.rs +++ b/crates/red_knot/src/types/infer.rs @@ -80,7 +80,15 @@ where .resolve(ast.as_any_node_ref()) .expect("node key should resolve"); - let ty = type_store.add_function(file_id, &node.name.id).into(); + let decorator_tys = node + .decorator_list + .iter() + .map(|decorator| infer_expr_type(db, file_id, &decorator.expression)) + .collect::>()?; + + let ty = type_store + .add_function(file_id, &node.name.id, decorator_tys) + .into(); type_store.cache_node_type(file_id, *node_key.erased(), ty); ty }