fix: subtype relation bug

This commit is contained in:
Shunsuke Shibayama 2024-02-03 19:36:27 +09:00
parent 8d5641503f
commit 9c9f8b7a0a
7 changed files with 233 additions and 121 deletions

View file

@ -329,7 +329,11 @@ impl Context {
if !self.is_trait(lhs) {
return (Maybe, false);
}
self._nominal_subtype_of(lhs, rhs, |ty_ctx| &ty_ctx.super_traits)
let (cred, judge) = self._nominal_subtype_of(lhs, rhs, |ty_ctx| &ty_ctx.super_traits[..]);
if judge {
return (cred, judge);
}
self._nominal_subtype_of(lhs, rhs, |ty_ctx| &ty_ctx.super_classes[..])
}
/// lhs :> rhs?

View file

@ -507,13 +507,15 @@ impl Context {
self.instantiate_mono_t(simple, opt_decl_t, tmp_tv_cache, not_found_is_qvar)
}
ast::PreDeclTypeSpec::Poly(poly) => match &poly.acc {
ast::ConstAccessor::Local(local) => self.instantiate_local_poly_t(
local,
&poly.args,
opt_decl_t,
tmp_tv_cache,
not_found_is_qvar,
),
ast::ConstAccessor::Local(local) => self
.instantiate_local_poly_t(
local,
&poly.args,
opt_decl_t,
tmp_tv_cache,
not_found_is_qvar,
)
.map_err(|(_, err)| err),
ast::ConstAccessor::Attr(attr) => {
let ctxs = self.get_singular_ctxs(&attr.obj.clone().downgrade(), self)?;
for ctx in ctxs {
@ -667,7 +669,7 @@ impl Context {
opt_decl_t: Option<&ParamTy>,
tmp_tv_cache: &mut TyVarCache,
not_found_is_qvar: bool,
) -> TyCheckResult<Type> {
) -> Failable<Type> {
match name.inspect().trim_start_matches([':', '.']) {
"Array" => {
let ctx = &self
@ -677,19 +679,22 @@ impl Context {
// TODO: kw
let mut pos_args = args.pos_args();
if let Some(first) = pos_args.next() {
let t = self.instantiate_const_expr_as_type(
&first.expr,
Some((ctx, 0)),
tmp_tv_cache,
not_found_is_qvar,
)?;
let t = self
.instantiate_const_expr_as_type(
&first.expr,
Some((ctx, 0)),
tmp_tv_cache,
not_found_is_qvar,
)
.map_err(|err| (Type::Failure, err))?;
let len = if let Some(len) = pos_args.next() {
self.instantiate_const_expr(
&len.expr,
Some((ctx, 1)),
tmp_tv_cache,
not_found_is_qvar,
)?
)
.map_err(|err| (Type::Failure, err))?
} else {
TyParam::erased(Nat)
};
@ -701,104 +706,131 @@ impl Context {
"Ref" => {
let mut pos_args = args.pos_args();
let Some(first) = pos_args.next() else {
return Err(TyCheckErrors::from(TyCheckError::args_missing_error(
self.cfg.input.clone(),
line!() as usize,
args.loc(),
"Ref",
self.caused_by(),
vec![Str::from("T")],
)));
return Err((
Failure,
TyCheckErrors::from(TyCheckError::args_missing_error(
self.cfg.input.clone(),
line!() as usize,
args.loc(),
"Ref",
self.caused_by(),
vec![Str::from("T")],
)),
));
};
let t = self.instantiate_const_expr_as_type(
&first.expr,
None,
tmp_tv_cache,
not_found_is_qvar,
)?;
let t = self
.instantiate_const_expr_as_type(
&first.expr,
None,
tmp_tv_cache,
not_found_is_qvar,
)
.map_err(|err| (Type::Failure, err))?;
Ok(ref_(t))
}
"RefMut" => {
// TODO after
let mut pos_args = args.pos_args();
let Some(first) = pos_args.next() else {
return Err(TyCheckErrors::from(TyCheckError::args_missing_error(
self.cfg.input.clone(),
line!() as usize,
args.loc(),
"RefMut",
self.caused_by(),
vec![Str::from("T")],
)));
return Err((
Failure,
TyCheckErrors::from(TyCheckError::args_missing_error(
self.cfg.input.clone(),
line!() as usize,
args.loc(),
"RefMut",
self.caused_by(),
vec![Str::from("T")],
)),
));
};
let t = self.instantiate_const_expr_as_type(
&first.expr,
None,
tmp_tv_cache,
not_found_is_qvar,
)?;
let t = self
.instantiate_const_expr_as_type(
&first.expr,
None,
tmp_tv_cache,
not_found_is_qvar,
)
.map_err(|err| (Type::Failure, err))?;
Ok(ref_mut(t, None))
}
"Structural" => {
let mut pos_args = args.pos_args();
let Some(first) = pos_args.next() else {
return Err(TyCheckErrors::from(TyCheckError::args_missing_error(
self.cfg.input.clone(),
line!() as usize,
args.loc(),
"Structural",
self.caused_by(),
vec![Str::from("Type")],
)));
return Err((
Failure,
TyCheckErrors::from(TyCheckError::args_missing_error(
self.cfg.input.clone(),
line!() as usize,
args.loc(),
"Structural",
self.caused_by(),
vec![Str::from("Type")],
)),
));
};
let t = self.instantiate_const_expr_as_type(
&first.expr,
None,
tmp_tv_cache,
not_found_is_qvar,
)?;
let t = self
.instantiate_const_expr_as_type(
&first.expr,
None,
tmp_tv_cache,
not_found_is_qvar,
)
.map_err(|err| (Type::Failure, err))?;
Ok(t.structuralize())
}
"NamedTuple" => {
let mut pose_args = args.pos_args();
let Some(first) = pose_args.next() else {
return Err(TyCheckErrors::from(TyCheckError::args_missing_error(
self.cfg.input.clone(),
line!() as usize,
args.loc(),
"NamedTuple",
self.caused_by(),
vec![Str::from("Fields")],
)));
return Err((
Failure,
TyCheckErrors::from(TyCheckError::args_missing_error(
self.cfg.input.clone(),
line!() as usize,
args.loc(),
"NamedTuple",
self.caused_by(),
vec![Str::from("Fields")],
)),
));
};
let ConstExpr::Record(fields) = &first.expr else {
return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error(
self.cfg.input.clone(),
line!() as usize,
first.expr.loc(),
self.caused_by(),
"NamedTuple",
None,
&mono("Record"),
&self.instantiate_const_expr_as_type(
&first.expr,
return Err((
Failure,
TyCheckErrors::from(TyCheckError::type_mismatch_error(
self.cfg.input.clone(),
line!() as usize,
first.expr.loc(),
self.caused_by(),
"NamedTuple",
None,
tmp_tv_cache,
not_found_is_qvar,
)?,
None,
None,
)));
&mono("Record"),
&self
.instantiate_const_expr_as_type(
&first.expr,
None,
tmp_tv_cache,
not_found_is_qvar,
)
.map_err(|err| (Type::Failure, err))?,
None,
None,
)),
));
};
let mut ts = vec![];
for def in fields.attrs.iter() {
let t = self.instantiate_const_expr_as_type(
&def.body.block[0],
None,
tmp_tv_cache,
not_found_is_qvar,
)?;
let vis = self.instantiate_vis_modifier(&def.ident.vis)?;
let t = self
.instantiate_const_expr_as_type(
&def.body.block[0],
None,
tmp_tv_cache,
not_found_is_qvar,
)
.map_err(|err| (Type::Failure, err))?;
let vis = self
.instantiate_vis_modifier(&def.ident.vis)
.map_err(|err| (Type::Failure, err))?;
ts.push((Field::new(vis, def.ident.inspect().clone()), t));
}
Ok(Type::NamedTuple(ts))
@ -824,15 +856,19 @@ impl Context {
if let Some(decl_t) = opt_decl_t {
return Ok(decl_t.typ().clone());
}
return Err(TyCheckErrors::from(TyCheckError::no_type_error(
self.cfg.input.clone(),
line!() as usize,
name.loc(),
self.caused_by(),
other,
self.get_similar_name(other),
)));
return Err((
Failure,
TyCheckErrors::from(TyCheckError::no_type_error(
self.cfg.input.clone(),
line!() as usize,
name.loc(),
self.caused_by(),
other,
self.get_similar_name(other),
)),
));
};
let mut errs = TyCheckErrors::empty();
// FIXME: kw args
let mut new_params = vec![];
for ((i, arg), (name, param_vi)) in
@ -844,23 +880,25 @@ impl Context {
tmp_tv_cache,
not_found_is_qvar,
);
let param = param.or_else(|e| {
if not_found_is_qvar {
let name = arg.expr.to_string();
// FIXME: handle `::` as a right way
let name = Str::rc(name.trim_start_matches("::"));
let tp = TyParam::named_free_var(
name.clone(),
self.level,
Constraint::Uninited,
);
let varname = VarName::from_str(name);
tmp_tv_cache.push_or_init_typaram(&varname, &tp, self)?;
Ok(tp)
} else {
Err(e)
}
})?;
let param = param
.or_else(|e| {
if not_found_is_qvar {
let name = arg.expr.to_string();
// FIXME: handle `::` as a right way
let name = Str::rc(name.trim_start_matches("::"));
let tp = TyParam::named_free_var(
name.clone(),
self.level,
Constraint::Uninited,
);
let varname = VarName::from_str(name);
tmp_tv_cache.push_or_init_typaram(&varname, &tp, self)?;
Ok(tp)
} else {
Err(e)
}
})
.map_err(|err| (Type::Failure, err))?;
let arg_t = self
.get_tp_t(&param)
.map_err(|err| {
@ -871,7 +909,8 @@ impl Context {
if self.subtype_of(&arg_t, &param_vi.t) {
new_params.push(param);
} else {
return Err(TyCheckErrors::from(TyCheckError::type_mismatch_error(
new_params.push(TyParam::erased(param_vi.t.clone()));
errs.push(TyCheckError::type_mismatch_error(
self.cfg.input.clone(),
line!() as usize,
arg.expr.loc(),
@ -882,11 +921,16 @@ impl Context {
&arg_t,
None,
None,
)));
));
}
}
// FIXME: non-builtin
Ok(poly(ctx.typ.qual_name(), new_params))
let t = poly(ctx.typ.qual_name(), new_params);
if errs.is_empty() {
Ok(t)
} else {
Err((t, errs))
}
}
}
}

View file

@ -32,5 +32,10 @@
.Tensor(_, _).
dtype: .DType
shape: .Size
view: (|T, Old: [Nat; _], S: {A: [Nat; _] | A.prod() == Old.prod()}|(
self: .Tensor(T, Old),
shape: {S},
) -> .Tensor(T, S)) \
and (|T|(self: .Tensor(T, _), shape: [Int; _]) -> .Tensor(T, _))
.relu: |T, S: [Nat; _]|(x: .Tensor(T, S)) -> .Tensor(T, S)

View file

@ -2,6 +2,7 @@
# {.Tensor;} = pyimport "torch"
.Module: ClassType
.Module <: InheritableType
.Module.
parameters: (self: Ref(.Module), recurse := Bool) -> Iterator .Parameter
named_parameters: (self: Ref(.Module), prefix := Str, recurse := Bool, remove_duplicate := Bool) -> Iterator((Str, .Parameter))

View file

@ -1,4 +1,9 @@
{Parameter;} = pyimport "torch/nn/parameter"
.Optimizer: ClassType
.Optimizer <: InheritableType
.Optimizer.
__call__: (params: Iterable(Parameter)) -> .Optimizer
.ASGD: ClassType
.ASGD <: .Optimizer
@ -8,6 +13,18 @@
.Adagrad <: .Optimizer
.Adam: ClassType
.Adam <: .Optimizer
.Adam.
__call__: (
params: Iterable(Parameter),
lr := Float,
betas := (Float, Float),
eps := Float,
weight_decay := Float,
amsgrad := Bool,
foreach := Bool,
maximize := Bool,
) -> .Adam
.AdamW: ClassType
.AdamW <: .Optimizer
.Adamax: ClassType

View file

@ -423,7 +423,7 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
else if self
.module
.context
.coerce(union_.clone(), &())
.coerce(union_.derefine(), &())
.map_or(true, |coerced| coerced.union_pair().is_some())
{
return Err(self.elem_err(&l, &r, elem));
@ -1373,10 +1373,12 @@ impl<A: ASTBuildable> GenericASTLowerer<A> {
} else {
if let hir::Expr::Call(call) = &obj {
if call.return_t().is_some() {
*obj.ref_mut_t().unwrap() = vi.t;
if let Some(ref_mut_t) = obj.ref_mut_t() {
*ref_mut_t = vi.t;
}
}
} else {
*obj.ref_mut_t().unwrap() = vi.t;
} else if let Some(ref_mut_t) = obj.ref_mut_t() {
*ref_mut_t = vi.t;
}
None
};

View file

@ -619,6 +619,35 @@ impl SubrType {
pub fn is_no_var(&self) -> bool {
self.var_params.is_none() && self.kw_var_params.is_none()
}
pub fn derefine(&self) -> Self {
let non_default_params = self
.non_default_params
.iter()
.map(|pt| pt.clone().map_type(|t| t.derefine()))
.collect();
let var_params = self
.var_params
.as_ref()
.map(|pt| pt.clone().map_type(|t| t.derefine()));
let default_params = self
.default_params
.iter()
.map(|pt| pt.clone().map_type(|t| t.derefine()))
.collect();
let kw_var_params = self
.kw_var_params
.as_ref()
.map(|pt| pt.clone().map_type(|t| t.derefine()));
Self::new(
self.kind,
non_default_params,
var_params,
default_params,
kw_var_params,
self.return_t.derefine(),
)
}
}
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
@ -3399,6 +3428,16 @@ impl Type {
sub: Box::new(sub.derefine()),
sup: Box::new(sup.derefine()),
},
Self::Callable { param_ts, return_t } => {
let param_ts = param_ts.iter().map(|t| t.derefine()).collect();
let return_t = return_t.derefine();
Self::Callable {
param_ts,
return_t: Box::new(return_t),
}
}
Self::Subr(subr) => Self::Subr(subr.derefine()),
Self::Quantified(quant) => quant.derefine().quantify(),
other => other.clone(),
}
}