diff --git a/src/eval/interpreter.zig b/src/eval/interpreter.zig index a6e61957bb..42254a1dcf 100644 --- a/src/eval/interpreter.zig +++ b/src/eval/interpreter.zig @@ -92,6 +92,11 @@ pub const Interpreter = struct { } }; const Binding = struct { pattern_idx: can.CIR.Pattern.Idx, value: StackValue }; + const DefInProgress = struct { + pattern_idx: can.CIR.Pattern.Idx, + expr_idx: can.CIR.Expr.Idx, + value: ?StackValue, + }; allocator: std.mem.Allocator, runtime_types: *types.store.Store, runtime_layout_store: layout.Store, @@ -127,6 +132,7 @@ pub const Interpreter = struct { builtins: BuiltinTypes, /// Map from module name to ModuleEnv for resolving e_lookup_external expressions imported_modules: std.StringHashMap(*const can.ModuleEnv), + def_stack: std.array_list.Managed(DefInProgress), pub fn init(allocator: std.mem.Allocator, env: *can.ModuleEnv, builtin_types: BuiltinTypes, imported_modules_map: ?*const std.AutoHashMap(base_pkg.Ident.Idx, can.Can.AutoImportedType)) !Interpreter { // Convert imported modules map to other_envs slice @@ -184,6 +190,7 @@ pub const Interpreter = struct { .scratch_tags = try std.array_list.Managed(types.Tag).initCapacity(allocator, 8), .builtins = builtin_types, .imported_modules = std.StringHashMap(*const can.ModuleEnv).init(allocator), + .def_stack = try std.array_list.Managed(DefInProgress).initCapacity(allocator, 4), }; result.runtime_layout_store = try layout.Store.init(env, result.runtime_types); @@ -195,6 +202,14 @@ pub const Interpreter = struct { return try self.evalExprMinimal(expr_idx, roc_ops, null); } + fn registerDefValue(self: *Interpreter, expr_idx: can.CIR.Expr.Idx, value: StackValue) void { + if (self.def_stack.items.len == 0) return; + var top = &self.def_stack.items[self.def_stack.items.len - 1]; + if (top.expr_idx == expr_idx and top.value == null) { + top.value = value; + } + } + pub fn startTrace(self: *Interpreter) void { _ = self; } @@ -1236,6 +1251,7 @@ pub const Interpreter = struct { // Expect a closure layout from type-to-layout translation if (closure_layout.tag != .closure) return error.NotImplemented; const value = try self.pushRaw(closure_layout, 0); + self.registerDefValue(expr_idx, value); // Initialize the closure header if (value.ptr) |ptr| { const header: *layout.Closure = @ptrCast(@alignCast(ptr)); @@ -1298,7 +1314,24 @@ pub const Interpreter = struct { for (all_defs) |def_idx| { const def = self_interp.env.store.getDef(def_idx); if (def.pattern == cap.pattern_idx) { + var k: usize = self_interp.def_stack.items.len; + while (k > 0) { + k -= 1; + const entry = self_interp.def_stack.items[k]; + if (entry.pattern_idx == cap.pattern_idx) { + if (entry.value) |val| { + return val; + } + } + } // Found the def! Evaluate it to get the captured value + const new_entry = DefInProgress{ + .pattern_idx = def.pattern, + .expr_idx = def.expr, + .value = null, + }; + self_interp.def_stack.append(new_entry) catch return null; + defer _ = self_interp.def_stack.pop(); return self_interp.evalMinimal(def.expr, ops) catch null; } } @@ -1309,14 +1342,16 @@ pub const Interpreter = struct { for (caps, 0..) |cap_idx, i| { const cap = self.env.store.getCapture(cap_idx); field_names[i] = cap.name; - const captured_val = resolveCapture(self, cap, roc_ops) orelse return error.NotImplemented; - field_layouts[i] = captured_val.layout; + const cap_ct_var = can.ModuleEnv.varFrom(cap.pattern_idx); + const cap_rt_var = try self.translateTypeVar(self.env, cap_ct_var); + field_layouts[i] = try self.getRuntimeLayout(cap_rt_var); } const captures_layout_idx = try self.runtime_layout_store.putRecord(field_layouts, field_names); const captures_layout = self.runtime_layout_store.getLayout(captures_layout_idx); const closure_layout = Layout.closure(captures_layout_idx); const value = try self.pushRaw(closure_layout, 0); + self.registerDefValue(expr_idx, value); // Initialize header if (value.ptr) |ptr| { @@ -3321,6 +3356,7 @@ pub const Interpreter = struct { self.stack_memory.deinit(); self.bindings.deinit(); self.active_closures.deinit(); + self.def_stack.deinit(); self.scratch_tags.deinit(); self.imported_modules.deinit(); } diff --git a/test/snapshots/issue/segfault_pr_8315.md b/test/snapshots/issue/segfault_pr_8315.md new file mode 100644 index 0000000000..fc7f30fd4d --- /dev/null +++ b/test/snapshots/issue/segfault_pr_8315.md @@ -0,0 +1,75 @@ +# META +~~~ini +description=Regression test for segfault caused by self-capturing closure (PR #8315) +type=snippet +~~~ +# SOURCE +~~~roc +# Minimal reproduction of segfault bug +# A closure definition that captures itself causes infinite recursion +# during closure construction in the interpreter + +selfCapturing : {} -> U64 +selfCapturing = |{}| selfCapturing({}) +~~~ +# EXPECTED +NIL +# PROBLEMS +NIL +# TOKENS +~~~zig +LowerIdent,OpColon,OpenCurly,CloseCurly,OpArrow,UpperIdent, +LowerIdent,OpAssign,OpBar,OpenCurly,CloseCurly,OpBar,LowerIdent,NoSpaceOpenRound,OpenCurly,CloseCurly,CloseRound, +EndOfFile, +~~~ +# PARSE +~~~clojure +(file + (type-module) + (statements + (s-type-anno (name "selfCapturing") + (ty-fn + (ty-record) + (ty (name "U64")))) + (s-decl + (p-ident (raw "selfCapturing")) + (e-lambda + (args + (p-record)) + (e-apply + (e-ident (raw "selfCapturing")) + (e-record)))))) +~~~ +# FORMATTED +~~~roc +NO CHANGE +~~~ +# CANONICALIZE +~~~clojure +(can-ir + (d-let + (p-assign (ident "selfCapturing")) + (e-closure + (captures + (capture (ident "selfCapturing"))) + (e-lambda + (args + (p-record-destructure + (destructs))) + (e-call + (e-lookup-local + (p-assign (ident "selfCapturing"))) + (e-empty_record)))) + (annotation + (ty-fn (effectful false) + (ty-record) + (ty-lookup (name "U64") (builtin)))))) +~~~ +# TYPES +~~~clojure +(inferred-types + (defs + (patt (type "{ } -> Num(Int(Unsigned64))"))) + (expressions + (expr (type "{ } -> Num(Int(Unsigned64))")))) +~~~