diff --git a/crates/erg_compiler/lib/pystd/types.d.er b/crates/erg_compiler/lib/pystd/types.d.er index 57330b3b..cce81aa6 100644 --- a/crates/erg_compiler/lib/pystd/types.d.er +++ b/crates/erg_compiler/lib/pystd/types.d.er @@ -15,6 +15,8 @@ .MethodDescriptorType: ClassType .ClassMethodDescriptorType: ClassType .ModuleType: ClassType +.ModuleType. + __call__: (name: Str, doc := Str) -> ModuleType .EllipsisType: ClassType .GenericAlias: (Type, GenericTuple) -> ClassType # TODO: Tuple Type .UnionType: (Type, Type) -> Type diff --git a/crates/erg_compiler/lower.rs b/crates/erg_compiler/lower.rs index 9aaa582b..81246869 100644 --- a/crates/erg_compiler/lower.rs +++ b/crates/erg_compiler/lower.rs @@ -3225,7 +3225,16 @@ impl GenericASTLowerer { let mut errors = CompileErrors::empty(); let mut unverified_names = self.module.context.locals.keys().collect::>(); let mut super_impls = set! {}; - let tys_decls = if let Some(sups) = self.module.context.get_super_types(trait_type) { + let retained_decls = |ctx: &Context, super_impls: &Set<&VarName>| { + ctx.decls.clone().retained(|k, _| { + let implemented_in_super = super_impls.contains(k); + let class_decl = ctx.kind.is_class(); + !implemented_in_super && !class_decl + }) + }; + let tys_decls = if self.module.context.is_class(trait_type) { + vec![(impl_trait.clone(), retained_decls(trait_ctx, &super_impls))] + } else if let Some(sups) = self.module.context.get_super_types(trait_type) { sups.map(|sup| { if implemented.linear_contains(&sup) { return (sup, Dict::new()); @@ -3239,17 +3248,13 @@ impl GenericASTLowerer { for methods in &ctx.methods_list { super_impls.extend(methods.locals.keys()); } - ctx.decls.clone().retained(|k, _| { - let implemented_in_super = super_impls.contains(k); - let class_decl = ctx.kind.is_class(); - !implemented_in_super && !class_decl - }) + retained_decls(ctx, &super_impls) }); (sup, decls) }) .collect::>() } else { - vec![(impl_trait.clone(), trait_ctx.decls.clone())] + vec![(impl_trait.clone(), retained_decls(trait_ctx, &super_impls))] }; for (impl_trait, decls) in tys_decls { for (decl_name, decl_vi) in decls {