From 6318f032df7f2a2ca46c35097f44e88cad6a4caa Mon Sep 17 00:00:00 2001 From: Folkert Date: Thu, 15 Oct 2020 16:03:56 +0200 Subject: [PATCH] first shot at implementing closure layout --- compiler/gen/src/llvm/convert.rs | 95 ++++++++++------- compiler/gen/src/llvm/refcounting.rs | 42 ++++++-- compiler/gen/tests/gen_primitives.rs | 8 +- compiler/mono/src/ir.rs | 151 ++++++++++++++++++++++++++- compiler/mono/src/layout.rs | 42 +++++++- 5 files changed, 277 insertions(+), 61 deletions(-) diff --git a/compiler/gen/src/llvm/convert.rs b/compiler/gen/src/llvm/convert.rs index eb3878157a..32f3fe964a 100644 --- a/compiler/gen/src/llvm/convert.rs +++ b/compiler/gen/src/llvm/convert.rs @@ -48,6 +48,50 @@ pub fn get_array_type<'ctx>(bt_enum: &BasicTypeEnum<'ctx>, size: u32) -> ArrayTy } } +fn basic_type_from_function_layout<'ctx>( + arena: &Bump, + context: &'ctx Context, + args: &[Layout<'_>], + ret_layout: &Layout<'_>, + ptr_bytes: u32, +) -> BasicTypeEnum<'ctx> { + let ret_type = basic_type_from_layout(arena, context, &ret_layout, ptr_bytes); + let mut arg_basic_types = Vec::with_capacity_in(args.len(), arena); + + for arg_layout in args.iter() { + arg_basic_types.push(basic_type_from_layout( + arena, context, arg_layout, ptr_bytes, + )); + } + + let fn_type = get_fn_type(&ret_type, arg_basic_types.into_bump_slice()); + let ptr_type = fn_type.ptr_type(AddressSpace::Generic); + + ptr_type.as_basic_type_enum() +} + +fn basic_type_from_record<'ctx>( + arena: &Bump, + context: &'ctx Context, + fields: &[Layout<'_>], + ptr_bytes: u32, +) -> BasicTypeEnum<'ctx> { + let mut field_types = Vec::with_capacity_in(fields.len(), arena); + + for field_layout in fields.iter() { + field_types.push(basic_type_from_layout( + arena, + context, + field_layout, + ptr_bytes, + )); + } + + context + .struct_type(field_types.into_bump_slice(), false) + .as_basic_type_enum() +} + pub fn basic_type_from_layout<'ctx>( arena: &Bump, context: &'ctx Context, @@ -59,53 +103,26 @@ pub fn basic_type_from_layout<'ctx>( match layout { FunctionPointer(args, ret_layout) => { - let ret_type = basic_type_from_layout(arena, context, &ret_layout, ptr_bytes); - let mut arg_basic_types = Vec::with_capacity_in(args.len(), arena); + basic_type_from_function_layout(arena, context, args, ret_layout, ptr_bytes) + } + Closure(args, closure_layout, ret_layout) => { + let function_pointer = + basic_type_from_function_layout(arena, context, args, ret_layout, ptr_bytes); - for arg_layout in args.iter() { - arg_basic_types.push(basic_type_from_layout( - arena, context, arg_layout, ptr_bytes, - )); - } + let closure_data = basic_type_from_record(arena, context, closure_layout, ptr_bytes); - let fn_type = get_fn_type(&ret_type, arg_basic_types.into_bump_slice()); - let ptr_type = fn_type.ptr_type(AddressSpace::Generic); - - ptr_type.as_basic_type_enum() + context + .struct_type(&[function_pointer, closure_data], false) + .as_basic_type_enum() } Pointer(layout) => basic_type_from_layout(arena, context, &layout, ptr_bytes) .ptr_type(AddressSpace::Generic) .into(), - Struct(sorted_fields) => { - // Determine types - let mut field_types = Vec::with_capacity_in(sorted_fields.len(), arena); - - for field_layout in sorted_fields.iter() { - field_types.push(basic_type_from_layout( - arena, - context, - field_layout, - ptr_bytes, - )); - } - - context - .struct_type(field_types.into_bump_slice(), false) - .as_basic_type_enum() - } + Struct(sorted_fields) => basic_type_from_record(arena, context, sorted_fields, ptr_bytes), Union(tags) if tags.len() == 1 => { - let layouts = tags.iter().next().unwrap(); + let sorted_fields = tags.iter().next().unwrap(); - // Determine types - let mut field_types = Vec::with_capacity_in(layouts.len(), arena); - - for layout in layouts.iter() { - field_types.push(basic_type_from_layout(arena, context, layout, ptr_bytes)); - } - - context - .struct_type(field_types.into_bump_slice(), false) - .as_basic_type_enum() + basic_type_from_record(arena, context, sorted_fields, ptr_bytes) } RecursiveUnion(_) | Union(_) => block_of_memory(context, layout, ptr_bytes), RecursivePointer => { diff --git a/compiler/gen/src/llvm/refcounting.rs b/compiler/gen/src/llvm/refcounting.rs index a950e5eca9..6c2ad2eb09 100644 --- a/compiler/gen/src/llvm/refcounting.rs +++ b/compiler/gen/src/llvm/refcounting.rs @@ -29,6 +29,27 @@ pub fn refcount_1(ctx: &Context, ptr_bytes: u32) -> IntValue<'_> { } } +pub fn decrement_refcount_struct<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + layout_ids: &mut LayoutIds<'a>, + value: BasicValueEnum<'ctx>, + layouts: &[Layout<'a>], +) { + let wrapper_struct = value.into_struct_value(); + + for (i, field_layout) in layouts.iter().enumerate() { + if field_layout.contains_refcounted() { + let field_ptr = env + .builder + .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") + .unwrap(); + + decrement_refcount_layout(env, parent, layout_ids, field_ptr, field_layout) + } + } +} + pub fn decrement_refcount_layout<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, @@ -42,20 +63,21 @@ pub fn decrement_refcount_layout<'a, 'ctx, 'env>( Builtin(builtin) => { decrement_refcount_builtin(env, parent, layout_ids, value, layout, builtin) } - Struct(layouts) => { - let wrapper_struct = value.into_struct_value(); + Closure(_, closure_layout, _) => { + if closure_layout.iter().any(|f| f.contains_refcounted()) { + let wrapper_struct = value.into_struct_value(); - for (i, field_layout) in layouts.iter().enumerate() { - if field_layout.contains_refcounted() { - let field_ptr = env - .builder - .build_extract_value(wrapper_struct, i as u32, "decrement_struct_field") - .unwrap(); + let field_ptr = env + .builder + .build_extract_value(wrapper_struct, 1, "decrement_closure_data") + .unwrap(); - decrement_refcount_layout(env, parent, layout_ids, field_ptr, field_layout) - } + decrement_refcount_struct(env, parent, layout_ids, field_ptr, closure_layout) } } + Struct(layouts) => { + decrement_refcount_struct(env, parent, layout_ids, value, layouts); + } RecursivePointer => todo!("TODO implement decrement layout of recursive tag union"), Union(tags) => { diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 02b3d7fc96..8c5c3d636c 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -929,11 +929,13 @@ mod gen_primitives { r#" app Test provides [ main ] imports [] - x = 42 + foo = \{} -> + x = 42 + f = \{} -> x + f main = - f = \{} -> x - + f = foo {} f {} "# ), diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 29f4d8776b..9ebe9c3338 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -1248,6 +1248,7 @@ pub fn specialize_all<'a>( mut procs: Procs<'a>, layout_cache: &mut LayoutCache<'a>, ) -> Procs<'a> { + dbg!(&procs); let it = procs.externals_others_need.specs.clone(); let it = it .into_iter() @@ -1456,12 +1457,16 @@ fn build_specialized_proc<'a>( let proc_args = proc_args.into_bump_slice(); let closes_over = match closure_var { - Some(cvar) => layout_cache - .from_var(&env.arena, cvar, env.subs) - .unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err)), + Some(cvar) => match layout_cache.from_var(&env.arena, cvar, env.subs) { + Ok(layout) => layout, + Err(LayoutProblem::UnresolvedTypeVar) => Layout::Struct(&[]), + Err(err) => panic!("TODO handle invalid function {:?}", err), + }, None => Layout::Struct(&[]), }; + dbg!(&closes_over); + let ret_layout = layout_cache .from_var(&env.arena, ret_var, env.subs) .unwrap_or_else(|err| panic!("TODO handle invalid function {:?}", err)); @@ -1621,6 +1626,7 @@ pub fn with_hole<'a>( .. } = def.loc_expr.value { + dbg!(symbol); // Extract Procs, but discard the resulting Expr::Load. // That Load looks up the pointer, which we won't use here! @@ -1732,6 +1738,7 @@ pub fn with_hole<'a>( .. } = def.loc_expr.value { + dbg!(symbol); // Extract Procs, but discard the resulting Expr::Load. // That Load looks up the pointer, which we won't use here! @@ -2423,6 +2430,7 @@ pub fn with_hole<'a>( loc_body: boxed_body, .. } => { + dbg!(name); let loc_body = *boxed_body; match procs.insert_anonymous( @@ -2754,7 +2762,142 @@ pub fn from_can<'a>( return_type, ); - return from_can(env, cont.value, procs, layout_cache); + // does this function capture any local values? + let function_layout = + layout_cache.from_var(env.arena, function_type, env.subs); + let is_closure = + matches!(&function_layout, Ok(Layout::Closure(_, _, _))); + + if is_closure { + let function_layout = function_layout.unwrap(); + let full_layout = function_layout.clone(); + let fn_var = function_type; + let proc_name = *symbol; + let pending = PendingSpecialization::from_var(env.subs, fn_var); + + // When requested (that is, when procs.pending_specializations is `Some`), + // store a pending specialization rather than specializing immediately. + // + // We do this so that we can do specialization in two passes: first, + // build the mono_expr with all the specialized calls in place (but + // no specializations performed yet), and then second, *after* + // de-duplicating requested specializations (since multiple modules + // which could be getting monomorphized in parallel might request + // the same specialization independently), we work through the + // queue of pending specializations to complete each specialization + // exactly once. + match &mut procs.pending_specializations { + Some(pending_specializations) => { + // register the pending specialization, so this gets code genned later + add_pending( + pending_specializations, + proc_name, + full_layout.clone(), + pending, + ); + } + None => { + let opt_partial_proc = procs.partial_procs.get(&proc_name); + + match opt_partial_proc { + None => panic!("invalid proc"), + Some(partial_proc) => { + // TODO should pending_procs hold a Rc to avoid this .clone()? + let partial_proc = partial_proc.clone(); + + // Mark this proc as in-progress, so if we're dealing with + // mutually recursive functions, we don't loop forever. + // (We had a bug around this before this system existed!) + procs.specialized.insert( + (proc_name, full_layout.clone()), + InProgress, + ); + + match specialize( + env, + procs, + proc_name, + layout_cache, + pending, + partial_proc, + ) { + Ok((proc, layout)) => { + debug_assert_eq!(full_layout, layout); + let function_layout = + FunctionLayouts::from_layout(layout); + + procs + .specialized + .remove(&(proc_name, full_layout)); + + procs.specialized.insert( + ( + proc_name, + function_layout.full.clone(), + ), + Done(proc), + ); + } + Err(error) => { + let error_msg = env.arena.alloc(format!( + "TODO generate a RuntimeError message for {:?}", + error + )); + + procs + .runtime_errors + .insert(proc_name, error_msg); + + panic!(); + // Stmt::RuntimeError(error_msg) + } + } + } + } + } + } + + let mut stmt = from_can(env, cont.value, procs, layout_cache); + + let function_pointer = env.unique_symbol(); + let closure_data = env.unique_symbol(); + + // define the closure + let expr = + Expr::Struct(env.arena.alloc([function_pointer, closure_data])); + + stmt = Stmt::Let( + *symbol, + expr, + function_layout.clone(), + env.arena.alloc(stmt), + ); + + // define the closure data + let expr = Expr::Struct(&[]); + let closure_data_layout = Layout::Struct(&[]); + + stmt = Stmt::Let( + closure_data, + expr, + closure_data_layout, + env.arena.alloc(stmt), + ); + + // define the function pointer + let expr = Expr::FunctionPointer(*symbol, function_layout.clone()); + + stmt = Stmt::Let( + function_pointer, + expr, + function_layout, + env.arena.alloc(stmt), + ); + + return stmt; + } else { + return from_can(env, cont.value, procs, layout_cache); + } } _ => unreachable!(), } diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 6c4b438b39..6ba80f3921 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -26,6 +26,7 @@ pub enum Layout<'a> { RecursivePointer, /// A function. The types of its arguments, then the type of its return value. FunctionPointer(&'a [Layout<'a>], &'a Layout<'a>), + Closure(&'a [Layout<'a>], &'a [Layout<'a>], &'a Layout<'a>), Pointer(&'a Layout<'a>), } @@ -139,6 +140,9 @@ impl<'a> Layout<'a> { // Function pointers are immutable and can always be safely copied true } + Closure(_, closure_layout, _) => { + closure_layout.iter().all(|field| field.safe_to_memcpy()) + } Pointer(_) => { // We cannot memcpy pointers, because then we would have the same pointer in multiple places! false @@ -191,6 +195,13 @@ impl<'a> Layout<'a> { }) .max() .unwrap_or_default(), + Closure(_, closure_layout, _) => { + pointer_size + + closure_layout + .iter() + .map(|x| x.stack_size(pointer_size)) + .sum::() + } FunctionPointer(_, _) => pointer_size, RecursivePointer => pointer_size, Pointer(_) => pointer_size, @@ -220,6 +231,7 @@ impl<'a> Layout<'a> { .flatten() .any(|f| f.is_refcounted()), RecursiveUnion(_) => true, + Closure(_, closure_layout, _) => closure_layout.iter().any(|f| f.contains_refcounted()), FunctionPointer(_, _) | RecursivePointer | Pointer(_) => false, } } @@ -477,7 +489,7 @@ fn layout_from_flat_type<'a>( } } } - Func(args, _, ret_var) => { + Func(args, closure_var, ret_var) => { let mut fn_args = Vec::with_capacity_in(args.len(), arena); for arg_var in args { @@ -486,10 +498,30 @@ fn layout_from_flat_type<'a>( let ret = Layout::from_var(env, ret_var)?; - Ok(Layout::FunctionPointer( - fn_args.into_bump_slice(), - arena.alloc(ret), - )) + match Layout::from_var(env, closure_var) { + Ok(Layout::Builtin(builtin)) => Ok(Layout::Closure( + fn_args.into_bump_slice(), + arena.alloc([Layout::Builtin(builtin.clone())]), + arena.alloc(ret), + )), + Ok(Layout::Struct(closure_layouts)) => Ok(Layout::Closure( + fn_args.into_bump_slice(), + closure_layouts, + arena.alloc(ret), + )), + Ok(closure_layout) => { + // the closure parameter can be a tag union if there are multiple sizes + // we must make sure we can distinguish between that tag union, + // and the closure containing just one element, that happens to be a tag union. + todo!("TODO closure layout {:?}", &closure_layout) + } + Err(LayoutProblem::UnresolvedTypeVar) => Ok(Layout::FunctionPointer( + fn_args.into_bump_slice(), + arena.alloc(ret), + )), + + error => error, + } } Record(fields, ext_var) => { // Sort the fields by label