diff --git a/cli/src/repl/eval.rs b/cli/src/repl/eval.rs index 7a018bdc9a..94f51ec6cf 100644 --- a/cli/src/repl/eval.rs +++ b/cli/src/repl/eval.rs @@ -146,6 +146,9 @@ fn jit_to_ast_help<'a>( single_tag_union_to_ast(env, ptr, field_layouts, tag_name.clone(), payload_vars) } + Content::Structure(FlatType::FunctionOrTagUnion(tag_name, _, _)) => { + single_tag_union_to_ast(env, ptr, field_layouts, tag_name.clone(), &[]) + } other => { unreachable!( "Something had a Struct layout, but instead of a Record or TagUnion type, it had: {:?}", @@ -318,6 +321,9 @@ fn ptr_to_ast<'a>( let (tag_name, payload_vars) = tags.iter().next().unwrap(); single_tag_union_to_ast(env, ptr, field_layouts, tag_name.clone(), payload_vars) } + Content::Structure(FlatType::FunctionOrTagUnion(tag_name, _, _)) => { + single_tag_union_to_ast(env, ptr, field_layouts, tag_name.clone(), &[]) + } Content::Structure(FlatType::EmptyRecord) => { struct_to_ast(env, ptr, &[], &MutMap::default()) } diff --git a/compiler/can/src/expr.rs b/compiler/can/src/expr.rs index b141490fbb..9fd811c3e6 100644 --- a/compiler/can/src/expr.rs +++ b/compiler/can/src/expr.rs @@ -161,6 +161,14 @@ pub enum Expr { arguments: Vec<(Variable, Located)>, }, + ZeroArgumentTag { + closure_name: Symbol, + variant_var: Variable, + ext_var: Variable, + name: TagName, + arguments: Vec<(Variable, Located)>, + }, + // Test Expect(Box>, Box>), @@ -395,6 +403,17 @@ pub fn canonicalize_expr<'a>( name, arguments: args, }, + ZeroArgumentTag { + variant_var, + ext_var, + name, + .. + } => Tag { + variant_var, + ext_var, + name, + arguments: args, + }, _ => { // This could be something like ((if True then fn1 else fn2) arg1 arg2). Call( @@ -630,11 +649,14 @@ pub fn canonicalize_expr<'a>( let variant_var = var_store.fresh(); let ext_var = var_store.fresh(); + let symbol = env.gen_unique_symbol(); + ( - Tag { + ZeroArgumentTag { name: TagName::Global((*tag).into()), arguments: vec![], variant_var, + closure_name: symbol, ext_var, }, Output::default(), @@ -645,13 +667,15 @@ pub fn canonicalize_expr<'a>( let ext_var = var_store.fresh(); let tag_ident = env.ident_ids.get_or_insert(&(*tag).into()); let symbol = Symbol::new(env.home, tag_ident); + let lambda_set_symbol = env.gen_unique_symbol(); ( - Tag { + ZeroArgumentTag { name: TagName::Private(symbol), arguments: vec![], variant_var, ext_var, + closure_name: lambda_set_symbol, }, Output::default(), ) @@ -1431,6 +1455,23 @@ pub fn inline_calls(var_store: &mut VarStore, scope: &mut Scope, expr: Expr) -> ); } + ZeroArgumentTag { + closure_name, + variant_var, + ext_var, + name, + arguments, + } => { + todo!( + "Inlining for ZeroArgumentTag with closure_name {:?}, variant_var {:?}, ext_var {:?}, name {:?}, arguments {:?}", + closure_name, + variant_var, + ext_var, + name, + arguments + ); + } + Call(boxed_tuple, args, called_via) => { let (fn_var, loc_expr, closure_var, expr_var) = *boxed_tuple; diff --git a/compiler/can/src/module.rs b/compiler/can/src/module.rs index ae1d5015d9..ca6f076ede 100644 --- a/compiler/can/src/module.rs +++ b/compiler/can/src/module.rs @@ -509,7 +509,7 @@ fn fix_values_captured_in_closure_expr( fix_values_captured_in_closure_expr(&mut loc_expr.value, no_capture_symbols); } - Tag { arguments, .. } => { + Tag { arguments, .. } | ZeroArgumentTag { arguments, .. } => { for (_, loc_arg) in arguments.iter_mut() { fix_values_captured_in_closure_expr(&mut loc_arg.value, no_capture_symbols); } diff --git a/compiler/constrain/src/expr.rs b/compiler/constrain/src/expr.rs index 6f645539a8..c4c5b5ed8e 100644 --- a/compiler/constrain/src/expr.rs +++ b/compiler/constrain/src/expr.rs @@ -867,6 +867,58 @@ pub fn constrain_expr( exists(vars, And(arg_cons)) } + ZeroArgumentTag { + variant_var, + ext_var, + name, + arguments, + closure_name, + } => { + let mut vars = Vec::with_capacity(arguments.len()); + let mut types = Vec::with_capacity(arguments.len()); + let mut arg_cons = Vec::with_capacity(arguments.len()); + + for (var, loc_expr) in arguments { + let arg_con = constrain_expr( + env, + loc_expr.region, + &loc_expr.value, + Expected::NoExpectation(Type::Variable(*var)), + ); + + arg_cons.push(arg_con); + vars.push(*var); + types.push(Type::Variable(*var)); + } + + let union_con = Eq( + Type::FunctionOrTagUnion( + name.clone(), + *closure_name, + Box::new(Type::Variable(*ext_var)), + ), + expected.clone(), + Category::TagApply { + tag_name: name.clone(), + args_count: arguments.len(), + }, + region, + ); + let ast_con = Eq( + Type::Variable(*variant_var), + expected, + Category::Storage(std::file!(), std::line!()), + region, + ); + + vars.push(*variant_var); + vars.push(*ext_var); + arg_cons.push(union_con); + arg_cons.push(ast_con); + + exists(vars, And(arg_cons)) + } + RunLowLevel { args, ret_var, op } => { // This is a modified version of what we do for function calls. diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index d512954643..f9a859ccd1 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -3022,351 +3022,63 @@ pub fn with_hole<'a>( variant_var, name: tag_name, arguments: args, - ext_var, + .. + } => { + let arena = env.arena; + + debug_assert!(!matches!( + env.subs.get_without_compacting(variant_var).content, + Content::Structure(FlatType::Func(_, _, _)) + )); + convert_tag_union( + env, + variant_var, + assigned, + hole, + tag_name, + procs, + layout_cache, + args, + arena, + ) + } + + ZeroArgumentTag { + variant_var, + name: tag_name, + arguments: args, + ext_var, + .. } => { - use crate::layout::UnionVariant::*; let arena = env.arena; let desc = env.subs.get_without_compacting(variant_var); if let Content::Structure(FlatType::Func(arg_vars, _, ret_var)) = desc.content { - let mut loc_pattern_args = vec![]; - let mut loc_expr_args = vec![]; - - let proc_symbol = env.unique_symbol(); - - for arg_var in arg_vars { - let arg_symbol = env.unique_symbol(); - - let loc_pattern = - Located::at_zero(roc_can::pattern::Pattern::Identifier(arg_symbol)); - - let loc_expr = Located::at_zero(roc_can::expr::Expr::Var(arg_symbol)); - - loc_pattern_args.push((arg_var, loc_pattern)); - loc_expr_args.push((arg_var, loc_expr)); - } - - let loc_body = Located::at_zero(roc_can::expr::Expr::Tag { - variant_var: ret_var, - name: tag_name, - arguments: loc_expr_args, - ext_var, - }); - - let inserted = procs.insert_anonymous( + tag_union_to_function( env, - proc_symbol, - variant_var, - loc_pattern_args, - loc_body, - CapturedSymbols::None, + arg_vars, ret_var, + tag_name, + ext_var, + procs, + variant_var, layout_cache, - ); - - match inserted { - Ok(_layout) => { - todo!("depends on 0-argument tag unions having a lambda union") - // return Stmt::Let( - // assigned, - // todo!(), // call_by_pointer(env, procs, proc_symbol, layout), - // layout, - // hole, - // ); - } - Err(runtime_error) => { - return Stmt::RuntimeError(env.arena.alloc(format!( - "RuntimeError {} line {} {:?}", - file!(), - line!(), - runtime_error, - ))); - } - } - } - - let res_variant = crate::layout::union_sorted_tags(env.arena, variant_var, env.subs); - - let variant = match res_variant { - Ok(cached) => cached, - Err(LayoutProblem::UnresolvedTypeVar(_)) => { - return Stmt::RuntimeError(env.arena.alloc(format!( - "UnresolvedTypeVar {} line {}", - file!(), - line!() - ))); - } - Err(LayoutProblem::Erroneous) => { - return Stmt::RuntimeError(env.arena.alloc(format!( - "Erroneous {} line {}", - file!(), - line!() - ))); - } - }; - - match variant { - Never => unreachable!( - "The `[]` type has no constructors, source var {:?}", - variant_var - ), - Unit | UnitWithArguments => let_empty_struct(assigned, hole), - BoolUnion { ttrue, .. } => Stmt::Let( assigned, - Expr::Literal(Literal::Bool(tag_name == ttrue)), - Layout::Builtin(Builtin::Int1), hole, - ), - ByteUnion(tag_names) => { - let tag_id = tag_names - .iter() - .position(|key| key == &tag_name) - .expect("tag must be in its own type"); - - Stmt::Let( - assigned, - Expr::Literal(Literal::Byte(tag_id as u8)), - Layout::Builtin(Builtin::Int8), - hole, - ) - } - - Unwrapped(_, field_layouts) => { - let field_symbols_temp = sorted_field_symbols(env, procs, layout_cache, args); - - let mut field_symbols = Vec::with_capacity_in(field_layouts.len(), env.arena); - field_symbols.extend(field_symbols_temp.iter().map(|r| r.1)); - let field_symbols = field_symbols.into_bump_slice(); - - // Layout will unpack this unwrapped tack if it only has one (non-zero-sized) field - let layout = layout_cache - .from_var(env.arena, variant_var, env.subs) - .unwrap_or_else(|err| { - panic!("TODO turn fn_var into a RuntimeError {:?}", err) - }); - - // even though this was originally a Tag, we treat it as a Struct from now on - let stmt = if let [only_field] = field_symbols { - let mut hole = hole.clone(); - substitute_in_exprs(env.arena, &mut hole, assigned, *only_field); - hole - } else { - Stmt::Let(assigned, Expr::Struct(field_symbols), layout, hole) - }; - - let iter = field_symbols_temp.into_iter().map(|(_, _, data)| data); - assign_to_symbols(env, procs, layout_cache, iter, stmt) - } - Wrapped(variant) => { - let union_size = variant.number_of_tags() as u8; - let (tag_id, _) = variant.tag_name_to_id(&tag_name); - - let field_symbols_temp = sorted_field_symbols(env, procs, layout_cache, args); - - let field_symbols; - let opt_tag_id_symbol; - - use WrappedVariant::*; - let (tag, layout) = match variant { - Recursive { sorted_tag_layouts } => { - debug_assert!(sorted_tag_layouts.len() > 1); - let tag_id_symbol = env.unique_symbol(); - opt_tag_id_symbol = Some(tag_id_symbol); - - field_symbols = { - let mut temp = - Vec::with_capacity_in(field_symbols_temp.len() + 1, arena); - temp.push(tag_id_symbol); - - temp.extend(field_symbols_temp.iter().map(|r| r.1)); - - temp.into_bump_slice() - }; - - let mut layouts: Vec<&'a [Layout<'a>]> = - Vec::with_capacity_in(sorted_tag_layouts.len(), env.arena); - - for (_, arg_layouts) in sorted_tag_layouts.into_iter() { - layouts.push(arg_layouts); - } - - debug_assert!(layouts.len() > 1); - let layout = - Layout::Union(UnionLayout::Recursive(layouts.into_bump_slice())); - - let tag = Expr::Tag { - tag_layout: layout, - tag_name, - tag_id: tag_id as u8, - union_size, - arguments: field_symbols, - }; - - (tag, layout) - } - NonNullableUnwrapped { - fields, - tag_name: wrapped_tag_name, - } => { - debug_assert_eq!(tag_name, wrapped_tag_name); - - opt_tag_id_symbol = None; - - field_symbols = { - let mut temp = - Vec::with_capacity_in(field_symbols_temp.len(), arena); - - temp.extend(field_symbols_temp.iter().map(|r| r.1)); - - temp.into_bump_slice() - }; - - let layout = Layout::Union(UnionLayout::NonNullableUnwrapped(fields)); - - let tag = Expr::Tag { - tag_layout: layout, - tag_name, - tag_id: tag_id as u8, - union_size, - arguments: field_symbols, - }; - - (tag, layout) - } - NonRecursive { sorted_tag_layouts } => { - let tag_id_symbol = env.unique_symbol(); - opt_tag_id_symbol = Some(tag_id_symbol); - - field_symbols = { - let mut temp = - Vec::with_capacity_in(field_symbols_temp.len() + 1, arena); - temp.push(tag_id_symbol); - - temp.extend(field_symbols_temp.iter().map(|r| r.1)); - - temp.into_bump_slice() - }; - - let mut layouts: Vec<&'a [Layout<'a>]> = - Vec::with_capacity_in(sorted_tag_layouts.len(), env.arena); - - for (_, arg_layouts) in sorted_tag_layouts.into_iter() { - layouts.push(arg_layouts); - } - - let layout = - Layout::Union(UnionLayout::NonRecursive(layouts.into_bump_slice())); - - let tag = Expr::Tag { - tag_layout: layout, - tag_name, - tag_id: tag_id as u8, - union_size, - arguments: field_symbols, - }; - - (tag, layout) - } - NullableWrapped { - nullable_id, - nullable_name: _, - sorted_tag_layouts, - } => { - let tag_id_symbol = env.unique_symbol(); - opt_tag_id_symbol = Some(tag_id_symbol); - - field_symbols = { - let mut temp = - Vec::with_capacity_in(field_symbols_temp.len() + 1, arena); - temp.push(tag_id_symbol); - - temp.extend(field_symbols_temp.iter().map(|r| r.1)); - - temp.into_bump_slice() - }; - - let mut layouts: Vec<&'a [Layout<'a>]> = - Vec::with_capacity_in(sorted_tag_layouts.len(), env.arena); - - for (_, arg_layouts) in sorted_tag_layouts.into_iter() { - layouts.push(arg_layouts); - } - - let layout = Layout::Union(UnionLayout::NullableWrapped { - nullable_id, - other_tags: layouts.into_bump_slice(), - }); - - let tag = Expr::Tag { - tag_layout: layout, - tag_name, - tag_id: tag_id as u8, - union_size, - arguments: field_symbols, - }; - - (tag, layout) - } - NullableUnwrapped { - nullable_id, - nullable_name: _, - other_name: _, - other_fields, - } => { - // FIXME drop tag - let tag_id_symbol = env.unique_symbol(); - opt_tag_id_symbol = Some(tag_id_symbol); - - field_symbols = { - let mut temp = - Vec::with_capacity_in(field_symbols_temp.len() + 1, arena); - // FIXME drop tag - temp.push(tag_id_symbol); - - temp.extend(field_symbols_temp.iter().map(|r| r.1)); - - temp.into_bump_slice() - }; - - let layout = Layout::Union(UnionLayout::NullableUnwrapped { - nullable_id, - other_fields, - }); - - let tag = Expr::Tag { - tag_layout: layout, - tag_name, - tag_id: tag_id as u8, - union_size, - arguments: field_symbols, - }; - - (tag, layout) - } - }; - - let mut stmt = Stmt::Let(assigned, tag, layout, hole); - let iter = field_symbols_temp - .into_iter() - .map(|x| x.2 .0) - .rev() - .zip(field_symbols.iter().rev()); - - stmt = assign_to_symbols(env, procs, layout_cache, iter, stmt); - - if let Some(tag_id_symbol) = opt_tag_id_symbol { - // define the tag id - stmt = Stmt::Let( - tag_id_symbol, - Expr::Literal(Literal::Int(tag_id as i128)), - Layout::Builtin(TAG_SIZE), - arena.alloc(stmt), - ); - } - - stmt - } + ) + } else { + convert_tag_union( + env, + variant_var, + assigned, + hole, + tag_name, + procs, + layout_cache, + args, + arena, + ) } } @@ -4468,6 +4180,347 @@ fn construct_closure_data<'a>( } } +#[allow(clippy::too_many_arguments)] +fn convert_tag_union<'a>( + env: &mut Env<'a, '_>, + variant_var: Variable, + assigned: Symbol, + hole: &'a Stmt<'a>, + tag_name: TagName, + procs: &mut Procs<'a>, + layout_cache: &mut LayoutCache<'a>, + args: std::vec::Vec<(Variable, Located)>, + arena: &'a Bump, +) -> Stmt<'a> { + use crate::layout::UnionVariant::*; + let res_variant = crate::layout::union_sorted_tags(env.arena, variant_var, env.subs); + let variant = match res_variant { + Ok(cached) => cached, + Err(LayoutProblem::UnresolvedTypeVar(_)) => { + return Stmt::RuntimeError(env.arena.alloc(format!( + "UnresolvedTypeVar {} line {}", + file!(), + line!() + ))) + } + Err(LayoutProblem::Erroneous) => { + return Stmt::RuntimeError(env.arena.alloc(format!( + "Erroneous {} line {}", + file!(), + line!() + ))); + } + }; + match variant { + Never => unreachable!( + "The `[]` type has no constructors, source var {:?}", + variant_var + ), + Unit | UnitWithArguments => { + Stmt::Let(assigned, Expr::Struct(&[]), Layout::Struct(&[]), hole) + } + BoolUnion { ttrue, .. } => Stmt::Let( + assigned, + Expr::Literal(Literal::Bool(tag_name == ttrue)), + Layout::Builtin(Builtin::Int1), + hole, + ), + ByteUnion(tag_names) => { + let tag_id = tag_names + .iter() + .position(|key| key == &tag_name) + .expect("tag must be in its own type"); + + Stmt::Let( + assigned, + Expr::Literal(Literal::Byte(tag_id as u8)), + Layout::Builtin(Builtin::Int8), + hole, + ) + } + + Unwrapped(field_layouts) => { + let field_symbols_temp = sorted_field_symbols(env, procs, layout_cache, args); + + let mut field_symbols = Vec::with_capacity_in(field_layouts.len(), env.arena); + field_symbols.extend(field_symbols_temp.iter().map(|r| r.1)); + let field_symbols = field_symbols.into_bump_slice(); + + // Layout will unpack this unwrapped tack if it only has one (non-zero-sized) field + let layout = layout_cache + .from_var(env.arena, variant_var, env.subs) + .unwrap_or_else(|err| panic!("TODO turn fn_var into a RuntimeError {:?}", err)); + + // even though this was originally a Tag, we treat it as a Struct from now on + let stmt = Stmt::Let(assigned, Expr::Struct(field_symbols), layout, hole); + + let iter = field_symbols_temp.into_iter().map(|(_, _, data)| data); + assign_to_symbols(env, procs, layout_cache, iter, stmt) + } + Wrapped(variant) => { + let union_size = variant.number_of_tags() as u8; + let (tag_id, _) = variant.tag_name_to_id(&tag_name); + + let field_symbols_temp = sorted_field_symbols(env, procs, layout_cache, args); + + let field_symbols; + let opt_tag_id_symbol; + + use WrappedVariant::*; + let (tag, layout) = match variant { + Recursive { sorted_tag_layouts } => { + debug_assert!(sorted_tag_layouts.len() > 1); + let tag_id_symbol = env.unique_symbol(); + opt_tag_id_symbol = Some(tag_id_symbol); + + field_symbols = { + let mut temp = Vec::with_capacity_in(field_symbols_temp.len() + 1, arena); + temp.push(tag_id_symbol); + + temp.extend(field_symbols_temp.iter().map(|r| r.1)); + + temp.into_bump_slice() + }; + + let mut layouts: Vec<&'a [Layout<'a>]> = + Vec::with_capacity_in(sorted_tag_layouts.len(), env.arena); + + for (_, arg_layouts) in sorted_tag_layouts.into_iter() { + layouts.push(arg_layouts); + } + + debug_assert!(layouts.len() > 1); + let layout = Layout::Union(UnionLayout::Recursive(layouts.into_bump_slice())); + + let tag = Expr::Tag { + tag_layout: layout, + tag_name, + tag_id: tag_id as u8, + union_size, + arguments: field_symbols, + }; + + (tag, layout) + } + NonNullableUnwrapped { + fields, + tag_name: wrapped_tag_name, + } => { + debug_assert_eq!(tag_name, wrapped_tag_name); + + opt_tag_id_symbol = None; + + field_symbols = { + let mut temp = Vec::with_capacity_in(field_symbols_temp.len(), arena); + + temp.extend(field_symbols_temp.iter().map(|r| r.1)); + + temp.into_bump_slice() + }; + + let layout = Layout::Union(UnionLayout::NonNullableUnwrapped(fields)); + + let tag = Expr::Tag { + tag_layout: layout, + tag_name, + tag_id: tag_id as u8, + union_size, + arguments: field_symbols, + }; + + (tag, layout) + } + NonRecursive { sorted_tag_layouts } => { + let tag_id_symbol = env.unique_symbol(); + opt_tag_id_symbol = Some(tag_id_symbol); + + field_symbols = { + let mut temp = Vec::with_capacity_in(field_symbols_temp.len() + 1, arena); + temp.push(tag_id_symbol); + + temp.extend(field_symbols_temp.iter().map(|r| r.1)); + + temp.into_bump_slice() + }; + + let mut layouts: Vec<&'a [Layout<'a>]> = + Vec::with_capacity_in(sorted_tag_layouts.len(), env.arena); + + for (_, arg_layouts) in sorted_tag_layouts.into_iter() { + layouts.push(arg_layouts); + } + + let layout = + Layout::Union(UnionLayout::NonRecursive(layouts.into_bump_slice())); + + let tag = Expr::Tag { + tag_layout: layout, + tag_name, + tag_id: tag_id as u8, + union_size, + arguments: field_symbols, + }; + + (tag, layout) + } + NullableWrapped { + nullable_id, + nullable_name: _, + sorted_tag_layouts, + } => { + let tag_id_symbol = env.unique_symbol(); + opt_tag_id_symbol = Some(tag_id_symbol); + + field_symbols = { + let mut temp = Vec::with_capacity_in(field_symbols_temp.len() + 1, arena); + temp.push(tag_id_symbol); + + temp.extend(field_symbols_temp.iter().map(|r| r.1)); + + temp.into_bump_slice() + }; + + let mut layouts: Vec<&'a [Layout<'a>]> = + Vec::with_capacity_in(sorted_tag_layouts.len(), env.arena); + + for (_, arg_layouts) in sorted_tag_layouts.into_iter() { + layouts.push(arg_layouts); + } + + let layout = Layout::Union(UnionLayout::NullableWrapped { + nullable_id, + other_tags: layouts.into_bump_slice(), + }); + + let tag = Expr::Tag { + tag_layout: layout, + tag_name, + tag_id: tag_id as u8, + union_size, + arguments: field_symbols, + }; + + (tag, layout) + } + NullableUnwrapped { + nullable_id, + nullable_name: _, + other_name: _, + other_fields, + } => { + // FIXME drop tag + let tag_id_symbol = env.unique_symbol(); + opt_tag_id_symbol = Some(tag_id_symbol); + + field_symbols = { + let mut temp = Vec::with_capacity_in(field_symbols_temp.len() + 1, arena); + // FIXME drop tag + temp.push(tag_id_symbol); + + temp.extend(field_symbols_temp.iter().map(|r| r.1)); + + temp.into_bump_slice() + }; + + let layout = Layout::Union(UnionLayout::NullableUnwrapped { + nullable_id, + other_fields, + }); + + let tag = Expr::Tag { + tag_layout: layout, + tag_name, + tag_id: tag_id as u8, + union_size, + arguments: field_symbols, + }; + + (tag, layout) + } + }; + + let mut stmt = Stmt::Let(assigned, tag, layout, hole); + let iter = field_symbols_temp + .into_iter() + .map(|x| x.2 .0) + .rev() + .zip(field_symbols.iter().rev()); + + stmt = assign_to_symbols(env, procs, layout_cache, iter, stmt); + + if let Some(tag_id_symbol) = opt_tag_id_symbol { + // define the tag id + stmt = Stmt::Let( + tag_id_symbol, + Expr::Literal(Literal::Int(tag_id as i128)), + Layout::Builtin(TAG_SIZE), + arena.alloc(stmt), + ); + } + + stmt + } + } +} + +#[allow(clippy::too_many_arguments)] +fn tag_union_to_function<'a>( + env: &mut Env<'a, '_>, + arg_vars: std::vec::Vec, + ret_var: Variable, + tag_name: TagName, + ext_var: Variable, + procs: &mut Procs<'a>, + variant_var: Variable, + layout_cache: &mut LayoutCache<'a>, + assigned: Symbol, + hole: &'a Stmt<'a>, +) -> Stmt<'a> { + let mut loc_pattern_args = vec![]; + let mut loc_expr_args = vec![]; + let proc_symbol = env.unique_symbol(); + for arg_var in arg_vars { + let arg_symbol = env.unique_symbol(); + + let loc_pattern = Located::at_zero(roc_can::pattern::Pattern::Identifier(arg_symbol)); + + let loc_expr = Located::at_zero(roc_can::expr::Expr::Var(arg_symbol)); + + loc_pattern_args.push((arg_var, loc_pattern)); + loc_expr_args.push((arg_var, loc_expr)); + } + let loc_body = Located::at_zero(roc_can::expr::Expr::Tag { + variant_var: ret_var, + name: tag_name, + arguments: loc_expr_args, + ext_var, + }); + let inserted = procs.insert_anonymous( + env, + proc_symbol, + variant_var, + loc_pattern_args, + loc_body, + CapturedSymbols::None, + ret_var, + layout_cache, + ); + match inserted { + Ok(layout) => Stmt::Let( + assigned, + call_by_pointer(env, procs, proc_symbol, layout), + layout, + hole, + ), + Err(runtime_error) => Stmt::RuntimeError(env.arena.alloc(format!( + "RuntimeError {} line {} {:?}", + file!(), + line!(), + runtime_error, + ))), + } +} + #[allow(clippy::type_complexity)] fn sorted_field_symbols<'a>( env: &mut Env<'a, '_>, diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 058527b27b..d21ff72f0d 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -357,6 +357,12 @@ impl<'a, 'b> Env<'a, 'b> { self.seen.insert(var) } + + fn remove_seen(&mut self, var: Variable) -> bool { + let var = self.subs.get_root_key_without_compacting(var); + + self.seen.remove(&var) + } } impl<'a> Layout<'a> { @@ -1090,6 +1096,14 @@ fn layout_from_flat_type<'a>( Ok(layout_from_tag_union(arena, tags, subs)) } + FunctionOrTagUnion(tag_name, _, ext_var) => { + debug_assert!(ext_var_is_empty_tag_union(subs, ext_var)); + + let mut tags = MutMap::default(); + tags.insert(tag_name, vec![]); + + Ok(layout_from_tag_union(arena, tags, subs)) + } RecursiveTagUnion(rec_var, tags, ext_var) => { debug_assert!(ext_var_is_empty_tag_union(subs, ext_var)); @@ -1175,6 +1189,8 @@ fn layout_from_flat_type<'a>( UnionLayout::Recursive(tag_layouts.into_bump_slice()) }; + env.remove_seen(rec_var); + Ok(Layout::Union(union_layout)) } EmptyTagUnion => { diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 72108c3f45..6ace7c2a7c 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -241,18 +241,18 @@ mod test_mono { indoc!( r#" procedure Test.0 (): - let Test.9 = 0i64; - let Test.8 = 3i64; - let Test.2 = Just Test.9 Test.8; - let Test.5 = 0i64; - let Test.6 = Index 0 Test.2; - let Test.7 = lowlevel Eq Test.5 Test.6; - if Test.7 then - let Test.1 = Index 1 Test.2; - ret Test.1; + let Test.10 = 0i64; + let Test.9 = 3i64; + let Test.3 = Just Test.10 Test.9; + let Test.6 = 0i64; + let Test.7 = Index 0 Test.3; + let Test.8 = lowlevel Eq Test.6 Test.7; + if Test.8 then + let Test.2 = Index 1 Test.3; + ret Test.2; else - let Test.4 = 0i64; - ret Test.4; + let Test.5 = 0i64; + ret Test.5; "# ), ) @@ -270,23 +270,23 @@ mod test_mono { indoc!( r#" procedure Test.0 (): - let Test.10 = 1i64; - let Test.8 = 1i64; - let Test.9 = 2i64; - let Test.4 = These Test.10 Test.8 Test.9; - switch Test.4: + let Test.11 = 1i64; + let Test.9 = 1i64; + let Test.10 = 2i64; + let Test.5 = These Test.11 Test.9 Test.10; + switch Test.5: case 2: - let Test.1 = Index 1 Test.4; - ret Test.1; - - case 0: - let Test.2 = Index 1 Test.4; + let Test.2 = Index 1 Test.5; ret Test.2; - default: - let Test.3 = Index 1 Test.4; + case 0: + let Test.3 = Index 1 Test.5; ret Test.3; + default: + let Test.4 = Index 1 Test.5; + ret Test.4; + "# ), ) @@ -436,24 +436,24 @@ mod test_mono { indoc!( r#" procedure Num.24 (#Attr.2, #Attr.3): - let Test.5 = lowlevel NumAdd #Attr.2 #Attr.3; - ret Test.5; + let Test.6 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.6; procedure Test.0 (): - let Test.11 = 0i64; - let Test.10 = 41i64; - let Test.1 = Just Test.11 Test.10; - let Test.7 = 0i64; - let Test.8 = Index 0 Test.1; - let Test.9 = lowlevel Eq Test.7 Test.8; - if Test.9 then - let Test.2 = Index 1 Test.1; - let Test.4 = 1i64; - let Test.3 = CallByName Num.24 Test.2 Test.4; - ret Test.3; + let Test.12 = 0i64; + let Test.11 = 41i64; + let Test.1 = Just Test.12 Test.11; + let Test.8 = 0i64; + let Test.9 = Index 0 Test.1; + let Test.10 = lowlevel Eq Test.8 Test.9; + if Test.10 then + let Test.3 = Index 1 Test.1; + let Test.5 = 1i64; + let Test.4 = CallByName Num.24 Test.3 Test.5; + ret Test.4; else - let Test.6 = 1i64; - ret Test.6; + let Test.7 = 1i64; + ret Test.7; "# ), ) @@ -470,9 +470,6 @@ mod test_mono { "#, indoc!( r#" - procedure Test.0 (): - let Test.3 = 2i64; - ret Test.3; "# ), ) @@ -491,31 +488,31 @@ mod test_mono { "#, indoc!( r#" - procedure Test.1 (Test.2): - let Test.5 = 2i64; - joinpoint Test.11: - let Test.9 = 0i64; - ret Test.9; + procedure Test.1 (Test.3): + let Test.6 = 2i64; + joinpoint Test.12: + let Test.10 = 0i64; + ret Test.10; in - let Test.10 = 2i64; - let Test.13 = lowlevel Eq Test.10 Test.5; - if Test.13 then - joinpoint Test.7 Test.12: - if Test.12 then - let Test.6 = 42i64; - ret Test.6; + let Test.11 = 2i64; + let Test.14 = lowlevel Eq Test.11 Test.6; + if Test.14 then + joinpoint Test.8 Test.13: + if Test.13 then + let Test.7 = 42i64; + ret Test.7; else - jump Test.11; + jump Test.12; in - let Test.8 = false; - jump Test.7 Test.8; + let Test.9 = false; + jump Test.8 Test.9; else - jump Test.11; + jump Test.12; procedure Test.0 (): - let Test.4 = Struct {}; - let Test.3 = CallByName Test.1 Test.4; - ret Test.3; + let Test.5 = Struct {}; + let Test.4 = CallByName Test.1 Test.5; + ret Test.4; "# ), ) @@ -561,37 +558,37 @@ mod test_mono { indoc!( r#" procedure Num.24 (#Attr.2, #Attr.3): - let Test.6 = lowlevel NumAdd #Attr.2 #Attr.3; - ret Test.6; + let Test.8 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.8; procedure Test.0 (): - let Test.18 = 0i64; let Test.20 = 0i64; - let Test.19 = 41i64; - let Test.17 = Just Test.20 Test.19; - let Test.2 = Just Test.18 Test.17; - joinpoint Test.14: - let Test.8 = 1i64; - ret Test.8; + let Test.22 = 0i64; + let Test.21 = 41i64; + let Test.19 = Just Test.22 Test.21; + let Test.2 = Just Test.20 Test.19; + joinpoint Test.16: + let Test.10 = 1i64; + ret Test.10; in - let Test.12 = 0i64; - let Test.13 = Index 0 Test.2; - let Test.16 = lowlevel Eq Test.12 Test.13; - if Test.16 then - let Test.9 = Index 1 Test.2; - let Test.10 = 0i64; - let Test.11 = Index 0 Test.9; - let Test.15 = lowlevel Eq Test.10 Test.11; - if Test.15 then - let Test.7 = Index 1 Test.2; - let Test.3 = Index 1 Test.7; - let Test.5 = 1i64; - let Test.4 = CallByName Num.24 Test.3 Test.5; - ret Test.4; + let Test.14 = 0i64; + let Test.15 = Index 0 Test.2; + let Test.18 = lowlevel Eq Test.14 Test.15; + if Test.18 then + let Test.11 = Index 1 Test.2; + let Test.12 = 0i64; + let Test.13 = Index 0 Test.11; + let Test.17 = lowlevel Eq Test.12 Test.13; + if Test.17 then + let Test.9 = Index 1 Test.2; + let Test.5 = Index 1 Test.9; + let Test.7 = 1i64; + let Test.6 = CallByName Num.24 Test.5 Test.7; + ret Test.6; else - jump Test.14; + jump Test.16; else - jump Test.14; + jump Test.16; "# ), ) @@ -608,33 +605,33 @@ mod test_mono { indoc!( r#" procedure Num.24 (#Attr.2, #Attr.3): - let Test.6 = lowlevel NumAdd #Attr.2 #Attr.3; - ret Test.6; + let Test.7 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.7; procedure Test.0 (): - let Test.15 = 3i64; - let Test.14 = 2i64; - let Test.3 = Struct {Test.14, Test.15}; - joinpoint Test.11: - let Test.1 = Index 0 Test.3; - let Test.2 = Index 1 Test.3; - let Test.5 = CallByName Num.24 Test.1 Test.2; - ret Test.5; + let Test.16 = 3i64; + let Test.15 = 2i64; + let Test.4 = Struct {Test.15, Test.16}; + joinpoint Test.12: + let Test.2 = Index 0 Test.4; + let Test.3 = Index 1 Test.4; + let Test.6 = CallByName Num.24 Test.2 Test.3; + ret Test.6; in - let Test.9 = Index 1 Test.3; - let Test.10 = 3i64; - let Test.13 = lowlevel Eq Test.10 Test.9; - if Test.13 then - let Test.7 = Index 0 Test.3; - let Test.8 = 4i64; - let Test.12 = lowlevel Eq Test.8 Test.7; - if Test.12 then - let Test.4 = 9i64; - ret Test.4; + let Test.10 = Index 1 Test.4; + let Test.11 = 3i64; + let Test.14 = lowlevel Eq Test.11 Test.10; + if Test.14 then + let Test.8 = Index 0 Test.4; + let Test.9 = 4i64; + let Test.13 = lowlevel Eq Test.9 Test.8; + if Test.13 then + let Test.5 = 9i64; + ret Test.5; else - jump Test.11; + jump Test.12; else - jump Test.11; + jump Test.12; "# ), ) @@ -781,29 +778,29 @@ mod test_mono { "#, indoc!( r#" - procedure Test.1 (Test.4): + procedure Test.1 (Test.5): let Test.2 = 0u8; - joinpoint Test.8 Test.3: + joinpoint Test.9 Test.3: ret Test.3; in switch Test.2: case 1: - let Test.9 = 1i64; - jump Test.8 Test.9; + let Test.10 = 1i64; + jump Test.9 Test.10; case 2: - let Test.10 = 2i64; - jump Test.8 Test.10; + let Test.11 = 2i64; + jump Test.9 Test.11; default: - let Test.11 = 3i64; - jump Test.8 Test.11; + let Test.12 = 3i64; + jump Test.9 Test.12; procedure Test.0 (): - let Test.6 = Struct {}; - let Test.5 = CallByName Test.1 Test.6; - ret Test.5; + let Test.7 = Struct {}; + let Test.6 = CallByName Test.1 Test.7; + ret Test.6; "# ), ) @@ -821,13 +818,13 @@ mod test_mono { indoc!( r#" procedure Test.0 (): - let Test.2 = true; - if Test.2 then - let Test.3 = 1i64; - ret Test.3; + let Test.3 = true; + if Test.3 then + let Test.4 = 1i64; + ret Test.4; else - let Test.1 = 2i64; - ret Test.1; + let Test.2 = 2i64; + ret Test.2; "# ), ) @@ -847,18 +844,18 @@ mod test_mono { indoc!( r#" procedure Test.0 (): - let Test.4 = true; - if Test.4 then - let Test.5 = 1i64; - ret Test.5; + let Test.6 = true; + if Test.6 then + let Test.7 = 1i64; + ret Test.7; else - let Test.2 = false; - if Test.2 then - let Test.3 = 2i64; - ret Test.3; + let Test.4 = false; + if Test.4 then + let Test.5 = 2i64; + ret Test.5; else - let Test.1 = 3i64; - ret Test.1; + let Test.3 = 3i64; + ret Test.3; "# ), ) @@ -883,34 +880,34 @@ mod test_mono { "#, indoc!( r#" - procedure Test.1 (Test.4): - let Test.19 = 1i64; - let Test.18 = 2i64; - let Test.2 = Ok Test.19 Test.18; - joinpoint Test.8 Test.3: + procedure Test.1 (Test.5): + let Test.20 = 1i64; + let Test.19 = 2i64; + let Test.2 = Ok Test.20 Test.19; + joinpoint Test.9 Test.3: ret Test.3; in - let Test.15 = 1i64; - let Test.16 = Index 0 Test.2; - let Test.17 = lowlevel Eq Test.15 Test.16; - if Test.17 then - let Test.12 = Index 1 Test.2; - let Test.13 = 3i64; - let Test.14 = lowlevel Eq Test.13 Test.12; - if Test.14 then - let Test.9 = 1i64; - jump Test.8 Test.9; + let Test.16 = 1i64; + let Test.17 = Index 0 Test.2; + let Test.18 = lowlevel Eq Test.16 Test.17; + if Test.18 then + let Test.13 = Index 1 Test.2; + let Test.14 = 3i64; + let Test.15 = lowlevel Eq Test.14 Test.13; + if Test.15 then + let Test.10 = 1i64; + jump Test.9 Test.10; else - let Test.10 = 2i64; - jump Test.8 Test.10; + let Test.11 = 2i64; + jump Test.9 Test.11; else - let Test.11 = 3i64; - jump Test.8 Test.11; + let Test.12 = 3i64; + jump Test.9 Test.12; procedure Test.0 (): - let Test.6 = Struct {}; - let Test.5 = CallByName Test.1 Test.6; - ret Test.5; + let Test.7 = Struct {}; + let Test.6 = CallByName Test.1 Test.7; + ret Test.6; "# ), ) @@ -1370,70 +1367,70 @@ mod test_mono { indoc!( r#" procedure List.3 (#Attr.2, #Attr.3): - let Test.38 = lowlevel ListLen #Attr.2; - let Test.34 = lowlevel NumLt #Attr.3 Test.38; - if Test.34 then - let Test.37 = 1i64; - let Test.36 = lowlevel ListGetUnsafe #Attr.2 #Attr.3; - let Test.35 = Ok Test.37 Test.36; - ret Test.35; + let Test.39 = lowlevel ListLen #Attr.2; + let Test.35 = lowlevel NumLt #Attr.3 Test.39; + if Test.35 then + let Test.38 = 1i64; + let Test.37 = lowlevel ListGetUnsafe #Attr.2 #Attr.3; + let Test.36 = Ok Test.38 Test.37; + ret Test.36; else - let Test.33 = 0i64; - let Test.32 = Struct {}; - let Test.31 = Err Test.33 Test.32; - ret Test.31; + let Test.34 = 0i64; + let Test.33 = Struct {}; + let Test.32 = Err Test.34 Test.33; + ret Test.32; procedure List.4 (#Attr.2, #Attr.3, #Attr.4): - let Test.14 = lowlevel ListLen #Attr.2; - let Test.12 = lowlevel NumLt #Attr.3 Test.14; - if Test.12 then - let Test.13 = lowlevel ListSet #Attr.2 #Attr.3 #Attr.4; - ret Test.13; + let Test.15 = lowlevel ListLen #Attr.2; + let Test.13 = lowlevel NumLt #Attr.3 Test.15; + if Test.13 then + let Test.14 = lowlevel ListSet #Attr.2 #Attr.3 #Attr.4; + ret Test.14; else ret #Attr.2; procedure Test.1 (Test.2): - let Test.39 = 0i64; - let Test.29 = CallByName List.3 Test.2 Test.39; - let Test.30 = 0i64; - let Test.28 = CallByName List.3 Test.2 Test.30; - let Test.7 = Struct {Test.28, Test.29}; - joinpoint Test.25: - let Test.18 = Array []; - ret Test.18; + let Test.40 = 0i64; + let Test.30 = CallByName List.3 Test.2 Test.40; + let Test.31 = 0i64; + let Test.29 = CallByName List.3 Test.2 Test.31; + let Test.8 = Struct {Test.29, Test.30}; + joinpoint Test.26: + let Test.19 = Array []; + ret Test.19; in - let Test.22 = Index 1 Test.7; - let Test.23 = 1i64; - let Test.24 = Index 0 Test.22; - let Test.27 = lowlevel Eq Test.23 Test.24; - if Test.27 then - let Test.19 = Index 0 Test.7; - let Test.20 = 1i64; - let Test.21 = Index 0 Test.19; - let Test.26 = lowlevel Eq Test.20 Test.21; - if Test.26 then - let Test.17 = Index 0 Test.7; - let Test.3 = Index 1 Test.17; - let Test.16 = Index 1 Test.7; - let Test.4 = Index 1 Test.16; - let Test.15 = 0i64; - let Test.9 = CallByName List.4 Test.2 Test.15 Test.4; - let Test.10 = 0i64; - let Test.8 = CallByName List.4 Test.9 Test.10 Test.3; - ret Test.8; + let Test.23 = Index 1 Test.8; + let Test.24 = 1i64; + let Test.25 = Index 0 Test.23; + let Test.28 = lowlevel Eq Test.24 Test.25; + if Test.28 then + let Test.20 = Index 0 Test.8; + let Test.21 = 1i64; + let Test.22 = Index 0 Test.20; + let Test.27 = lowlevel Eq Test.21 Test.22; + if Test.27 then + let Test.18 = Index 0 Test.8; + let Test.4 = Index 1 Test.18; + let Test.17 = Index 1 Test.8; + let Test.5 = Index 1 Test.17; + let Test.16 = 0i64; + let Test.10 = CallByName List.4 Test.2 Test.16 Test.5; + let Test.11 = 0i64; + let Test.9 = CallByName List.4 Test.10 Test.11 Test.4; + ret Test.9; else dec Test.2; - jump Test.25; + jump Test.26; else dec Test.2; - jump Test.25; + jump Test.26; procedure Test.0 (): - let Test.40 = 1i64; - let Test.41 = 2i64; - let Test.6 = Array [Test.40, Test.41]; - let Test.5 = CallByName Test.1 Test.6; - ret Test.5; + let Test.41 = 1i64; + let Test.42 = 2i64; + let Test.7 = Array [Test.41, Test.42]; + let Test.6 = CallByName Test.1 Test.7; + ret Test.6; "# ), ) @@ -1704,16 +1701,16 @@ mod test_mono { r#" procedure Test.1 (Test.2): inc Test.2; - let Test.5 = Struct {Test.2, Test.2}; - ret Test.5; + let Test.6 = Struct {Test.2, Test.2}; + ret Test.6; procedure Test.0 (): - let Test.6 = 1i64; - let Test.7 = 2i64; - let Test.8 = 3i64; - let Test.4 = Array [Test.6, Test.7, Test.8]; - let Test.3 = CallByName Test.1 Test.4; - ret Test.3; + let Test.7 = 1i64; + let Test.8 = 2i64; + let Test.9 = 3i64; + let Test.5 = Array [Test.7, Test.8, Test.9]; + let Test.4 = CallByName Test.1 Test.5; + ret Test.4; "# ), ) @@ -1882,14 +1879,14 @@ mod test_mono { indoc!( r#" procedure Test.0 (): - let Test.5 = 0i64; - let Test.7 = 0i64; let Test.9 = 0i64; - let Test.10 = 1i64; - let Test.8 = Z Test.10; - let Test.6 = S Test.9 Test.8; - let Test.4 = S Test.7 Test.6; - let Test.2 = S Test.5 Test.4; + let Test.11 = 0i64; + let Test.13 = 0i64; + let Test.14 = 1i64; + let Test.12 = Z Test.14; + let Test.10 = S Test.13 Test.12; + let Test.8 = S Test.11 Test.10; + let Test.2 = S Test.9 Test.8; ret Test.2; "# ), @@ -1914,24 +1911,24 @@ mod test_mono { indoc!( r#" procedure Test.0 (): - let Test.9 = 0i64; - let Test.11 = 0i64; let Test.13 = 0i64; - let Test.14 = 1i64; - let Test.12 = Z Test.14; - let Test.10 = S Test.13 Test.12; - let Test.8 = S Test.11 Test.10; - let Test.2 = S Test.9 Test.8; - let Test.5 = 1i64; - let Test.6 = Index 0 Test.2; + let Test.15 = 0i64; + let Test.17 = 0i64; + let Test.18 = 1i64; + let Test.16 = Z Test.18; + let Test.14 = S Test.17 Test.16; + let Test.12 = S Test.15 Test.14; + let Test.2 = S Test.13 Test.12; + let Test.9 = 1i64; + let Test.10 = Index 0 Test.2; dec Test.2; - let Test.7 = lowlevel Eq Test.5 Test.6; - if Test.7 then - let Test.3 = 0i64; - ret Test.3; + let Test.11 = lowlevel Eq Test.9 Test.10; + if Test.11 then + let Test.7 = 0i64; + ret Test.7; else - let Test.4 = 1i64; - ret Test.4; + let Test.8 = 1i64; + ret Test.8; "# ), ) @@ -1956,33 +1953,33 @@ mod test_mono { indoc!( r#" procedure Test.0 (): - let Test.15 = 0i64; - let Test.17 = 0i64; let Test.19 = 0i64; - let Test.20 = 1i64; - let Test.18 = Z Test.20; - let Test.16 = S Test.19 Test.18; - let Test.14 = S Test.17 Test.16; - let Test.2 = S Test.15 Test.14; - let Test.11 = 0i64; - let Test.12 = Index 0 Test.2; - let Test.13 = lowlevel Eq Test.11 Test.12; - if Test.13 then - let Test.7 = Index 1 Test.2; - let Test.8 = 0i64; - let Test.9 = Index 0 Test.7; - dec Test.7; + let Test.21 = 0i64; + let Test.23 = 0i64; + let Test.24 = 1i64; + let Test.22 = Z Test.24; + let Test.20 = S Test.23 Test.22; + let Test.18 = S Test.21 Test.20; + let Test.2 = S Test.19 Test.18; + let Test.15 = 0i64; + let Test.16 = Index 0 Test.2; + let Test.17 = lowlevel Eq Test.15 Test.16; + if Test.17 then + let Test.11 = Index 1 Test.2; + let Test.12 = 0i64; + let Test.13 = Index 0 Test.11; + dec Test.11; decref Test.2; - let Test.10 = lowlevel Eq Test.8 Test.9; - if Test.10 then - let Test.3 = 1i64; - ret Test.3; + let Test.14 = lowlevel Eq Test.12 Test.13; + if Test.14 then + let Test.7 = 1i64; + ret Test.7; else - let Test.5 = 0i64; - ret Test.5; + let Test.9 = 0i64; + ret Test.9; else - let Test.6 = 0i64; - ret Test.6; + let Test.10 = 0i64; + ret Test.10; "# ), ) @@ -2009,14 +2006,14 @@ mod test_mono { indoc!( r#" procedure Num.26 (#Attr.2, #Attr.3): - let Test.13 = lowlevel NumMul #Attr.2 #Attr.3; - ret Test.13; + let Test.17 = lowlevel NumMul #Attr.2 #Attr.3; + ret Test.17; procedure Test.1 (Test.6): - let Test.18 = Index 1 Test.6; - let Test.19 = false; - let Test.20 = lowlevel Eq Test.19 Test.18; - if Test.20 then + let Test.22 = Index 1 Test.6; + let Test.23 = false; + let Test.24 = lowlevel Eq Test.23 Test.22; + if Test.24 then let Test.8 = Index 0 Test.6; ret Test.8; else @@ -2024,10 +2021,10 @@ mod test_mono { ret Test.10; procedure Test.1 (Test.6): - let Test.29 = Index 0 Test.6; - let Test.30 = false; - let Test.31 = lowlevel Eq Test.30 Test.29; - if Test.31 then + let Test.33 = Index 0 Test.6; + let Test.34 = false; + let Test.35 = lowlevel Eq Test.34 Test.33; + if Test.35 then let Test.8 = 3i64; ret Test.8; else @@ -2035,22 +2032,6 @@ mod test_mono { ret Test.10; procedure Test.0 (): - let Test.34 = true; - let Test.5 = CallByName Test.1 Test.34; - let Test.32 = false; - let Test.3 = CallByName Test.1 Test.32; - let Test.24 = 11i64; - let Test.25 = true; - let Test.23 = Struct {Test.24, Test.25}; - let Test.4 = CallByName Test.1 Test.23; - let Test.21 = 7i64; - let Test.22 = false; - let Test.15 = Struct {Test.21, Test.22}; - let Test.2 = CallByName Test.1 Test.15; - let Test.14 = CallByName Num.26 Test.2 Test.3; - let Test.12 = CallByName Num.26 Test.14 Test.4; - let Test.11 = CallByName Num.26 Test.12 Test.5; - ret Test.11; "# ), ) @@ -2074,37 +2055,37 @@ mod test_mono { indoc!( r#" procedure Num.24 (#Attr.2, #Attr.3): - let Test.6 = lowlevel NumAdd #Attr.2 #Attr.3; - ret Test.6; + let Test.8 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.8; procedure Test.0 (): - let Test.18 = 0i64; let Test.20 = 0i64; - let Test.19 = 41i64; - let Test.17 = Just Test.20 Test.19; - let Test.2 = Just Test.18 Test.17; - joinpoint Test.14: - let Test.8 = 1i64; - ret Test.8; + let Test.22 = 0i64; + let Test.21 = 41i64; + let Test.19 = Just Test.22 Test.21; + let Test.2 = Just Test.20 Test.19; + joinpoint Test.16: + let Test.10 = 1i64; + ret Test.10; in - let Test.12 = 0i64; - let Test.13 = Index 0 Test.2; - let Test.16 = lowlevel Eq Test.12 Test.13; - if Test.16 then - let Test.9 = Index 1 Test.2; - let Test.10 = 0i64; - let Test.11 = Index 0 Test.9; - let Test.15 = lowlevel Eq Test.10 Test.11; - if Test.15 then - let Test.7 = Index 1 Test.2; - let Test.3 = Index 1 Test.7; - let Test.5 = 1i64; - let Test.4 = CallByName Num.24 Test.3 Test.5; - ret Test.4; + let Test.14 = 0i64; + let Test.15 = Index 0 Test.2; + let Test.18 = lowlevel Eq Test.14 Test.15; + if Test.18 then + let Test.11 = Index 1 Test.2; + let Test.12 = 0i64; + let Test.13 = Index 0 Test.11; + let Test.17 = lowlevel Eq Test.12 Test.13; + if Test.17 then + let Test.9 = Index 1 Test.2; + let Test.5 = Index 1 Test.9; + let Test.7 = 1i64; + let Test.6 = CallByName Num.24 Test.5 Test.7; + ret Test.6; else - jump Test.14; + jump Test.16; else - jump Test.14; + jump Test.16; "# ), ) @@ -2193,67 +2174,67 @@ mod test_mono { indoc!( r#" procedure List.3 (#Attr.2, #Attr.3): - let Test.40 = lowlevel ListLen #Attr.2; - let Test.36 = lowlevel NumLt #Attr.3 Test.40; - if Test.36 then - let Test.39 = 1i64; - let Test.38 = lowlevel ListGetUnsafe #Attr.2 #Attr.3; - let Test.37 = Ok Test.39 Test.38; - ret Test.37; + let Test.41 = lowlevel ListLen #Attr.2; + let Test.37 = lowlevel NumLt #Attr.3 Test.41; + if Test.37 then + let Test.40 = 1i64; + let Test.39 = lowlevel ListGetUnsafe #Attr.2 #Attr.3; + let Test.38 = Ok Test.40 Test.39; + ret Test.38; else - let Test.35 = 0i64; - let Test.34 = Struct {}; - let Test.33 = Err Test.35 Test.34; - ret Test.33; + let Test.36 = 0i64; + let Test.35 = Struct {}; + let Test.34 = Err Test.36 Test.35; + ret Test.34; procedure List.4 (#Attr.2, #Attr.3, #Attr.4): - let Test.18 = lowlevel ListLen #Attr.2; - let Test.16 = lowlevel NumLt #Attr.3 Test.18; - if Test.16 then - let Test.17 = lowlevel ListSet #Attr.2 #Attr.3 #Attr.4; - ret Test.17; + let Test.19 = lowlevel ListLen #Attr.2; + let Test.17 = lowlevel NumLt #Attr.3 Test.19; + if Test.17 then + let Test.18 = lowlevel ListSet #Attr.2 #Attr.3 #Attr.4; + ret Test.18; else ret #Attr.2; procedure Test.1 (Test.2, Test.3, Test.4): - let Test.32 = CallByName List.3 Test.4 Test.3; - let Test.31 = CallByName List.3 Test.4 Test.2; - let Test.12 = Struct {Test.31, Test.32}; - joinpoint Test.28: - let Test.21 = Array []; - ret Test.21; + let Test.33 = CallByName List.3 Test.4 Test.3; + let Test.32 = CallByName List.3 Test.4 Test.2; + let Test.13 = Struct {Test.32, Test.33}; + joinpoint Test.29: + let Test.22 = Array []; + ret Test.22; in - let Test.25 = Index 1 Test.12; - let Test.26 = 1i64; - let Test.27 = Index 0 Test.25; - let Test.30 = lowlevel Eq Test.26 Test.27; - if Test.30 then - let Test.22 = Index 0 Test.12; - let Test.23 = 1i64; - let Test.24 = Index 0 Test.22; - let Test.29 = lowlevel Eq Test.23 Test.24; - if Test.29 then - let Test.20 = Index 0 Test.12; - let Test.5 = Index 1 Test.20; - let Test.19 = Index 1 Test.12; - let Test.6 = Index 1 Test.19; - let Test.14 = CallByName List.4 Test.4 Test.2 Test.6; - let Test.13 = CallByName List.4 Test.14 Test.3 Test.5; - ret Test.13; + let Test.26 = Index 1 Test.13; + let Test.27 = 1i64; + let Test.28 = Index 0 Test.26; + let Test.31 = lowlevel Eq Test.27 Test.28; + if Test.31 then + let Test.23 = Index 0 Test.13; + let Test.24 = 1i64; + let Test.25 = Index 0 Test.23; + let Test.30 = lowlevel Eq Test.24 Test.25; + if Test.30 then + let Test.21 = Index 0 Test.13; + let Test.6 = Index 1 Test.21; + let Test.20 = Index 1 Test.13; + let Test.7 = Index 1 Test.20; + let Test.15 = CallByName List.4 Test.4 Test.2 Test.7; + let Test.14 = CallByName List.4 Test.15 Test.3 Test.6; + ret Test.14; else dec Test.4; - jump Test.28; + jump Test.29; else dec Test.4; - jump Test.28; + jump Test.29; procedure Test.0 (): - let Test.9 = 0i64; let Test.10 = 0i64; - let Test.41 = 1i64; - let Test.11 = Array [Test.41]; - let Test.8 = CallByName Test.1 Test.9 Test.10 Test.11; - ret Test.8; + let Test.11 = 0i64; + let Test.42 = 1i64; + let Test.12 = Array [Test.42]; + let Test.9 = CallByName Test.1 Test.10 Test.11 Test.12; + ret Test.9; "# ), ) diff --git a/compiler/solve/src/solve.rs b/compiler/solve/src/solve.rs index f176dec94b..4e54259f5a 100644 --- a/compiler/solve/src/solve.rs +++ b/compiler/solve/src/solve.rs @@ -728,6 +728,27 @@ fn type_to_variable( register(subs, rank, pools, content) } + FunctionOrTagUnion(tag_name, symbol, ext) => { + let temp_ext_var = type_to_variable(subs, rank, pools, cached, ext); + let mut ext_tag_vec = Vec::new(); + let new_ext_var = match roc_types::pretty_print::chase_ext_tag_union( + subs, + temp_ext_var, + &mut ext_tag_vec, + ) { + Ok(()) => Variable::EMPTY_TAG_UNION, + Err((new, _)) => new, + }; + debug_assert!(ext_tag_vec.is_empty()); + + let content = Content::Structure(FlatType::FunctionOrTagUnion( + tag_name.clone(), + *symbol, + new_ext_var, + )); + + register(subs, rank, pools, content) + } RecursiveTagUnion(rec_var, tags, ext) => { let mut tag_vars = MutMap::default(); @@ -1134,6 +1155,10 @@ fn adjust_rank_content( rank } + FunctionOrTagUnion(_, _, ext_var) => { + adjust_rank(subs, young_mark, visit_mark, group_rank, *ext_var) + } + RecursiveTagUnion(rec_var, tags, ext_var) => { let mut rank = adjust_rank(subs, young_mark, visit_mark, group_rank, *ext_var); @@ -1309,6 +1334,12 @@ fn instantiate_rigids_help( ) } + FunctionOrTagUnion(tag_name, symbol, ext_var) => FunctionOrTagUnion( + tag_name, + symbol, + instantiate_rigids_help(subs, max_rank, pools, ext_var), + ), + RecursiveTagUnion(rec_var, tags, ext_var) => { let mut new_tags = MutMap::default(); @@ -1495,6 +1526,12 @@ fn deep_copy_var_help( TagUnion(new_tags, deep_copy_var_help(subs, max_rank, pools, ext_var)) } + FunctionOrTagUnion(tag_name, symbol, ext_var) => FunctionOrTagUnion( + tag_name, + symbol, + deep_copy_var_help(subs, max_rank, pools, ext_var), + ), + RecursiveTagUnion(rec_var, tags, ext_var) => { let mut new_tags = MutMap::default(); diff --git a/compiler/types/src/pretty_print.rs b/compiler/types/src/pretty_print.rs index aa1a39fa89..2b9ac83553 100644 --- a/compiler/types/src/pretty_print.rs +++ b/compiler/types/src/pretty_print.rs @@ -179,6 +179,9 @@ fn find_names_needed( find_names_needed(ext_var, subs, roots, root_appearances, names_taken); } + Structure(FunctionOrTagUnion(_, _, ext_var)) => { + find_names_needed(ext_var, subs, roots, root_appearances, names_taken); + } Structure(RecursiveTagUnion(rec_var, tags, ext_var)) => { let mut sorted_tags: Vec<_> = tags.iter().collect(); sorted_tags.sort(); @@ -487,6 +490,28 @@ fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String } } + FunctionOrTagUnion(tag_name, _, ext_var) => { + let interns = &env.interns; + let home = env.home; + + buf.push_str("[ "); + + buf.push_str(&tag_name.as_string(&interns, home)); + + buf.push_str(" ]"); + + let mut sorted_fields = vec![(tag_name, vec![])]; + let ext_content = chase_ext_tag_union(subs, ext_var, &mut sorted_fields); + if let Err((_, content)) = ext_content { + // This is an open tag union, so print the variable + // right after the ']' + // + // e.g. the "*" at the end of `{ x: I64 }*` + // or the "r" at the end of `{ x: I64 }r` + write_content(env, content, subs, buf, parens) + } + } + RecursiveTagUnion(rec_var, tags, ext_var) => { let interns = &env.interns; let home = env.home; @@ -570,6 +595,11 @@ pub fn chase_ext_tag_union( chase_ext_tag_union(subs, ext_var, fields) } + Content::Structure(FunctionOrTagUnion(tag_name, _, ext_var)) => { + fields.push((tag_name, vec![])); + + chase_ext_tag_union(subs, ext_var, fields) + } Content::Alias(_, _, var) => chase_ext_tag_union(subs, var, fields), content => Err((var, content)), diff --git a/compiler/types/src/solved_types.rs b/compiler/types/src/solved_types.rs index a70bc2b7c8..fa71812f3e 100644 --- a/compiler/types/src/solved_types.rs +++ b/compiler/types/src/solved_types.rs @@ -104,6 +104,10 @@ fn hash_solved_type_help( hash_solved_type_help(ext, flex_vars, state); } + FunctionOrTagUnion(_, _, ext) => { + hash_solved_type_help(ext, flex_vars, state); + } + RecursiveTagUnion(rec, tags, ext) => { var_id_hash_help(*rec, flex_vars, state); for (name, arguments) in tags { @@ -172,6 +176,7 @@ pub enum SolvedType { }, EmptyRecord, TagUnion(Vec<(TagName, Vec)>, Box), + FunctionOrTagUnion(TagName, Symbol, Box), RecursiveTagUnion(VarId, Vec<(TagName, Vec)>, Box), EmptyTagUnion, /// A type from an Invalid module @@ -263,6 +268,10 @@ impl SolvedType { SolvedType::TagUnion(solved_tags, Box::new(solved_ext)) } + FunctionOrTagUnion(tag_name, symbol, box_ext) => { + let solved_ext = Self::from_type(solved_subs, box_ext); + SolvedType::FunctionOrTagUnion(tag_name.clone(), *symbol, Box::new(solved_ext)) + } RecursiveTagUnion(rec_var, tags, box_ext) => { let solved_ext = Self::from_type(solved_subs, box_ext); let mut solved_tags = Vec::with_capacity(tags.len()); @@ -423,6 +432,11 @@ impl SolvedType { SolvedType::TagUnion(new_tags, Box::new(ext)) } + FunctionOrTagUnion(tag_name, symbol, ext_var) => { + let ext = Self::from_var_help(subs, recursion_vars, ext_var); + + SolvedType::FunctionOrTagUnion(tag_name, symbol, Box::new(ext)) + } RecursiveTagUnion(rec_var, tags, ext_var) => { recursion_vars.insert(subs, rec_var); @@ -562,6 +576,11 @@ pub fn to_type( Type::TagUnion(new_tags, Box::new(to_type(ext, free_vars, var_store))) } + FunctionOrTagUnion(tag_name, symbol, ext) => Type::FunctionOrTagUnion( + tag_name.clone(), + *symbol, + Box::new(to_type(ext, free_vars, var_store)), + ), RecursiveTagUnion(rec_var_id, tags, ext) => { let mut new_tags = Vec::with_capacity(tags.len()); diff --git a/compiler/types/src/subs.rs b/compiler/types/src/subs.rs index c43107f7d8..29a834378e 100644 --- a/compiler/types/src/subs.rs +++ b/compiler/types/src/subs.rs @@ -606,6 +606,7 @@ pub enum FlatType { Func(Vec, Variable, Variable), Record(MutMap>, Variable), TagUnion(MutMap>, Variable), + FunctionOrTagUnion(TagName, Symbol, Variable), RecursiveTagUnion(Variable, MutMap>, Variable), Erroneous(Problem), EmptyRecord, @@ -662,6 +663,10 @@ fn occurs( let it = once(&ext_var).chain(tags.values().flatten()); short_circuit(subs, root_var, &new_seen, it) } + FunctionOrTagUnion(_, _, ext_var) => { + let it = once(&ext_var); + short_circuit(subs, root_var, &new_seen, it) + } RecursiveTagUnion(_rec_var, tags, ext_var) => { // TODO rec_var is excluded here, verify that this is correct let it = once(&ext_var).chain(tags.values().flatten()); @@ -752,6 +757,13 @@ fn explicit_substitute( } subs.set_content(in_var, Structure(TagUnion(tags, new_ext_var))); } + FunctionOrTagUnion(tag_name, symbol, ext_var) => { + let new_ext_var = explicit_substitute(subs, from, to, ext_var, seen); + subs.set_content( + in_var, + Structure(FunctionOrTagUnion(tag_name, symbol, new_ext_var)), + ); + } RecursiveTagUnion(rec_var, mut tags, ext_var) => { // NOTE rec_var is not substituted, verify that this is correct! let new_ext_var = explicit_substitute(subs, from, to, ext_var, seen); @@ -891,6 +903,10 @@ fn get_var_names( taken_names } + FlatType::FunctionOrTagUnion(_, _, ext_var) => { + get_var_names(subs, ext_var, taken_names) + } + FlatType::RecursiveTagUnion(rec_var, tags, ext_var) => { let taken_names = get_var_names(subs, ext_var, taken_names); let mut taken_names = get_var_names(subs, rec_var, taken_names); @@ -1142,6 +1158,32 @@ fn flat_type_to_err_type( } } + FunctionOrTagUnion(tag_name, _, ext_var) => { + let mut err_tags = SendMap::default(); + + err_tags.insert(tag_name, vec![]); + + match var_to_err_type(subs, state, ext_var).unwrap_alias() { + ErrorType::TagUnion(sub_tags, sub_ext) => { + ErrorType::TagUnion(sub_tags.union(err_tags), sub_ext) + } + ErrorType::RecursiveTagUnion(_, sub_tags, sub_ext) => { + ErrorType::TagUnion(sub_tags.union(err_tags), sub_ext) + } + + ErrorType::FlexVar(var) => { + ErrorType::TagUnion(err_tags, TypeExt::FlexOpen(var)) + } + + ErrorType::RigidVar(var) => { + ErrorType::TagUnion(err_tags, TypeExt::RigidOpen(var)) + } + + other => + panic!("Tried to convert a tag union extension to an error, but the tag union extension had the ErrorType of {:?}", other) + } + } + RecursiveTagUnion(rec_var, tags, ext_var) => { let mut err_tags = SendMap::default(); @@ -1236,6 +1278,9 @@ fn restore_content(subs: &mut Subs, content: &Content) { subs.restore(*ext_var); } + FunctionOrTagUnion(_, _, ext_var) => { + subs.restore(*ext_var); + } RecursiveTagUnion(rec_var, tags, ext_var) => { for var in tags.values().flatten() { diff --git a/compiler/types/src/types.rs b/compiler/types/src/types.rs index bdbabd53ed..7e0e54d3f4 100644 --- a/compiler/types/src/types.rs +++ b/compiler/types/src/types.rs @@ -141,6 +141,7 @@ pub enum Type { Function(Vec, Box, Box), Record(SendMap>, Box), TagUnion(Vec<(TagName, Vec)>, Box), + FunctionOrTagUnion(TagName, Symbol, Box), Alias(Symbol, Vec<(Lowercase, Type)>, Box), HostExposedAlias { name: Symbol, @@ -308,6 +309,26 @@ impl fmt::Debug for Type { } } } + Type::FunctionOrTagUnion(tag_name, _, ext) => { + write!(f, "[")?; + write!(f, "{:?}", tag_name)?; + write!(f, "]")?; + + match *ext.clone() { + Type::EmptyTagUnion => { + // This is a closed variant. We're done! + Ok(()) + } + other => { + // This is an open tag union, so print the variable + // right after the ']' + // + // e.g. the "*" at the end of `[ Foo ]*` + // or the "r" at the end of `[ DivByZero ]r` + other.fmt(f) + } + } + } Type::RecursiveTagUnion(rec, tags, ext) => { write!(f, "[")?; @@ -404,6 +425,9 @@ impl Type { } ext.substitute(substitutions); } + FunctionOrTagUnion(_, _, ext) => { + ext.substitute(substitutions); + } RecursiveTagUnion(_, tags, ext) => { for (_, args) in tags { for x in args { @@ -456,6 +480,9 @@ impl Type { closure.substitute_alias(rep_symbol, actual); ret.substitute_alias(rep_symbol, actual); } + FunctionOrTagUnion(_, _, ext) => { + ext.substitute_alias(rep_symbol, actual); + } RecursiveTagUnion(_, tags, ext) | TagUnion(tags, ext) => { for (_, args) in tags { for x in args { @@ -506,6 +533,7 @@ impl Type { || closure.contains_symbol(rep_symbol) || args.iter().any(|arg| arg.contains_symbol(rep_symbol)) } + FunctionOrTagUnion(_, _, ext) => ext.contains_symbol(rep_symbol), RecursiveTagUnion(_, tags, ext) | TagUnion(tags, ext) => { ext.contains_symbol(rep_symbol) || tags @@ -541,6 +569,7 @@ impl Type { || closure.contains_variable(rep_variable) || args.iter().any(|arg| arg.contains_variable(rep_variable)) } + FunctionOrTagUnion(_, _, ext) => ext.contains_variable(rep_variable), RecursiveTagUnion(_, tags, ext) | TagUnion(tags, ext) => { ext.contains_variable(rep_variable) || tags @@ -595,6 +624,9 @@ impl Type { closure.instantiate_aliases(region, aliases, var_store, introduced); ret.instantiate_aliases(region, aliases, var_store, introduced); } + FunctionOrTagUnion(_, _, ext) => { + ext.instantiate_aliases(region, aliases, var_store, introduced); + } RecursiveTagUnion(_, tags, ext) | TagUnion(tags, ext) => { for (_, args) in tags { for x in args { @@ -734,6 +766,9 @@ fn symbols_help(tipe: &Type, accum: &mut ImSet) { symbols_help(&closure, accum); args.iter().for_each(|arg| symbols_help(arg, accum)); } + FunctionOrTagUnion(_, _, ext) => { + symbols_help(&ext, accum); + } RecursiveTagUnion(_, tags, ext) | TagUnion(tags, ext) => { symbols_help(&ext, accum); tags.iter() @@ -807,6 +842,9 @@ fn variables_help(tipe: &Type, accum: &mut ImSet) { } variables_help(ext, accum); } + FunctionOrTagUnion(_, _, ext) => { + variables_help(ext, accum); + } RecursiveTagUnion(rec, tags, ext) => { for (_, args) in tags { for x in args { @@ -900,6 +938,9 @@ fn variables_help_detailed(tipe: &Type, accum: &mut VariableDetail) { } variables_help_detailed(ext, accum); } + FunctionOrTagUnion(_, _, ext) => { + variables_help_detailed(ext, accum); + } RecursiveTagUnion(rec, tags, ext) => { for (_, args) in tags { for x in args { diff --git a/compiler/unify/src/unify.rs b/compiler/unify/src/unify.rs index fea0690de8..3f88aff3bd 100644 --- a/compiler/unify/src/unify.rs +++ b/compiler/unify/src/unify.rs @@ -237,6 +237,10 @@ fn unify_structure( // unify the structure with this recursive tag union unify_pool(subs, pool, ctx.first, *structure) } + FlatType::FunctionOrTagUnion(_, _, _) => { + // unify the structure with this unrecursive tag union + unify_pool(subs, pool, ctx.first, *structure) + } _ => todo!("rec structure {:?}", &flat_type), }, @@ -978,12 +982,106 @@ fn unify_flat_type( problems } } - (TagUnion(tags, ext), Func(args, closure, ret)) if tags.len() == 1 => { - unify_tag_union_and_func(tags, args, subs, pool, ctx, ext, ret, closure, true) + (FunctionOrTagUnion(tag_name, tag_symbol, ext), Func(args, closure, ret)) => { + unify_function_or_tag_union_and_func( + subs, + pool, + ctx, + tag_name, + *tag_symbol, + *ext, + args, + *ret, + *closure, + true, + ) } - (Func(args, closure, ret), TagUnion(tags, ext)) if tags.len() == 1 => { - unify_tag_union_and_func(tags, args, subs, pool, ctx, ext, ret, closure, false) + (Func(args, closure, ret), FunctionOrTagUnion(tag_name, tag_symbol, ext)) => { + unify_function_or_tag_union_and_func( + subs, + pool, + ctx, + tag_name, + *tag_symbol, + *ext, + args, + *ret, + *closure, + false, + ) } + (FunctionOrTagUnion(tag_name_1, _, ext_1), FunctionOrTagUnion(tag_name_2, _, ext_2)) => { + if tag_name_1 == tag_name_2 { + let problems = unify_pool(subs, pool, *ext_1, *ext_2); + if problems.is_empty() { + let desc = subs.get(ctx.second); + merge(subs, ctx, desc.content) + } else { + problems + } + } else { + let mut tags1 = MutMap::default(); + tags1.insert(tag_name_1.clone(), vec![]); + let union1 = gather_tags(subs, tags1, *ext_1); + + let mut tags2 = MutMap::default(); + tags2.insert(tag_name_2.clone(), vec![]); + let union2 = gather_tags(subs, tags2, *ext_2); + + unify_tag_union(subs, pool, ctx, union1, union2, (None, None)) + } + } + (TagUnion(tags1, ext1), FunctionOrTagUnion(tag_name, _, ext2)) => { + let union1 = gather_tags(subs, tags1.clone(), *ext1); + + let mut tags2 = MutMap::default(); + tags2.insert(tag_name.clone(), vec![]); + let union2 = gather_tags(subs, tags2, *ext2); + + unify_tag_union(subs, pool, ctx, union1, union2, (None, None)) + } + (FunctionOrTagUnion(tag_name, _, ext1), TagUnion(tags2, ext2)) => { + let mut tags1 = MutMap::default(); + tags1.insert(tag_name.clone(), vec![]); + let union1 = gather_tags(subs, tags1, *ext1); + + let union2 = gather_tags(subs, tags2.clone(), *ext2); + + unify_tag_union(subs, pool, ctx, union1, union2, (None, None)) + } + + (RecursiveTagUnion(recursion_var, tags1, ext1), FunctionOrTagUnion(tag_name, _, ext2)) => { + // this never happens in type-correct programs, but may happen if there is a type error + debug_assert!(is_recursion_var(subs, *recursion_var)); + + let mut tags2 = MutMap::default(); + tags2.insert(tag_name.clone(), vec![]); + + let union1 = gather_tags(subs, tags1.clone(), *ext1); + let union2 = gather_tags(subs, tags2, *ext2); + + unify_tag_union( + subs, + pool, + ctx, + union1, + union2, + (Some(*recursion_var), None), + ) + } + + (FunctionOrTagUnion(tag_name, _, ext1), RecursiveTagUnion(recursion_var, tags2, ext2)) => { + debug_assert!(is_recursion_var(subs, *recursion_var)); + + let mut tags1 = MutMap::default(); + tags1.insert(tag_name.clone(), vec![]); + + let union1 = gather_tags(subs, tags1, *ext1); + let union2 = gather_tags(subs, tags2.clone(), *ext2); + + unify_tag_union_not_recursive_recursive(subs, pool, ctx, union1, union2, *recursion_var) + } + (other1, other2) => mismatch!( "Trying to unify two flat types that are incompatible: {:?} ~ {:?}", other1, @@ -1166,53 +1264,44 @@ fn is_recursion_var(subs: &Subs, var: Variable) -> bool { ) } -#[allow(clippy::too_many_arguments, clippy::ptr_arg)] -fn unify_tag_union_and_func( - tags: &MutMap>, - args: &Vec, +#[allow(clippy::too_many_arguments)] +fn unify_function_or_tag_union_and_func( subs: &mut Subs, pool: &mut Pool, ctx: &Context, - ext: &Variable, - ret: &Variable, - closure: &Variable, + tag_name: &TagName, + _tag_symbol: Symbol, + tag_ext: Variable, + function_arguments: &[Variable], + function_return: Variable, + _function_lambda_set: Variable, left: bool, ) -> Outcome { use FlatType::*; - let (tag_name, payload) = tags.iter().next().unwrap(); + let mut new_tags = MutMap::with_capacity_and_hasher(1, default_hasher()); - if payload.is_empty() { - let mut new_tags = MutMap::with_capacity_and_hasher(1, default_hasher()); + new_tags.insert(tag_name.clone(), function_arguments.to_owned()); - new_tags.insert(tag_name.clone(), args.to_owned()); + let content = Structure(TagUnion(new_tags, tag_ext)); - let content = Structure(TagUnion(new_tags, *ext)); + let new_tag_union_var = fresh(subs, pool, ctx, content); - let new_tag_union_var = fresh(subs, pool, ctx, content); + let problems = if left { + unify_pool(subs, pool, new_tag_union_var, function_return) + } else { + unify_pool(subs, pool, function_return, new_tag_union_var) + }; - let problems = if left { - unify_pool(subs, pool, new_tag_union_var, *ret) + if problems.is_empty() { + let desc = if left { + subs.get(ctx.second) } else { - unify_pool(subs, pool, *ret, new_tag_union_var) + subs.get(ctx.first) }; - if problems.is_empty() { - let desc = if left { - subs.get(ctx.second) - } else { - subs.get(ctx.first) - }; - - subs.union(ctx.first, ctx.second, desc); - } - - problems - } else { - mismatch!( - "Trying to unify two flat types that are incompatible: {:?} ~ {:?}", - TagUnion(tags.clone(), *ext), - Func(args.to_owned(), *closure, *ret) - ) + subs.union(ctx.first, ctx.second, desc); } + + problems } diff --git a/editor/src/lang/solve.rs b/editor/src/lang/solve.rs index 7f9d5c7874..eac4fe4bf5 100644 --- a/editor/src/lang/solve.rs +++ b/editor/src/lang/solve.rs @@ -1243,6 +1243,10 @@ fn adjust_rank_content( rank } + FunctionOrTagUnion(_, _, ext_var) => { + adjust_rank(subs, young_mark, visit_mark, group_rank, *ext_var) + } + RecursiveTagUnion(rec_var, tags, ext_var) => { let mut rank = adjust_rank(subs, young_mark, visit_mark, group_rank, *ext_var); @@ -1418,6 +1422,12 @@ fn instantiate_rigids_help( ) } + FunctionOrTagUnion(tag_name, symbol, ext_var) => FunctionOrTagUnion( + tag_name, + symbol, + instantiate_rigids_help(subs, max_rank, pools, ext_var), + ), + RecursiveTagUnion(rec_var, tags, ext_var) => { let mut new_tags = MutMap::default(); @@ -1604,6 +1614,12 @@ fn deep_copy_var_help( TagUnion(new_tags, deep_copy_var_help(subs, max_rank, pools, ext_var)) } + FunctionOrTagUnion(tag_name, symbol, ext_var) => FunctionOrTagUnion( + tag_name, + symbol, + deep_copy_var_help(subs, max_rank, pools, ext_var), + ), + RecursiveTagUnion(rec_var, tags, ext_var) => { let mut new_tags = MutMap::default();