From 966bb6076601bb581973e13e0fed8bc8efc72026 Mon Sep 17 00:00:00 2001 From: Richard Feldman Date: Sat, 31 Aug 2019 21:32:14 -0400 Subject: [PATCH] Infer function calls --- src/canonicalize.rs | 6 ++-- src/constrain.rs | 62 ++++++++++++++++++++++++++++++++ src/pretty_print_types.rs | 2 +- src/subs.rs | 2 +- src/types.rs | 6 +++- src/unify.rs | 45 +++++++++++++++++------ tests/test_infer.rs | 75 ++++++++++++++++++--------------------- 7 files changed, 140 insertions(+), 58 deletions(-) diff --git a/src/canonicalize.rs b/src/canonicalize.rs index 4085ffcd80..bec3e10c36 100644 --- a/src/canonicalize.rs +++ b/src/canonicalize.rs @@ -33,7 +33,7 @@ pub enum Expr { Assign(Vec<(Located, Located)>, Box>), // Application - Apply(Box>, Vec>), + Call(Box>, Vec>), ApplyVariant(Symbol, Option>>), // Product Types @@ -372,7 +372,7 @@ fn canonicalize( _ => () }; - let expr = Apply(Box::new(fn_expr), args); + let expr = Call(Box::new(fn_expr), args); for arg_out in outputs { output.references = output.references.union(arg_out.references); @@ -398,7 +398,7 @@ fn canonicalize( Pizza => { match &right_expr.value { &Var(ref sym) => Some(sym.clone()), - &Apply(ref loc_boxed_expr, _) => { + &Call(ref loc_boxed_expr, _) => { match (*loc_boxed_expr.clone()).value { Var(sym) => Some(sym), _ => None diff --git a/src/constrain.rs b/src/constrain.rs index b6401f24f8..f15fa652b3 100644 --- a/src/constrain.rs +++ b/src/constrain.rs @@ -62,10 +62,72 @@ pub fn constrain( constrain_defs(assignments, bound_vars, subs, ret_con) } } + Call(box_loc_fn_expr, args) => { + constrain_call(bound_vars, subs, *box_loc_fn_expr, args, expected, region) + }, _ => { panic!("TODO constraints for {:?}", loc_expr.value) } } } +fn constrain_call( + bound_vars: &BoundTypeVars, + subs: &mut Subs, + loc_expr: Located, + args: Vec>, + expected: Expected, + region: Region +) -> Constraint { +// constrainCall :: RTV -> A.Region -> Can.Expr -> [Can.Expr] -> Expected Type -> IO Constraint +// constrainCall rtv region func@(A.At funcRegion _) args expected = + // let maybeName = getName func + + let fn_var = subs.mk_flex_var(); + let ret_var = subs.mk_flex_var(); + let fn_type = Variable(fn_var); + let ret_type = Variable(ret_var); + let fn_region = loc_expr.region.clone(); + let fn_expected = NoExpectation(fn_type.clone()); + let fn_con = constrain(bound_vars, subs, loc_expr, fn_expected); + let fn_reason = + // TODO look up the name and use NamedFnArg if possible. + Reason::AnonymousFnCall(args.len() as u8); + + let mut arg_vars = Vec::with_capacity(args.len()); + let mut arg_types = Vec::with_capacity(args.len()); + let mut arg_cons = Vec::with_capacity(args.len()); + + for (index, loc_arg) in args.into_iter().enumerate() { + let region = loc_arg.region.clone(); + let arg_var = subs.mk_flex_var(); + let arg_type = Variable(arg_var); + let reason = + // TODO look up the name and use NamedFnArg if possible. + Reason::AnonymousFnArg(index as u8); + let expected_arg = ForReason(reason, arg_type.clone(), region.clone()); + let arg_con = constrain(bound_vars, subs, loc_arg, expected_arg); + + arg_vars.push(arg_var); + arg_types.push(arg_type); + arg_cons.push(arg_con); + } + + // TODO occurs check! + // return $ exists (funcVar:resultVar:argVars) $ CAnd ... + + let expected_fn_type = ForReason( + fn_reason, + Function(arg_types, Box::new(ret_type.clone())), + region.clone() + ); + + And(vec![ + fn_con, + Eq(fn_type, expected_fn_type, fn_region), + And(arg_cons), + Eq(ret_type, expected, region) + ]) +} + pub fn constrain_defs( assignments: Vec<(Located, Located)>, bound_vars: &BoundTypeVars, diff --git a/src/pretty_print_types.rs b/src/pretty_print_types.rs index 684de0789f..4c3a02a0a4 100644 --- a/src/pretty_print_types.rs +++ b/src/pretty_print_types.rs @@ -19,7 +19,7 @@ fn write_content(content: Content, subs: &mut Subs, buf: &mut String, use_parens FlexVar(None) => buf.push_str(WILDCARD), RigidVar(name) => buf.push_str(&name), Structure(flat_type) => write_flat_type(flat_type, subs, buf, use_parens), - Error => buf.push_str("") + Error(_) => buf.push_str("") } } diff --git a/src/subs.rs b/src/subs.rs index 7eb07cfcbb..a2d5c1f6a5 100644 --- a/src/subs.rs +++ b/src/subs.rs @@ -136,7 +136,7 @@ pub enum Content { FlexVar(Option /* name */), RigidVar(String /* name */), Structure(FlatType), - Error + Error(Problem) } #[derive(Clone, Debug, PartialEq, Eq)] diff --git a/src/types.rs b/src/types.rs index 4e92f3d52c..8aae6c0606 100644 --- a/src/types.rs +++ b/src/types.rs @@ -36,6 +36,10 @@ impl Expected { #[derive(Debug, Clone)] pub enum Reason { + AnonymousFnArg(u8 /* arg index */), + NamedFnArg(String /* function name */, u8 /* arg index */), + AnonymousFnCall(u8 /* arity */), + NamedFnCall(String /* function name */, u8 /* arity */), OperatorLeftArg(Operator), OperatorRightArg(Operator), FractionalLiteral, @@ -63,7 +67,7 @@ pub struct LetConstraint { #[derive(PartialEq, Eq, Debug, Clone)] pub enum Problem { - GenericMismatch(Box, Box), + GenericMismatch, ExtraArguments, MissingArguments, IfConditionNotBool, diff --git a/src/unify.rs b/src/unify.rs index 59c69accfc..b1e3e1c846 100644 --- a/src/unify.rs +++ b/src/unify.rs @@ -1,5 +1,6 @@ use subs::{Descriptor, FlatType, Variable, Subs}; use subs::Content::{self, *}; +use types::Problem; #[inline(always)] pub fn unify_vars(subs: &mut Subs, left_key: Variable, right_key: Variable) -> Descriptor { @@ -26,9 +27,9 @@ pub fn unify(subs: &mut Subs, left: &Descriptor, right: &Descriptor) -> Descript Structure(ref flat_type) => { unify_structure(subs, flat_type, &right.content) } - Error => { + Error(ref problem) => { // Error propagates. Whatever we're comparing it to doesn't matter! - from_content(Error) + from_content(Error(problem.clone())) } }; @@ -47,15 +48,15 @@ fn unify_structure(subs: &mut Subs, flat_type: &FlatType, other: &Content) -> De }, RigidVar(_) => { // Type mismatch! Rigid can only unify with flex. - from_content(Error) + from_content(Error(Problem::GenericMismatch)) }, Structure(ref other_flat_type) => { // Type mismatch! Rigid can only unify with flex. unify_flat_type(subs, flat_type, other_flat_type) }, - Error => { + Error(problem) => { // Error propagates. - from_content(Error) + from_content(Error(problem.clone())) }, } } @@ -75,8 +76,20 @@ fn unify_flat_type(subs: &mut Subs, left: &FlatType, right: &FlatType) -> Descri from_content(Structure(flat_type)) }, - (Func(_, _), Func(_, _)) => panic!("TODO unify_flat_type for Func"), - _ => from_content(Error) + (Func(l_args, l_ret), Func(r_args, r_ret)) => { + if l_args.len() == r_args.len() { + let args = unify_args(subs, l_args.iter(), r_args.iter()); + let ret = union_vars(subs, l_ret.clone(), r_ret.clone()); + let flat_type = Func(args, ret); + + from_content(Structure(flat_type)) + } else if l_args.len() > r_args.len() { + from_content(Error(Problem::ExtraArguments)) + } else { + from_content(Error(Problem::MissingArguments)) + } + }, + _ => from_content(Error(Problem::GenericMismatch)) } } @@ -95,6 +108,16 @@ where I: Iterator }).collect() } +fn union_vars(subs: &mut Subs, l_var: Variable, r_var: Variable) -> Variable { + // Look up the descriptors we have for these variables, and unify them. + let descriptor = unify_vars(subs, l_var.clone(), r_var.clone()); + + // set r_var to be the unioned value, then union l_var to r_var + subs.set(r_var.clone(), descriptor); + subs.union(l_var.clone(), r_var.clone()); + + r_var.clone() +} #[inline(always)] fn unify_rigid(name: &String, other: &Content) -> Descriptor { @@ -106,11 +129,11 @@ fn unify_rigid(name: &String, other: &Content) -> Descriptor { RigidVar(_) | Structure(_) => { // Type mismatch! Rigid can only unify with flex, even if the // rigid names are the same. - from_content(Error) + from_content(Error(Problem::GenericMismatch)) }, - Error => { + Error(problem) => { // Error propagates. - from_content(Error) + from_content(Error(problem.clone())) }, } } @@ -123,7 +146,7 @@ fn unify_flex(opt_name: &Option, other: &Content) -> Descriptor { // If both are flex, and only left has a name, keep the name around. from_content(FlexVar(opt_name.clone())) }, - FlexVar(Some(_)) | RigidVar(_) | Structure(_) | Error => { + FlexVar(Some(_)) | RigidVar(_) | Structure(_) | Error(_) => { // In all other cases, if left is flex, defer to right. // (This includes using right's name if both are flex and named.) from_content(other.clone()) diff --git a/tests/test_infer.rs b/tests/test_infer.rs index ae362a5c32..1371652df1 100644 --- a/tests/test_infer.rs +++ b/tests/test_infer.rs @@ -401,53 +401,46 @@ mod test_infer { ); } - // TODO identity function + // CALLING FUNCTIONS + + #[test] + fn call_returns_list() { + infer_eq( + indoc!(r#" + enlist = \val -> [ val ] + + enlist 5 + "#), + "List.List (Num.Num *)" + ); + } + // TODO calling functions + // TODO conditionals + // TODO type annotations // TODO BoundTypeVariables - // #[test] - // fn int_thunk() { - // assert_eq!( - // infer(indoc!(r#" - // \_ -> 5 - // "#)), - // Function(vec![var(0)], Box::new(Builtin(Int))) - // ); - // } - - // #[test] - // fn string_thunk() { - // assert_eq!( - // infer(indoc!(r#" - // \_ -> "thunk!" - // "#)), - // Function(vec![var(0)], Box::new(Builtin(Str))) - // ); - // } +// #[test] +// fn identity() { +// infer_eq( +// indoc!(r#" +// \val -> val +// "#), +// "a -> a" +// ); +// } - // #[test] - // fn identity_function() { - // assert_eq!( - // infer(indoc!(r#" - // \val -> val - // "#)), - // Function(vec![var(0)], box_var(0)) - // ); - // } +// #[test] +// fn always_function() { +// infer_eq( +// indoc!(r#" +// \val -> \_ -> val +// "#), +// "a -> (* -> a)" +// ); +// } - // #[test] - // fn always_function() { - // assert_eq!( - // infer(indoc!(r#" - // \val -> \_ -> val - // "#)), - // Function( - // vec![var(0)], - // Box::new(Function(vec![var(1)], box_var(0))) - // ) - // ); - // } // #[test] // fn basic_circular_type() {