feat: introduce bidirectional type checking

This commit is contained in:
Shunsuke Shibayama 2023-09-05 13:57:58 +09:00
parent 5f8d744e47
commit 75b5b68831
13 changed files with 501 additions and 135 deletions

View file

@ -20,7 +20,7 @@ use erg_parser::desugar::Desugarer;
use erg_parser::token::{Token, TokenKind};
use crate::ty::constructors::{
array_t, dict_t, mono, mono_q, named_free_var, poly, proj, proj_call, ref_, ref_mut,
array_t, bounded, dict_t, mono, mono_q, named_free_var, poly, proj, proj_call, ref_, ref_mut,
refinement, set_t, subr_t, subtypeof, tp_enum, tuple_t, v_enum,
};
use crate::ty::free::{Constraint, HasLevel};
@ -1668,8 +1668,8 @@ impl Context {
if let Some(fv) = lhs.as_free() {
let (sub, sup) = fv.get_subsup().unwrap();
if self.is_trait(&sup) && !self.trait_impl_exists(&sub, &sup) {
// link to `Never` to prevent double errors from being reported
lhs.destructive_link(&Never);
// link to `Never..Obj` to prevent double errors from being reported
lhs.destructive_link(&bounded(Never, Type::Obj));
let sub = if cfg!(feature = "debug") {
sub
} else {

View file

@ -897,6 +897,36 @@ impl Context {
}
}
fn search_callee_info_without_args(
&self,
obj: &hir::Expr,
attr_name: &Option<Identifier>,
input: &Input,
namespace: &Context,
) -> SingleTyCheckResult<VarInfo> {
if obj.ref_t() == Type::FAILURE {
// (...Obj) -> Failure
return Ok(VarInfo {
t: Type::Subr(SubrType::new(
SubrKind::Func,
vec![],
Some(ParamTy::Pos(ref_(Obj))),
vec![],
Failure,
)),
..VarInfo::default()
});
}
if let Some(attr_name) = attr_name.as_ref() {
self.search_method_info_without_args(obj, attr_name, input, namespace)
} else {
Ok(VarInfo {
t: obj.t(),
..VarInfo::default()
})
}
}
fn resolve_overload(
&self,
obj: &hir::Expr,
@ -1155,6 +1185,150 @@ impl Context {
))
}
fn search_method_info_without_args(
&self,
obj: &hir::Expr,
attr_name: &Identifier,
input: &Input,
namespace: &Context,
) -> SingleTyCheckResult<VarInfo> {
match self.get_attr_info_from_attributive(obj.ref_t(), attr_name) {
Triple::Ok(vi) => {
return Ok(vi);
}
Triple::Err(e) => {
return Err(e);
}
_ => {}
}
for ctx in self
.get_nominal_super_type_ctxs(obj.ref_t())
.ok_or_else(|| {
TyCheckError::type_not_found(
self.cfg.input.clone(),
line!() as usize,
obj.loc(),
self.caused_by(),
obj.ref_t(),
)
})?
{
if let Some(vi) = ctx
.locals
.get(attr_name.inspect())
.or_else(|| ctx.decls.get(attr_name.inspect()))
{
self.validate_visibility(attr_name, vi, input, namespace)?;
return Ok(vi.clone());
}
for (_, methods_ctx) in ctx.methods_list.iter() {
if let Some(vi) = methods_ctx
.locals
.get(attr_name.inspect())
.or_else(|| methods_ctx.decls.get(attr_name.inspect()))
{
self.validate_visibility(attr_name, vi, input, namespace)?;
return Ok(vi.clone());
}
}
if let Some(ctx) = self.get_same_name_context(&ctx.name) {
match ctx.rec_get_var_info(attr_name, AccessKind::BoundAttr, input, namespace) {
Triple::Ok(t) => {
return Ok(t);
}
Triple::Err(e) => {
return Err(e);
}
Triple::None => {}
}
}
}
if let Ok(singular_ctxs) = self.get_singular_ctxs_by_hir_expr(obj, namespace) {
for ctx in singular_ctxs {
if let Some(vi) = ctx
.locals
.get(attr_name.inspect())
.or_else(|| ctx.decls.get(attr_name.inspect()))
{
self.validate_visibility(attr_name, vi, input, namespace)?;
return Ok(vi.clone());
}
for (_, method_ctx) in ctx.methods_list.iter() {
if let Some(vi) = method_ctx
.locals
.get(attr_name.inspect())
.or_else(|| method_ctx.decls.get(attr_name.inspect()))
{
self.validate_visibility(attr_name, vi, input, namespace)?;
return Ok(vi.clone());
}
}
}
return Err(TyCheckError::singular_no_attr_error(
self.cfg.input.clone(),
line!() as usize,
attr_name.loc(),
namespace.name.to_string(),
obj.qual_name().as_deref().unwrap_or("?"),
obj.ref_t(),
attr_name.inspect(),
self.get_similar_attr_from_singular(obj, attr_name.inspect()),
));
}
match self.get_attr_type_by_name(obj, attr_name) {
Triple::Ok(method) => {
let def_t = self.instantiate_def_type(&method.definition_type).unwrap();
self.sub_unify(obj.ref_t(), &def_t, obj, None)
// HACK: change this func's return type to TyCheckResult<Type>
.map_err(|mut errs| errs.remove(0))?;
return Ok(method.method_info.clone());
}
Triple::Err(err) => {
return Err(err);
}
_ => {}
}
for patch in self.find_patches_of(obj.ref_t()) {
if let Some(vi) = patch
.locals
.get(attr_name.inspect())
.or_else(|| patch.decls.get(attr_name.inspect()))
{
self.validate_visibility(attr_name, vi, input, namespace)?;
return Ok(vi.clone());
}
for (_, methods_ctx) in patch.methods_list.iter() {
if let Some(vi) = methods_ctx
.locals
.get(attr_name.inspect())
.or_else(|| methods_ctx.decls.get(attr_name.inspect()))
{
self.validate_visibility(attr_name, vi, input, namespace)?;
return Ok(vi.clone());
}
}
}
let coerced = self
.coerce(obj.t(), &())
.map_err(|mut errs| errs.remove(0))?;
if &coerced != obj.ref_t() {
let hash = get_hash(obj.ref_t());
obj.ref_t().destructive_coerce();
if get_hash(obj.ref_t()) != hash {
return self.search_method_info_without_args(obj, attr_name, input, namespace);
}
}
Err(TyCheckError::no_attr_error(
self.cfg.input.clone(),
line!() as usize,
attr_name.loc(),
namespace.name.to_string(),
obj.ref_t(),
attr_name.inspect(),
self.get_similar_attr(obj.ref_t(), attr_name.inspect()),
))
}
fn validate_visibility(
&self,
ident: &Identifier,
@ -1930,6 +2104,37 @@ impl Context {
Ok(())
}
pub(crate) fn get_call_t_without_args(
&self,
obj: &hir::Expr,
attr_name: &Option<Identifier>,
input: &Input,
namespace: &Context,
) -> Result<VarInfo, (Option<VarInfo>, TyCheckErrors)> {
let found = self
.search_callee_info_without_args(obj, attr_name, input, namespace)
.map_err(|err| (None, TyCheckErrors::from(err)))?;
log!(
"Found:\ncallee: {obj}{}\nfound: {found}",
fmt_option!(pre ".", attr_name.as_ref().map(|ident| &ident.name))
);
let instance = self
.instantiate(found.t.clone(), obj)
.map_err(|errs| (Some(found.clone()), errs))?;
log!("Instantiated:\ninstance: {instance}");
debug_assert!(
!instance.is_quantified_subr(),
"{instance} is quantified subr"
);
log!(info "Substituted:\ninstance: {instance}");
debug_assert!(instance.has_no_qvar(), "{instance} has qvar");
let res = VarInfo {
t: instance,
..found
};
Ok(res)
}
pub(crate) fn get_call_t(
&self,
obj: &hir::Expr,

View file

@ -848,9 +848,7 @@ impl Context {
}
// HACK: {op: |T|(T -> T) | op == F} => ?T -> ?T
Refinement(refine) if refine.t.is_quantified_subr() => {
let quant = enum_unwrap!(*refine.t, Type::Quantified);
let mut tmp_tv_cache = TyVarCache::new(self.level, self);
let t = self.instantiate_t_inner(*quant, &mut tmp_tv_cache, callee)?;
let t = self.instantiate(*refine.t, callee)?;
match &t {
Type::Subr(subr) => {
if let Some(self_t) = subr.self_t() {

View file

@ -665,26 +665,23 @@ impl Context {
pub(crate) fn assign_params(
&mut self,
params: &mut hir::Params,
opt_decl_subr_t: Option<SubrType>,
expect: Option<SubrType>,
) -> TyCheckResult<()> {
let mut errs = TyCheckErrors::empty();
if let Some(decl_subr_t) = opt_decl_subr_t {
debug_assert_eq!(
params.non_defaults.len(),
decl_subr_t.non_default_params.len()
);
debug_assert_eq!(params.defaults.len(), decl_subr_t.default_params.len());
if let Some(subr_t) = expect {
debug_assert_eq!(params.non_defaults.len(), subr_t.non_default_params.len());
debug_assert_eq!(params.defaults.len(), subr_t.default_params.len());
for (non_default, pt) in params
.non_defaults
.iter_mut()
.zip(decl_subr_t.non_default_params.iter())
.zip(subr_t.non_default_params.iter())
{
if let Err(es) = self.assign_param(non_default, Some(pt), ParamKind::NonDefault) {
errs.extend(es);
}
}
if let Some(var_params) = &mut params.var_params {
if let Some(pt) = &decl_subr_t.var_params {
if let Some(pt) = &subr_t.var_params {
let pt = pt.clone().map_type(unknown_len_array_t);
if let Err(es) = self.assign_param(var_params, Some(&pt), ParamKind::VarParams)
{
@ -694,11 +691,7 @@ impl Context {
errs.extend(es);
}
}
for (default, pt) in params
.defaults
.iter_mut()
.zip(decl_subr_t.default_params.iter())
{
for (default, pt) in params.defaults.iter_mut().zip(subr_t.default_params.iter()) {
if let Err(es) = self.assign_param(
&mut default.sig,
Some(pt),

View file

@ -716,6 +716,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
// ?T(<: Add(?T))
// ?U(:> {1, 2}, <: Add(?U)) ==> {1, 2}
sup_fv.dummy_link();
sub_fv.dummy_link();
if lsub.qual_name() == rsub.qual_name() {
for (lps, rps) in lsub.typarams().iter().zip(rsub.typarams().iter()) {
self.sub_unify_tp(lps, rps, None, false).map_err(|errs| {
@ -735,6 +736,7 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
}
sup_fv.undo();
sub_fv.undo();
let intersec = self.ctx.intersection(&lsup, &rsup);
if intersec == Type::Never {
return Err(TyCheckErrors::from(TyCheckError::subtyping_error(
@ -765,7 +767,9 @@ impl<'c, 'l, 'u, L: Locational> Unifier<'c, 'l, 'u, L> {
}
// e.g. intersec == Int, rsup == Add(?T)
// => ?T(:> Int)
self.sub_unify(&intersec, &rsup)?;
if !(intersec.is_recursive() && rsup.is_recursive()) {
self.sub_unify(&intersec, &rsup)?;
}
self.sub_unify(&rsub, &union)?;
// self.sub_unify(&intersec, &lsup, loc, param_name)?;
// self.sub_unify(&lsub, &union, loc, param_name)?;

View file

@ -155,6 +155,19 @@ impl ASTLowerer {
res
}
fn fake_lower_literal(&self, lit: ast::Literal) -> LowerResult<hir::Literal> {
let loc = lit.loc();
let lit = hir::Literal::try_from(lit.token).map_err(|_| {
LowerError::invalid_literal(
self.cfg.input.clone(),
line!() as usize,
loc,
self.module.context.caused_by(),
)
})?;
Ok(lit)
}
fn fake_lower_acc(&self, acc: ast::Accessor) -> LowerResult<hir::Accessor> {
// TypeApp is lowered in `fake_lower_expr`
match acc {
@ -501,7 +514,7 @@ impl ASTLowerer {
pub(crate) fn fake_lower_expr(&self, expr: ast::Expr) -> LowerResult<hir::Expr> {
match expr {
ast::Expr::Literal(lit) => Ok(hir::Expr::Literal(self.lower_literal(lit)?)),
ast::Expr::Literal(lit) => Ok(hir::Expr::Literal(self.fake_lower_literal(lit)?)),
ast::Expr::BinOp(binop) => Ok(hir::Expr::BinOp(self.fake_lower_binop(binop)?)),
ast::Expr::UnaryOp(unop) => Ok(hir::Expr::UnaryOp(self.fake_lower_unaryop(unop)?)),
ast::Expr::Array(arr) => Ok(hir::Expr::Array(self.fake_lower_array(arr)?)),
@ -857,9 +870,11 @@ impl ASTLowerer {
log!(info "entered {}", fn_name!());
match expr {
ast::Expr::Literal(lit) if lit.is_doc_comment() => {
Ok(hir::Expr::Literal(self.lower_literal(lit)?))
Ok(hir::Expr::Literal(self.lower_literal(lit, None)?))
}
ast::Expr::Accessor(acc) if allow_acc => {
Ok(hir::Expr::Accessor(self.lower_acc(acc, None)?))
}
ast::Expr::Accessor(acc) if allow_acc => Ok(hir::Expr::Accessor(self.lower_acc(acc)?)),
ast::Expr::Def(def) => Ok(hir::Expr::Def(self.declare_def(def)?)),
ast::Expr::TypeAscription(tasc) => Ok(hir::Expr::TypeAsc(self.declare_ident(tasc)?)),
ast::Expr::Call(call)
@ -868,7 +883,7 @@ impl ASTLowerer {
.map(|op| op.is_import())
.unwrap_or(false) =>
{
Ok(hir::Expr::Call(self.lower_call(call)?))
Ok(hir::Expr::Call(self.lower_call(call, None)?))
}
other => Err(LowerErrors::from(LowerError::declare_error(
self.cfg().input.clone(),

View file

@ -260,7 +260,7 @@ impl ASTLowerer {
self.errs.extend(errs);
}
for chunk in ast.module.into_iter() {
match self.lower_chunk(chunk) {
match self.lower_chunk(chunk, None) {
Ok(chunk) => {
module.push(chunk);
}

View file

@ -31,7 +31,9 @@ use crate::ty::constructors::{
use crate::ty::free::Constraint;
use crate::ty::typaram::TyParam;
use crate::ty::value::{GenTypeObj, TypeObj, ValueObj};
use crate::ty::{CastTarget, GuardType, HasType, ParamTy, Predicate, Type, VisibilityModifier};
use crate::ty::{
CastTarget, GuardType, HasType, ParamTy, Predicate, SubrType, Type, VisibilityModifier,
};
use crate::context::{
ClassDefType, Context, ContextKind, ContextProvider, ControlKind, ModuleContext,
@ -223,7 +225,11 @@ impl ASTLowerer {
}
impl ASTLowerer {
pub(crate) fn lower_literal(&self, lit: ast::Literal) -> LowerResult<hir::Literal> {
pub(crate) fn lower_literal(
&mut self,
lit: ast::Literal,
expect: Option<&Type>,
) -> LowerResult<hir::Literal> {
let loc = lit.loc();
let lit = hir::Literal::try_from(lit.token).map_err(|_| {
LowerError::invalid_literal(
@ -233,16 +239,23 @@ impl ASTLowerer {
self.module.context.caused_by(),
)
})?;
if let Some(expect) = expect {
if let Err(_errs) = self.module.context.sub_unify(&lit.t(), expect, &loc, None) {
// self.errs.extend(errs);
}
}
Ok(lit)
}
fn lower_array(&mut self, array: ast::Array) -> LowerResult<hir::Array> {
fn lower_array(&mut self, array: ast::Array, expect: Option<&Type>) -> LowerResult<hir::Array> {
log!(info "entered {}({array})", fn_name!());
match array {
ast::Array::Normal(arr) => Ok(hir::Array::Normal(self.lower_normal_array(arr)?)),
ast::Array::WithLength(arr) => {
Ok(hir::Array::WithLength(self.lower_array_with_length(arr)?))
ast::Array::Normal(arr) => {
Ok(hir::Array::Normal(self.lower_normal_array(arr, expect)?))
}
ast::Array::WithLength(arr) => Ok(hir::Array::WithLength(
self.lower_array_with_length(arr, expect)?,
)),
other => feature_error!(
LowerErrors,
LowerError,
@ -278,14 +291,18 @@ impl ASTLowerer {
))
}
fn lower_normal_array(&mut self, array: ast::NormalArray) -> LowerResult<hir::NormalArray> {
fn lower_normal_array(
&mut self,
array: ast::NormalArray,
_expect: Option<&Type>,
) -> LowerResult<hir::NormalArray> {
log!(info "entered {}({array})", fn_name!());
let mut new_array = vec![];
let eval_result = self.module.context.eval_const_normal_array(&array);
let (elems, ..) = array.elems.deconstruct();
let mut union = Type::Never;
for elem in elems.into_iter() {
let elem = self.lower_expr(elem.expr)?;
let elem = self.lower_expr(elem.expr, None)?;
let union_ = self.module.context.union(&union, elem.ref_t());
if let Some((l, r)) = union_.union_pair() {
match (l.is_unbound_var(), r.is_unbound_var()) {
@ -335,11 +352,12 @@ impl ASTLowerer {
fn lower_array_with_length(
&mut self,
array: ast::ArrayWithLength,
_expect: Option<&Type>,
) -> LowerResult<hir::ArrayWithLength> {
log!(info "entered {}({array})", fn_name!());
let elem = self.lower_expr(array.elem.expr)?;
let elem = self.lower_expr(array.elem.expr, None)?;
let array_t = self.gen_array_with_length_type(&elem, &array.len);
let len = self.lower_expr(*array.len)?;
let len = self.lower_expr(*array.len, None)?;
let hir_array = hir::ArrayWithLength::new(array.l_sqbr, array.r_sqbr, array_t, elem, len);
Ok(hir_array)
}
@ -354,35 +372,49 @@ impl ASTLowerer {
}
}
fn lower_tuple(&mut self, tuple: ast::Tuple) -> LowerResult<hir::Tuple> {
fn lower_tuple(&mut self, tuple: ast::Tuple, expect: Option<&Type>) -> LowerResult<hir::Tuple> {
log!(info "entered {}({tuple})", fn_name!());
match tuple {
ast::Tuple::Normal(tup) => Ok(hir::Tuple::Normal(self.lower_normal_tuple(tup)?)),
ast::Tuple::Normal(tup) => {
Ok(hir::Tuple::Normal(self.lower_normal_tuple(tup, expect)?))
}
}
}
fn lower_normal_tuple(&mut self, tuple: ast::NormalTuple) -> LowerResult<hir::NormalTuple> {
fn lower_normal_tuple(
&mut self,
tuple: ast::NormalTuple,
_expect: Option<&Type>,
) -> LowerResult<hir::NormalTuple> {
log!(info "entered {}({tuple})", fn_name!());
let mut new_tuple = vec![];
let (elems, .., paren) = tuple.elems.deconstruct();
for elem in elems {
let elem = self.lower_expr(elem.expr)?;
let elem = self.lower_expr(elem.expr, None)?;
new_tuple.push(elem);
}
Ok(hir::NormalTuple::new(hir::Args::values(new_tuple, paren)))
}
fn lower_record(&mut self, record: ast::Record) -> LowerResult<hir::Record> {
fn lower_record(
&mut self,
record: ast::Record,
expect: Option<&Type>,
) -> LowerResult<hir::Record> {
log!(info "entered {}({record})", fn_name!());
match record {
ast::Record::Normal(rec) => self.lower_normal_record(rec),
ast::Record::Normal(rec) => self.lower_normal_record(rec, expect),
ast::Record::Mixed(mixed) => {
self.lower_normal_record(Desugarer::desugar_shortened_record_inner(mixed))
self.lower_normal_record(Desugarer::desugar_shortened_record_inner(mixed), None)
}
}
}
fn lower_normal_record(&mut self, record: ast::NormalRecord) -> LowerResult<hir::Record> {
fn lower_normal_record(
&mut self,
record: ast::NormalRecord,
_expect: Option<&Type>,
) -> LowerResult<hir::Record> {
log!(info "entered {}({record})", fn_name!());
let mut hir_record =
hir::Record::new(record.l_brace, record.r_brace, hir::RecordAttrs::empty());
@ -411,11 +443,13 @@ impl ASTLowerer {
Ok(hir_record)
}
fn lower_set(&mut self, set: ast::Set) -> LowerResult<hir::Set> {
fn lower_set(&mut self, set: ast::Set, expect: Option<&Type>) -> LowerResult<hir::Set> {
log!(info "enter {}({set})", fn_name!());
match set {
ast::Set::Normal(set) => Ok(hir::Set::Normal(self.lower_normal_set(set)?)),
ast::Set::WithLength(set) => Ok(hir::Set::WithLength(self.lower_set_with_length(set)?)),
ast::Set::Normal(set) => Ok(hir::Set::Normal(self.lower_normal_set(set, expect)?)),
ast::Set::WithLength(set) => Ok(hir::Set::WithLength(
self.lower_set_with_length(set, expect)?,
)),
ast::Set::Comprehension(set) => feature_error!(
LowerErrors,
LowerError,
@ -426,13 +460,17 @@ impl ASTLowerer {
}
}
fn lower_normal_set(&mut self, set: ast::NormalSet) -> LowerResult<hir::NormalSet> {
fn lower_normal_set(
&mut self,
set: ast::NormalSet,
_expect: Option<&Type>,
) -> LowerResult<hir::NormalSet> {
log!(info "entered {}({set})", fn_name!());
let (elems, ..) = set.elems.deconstruct();
let mut union = Type::Never;
let mut new_set = vec![];
for elem in elems {
let elem = self.lower_expr(elem.expr)?;
let elem = self.lower_expr(elem.expr, None)?;
union = self.module.context.union(&union, elem.ref_t());
if ERG_MODE && union.is_union_type() {
return Err(LowerErrors::from(LowerError::syntax_error(
@ -504,11 +542,12 @@ impl ASTLowerer {
fn lower_set_with_length(
&mut self,
set: ast::SetWithLength,
_expect: Option<&Type>,
) -> LowerResult<hir::SetWithLength> {
log!("entered {}({set})", fn_name!());
let elem = self.lower_expr(set.elem.expr)?;
let elem = self.lower_expr(set.elem.expr, None)?;
let set_t = self.gen_set_with_length_type(&elem, &set.len);
let len = self.lower_expr(*set.len)?;
let len = self.lower_expr(*set.len, None)?;
let hir_set = hir::SetWithLength::new(set.l_brace, set.r_brace, set_t, elem, len);
Ok(hir_set)
}
@ -528,10 +567,10 @@ impl ASTLowerer {
}
}
fn lower_dict(&mut self, dict: ast::Dict) -> LowerResult<hir::Dict> {
fn lower_dict(&mut self, dict: ast::Dict, expect: Option<&Type>) -> LowerResult<hir::Dict> {
log!(info "enter {}({dict})", fn_name!());
match dict {
ast::Dict::Normal(set) => Ok(hir::Dict::Normal(self.lower_normal_dict(set)?)),
ast::Dict::Normal(set) => Ok(hir::Dict::Normal(self.lower_normal_dict(set, expect)?)),
other => feature_error!(
LowerErrors,
LowerError,
@ -539,17 +578,21 @@ impl ASTLowerer {
other.loc(),
"dict comprehension"
),
// ast::Dict::WithLength(set) => Ok(hir::Dict::WithLength(self.lower_dict_with_length(set)?)),
// ast::Dict::WithLength(set) => Ok(hir::Dict::WithLength(self.lower_dict_with_length(set, expect)?)),
}
}
fn lower_normal_dict(&mut self, dict: ast::NormalDict) -> LowerResult<hir::NormalDict> {
fn lower_normal_dict(
&mut self,
dict: ast::NormalDict,
_expect: Option<&Type>,
) -> LowerResult<hir::NormalDict> {
log!(info "enter {}({dict})", fn_name!());
let mut union = dict! {};
let mut new_kvs = vec![];
for kv in dict.kvs {
let key = self.lower_expr(kv.key)?;
let value = self.lower_expr(kv.value)?;
let key = self.lower_expr(kv.key, None)?;
let value = self.lower_expr(kv.value, None)?;
if let Some(popped_val_t) = union.insert(key.t(), value.t()) {
if PYTHON_MODE {
let val_t = union.get_mut(key.ref_t()).unwrap();
@ -627,16 +670,20 @@ impl ASTLowerer {
))
}
pub(crate) fn lower_acc(&mut self, acc: ast::Accessor) -> LowerResult<hir::Accessor> {
pub(crate) fn lower_acc(
&mut self,
acc: ast::Accessor,
expect: Option<&Type>,
) -> LowerResult<hir::Accessor> {
log!(info "entered {}({acc})", fn_name!());
match acc {
ast::Accessor::Ident(ident) => {
let ident = self.lower_ident(ident)?;
let ident = self.lower_ident(ident, expect)?;
let acc = hir::Accessor::Ident(ident);
Ok(acc)
}
ast::Accessor::Attr(attr) => {
let obj = self.lower_expr(*attr.obj)?;
let obj = self.lower_expr(*attr.obj, None)?;
let vi = match self.module.context.get_attr_info(
&obj,
&attr.ident,
@ -686,7 +733,11 @@ impl ASTLowerer {
}
}
fn lower_ident(&mut self, ident: ast::Identifier) -> LowerResult<hir::Identifier> {
fn lower_ident(
&mut self,
ident: ast::Identifier,
expect: Option<&Type>,
) -> LowerResult<hir::Identifier> {
// `match` is a special form, typing is magic
let (vi, __name__) = if ident.vis.is_private()
&& (&ident.inspect()[..] == "match" || &ident.inspect()[..] == "match!")
@ -739,6 +790,15 @@ impl ASTLowerer {
)
};
self.inc_ref(ident.inspect(), &vi, &ident.name);
if let Some(expect) = expect {
if let Err(_errs) = self
.module
.context
.sub_unify(&vi.t, expect, &ident.loc(), None)
{
// self.errs.extend(errs);
}
}
let ident = hir::Identifier::new(ident, __name__, vi);
Ok(ident)
}
@ -805,18 +865,18 @@ impl ASTLowerer {
}
}
fn lower_bin(&mut self, bin: ast::BinOp) -> hir::BinOp {
fn lower_bin(&mut self, bin: ast::BinOp, expect: Option<&Type>) -> hir::BinOp {
log!(info "entered {}({bin})", fn_name!());
let mut args = bin.args.into_iter();
let lhs = *args.next().unwrap();
let rhs = *args.next().unwrap();
let guard = self.get_guard_type(&bin.op, &lhs, &rhs);
let lhs = self.lower_expr(lhs).unwrap_or_else(|errs| {
let lhs = self.lower_expr(lhs, None).unwrap_or_else(|errs| {
self.errs.extend(errs);
hir::Expr::Dummy(hir::Dummy::new(vec![]))
});
let lhs = hir::PosArg::new(lhs);
let rhs = self.lower_expr(rhs).unwrap_or_else(|errs| {
let rhs = self.lower_expr(rhs, None).unwrap_or_else(|errs| {
self.errs.extend(errs);
hir::Expr::Dummy(hir::Dummy::new(vec![]))
});
@ -845,23 +905,32 @@ impl ASTLowerer {
*return_t = guard;
}
}
if let Some(expect) = expect {
if let Err(_errs) =
self.module
.context
.sub_unify(vi.t.return_t().unwrap(), expect, &args, None)
{
// self.errs.extend(errs);
}
}
let mut args = args.into_iter();
let lhs = args.next().unwrap().expr;
let rhs = args.next().unwrap().expr;
hir::BinOp::new(bin.op, lhs, rhs, vi)
}
fn lower_unary(&mut self, unary: ast::UnaryOp) -> hir::UnaryOp {
fn lower_unary(&mut self, unary: ast::UnaryOp, expect: Option<&Type>) -> hir::UnaryOp {
log!(info "entered {}({unary})", fn_name!());
let mut args = unary.args.into_iter();
let arg = self
.lower_expr(*args.next().unwrap())
.lower_expr(*args.next().unwrap(), None)
.unwrap_or_else(|errs| {
self.errs.extend(errs);
hir::Expr::Dummy(hir::Dummy::new(vec![]))
});
let args = [hir::PosArg::new(arg)];
let t = self
let vi = self
.module
.context
.get_unaryop_t(&unary.op, &args, &self.cfg.input, &self.module.context)
@ -869,12 +938,26 @@ impl ASTLowerer {
self.errs.extend(errs);
VarInfo::ILLEGAL
});
if let Some(expect) = expect {
if let Err(_errs) =
self.module
.context
.sub_unify(vi.t.return_t().unwrap(), expect, &args, None)
{
// self.errs.extend(errs);
}
}
let mut args = args.into_iter();
let expr = args.next().unwrap().expr;
hir::UnaryOp::new(unary.op, expr, t)
hir::UnaryOp::new(unary.op, expr, vi)
}
fn lower_args(&mut self, args: ast::Args, errs: &mut LowerErrors) -> hir::Args {
fn lower_args(
&mut self,
args: ast::Args,
expect: Option<&SubrType>,
errs: &mut LowerErrors,
) -> hir::Args {
let (pos_args, var_args, kw_args, paren) = args.deconstruct();
let mut hir_args = hir::Args::new(
Vec::with_capacity(pos_args.len()),
@ -882,8 +965,14 @@ impl ASTLowerer {
Vec::with_capacity(kw_args.len()),
paren,
);
for (nth, arg) in pos_args.into_iter().enumerate() {
match self.lower_expr(arg.expr) {
let pos_params = expect
.as_ref()
.map(|subr| subr.pos_params().map(|p| Some(p.typ())))
.map_or(vec![None; pos_args.len()], |params| {
params.take(pos_args.len()).collect()
});
for (nth, (arg, param)) in pos_args.into_iter().zip(pos_params).enumerate() {
match self.lower_expr(arg.expr, param) {
Ok(expr) => {
if let Some(kind) = self.module.context.control_kind() {
self.push_guard(nth, kind, expr.ref_t());
@ -897,7 +986,7 @@ impl ASTLowerer {
}
}
if let Some(var_args) = var_args {
match self.lower_expr(var_args.expr) {
match self.lower_expr(var_args.expr, None) {
Ok(expr) => hir_args.var_args = Some(Box::new(hir::PosArg::new(expr))),
Err(es) => {
errs.extend(es);
@ -907,7 +996,7 @@ impl ASTLowerer {
}
}
for arg in kw_args.into_iter() {
match self.lower_expr(arg.expr) {
match self.lower_expr(arg.expr, None) {
Ok(expr) => hir_args.push_kw(hir::KwArg::new(arg.keyword, expr)),
Err(es) => {
errs.extend(es);
@ -947,7 +1036,11 @@ impl ASTLowerer {
/// returning `Ok(call)` does not mean the call is valid, just means it is syntactically valid
/// `ASTLowerer` is designed to cause as little information loss in HIR as possible
pub(crate) fn lower_call(&mut self, call: ast::Call) -> LowerResult<hir::Call> {
pub(crate) fn lower_call(
&mut self,
call: ast::Call,
_expect: Option<&Type>,
) -> LowerResult<hir::Call> {
log!(info "entered {}({}{}(...))", fn_name!(), call.obj, fmt_option!(call.attr_name));
if let (Some(name), None) = (call.obj.get_name(), &call.attr_name) {
self.module.context.higher_order_caller.push(name.clone());
@ -968,8 +1061,7 @@ impl ASTLowerer {
} else {
None
};
let hir_args = self.lower_args(call.args, &mut errs);
let mut obj = match self.lower_expr(*call.obj) {
let mut obj = match self.lower_expr(*call.obj, None) {
Ok(obj) => obj,
Err(es) => {
self.module.context.higher_order_caller.pop();
@ -977,6 +1069,17 @@ impl ASTLowerer {
return Err(errs);
}
};
let opt_vi = self.module.context.get_call_t_without_args(
&obj,
&call.attr_name,
&self.cfg.input,
&self.module.context,
);
let expect_subr = opt_vi
.as_ref()
.ok()
.and_then(|vi| <&SubrType>::try_from(&vi.t).ok());
let hir_args = self.lower_args(call.args, expect_subr, &mut errs);
let mut vi = match self.module.context.get_call_t(
&obj,
&call.attr_name,
@ -1023,6 +1126,11 @@ impl ASTLowerer {
None
};
let mut call = hir::Call::new(obj, attr_name, hir_args);
/*if let Some((found, expect)) = call.signature_t().and_then(|sig| sig.return_t()).zip(expect) {
if let Err(errs) = self.module.context.sub_unify(found, expect, &call, None) {
self.errs.extend(errs);
}
}*/
self.module.context.higher_order_caller.pop();
if errs.is_empty() {
self.exec_additional_op(&mut call)?;
@ -1088,10 +1196,14 @@ impl ASTLowerer {
}
}
fn lower_pack(&mut self, pack: ast::DataPack) -> LowerResult<hir::Call> {
fn lower_pack(
&mut self,
pack: ast::DataPack,
_expect: Option<&Type>,
) -> LowerResult<hir::Call> {
log!(info "entered {}({pack})", fn_name!());
let class = self.lower_expr(*pack.class)?;
let args = self.lower_record(pack.args)?;
let class = self.lower_expr(*pack.class, None)?;
let args = self.lower_record(pack.args, None)?;
let args = vec![hir::PosArg::new(hir::Expr::Record(args))];
let attr_name = ast::Identifier::new(
VisModifierSpec::Public(Token::new(
@ -1172,7 +1284,7 @@ impl ASTLowerer {
};
let mut hir_defaults = vec![];
for default in params.defaults.into_iter() {
match self.lower_expr(default.default_val) {
match self.lower_expr(default.default_val, None) {
Ok(default_val) => {
let sig = self.lower_non_default_param(default.sig)?;
hir_defaults.push(hir::DefaultParamSignature::new(sig, default_val));
@ -1193,7 +1305,12 @@ impl ASTLowerer {
}
}
fn lower_lambda(&mut self, lambda: ast::Lambda) -> LowerResult<hir::Lambda> {
fn lower_lambda(
&mut self,
lambda: ast::Lambda,
expect: Option<&Type>,
) -> LowerResult<hir::Lambda> {
let expect = expect.and_then(|t| <&SubrType>::try_from(t).ok());
log!(info "entered {}({lambda})", fn_name!());
let in_statement = PYTHON_MODE
&& self
@ -1224,7 +1341,11 @@ impl ASTLowerer {
}
errs
})?;
if let Err(errs) = self.module.context.assign_params(&mut params, None) {
if let Err(errs) = self
.module
.context
.assign_params(&mut params, expect.cloned())
{
self.errs.extend(errs);
}
let overwritten = {
@ -1244,7 +1365,7 @@ impl ASTLowerer {
if let Err(errs) = self.module.context.register_const(&lambda.body) {
self.errs.extend(errs);
}
let body = self.lower_block(lambda.body).map_err(|errs| {
let body = self.lower_block(lambda.body, None).map_err(|errs| {
if !in_statement {
self.pop_append_errs();
}
@ -1467,7 +1588,7 @@ impl ASTLowerer {
if let Err(errs) = self.module.context.register_const(&body.block) {
self.errs.extend(errs);
}
match self.lower_block(body.block) {
match self.lower_block(body.block, None) {
Ok(block) => {
let found_body_t = block.ref_t();
let outer = self.module.context.outer.as_ref().unwrap();
@ -1565,7 +1686,7 @@ impl ASTLowerer {
if let Err(errs) = self.module.context.register_const(&body.block) {
self.errs.extend(errs);
}
match self.lower_block(body.block) {
match self.lower_block(body.block, None) {
Ok(block) => {
let found_body_t = self.module.context.squash_tyvar(block.t());
let vi = match self.module.context.outer.as_mut().unwrap().assign_subr(
@ -1636,7 +1757,7 @@ impl ASTLowerer {
.as_mut()
.unwrap()
.fake_subr_assign(&sig.ident, &sig.decorators, Type::Failure)?;
let block = self.lower_block(body.block)?;
let block = self.lower_block(body.block, None)?;
let ident = hir::Identifier::bare(sig.ident);
let ret_t_spec = if let Some(ts) = sig.return_t_spec {
let spec_t = self.module.context.instantiate_typespec(&ts.t_spec)?;
@ -1729,7 +1850,7 @@ impl ASTLowerer {
self.errs.extend(errs);
}
},
ast::ClassAttr::Decl(decl) => match self.lower_type_asc(decl) {
ast::ClassAttr::Decl(decl) => match self.lower_type_asc(decl, None) {
Ok(decl) => {
hir_methods.push(hir::Expr::TypeAsc(decl));
}
@ -1737,7 +1858,7 @@ impl ASTLowerer {
self.errs.extend(errs);
}
},
ast::ClassAttr::Doc(doc) => match self.lower_literal(doc) {
ast::ClassAttr::Doc(doc) => match self.lower_literal(doc, None) {
Ok(doc) => {
hir_methods.push(hir::Expr::Literal(doc));
}
@ -1936,7 +2057,7 @@ impl ASTLowerer {
self.errs.extend(errs);
}
},
ast::ClassAttr::Decl(decl) => match self.lower_type_asc(decl) {
ast::ClassAttr::Decl(decl) => match self.lower_type_asc(decl, None) {
Ok(decl) => {
hir_methods.push(hir::Expr::TypeAsc(decl));
}
@ -1944,7 +2065,7 @@ impl ASTLowerer {
self.errs.extend(errs);
}
},
ast::ClassAttr::Doc(doc) => match self.lower_literal(doc) {
ast::ClassAttr::Doc(doc) => match self.lower_literal(doc, None) {
Ok(doc) => {
hir_methods.push(hir::Expr::Literal(doc));
}
@ -1964,8 +2085,8 @@ impl ASTLowerer {
fn lower_redef(&mut self, redef: ast::ReDef) -> LowerResult<hir::ReDef> {
log!(info "entered {}({redef})", fn_name!());
let mut attr = self.lower_acc(redef.attr)?;
let expr = self.lower_expr(*redef.expr)?;
let mut attr = self.lower_acc(redef.attr, None)?;
let expr = self.lower_expr(*redef.expr, None)?;
if let Err(err) =
self.var_result_t_check(&attr, &Str::from(attr.show()), attr.ref_t(), expr.ref_t())
{
@ -2302,14 +2423,18 @@ impl ASTLowerer {
}
}
fn lower_type_asc(&mut self, tasc: ast::TypeAscription) -> LowerResult<hir::TypeAscription> {
fn lower_type_asc(
&mut self,
tasc: ast::TypeAscription,
expect: Option<&Type>,
) -> LowerResult<hir::TypeAscription> {
log!(info "entered {}({tasc})", fn_name!());
let kind = tasc.kind();
let spec_t = self
.module
.context
.instantiate_typespec(&tasc.t_spec.t_spec)?;
let expr = self.lower_expr(*tasc.expr)?;
let expr = self.lower_expr(*tasc.expr, expect)?;
match kind {
AscriptionKind::TypeOf | AscriptionKind::AsCast => {
self.module.context.sub_unify(
@ -2438,25 +2563,27 @@ impl ASTLowerer {
// Call.obj == Accessor cannot be type inferred by itself (it can only be inferred with arguments)
// so turn off type checking (check=false)
fn lower_expr(&mut self, expr: ast::Expr) -> LowerResult<hir::Expr> {
fn lower_expr(&mut self, expr: ast::Expr, expect: Option<&Type>) -> LowerResult<hir::Expr> {
log!(info "entered {}", fn_name!());
let casted = self.module.context.get_casted_type(&expr);
let mut expr = match expr {
ast::Expr::Literal(lit) => hir::Expr::Literal(self.lower_literal(lit)?),
ast::Expr::Array(arr) => hir::Expr::Array(self.lower_array(arr)?),
ast::Expr::Tuple(tup) => hir::Expr::Tuple(self.lower_tuple(tup)?),
ast::Expr::Record(rec) => hir::Expr::Record(self.lower_record(rec)?),
ast::Expr::Set(set) => hir::Expr::Set(self.lower_set(set)?),
ast::Expr::Dict(dict) => hir::Expr::Dict(self.lower_dict(dict)?),
ast::Expr::Accessor(acc) => hir::Expr::Accessor(self.lower_acc(acc)?),
ast::Expr::BinOp(bin) => hir::Expr::BinOp(self.lower_bin(bin)),
ast::Expr::UnaryOp(unary) => hir::Expr::UnaryOp(self.lower_unary(unary)),
ast::Expr::Call(call) => hir::Expr::Call(self.lower_call(call)?),
ast::Expr::DataPack(pack) => hir::Expr::Call(self.lower_pack(pack)?),
ast::Expr::Lambda(lambda) => hir::Expr::Lambda(self.lower_lambda(lambda)?),
ast::Expr::TypeAscription(tasc) => hir::Expr::TypeAsc(self.lower_type_asc(tasc)?),
ast::Expr::Literal(lit) => hir::Expr::Literal(self.lower_literal(lit, expect)?),
ast::Expr::Array(arr) => hir::Expr::Array(self.lower_array(arr, expect)?),
ast::Expr::Tuple(tup) => hir::Expr::Tuple(self.lower_tuple(tup, expect)?),
ast::Expr::Record(rec) => hir::Expr::Record(self.lower_record(rec, expect)?),
ast::Expr::Set(set) => hir::Expr::Set(self.lower_set(set, expect)?),
ast::Expr::Dict(dict) => hir::Expr::Dict(self.lower_dict(dict, expect)?),
ast::Expr::Accessor(acc) => hir::Expr::Accessor(self.lower_acc(acc, expect)?),
ast::Expr::BinOp(bin) => hir::Expr::BinOp(self.lower_bin(bin, expect)),
ast::Expr::UnaryOp(unary) => hir::Expr::UnaryOp(self.lower_unary(unary, expect)),
ast::Expr::Call(call) => hir::Expr::Call(self.lower_call(call, expect)?),
ast::Expr::DataPack(pack) => hir::Expr::Call(self.lower_pack(pack, expect)?),
ast::Expr::Lambda(lambda) => hir::Expr::Lambda(self.lower_lambda(lambda, expect)?),
ast::Expr::TypeAscription(tasc) => {
hir::Expr::TypeAsc(self.lower_type_asc(tasc, expect)?)
}
// Checking is also performed for expressions in Dummy. However, it has no meaning in code generation
ast::Expr::Dummy(dummy) => hir::Expr::Dummy(self.lower_dummy(dummy)?),
ast::Expr::Dummy(dummy) => hir::Expr::Dummy(self.lower_dummy(dummy, expect)?),
other => {
log!(err "unreachable: {other}");
return unreachable_error!(LowerErrors, LowerError, self.module.context);
@ -2475,7 +2602,11 @@ impl ASTLowerer {
/// The meaning of TypeAscription changes between chunk and expr.
/// For example, `x: Int`, as expr, is `x` itself,
/// but as chunk, it declares that `x` is of type `Int`, and is valid even before `x` is defined.
pub fn lower_chunk(&mut self, chunk: ast::Expr) -> LowerResult<hir::Expr> {
pub fn lower_chunk(
&mut self,
chunk: ast::Expr,
expect: Option<&Type>,
) -> LowerResult<hir::Expr> {
log!(info "entered {}", fn_name!());
match chunk {
ast::Expr::Def(def) => Ok(hir::Expr::Def(self.lower_def(def)?)),
@ -2483,15 +2614,21 @@ impl ASTLowerer {
ast::Expr::PatchDef(defs) => Ok(hir::Expr::PatchDef(self.lower_patch_def(defs)?)),
ast::Expr::ReDef(redef) => Ok(hir::Expr::ReDef(self.lower_redef(redef)?)),
ast::Expr::TypeAscription(tasc) => Ok(hir::Expr::TypeAsc(self.lower_decl(tasc)?)),
other => self.lower_expr(other),
other => self.lower_expr(other, expect),
}
}
fn lower_block(&mut self, ast_block: ast::Block) -> LowerResult<hir::Block> {
fn lower_block(
&mut self,
ast_block: ast::Block,
expect: Option<&Type>,
) -> LowerResult<hir::Block> {
log!(info "entered {}", fn_name!());
let mut hir_block = Vec::with_capacity(ast_block.len());
for chunk in ast_block.into_iter() {
let chunk = match self.lower_chunk(chunk) {
let last = ast_block.len() - 1;
for (i, chunk) in ast_block.into_iter().enumerate() {
let expect = if i == last { expect } else { None };
let chunk = match self.lower_chunk(chunk, expect) {
Ok(chunk) => chunk,
Err(errs) => {
self.errs.extend(errs);
@ -2503,11 +2640,17 @@ impl ASTLowerer {
Ok(hir::Block::new(hir_block))
}
fn lower_dummy(&mut self, ast_dummy: ast::Dummy) -> LowerResult<hir::Dummy> {
fn lower_dummy(
&mut self,
ast_dummy: ast::Dummy,
expect: Option<&Type>,
) -> LowerResult<hir::Dummy> {
log!(info "entered {}", fn_name!());
let mut hir_dummy = Vec::with_capacity(ast_dummy.len());
for chunk in ast_dummy.into_iter() {
let chunk = self.lower_chunk(chunk)?;
let last = ast_dummy.len() - 1;
for (i, chunk) in ast_dummy.into_iter().enumerate() {
let expect = if i == last { expect } else { None };
let chunk = self.lower_chunk(chunk, expect)?;
hir_dummy.push(chunk);
}
Ok(hir::Dummy::new(hir_dummy))
@ -2562,7 +2705,7 @@ impl ASTLowerer {
self.errs.extend(errs);
}
for chunk in ast.module.into_iter() {
match self.lower_chunk(chunk) {
match self.lower_chunk(chunk, None) {
Ok(chunk) => {
module.push(chunk);
}

View file

@ -556,6 +556,21 @@ impl SubrType {
}
}
/// WARN: This is an infinite iterator
///
/// `self` is not included
pub fn pos_params(&self) -> impl Iterator<Item = &ParamTy> + Clone {
let non_defaults = self
.non_default_params
.iter()
.filter(|pt| !pt.name().is_some_and(|n| &n[..] == "self"));
if let Some(var_params) = self.var_params.as_ref() {
non_defaults.chain(std::iter::repeat(var_params.as_ref()))
} else {
non_defaults.chain(std::iter::repeat(&ParamTy::Pos(Type::Failure)))
}
}
pub fn param_names(&self) -> impl Iterator<Item = &str> + Clone {
self.non_default_params
.iter()
@ -2388,19 +2403,6 @@ impl Type {
}
}
pub fn union_types(&self) -> Vec<Type> {
match self {
Self::FreeVar(fv) if fv.is_linked() => fv.crack().union_types(),
Self::Refinement(refine) => refine.t.union_types(),
Self::Or(t1, t2) => {
let mut types = t1.union_types();
types.extend(t2.union_types());
types
}
_ => vec![self.clone()],
}
}
/// assert!((A or B).contains_union(B))
pub fn contains_union(&self, typ: &Type) -> bool {
match self {