diff --git a/src/unify.rs b/src/unify.rs index 9de032c497..cc169d50d4 100644 --- a/src/unify.rs +++ b/src/unify.rs @@ -13,12 +13,18 @@ struct Context { struct RecordStructure { fields: ImMap, - extension: Variable, + ext: Variable, } +type UnifyResult = Result<(), Problem>; + +type Problems = Vec; + #[inline(always)] -pub fn unify(subs: &mut Subs, var1: Variable, var2: Variable) { - if !subs.equivalent(var1, var2) { +pub fn unify(subs: &mut Subs, var1: Variable, var2: Variable) -> UnifyResult { + if subs.equivalent(var1, var2) { + Ok(()) + } else { let ctx = Context { first: var1, first_desc: subs.get(var1), @@ -30,7 +36,7 @@ pub fn unify(subs: &mut Subs, var1: Variable, var2: Variable) { } } -fn unify_context(subs: &mut Subs, ctx: Context) { +fn unify_context(subs: &mut Subs, ctx: Context) -> UnifyResult { match &ctx.first_desc.content { FlexVar(opt_name) => unify_flex(subs, &ctx, opt_name, &ctx.second_desc.content), RigidVar(name) => unify_rigid(subs, &ctx, name, &ctx.second_desc.content), @@ -38,7 +44,9 @@ fn unify_context(subs: &mut Subs, ctx: Context) { Alias(home, name, args, real_var) => unify_alias(subs, &ctx, home, name, args, real_var), Error(problem) => { // Error propagates. Whatever we're comparing it to doesn't matter! - merge(subs, &ctx, Error(problem.clone())) + merge(subs, &ctx, Error(problem.clone())); + + Err(problem.clone()) } } } @@ -51,7 +59,7 @@ fn unify_alias( name: &Uppercase, args: &Vec<(Lowercase, Variable)>, real_var: &Variable, -) { +) -> UnifyResult { let other_content = &ctx.second_desc.content; match other_content { @@ -61,41 +69,71 @@ fn unify_alias( subs, &ctx, Alias(home.clone(), name.clone(), args.clone(), *real_var), - ) + ); + + Ok(()) } RigidVar(_) => unify(subs, *real_var, ctx.second), Alias(other_home, other_name, other_args, other_real_var) => { if name == other_name && home == other_home { if args.len() == other_args.len() { + let mut answer = Ok(()); + for ((_, l_var), (_, r_var)) in args.iter().zip(other_args.iter()) { - unify(subs, *l_var, *r_var); + let result = unify(subs, *l_var, *r_var); + + answer = answer.and_then(|()| result); } - merge(subs, &ctx, other_content.clone()) + merge(subs, &ctx, other_content.clone()); + + answer } else if args.len() > other_args.len() { - merge(subs, &ctx, Error(Problem::ExtraArguments)) + let problem = Problem::ExtraArguments; + + merge(subs, &ctx, Error(problem.clone())); + + Err(problem) } else { - merge(subs, &ctx, Error(Problem::MissingArguments)) + let problem = Problem::MissingArguments; + + merge(subs, &ctx, Error(problem.clone())); + + Err(problem) } } else { unify(subs, *real_var, *other_real_var) } } Structure(_) => unify(subs, *real_var, ctx.second), - Error(problem) => merge(subs, ctx, Error(problem.clone())), + Error(problem) => { + merge(subs, ctx, Error(problem.clone())); + + Err(problem.clone()) + } } } #[inline(always)] -fn unify_structure(subs: &mut Subs, ctx: &Context, flat_type: &FlatType, other: &Content) { +fn unify_structure( + subs: &mut Subs, + ctx: &Context, + flat_type: &FlatType, + other: &Content, +) -> UnifyResult { match other { FlexVar(_) => { // If the other is flex, Structure wins! - merge(subs, ctx, Structure(flat_type.clone())) + merge(subs, ctx, Structure(flat_type.clone())); + + Ok(()) } RigidVar(_) => { + let problem = Problem::GenericMismatch; // Type mismatch! Rigid can only unify with flex. - merge(subs, ctx, Error(Problem::GenericMismatch)) + merge(subs, ctx, Error(problem.clone())); + + Err(problem) } Structure(ref other_flat_type) => { // Unify the two flat types @@ -104,32 +142,114 @@ fn unify_structure(subs: &mut Subs, ctx: &Context, flat_type: &FlatType, other: Alias(_, _, _, real_var) => unify(subs, ctx.first, *real_var), Error(problem) => { // Error propagates. - merge(subs, ctx, Error(problem.clone())) + merge(subs, ctx, Error(problem.clone())); + + Err(problem.clone()) } } } -fn unify_record(ctx: &Context, structure1: RecordStructure, structure2: RecordStructure) { - panic!("TODO unify_record"); - let x = 5; +fn unify_record( + subs: &mut Subs, + ctx: &Context, + rec1: RecordStructure, + rec2: RecordStructure, +) -> UnifyResult { + let fields1 = rec1.fields; + let fields2 = rec2.fields; + let shared_fields = fields1 + .clone() + .intersection_with(fields2.clone(), |one, two| (one, two)); + let unique_fields1 = fields1.clone().difference(fields2.clone()); + let unique_fields2 = fields2.difference(fields1); + + if unique_fields1.is_empty() { + if unique_fields2.is_empty() { + unify(subs, rec1.ext, rec2.ext); + unify_shared_fields(subs, ctx, shared_fields, ImMap::default(), rec1.ext) + } else { + // subRecord <- fresh context (Structure (Record1 uniqueFields2 ext2)) + // subUnify ext1 subRecord + // unifySharedFields context sharedFields Map.empty subRecord + panic!("TODO 1"); + } + } else if unique_fields2.is_empty() { + // subRecord <- fresh context (Structure (Record1 uniqueFields1 ext1)) + // subUnify subRecord ext2 + // unifySharedFields context sharedFields Map.empty subRecord + panic!("TODO 2"); + } else { + // let otherFields = Map.union uniqueFields1 uniqueFields2 + // ext <- fresh context Type.unnamedFlexVar + // sub1 <- fresh context (Structure (Record1 uniqueFields1 ext)) + // sub2 <- fresh context (Structure (Record1 uniqueFields2 ext)) + // subUnify ext1 sub2 + // subUnify sub1 ext2 + // unifySharedFields context sharedFields otherFields ext + // + panic!("TODO 3"); + } +} + +fn unify_shared_fields( + subs: &mut Subs, + ctx: &Context, + shared_fields: ImMap, + other_fields: ImMap, + ext: Variable, +) -> UnifyResult { + let mut matching_fields = ImMap::default(); + let num_shared_fields = shared_fields.len(); + + for (name, (actual, expected)) in shared_fields { + // TODO another way to do this might be to pass around a problems vec + // and check to see if its length increased after doing this unification. + if unify(subs, actual, expected).is_ok() { + matching_fields.insert(name, actual); + } + } + + if num_shared_fields == matching_fields.len() { + let flat_type = FlatType::Record(matching_fields.union(other_fields), ext); + + merge(subs, ctx, Structure(flat_type)); + + Ok(()) + } else { + let problem = Problem::GenericMismatch; + + // Type mismatch! Rigid can only unify with flex. + merge(subs, ctx, Error(problem.clone())); + + Err(problem) + } } #[inline(always)] -fn unify_flat_type(subs: &mut Subs, ctx: &Context, left: &FlatType, right: &FlatType) { +fn unify_flat_type( + subs: &mut Subs, + ctx: &Context, + left: &FlatType, + right: &FlatType, +) -> UnifyResult { use crate::subs::FlatType::*; match (left, right) { - (EmptyRecord, EmptyRecord) => merge(subs, ctx, Structure(left.clone())), + (EmptyRecord, EmptyRecord) => { + merge(subs, ctx, Structure(left.clone())); + + Ok(()) + } (Record(fields, ext), EmptyRecord) if fields.is_empty() => unify(subs, *ext, ctx.second), (EmptyRecord, Record(fields, ext)) if fields.is_empty() => unify(subs, ctx.first, *ext), (Record(fields1, ext1), Record(fields2, ext2)) => { - let structure1 = gather_fields(subs, fields1.clone(), *ext1); - let structure2 = gather_fields(subs, fields2.clone(), *ext2); + let rec1 = gather_fields(subs, fields1.clone(), *ext1); + let rec2 = gather_fields(subs, fields2.clone(), *ext2); - unify_record(ctx, structure1, structure2) + unify_record(subs, ctx, rec1, rec2) } ( Apply { @@ -153,21 +273,35 @@ fn unify_flat_type(subs: &mut Subs, ctx: &Context, left: &FlatType, right: &Flat name: (*r_type_name).clone(), args: (*r_args).clone(), }), - ) + ); + + Ok(()) } (Func(l_args, l_ret), Func(r_args, r_ret)) => { if l_args.len() == r_args.len() { unify_zip(subs, l_args.iter(), r_args.iter()); - unify(subs, *l_ret, *r_ret); + let answer = unify(subs, *l_ret, *r_ret); - merge(subs, ctx, Structure(Func((*r_args).clone(), *r_ret))) + merge(subs, ctx, Structure(Func((*r_args).clone(), *r_ret))); + + answer } else if l_args.len() > r_args.len() { - merge(subs, ctx, Error(Problem::ExtraArguments)) + merge(subs, ctx, Error(Problem::ExtraArguments)); + + Ok(()) } else { - merge(subs, ctx, Error(Problem::MissingArguments)) + merge(subs, ctx, Error(Problem::MissingArguments)); + + Ok(()) } } - _ => merge(subs, ctx, Error(Problem::GenericMismatch)), + _ => { + let problem = Problem::GenericMismatch; + + merge(subs, ctx, Error(problem.clone())); + + Err(problem) + } } } @@ -181,39 +315,60 @@ where } #[inline(always)] -fn unify_rigid(subs: &mut Subs, ctx: &Context, name: &str, other: &Content) { +fn unify_rigid(subs: &mut Subs, ctx: &Context, name: &str, other: &Content) -> UnifyResult { match other { FlexVar(_) => { // If the other is flex, rigid wins! - merge(subs, ctx, RigidVar(name.into())) + merge(subs, ctx, RigidVar(name.into())); + + Ok(()) } RigidVar(_) | Structure(_) => { // Type mismatch! Rigid can only unify with flex, even if the // rigid names are the same. - merge(subs, ctx, Error(Problem::GenericMismatch)) + merge(subs, ctx, Error(Problem::GenericMismatch)); + + Ok(()) } Alias(_, _, _, _) => { panic!("TODO unify_rigid Alias"); - panic!("TODO"); + + Ok(()) } Error(problem) => { // Error propagates. - merge(subs, ctx, Error(problem.clone())) + merge(subs, ctx, Error(problem.clone())); + + Err(problem.clone()) } } } #[inline(always)] -fn unify_flex(subs: &mut Subs, ctx: &Context, opt_name: &Option>, other: &Content) { +fn unify_flex( + subs: &mut Subs, + ctx: &Context, + opt_name: &Option>, + other: &Content, +) -> UnifyResult { match other { FlexVar(None) => { // If both are flex, and only left has a name, keep the name around. - merge(subs, ctx, FlexVar(opt_name.clone())) + merge(subs, ctx, FlexVar(opt_name.clone())); + + Ok(()) } - FlexVar(Some(_)) | RigidVar(_) | Structure(_) | Alias(_, _, _, _) | Error(_) => { + FlexVar(Some(_)) | RigidVar(_) | Structure(_) | Alias(_, _, _, _) => { // In all other cases, if left is flex, defer to right. // (This includes using right's name if both are flex and named.) - merge(subs, ctx, other.clone()) + merge(subs, ctx, other.clone()); + + Ok(()) + } + Error(problem) => { + merge(subs, ctx, Error(problem.clone())); + + Err(problem.clone()) } } } @@ -235,10 +390,7 @@ fn gather_fields( gather_fields(subs, fields, var) } - _ => RecordStructure { - fields, - extension: var, - }, + _ => RecordStructure { fields, ext: var }, } }