diff --git a/crates/compiler/mono/src/ir.rs b/crates/compiler/mono/src/ir.rs index 75738ac8d1..a680a9d147 100644 --- a/crates/compiler/mono/src/ir.rs +++ b/crates/compiler/mono/src/ir.rs @@ -9647,6 +9647,9 @@ where lambda_set_id = lambda_set_id.next(); stack.push(layout_interner.get(lambda_set.runtime_representation())); + + // TODO: figure out if we need to look at the other layouts + // stack.push(layout_interner.get(lambda_set.ret)); } Layout::RecursivePointer(_) => { /* do nothing, we've already generated for this type through the Union(_) */ diff --git a/crates/glue/src/RustGlue.roc b/crates/glue/src/RustGlue.roc index ce8c3724f6..87545ba06c 100644 --- a/crates/glue/src/RustGlue.roc +++ b/crates/glue/src/RustGlue.roc @@ -1,6 +1,6 @@ app "rust-glue" packages { pf: "../platform/main.roc" } - imports [pf.Types.{ Types }, pf.Shape.{ RocFn }, pf.File.{ File }, pf.TypeId.{ TypeId }] + imports [pf.Types.{ Types }, pf.Shape.{ Shape, RocFn }, pf.File.{ File }, pf.TypeId.{ TypeId }] provides [makeGlue] to pf makeGlue : List Types -> Result (List File) Str @@ -182,11 +182,19 @@ generateFunction = \buf, types, rocFn -> "arg\(c): \(type)" |> Str.joinWith ", " - externArguments = + externDefArguments = + rocFn.args + |> List.mapWithIndex \argId, i -> + type = typeName types argId + c = Num.toStr i + "arg\(c): *const \(type)" + |> Str.joinWith ", " + + externCallArguments = rocFn.args |> List.mapWithIndex \_, i -> c = Num.toStr i - "arg\(c)" + "&arg\(c)" |> Str.joinWith ", " externComma = if Str.isEmpty publicArguments then "" else ", " @@ -197,6 +205,7 @@ generateFunction = \buf, types, rocFn -> \(buf) #[repr(C)] + #[derive(Debug, Clone)] pub struct \(name) { closure_data: \(lambdaSet), } @@ -204,13 +213,16 @@ generateFunction = \buf, types, rocFn -> impl \(name) { pub fn force_thunk(mut self, \(publicArguments)) -> \(ret) { extern "C" { - fn \(externName)(\(publicArguments)\(externComma) closure_data: *mut u8, output: *mut \(ret)); + fn \(externName)(\(externDefArguments)\(externComma) closure_data: *mut u8, output: *mut \(ret)); } let mut output = std::mem::MaybeUninit::uninit(); let ptr = &mut self.closure_data as *mut _ as *mut u8; - unsafe { \(externName)(\(externArguments)\(externComma) ptr, output.as_mut_ptr(), ) }; + unsafe { \(externName)(\(externCallArguments)\(externComma) ptr, output.as_mut_ptr(), ) }; + + // ownership of the closure is transferred back to roc + core::mem::forget(self.closure_data); unsafe { output.assume_init() } } @@ -799,35 +811,80 @@ generateZeroElementSingleTagStruct = \buf, name, tagName -> """ generateDeriveStr = \buf, types, type, includeDebug -> + condWrite = \b, cond, str -> + if cond then + Str.concat b str + else + b + + deriveDebug = when includeDebug is + IncludeDebug -> Bool.true + ExcludeDebug -> Bool.false + buf |> Str.concat "#[derive(Clone, " - |> \b -> - if !(cannotDeriveCopy types type) then - Str.concat b "Copy, " - else - b - |> \b -> - if !(cannotDeriveDefault types type) then - Str.concat b "Default, " - else - b - |> \b -> - when includeDebug is - IncludeDebug -> - Str.concat b "Debug, " + |> condWrite (!(cannotDeriveCopy types type)) "Copy, " + |> condWrite (!(cannotDeriveDefault types type)) "Default, " + |> condWrite deriveDebug "Debug, " + |> condWrite (canDerivePartialEq types type) "PartialEq, PartialOrd, " + |> condWrite (!(hasFloat types type) && (canDerivePartialEq types type)) "Eq, Ord, Hash, " + |> Str.concat ")]\n" - ExcludeDebug -> - b - |> \b -> - if !(hasFloat types type) then - Str.concat b "Eq, Ord, Hash, " - else - b - |> Str.concat "PartialEq, PartialOrd)]\n" +canDerivePartialEq : Types, Shape -> Bool +canDerivePartialEq = \types, type -> + when type is + Function rocFn -> + runtimeRepresentation = Types.shape types rocFn.lambdaSet + canDerivePartialEq types runtimeRepresentation + Unsized -> Bool.false + + Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.true + RocStr -> Bool.true + RocList inner | RocSet inner | RocBox inner -> + innerType = Types.shape types inner + canDerivePartialEq types innerType + + RocDict k v -> + kType = Types.shape types k + vType = Types.shape types v + + canDerivePartialEq types kType && canDerivePartialEq types vType + + TagUnion (NullableUnwrapped _) | TagUnion (NullableWrapped _) | TagUnion (Recursive _) | TagUnion (NonNullableUnwrapped _) | RecursivePointer _ -> crash "TODO" + TagUnion (SingleTagStruct { payload: HasNoClosure fields }) -> + List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id) + + TagUnion (SingleTagStruct { payload: HasClosure fields }) -> + List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id) + + TagUnion (NonRecursive { tags }) -> + List.all tags \{ payload } -> + when payload is + Some id -> canDerivePartialEq types (Types.shape types id) + None -> Bool.true + + RocResult okId errId -> + canDerivePartialEq types (Types.shape types okId) + && canDerivePartialEq types (Types.shape types errId) + + Struct { fields: HasNoClosure fields } | TagUnionPayload { fields: HasNoClosure fields } -> + List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id) + + Struct { fields: HasClosure fields } | TagUnionPayload { fields: HasClosure fields } -> + List.all fields \{ id } -> canDerivePartialEq types (Types.shape types id) + +cannotDeriveCopy : Types, Shape -> Bool cannotDeriveCopy = \types, type -> when type is - Unit | Unsized | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) | Function _ -> Bool.false + Function rocFn -> + runtimeRepresentation = Types.shape types rocFn.lambdaSet + cannotDeriveCopy types runtimeRepresentation + + # unsized values are heap-allocated + Unsized -> Bool.true + + Unit | EmptyTagUnion | Bool | Num _ | TagUnion (Enumeration _) -> Bool.false RocStr | RocList _ | RocDict _ _ | RocSet _ | RocBox _ | TagUnion (NullableUnwrapped _) | TagUnion (NullableWrapped _) | TagUnion (Recursive _) | TagUnion (NonNullableUnwrapped _) | RecursivePointer _ -> Bool.true TagUnion (SingleTagStruct { payload: HasNoClosure fields }) -> List.any fields \{ id } -> cannotDeriveCopy types (Types.shape types id) diff --git a/crates/glue/src/types.rs b/crates/glue/src/types.rs index 212d83ece0..c9abb0ac9a 100644 --- a/crates/glue/src/types.rs +++ b/crates/glue/src/types.rs @@ -1924,7 +1924,7 @@ where RocStructFields::HasClosure { fields } } None => { - debug_assert!(!layout.has_varying_stack_size(&env.layout_cache.interner, arena)); + // debug_assert!(!layout.has_varying_stack_size(&env.layout_cache.interner, arena)); let fields: Vec<(String, TypeId)> = sortables .into_iter() diff --git a/crates/glue/tests/fixtures/return-function/app.roc b/crates/glue/tests/fixtures/return-function/app.roc new file mode 100644 index 0000000000..814a2d6b26 --- /dev/null +++ b/crates/glue/tests/fixtures/return-function/app.roc @@ -0,0 +1,10 @@ +app "app" + packages { pf: "platform.roc" } + imports [] + provides [main] to pf + +main : { f: I64, I64 -> I64 } +main = { f: increment } + +increment : I64, I64 -> I64 +increment = \x, y -> x + y diff --git a/crates/glue/tests/fixtures/return-function/platform.roc b/crates/glue/tests/fixtures/return-function/platform.roc new file mode 100644 index 0000000000..dc74a99c25 --- /dev/null +++ b/crates/glue/tests/fixtures/return-function/platform.roc @@ -0,0 +1,9 @@ +platform "test-platform" + requires {} { main : { f: I64, I64 -> I64 } } + exposes [] + packages {} + imports [] + provides [mainForHost] + +mainForHost : { f: I64, I64 -> I64 } +mainForHost = main diff --git a/crates/glue/tests/fixtures/return-function/src/lib.rs b/crates/glue/tests/fixtures/return-function/src/lib.rs new file mode 100644 index 0000000000..85875c826c --- /dev/null +++ b/crates/glue/tests/fixtures/return-function/src/lib.rs @@ -0,0 +1,61 @@ +mod test_glue; + +#[no_mangle] +pub extern "C" fn rust_main() -> i32 { + let record = test_glue::mainForHost(); + let answer = record.f.force_thunk(42i64, 1); + + println!("Answer was: {:?}", answer); // Debug + + // Exit code + 0 +} + +// Externs required by roc_std and by the Roc app + +use core::ffi::c_void; +use std::ffi::CStr; +use std::os::raw::c_char; + +#[no_mangle] +pub unsafe extern "C" fn roc_alloc(size: usize, _alignment: u32) -> *mut c_void { + return libc::malloc(size); +} + +#[no_mangle] +pub unsafe extern "C" fn roc_realloc( + c_ptr: *mut c_void, + new_size: usize, + _old_size: usize, + _alignment: u32, +) -> *mut c_void { + return libc::realloc(c_ptr, new_size); +} + +#[no_mangle] +pub unsafe extern "C" fn roc_dealloc(c_ptr: *mut c_void, _alignment: u32) { + return libc::free(c_ptr); +} + +#[no_mangle] +pub unsafe extern "C" fn roc_panic(c_ptr: *mut c_void, tag_id: u32) { + match tag_id { + 0 => { + let slice = CStr::from_ptr(c_ptr as *const c_char); + let string = slice.to_str().unwrap(); + eprintln!("Roc hit a panic: {}", string); + std::process::exit(1); + } + _ => todo!(), + } +} + +#[no_mangle] +pub unsafe extern "C" fn roc_memcpy(dst: *mut c_void, src: *mut c_void, n: usize) -> *mut c_void { + libc::memcpy(dst, src, n) +} + +#[no_mangle] +pub unsafe extern "C" fn roc_memset(dst: *mut c_void, c: i32, n: usize) -> *mut c_void { + libc::memset(dst, c, n) +} diff --git a/crates/glue/tests/test_glue_cli.rs b/crates/glue/tests/test_glue_cli.rs index cbd6d1c0fa..16658c0d0f 100644 --- a/crates/glue/tests/test_glue_cli.rs +++ b/crates/glue/tests/test_glue_cli.rs @@ -133,6 +133,9 @@ mod glue_cli_run { Answer was: "Hello World!" Answer was: discriminant_U1::None "#), + return_function:"return-function" => indoc!(r#" + Answer was: 43 + "#), } fn check_for_tests(all_fixtures: &mut roc_collections::VecSet) {