fix: superclass declaration bug

This commit is contained in:
Shunsuke Shibayama 2023-12-13 14:37:48 +09:00
parent 4f02d6ce2d
commit 83cd92bb48
7 changed files with 65 additions and 7 deletions

View file

@ -1488,6 +1488,47 @@ impl Context {
// .retain(|t| !ctx.same_type_of(t, trait_)); // .retain(|t| !ctx.same_type_of(t, trait_));
} }
pub(crate) fn register_base_class(&mut self, ctx: &Self, class: Type) -> CompileResult<()> {
let class_ctx = ctx.get_nominal_type_ctx(&class).ok_or_else(|| {
CompileError::type_not_found(
self.cfg.input.clone(),
line!() as usize,
().loc(),
self.caused_by(),
&class,
)
})?;
if class_ctx.typ.has_qvar() {
let _substituter = Substituter::substitute_typarams(ctx, &class_ctx.typ, &class)?;
self.super_classes.push(class);
let mut tv_cache = TyVarCache::new(ctx.level, ctx);
let classes = class_ctx.super_classes.iter().cloned().map(|ty| {
if ty.has_undoable_linked_var() {
ctx.detach(ty, &mut tv_cache)
} else {
ty
}
});
self.super_classes.extend(classes);
let traits = class_ctx.super_traits.iter().cloned().map(|ty| {
if ty.has_undoable_linked_var() {
ctx.detach(ty, &mut tv_cache)
} else {
ty
}
});
self.super_traits.extend(traits);
} else {
self.super_classes.push(class);
let classes = class_ctx.super_classes.clone();
self.super_classes.extend(classes);
let traits = class_ctx.super_traits.clone();
self.super_traits.extend(traits);
}
unique_in_place(&mut self.super_classes);
Ok(())
}
pub(crate) fn register_gen_const( pub(crate) fn register_gen_const(
&mut self, &mut self,
ident: &Identifier, ident: &Identifier,

View file

@ -937,7 +937,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
Ok(()) Ok(())
} }
fn declare_subtype(&mut self, ident: &ast::Identifier, trait_: &Type) -> LowerResult<()> { fn declare_subtype(&mut self, ident: &ast::Identifier, sup: &Type) -> LowerResult<()> {
if ident.is_raw() { if ident.is_raw() {
return Ok(()); return Ok(());
} }
@ -952,8 +952,12 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
}; };
if let Some(ctx) = self.module.context.rec_get_mut_type(&name) { if let Some(ctx) = self.module.context.rec_get_mut_type(&name) {
let mut tmp = mem::take(ctx); let mut tmp = mem::take(ctx);
tmp.register_marker_trait(&self.module.context, trait_.clone()) let res = if self.module.context.is_class(sup) {
.map_err(|err| { tmp.register_base_class(&self.module.context, sup.clone())
} else {
tmp.register_marker_trait(&self.module.context, sup.clone())
};
res.map_err(|err| {
let ctx = self.module.context.rec_get_mut_type(&name).unwrap(); let ctx = self.module.context.rec_get_mut_type(&name).unwrap();
mem::swap(ctx, &mut tmp); mem::swap(ctx, &mut tmp);
err err

View file

@ -1 +1,5 @@
.nn = pyimport "./nn"
.serialization = pyimport "./serialization"
.util = pyimport "./util"
{.load!; .save!;} = pyimport "./serialization" {.load!; .save!;} = pyimport "./serialization"

View file

@ -0,0 +1,2 @@
.datasets = pyimport "./datasets"
.transforms = pyimport "./transforms"

View file

@ -0,0 +1,5 @@
.mnist = pyimport "./mnist"
.utils = pyimport "./utils"
.vision = pyimport "./vision"
{.MNIST; .FashionMNIST;} = .mnist

View file

@ -0,0 +1 @@
{.ToTensor;} = pyimport "./transforms"

View file

@ -1,3 +1,4 @@
.ToTensor: ClassType .ToTensor: ClassType
.ToTensor <: GenericCallable
.ToTensor. .ToTensor.
__call__: () -> .ToTensor __call__: () -> .ToTensor