From 83cd92bb4837e8a38dd325941903c1c68f1c0366 Mon Sep 17 00:00:00 2001 From: Shunsuke Shibayama Date: Wed, 13 Dec 2023 14:37:48 +0900 Subject: [PATCH] fix: superclass declaration bug --- crates/erg_compiler/context/register.rs | 41 +++++++++++++++++++ crates/erg_compiler/declare.rs | 18 ++++---- .../lib/external/torch.d/__init__.d.er | 4 ++ .../lib/external/torchvision.d/__init__.d.er | 2 + .../torchvision.d/datasets.d/__init__.d.er | 5 +++ .../torchvision.d/transforms.d/__init__.d.er | 1 + .../transforms.d/transforms.d.er | 1 + 7 files changed, 65 insertions(+), 7 deletions(-) diff --git a/crates/erg_compiler/context/register.rs b/crates/erg_compiler/context/register.rs index 2b9518a1..0d5c8ed7 100644 --- a/crates/erg_compiler/context/register.rs +++ b/crates/erg_compiler/context/register.rs @@ -1488,6 +1488,47 @@ impl Context { // .retain(|t| !ctx.same_type_of(t, trait_)); } + pub(crate) fn register_base_class(&mut self, ctx: &Self, class: Type) -> CompileResult<()> { + let class_ctx = ctx.get_nominal_type_ctx(&class).ok_or_else(|| { + CompileError::type_not_found( + self.cfg.input.clone(), + line!() as usize, + ().loc(), + self.caused_by(), + &class, + ) + })?; + if class_ctx.typ.has_qvar() { + let _substituter = Substituter::substitute_typarams(ctx, &class_ctx.typ, &class)?; + self.super_classes.push(class); + let mut tv_cache = TyVarCache::new(ctx.level, ctx); + let classes = class_ctx.super_classes.iter().cloned().map(|ty| { + if ty.has_undoable_linked_var() { + ctx.detach(ty, &mut tv_cache) + } else { + ty + } + }); + self.super_classes.extend(classes); + let traits = class_ctx.super_traits.iter().cloned().map(|ty| { + if ty.has_undoable_linked_var() { + ctx.detach(ty, &mut tv_cache) + } else { + ty + } + }); + self.super_traits.extend(traits); + } else { + self.super_classes.push(class); + let classes = class_ctx.super_classes.clone(); + self.super_classes.extend(classes); + let traits = class_ctx.super_traits.clone(); + self.super_traits.extend(traits); + } + unique_in_place(&mut self.super_classes); + Ok(()) + } + pub(crate) fn register_gen_const( &mut self, ident: &Identifier, diff --git a/crates/erg_compiler/declare.rs b/crates/erg_compiler/declare.rs index 90ea628f..97ed04de 100644 --- a/crates/erg_compiler/declare.rs +++ b/crates/erg_compiler/declare.rs @@ -937,7 +937,7 @@ impl GenericASTLowerer { Ok(()) } - fn declare_subtype(&mut self, ident: &ast::Identifier, trait_: &Type) -> LowerResult<()> { + fn declare_subtype(&mut self, ident: &ast::Identifier, sup: &Type) -> LowerResult<()> { if ident.is_raw() { return Ok(()); } @@ -952,12 +952,16 @@ impl GenericASTLowerer { }; if let Some(ctx) = self.module.context.rec_get_mut_type(&name) { let mut tmp = mem::take(ctx); - tmp.register_marker_trait(&self.module.context, trait_.clone()) - .map_err(|err| { - let ctx = self.module.context.rec_get_mut_type(&name).unwrap(); - mem::swap(ctx, &mut tmp); - err - })?; + let res = if self.module.context.is_class(sup) { + tmp.register_base_class(&self.module.context, sup.clone()) + } else { + tmp.register_marker_trait(&self.module.context, sup.clone()) + }; + res.map_err(|err| { + let ctx = self.module.context.rec_get_mut_type(&name).unwrap(); + mem::swap(ctx, &mut tmp); + err + })?; let ctx = self.module.context.rec_get_mut_type(&name).unwrap(); mem::swap(ctx, &mut tmp); Ok(()) diff --git a/crates/erg_compiler/lib/external/torch.d/__init__.d.er b/crates/erg_compiler/lib/external/torch.d/__init__.d.er index d8648001..c531b4c1 100644 --- a/crates/erg_compiler/lib/external/torch.d/__init__.d.er +++ b/crates/erg_compiler/lib/external/torch.d/__init__.d.er @@ -1 +1,5 @@ +.nn = pyimport "./nn" +.serialization = pyimport "./serialization" +.util = pyimport "./util" + {.load!; .save!;} = pyimport "./serialization" diff --git a/crates/erg_compiler/lib/external/torchvision.d/__init__.d.er b/crates/erg_compiler/lib/external/torchvision.d/__init__.d.er index e69de29b..e08f18c0 100644 --- a/crates/erg_compiler/lib/external/torchvision.d/__init__.d.er +++ b/crates/erg_compiler/lib/external/torchvision.d/__init__.d.er @@ -0,0 +1,2 @@ +.datasets = pyimport "./datasets" +.transforms = pyimport "./transforms" diff --git a/crates/erg_compiler/lib/external/torchvision.d/datasets.d/__init__.d.er b/crates/erg_compiler/lib/external/torchvision.d/datasets.d/__init__.d.er index e69de29b..1d858077 100644 --- a/crates/erg_compiler/lib/external/torchvision.d/datasets.d/__init__.d.er +++ b/crates/erg_compiler/lib/external/torchvision.d/datasets.d/__init__.d.er @@ -0,0 +1,5 @@ +.mnist = pyimport "./mnist" +.utils = pyimport "./utils" +.vision = pyimport "./vision" + +{.MNIST; .FashionMNIST;} = .mnist diff --git a/crates/erg_compiler/lib/external/torchvision.d/transforms.d/__init__.d.er b/crates/erg_compiler/lib/external/torchvision.d/transforms.d/__init__.d.er index e69de29b..28878be4 100644 --- a/crates/erg_compiler/lib/external/torchvision.d/transforms.d/__init__.d.er +++ b/crates/erg_compiler/lib/external/torchvision.d/transforms.d/__init__.d.er @@ -0,0 +1 @@ +{.ToTensor;} = pyimport "./transforms" diff --git a/crates/erg_compiler/lib/external/torchvision.d/transforms.d/transforms.d.er b/crates/erg_compiler/lib/external/torchvision.d/transforms.d/transforms.d.er index 3bfa9192..be34ddfd 100644 --- a/crates/erg_compiler/lib/external/torchvision.d/transforms.d/transforms.d.er +++ b/crates/erg_compiler/lib/external/torchvision.d/transforms.d/transforms.d.er @@ -1,3 +1,4 @@ .ToTensor: ClassType +.ToTensor <: GenericCallable .ToTensor. __call__: () -> .ToTensor