fix: declared type instantiation bug

This commit is contained in:
Shunsuke Shibayama 2024-02-07 01:58:39 +09:00
parent 9a031ad514
commit 53172d5132
2 changed files with 26 additions and 25 deletions

View file

@ -621,13 +621,6 @@ impl Context {
self.inc_ref(ident.inspect(), vi, ident, self);
return Ok(*t);
}
if let Some(outer) = &self.outer {
if let Ok(t) =
outer.instantiate_mono_t(ident, opt_decl_t, tmp_tv_cache, not_found_is_qvar)
{
return Ok(t);
}
}
if let Some(typ) = self
.consts
.get(ident.inspect())
@ -636,8 +629,16 @@ impl Context {
if let Some((_, vi)) = self.get_var_info(ident.inspect()) {
self.inc_ref(ident.inspect(), vi, ident, self);
}
Ok(typ)
} else if let Some(ctx) = self.get_type_ctx(ident.inspect()) {
return Ok(typ);
}
if let Some(outer) = &self.outer {
if let Ok(t) =
outer.instantiate_mono_t(ident, opt_decl_t, tmp_tv_cache, not_found_is_qvar)
{
return Ok(t);
}
}
if let Some(ctx) = self.get_type_ctx(ident.inspect()) {
if let Some((_, vi)) = self.get_var_info(ident.inspect()) {
self.inc_ref(ident.inspect(), vi, ident, self);
}
@ -836,6 +837,12 @@ impl Context {
Ok(Type::NamedTuple(ts))
}
other => {
let Some(ctx) = self.get_type_ctx(other).or_else(|| {
self.consts
.get(other)
.and_then(|v| self.convert_value_into_type(v.clone()).ok())
.and_then(|typ| self.get_nominal_type_ctx(&typ))
}) else {
if let Some(outer) = &self.outer {
if let Ok(t) = outer.instantiate_local_poly_t(
name,
@ -847,12 +854,6 @@ impl Context {
return Ok(t);
}
}
let Some(ctx) = self.get_type_ctx(other).or_else(|| {
self.consts
.get(other)
.and_then(|v| self.convert_value_into_type(v.clone()).ok())
.and_then(|typ| self.get_nominal_type_ctx(&typ))
}) else {
if let Some(decl_t) = opt_decl_t {
return Ok(decl_t.typ().clone());
}

View file

@ -1,13 +1,13 @@
# {Tensor!;} = pyimport "torch"
{Tensor!;} = pyimport "torch"
{Parameter;} = pyimport "torch/nn/parameter"
.Module: ClassType
.Module <: InheritableType
.Module|<: GenericCallable|.
__call__: |M <: .Module|(
self: M,
input: Obj #Tensor!(T, _),
) -> Obj #Tensor!(T, _)
__call__: |T|(
self: .Module,
input: Tensor!(T, _),
) -> Tensor!(T, _)
.Module.
parameters: (self: Ref(.Module), recurse := Bool) -> Iterator Parameter
named_parameters: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, Parameter))