Merge pull request #8705 from roc-lang/cor-mono2

Expand monomorphization + tests
This commit is contained in:
Richard Feldman 2025-12-22 22:09:45 -05:00 committed by GitHub
commit 35da162c5b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 1181 additions and 14 deletions

View file

@ -0,0 +1,847 @@
//! Closure Transformer
//!
//! Transforms closures with captures into tagged values with explicit capture records.
//! This is the first step of lambda set specialization following the Cor approach.
//!
//! ## Transformation Example
//!
//! Input:
//! ```roc
//! {
//! x = 42
//! addX = |y| x + y
//! addX(10)
//! }
//! ```
//!
//! Output:
//! ```roc
//! {
//! x = 42
//! addX = #addX({ x: x })
//! match addX {
//! #addX({ x }) => {
//! y = 10
//! (x + y)
//! },
//! }
//! }
//! ```
//!
//! ## Implementation Notes
//!
//! - Closures become tags with capture records (using `#` prefix to avoid clashing with userspace tags)
//! - Call sites become inline match expressions that dispatch based on the lambda set
//! - Pure lambdas (no captures) become tags with empty records
const std = @import("std");
const base = @import("base");
const ModuleEnv = @import("ModuleEnv.zig");
const CIR = @import("CIR.zig");
const Expr = CIR.Expr;
const Pattern = @import("Pattern.zig").Pattern;
const RecordField = CIR.RecordField;
const Self = @This();
/// Information about a transformed closure
pub const ClosureInfo = struct {
/// The tag name for this closure (e.g., `addX`)
tag_name: base.Ident.Idx,
/// The lambda body expression
lambda_body: Expr.Idx,
/// The lambda arguments
lambda_args: CIR.Pattern.Span,
/// The capture names (for generating dispatch function patterns)
capture_names: std.ArrayList(base.Ident.Idx),
};
/// Information for generating a dispatch function
pub const DispatchFunction = struct {
/// Name of the dispatch function (e.g., `call_addX`)
name: base.Ident.Idx,
/// The closures that can reach this call site
closures: std.ArrayList(ClosureInfo),
};
/// The allocator for intermediate allocations
allocator: std.mem.Allocator,
/// The module environment containing the CIR (mutable for adding new expressions)
module_env: *ModuleEnv,
/// Counter for generating unique closure names
closure_counter: u32,
/// Map from original closure expression to its transformation info
closures: std.AutoHashMap(Expr.Idx, ClosureInfo),
/// Map from pattern index to closure info (for tracking which variables hold closures)
pattern_closures: std.AutoHashMap(CIR.Pattern.Idx, ClosureInfo),
/// List of dispatch functions to generate
dispatch_functions: std.ArrayList(DispatchFunction),
/// Initialize the transformer
pub fn init(allocator: std.mem.Allocator, module_env: *ModuleEnv) Self {
return .{
.allocator = allocator,
.module_env = module_env,
.closure_counter = 0,
.closures = std.AutoHashMap(Expr.Idx, ClosureInfo).init(allocator),
.pattern_closures = std.AutoHashMap(CIR.Pattern.Idx, ClosureInfo).init(allocator),
.dispatch_functions = std.ArrayList(DispatchFunction).empty,
};
}
/// Free resources
pub fn deinit(self: *Self) void {
// Free capture name lists
var closure_iter = self.closures.valueIterator();
while (closure_iter.next()) |info| {
info.capture_names.deinit(self.allocator);
}
self.closures.deinit();
// pattern_closures shares ClosureInfo with closures, don't double-free
self.pattern_closures.deinit();
// Free dispatch function closure lists
for (self.dispatch_functions.items) |*df| {
df.closures.deinit(self.allocator);
}
self.dispatch_functions.deinit(self.allocator);
}
/// Generate a unique tag name for a closure
pub fn generateClosureTagName(self: *Self, hint: ?base.Ident.Idx) !base.Ident.Idx {
self.closure_counter += 1;
// If we have a hint (e.g., from the variable name), use it
if (hint) |h| {
const hint_name = self.module_env.getIdent(h);
// Use # prefix since it's Roc's comment syntax and can't clash with userspace tags
// e.g., "myFunc" becomes "#myFunc"
const tag_name = try std.fmt.allocPrint(
self.allocator,
"#{s}",
.{hint_name},
);
defer self.allocator.free(tag_name);
return try self.module_env.insertIdent(base.Ident.for_text(tag_name));
}
// Otherwise generate a numeric name
const tag_name = try std.fmt.allocPrint(
self.allocator,
"#{d}",
.{self.closure_counter},
);
defer self.allocator.free(tag_name);
return try self.module_env.insertIdent(base.Ident.for_text(tag_name));
}
/// Generate a dispatch match expression for a closure call.
///
/// Transforms a call like `f(10)` where `f` is a closure into:
/// ```roc
/// match f {
/// #f({ x }) => {
/// y = 10 # Bind call arguments to lambda parameters
/// x + y # Original lambda body
/// }
/// }
/// ```
fn generateDispatchMatch(
self: *Self,
closure_var_expr: Expr.Idx,
closure_info: ClosureInfo,
call_args: []const Expr.Idx,
) !Expr.Idx {
// Step 1: Create the capture record destructure pattern
// For `{ x, y }` we need a record_destructure with each field
const record_destruct_start = self.module_env.store.scratchRecordDestructTop();
for (closure_info.capture_names.items) |capture_name| {
// Create an assign pattern for this capture binding
const assign_pattern = try self.module_env.store.addPattern(
Pattern{ .assign = .{ .ident = capture_name } },
base.Region.zero(),
);
// Create the record destruct for this field
const destruct = Pattern.RecordDestruct{
.label = capture_name,
.ident = capture_name,
.kind = .{ .Required = assign_pattern },
};
const destruct_idx = try self.module_env.store.addRecordDestruct(destruct, base.Region.zero());
try self.module_env.store.addScratchRecordDestruct(destruct_idx);
}
const destructs_span = try self.module_env.store.recordDestructSpanFrom(record_destruct_start);
// Create the record destructure pattern
const record_pattern = try self.module_env.store.addPattern(
Pattern{ .record_destructure = .{ .destructs = destructs_span } },
base.Region.zero(),
);
// Step 2: Create the applied_tag pattern: `f({ x, y })
// The tag pattern takes the record pattern as its single argument
const pattern_args_start = self.module_env.store.scratchPatternTop();
try self.module_env.store.addScratchPattern(record_pattern);
const pattern_args_span = try self.module_env.store.patternSpanFrom(pattern_args_start);
const tag_pattern = try self.module_env.store.addPattern(
Pattern{ .applied_tag = .{
.name = closure_info.tag_name,
.args = pattern_args_span,
} },
base.Region.zero(),
);
// Step 3: Create the body - a block that binds arguments then executes lambda body
// We need to bind each call argument to the corresponding lambda parameter
const lambda_params = self.module_env.store.slicePatterns(closure_info.lambda_args);
// If we have arguments to bind, create a block with let bindings
const body_expr = if (call_args.len > 0 and lambda_params.len > 0) blk: {
const stmt_start = self.module_env.store.scratch.?.statements.top();
// Bind each argument to its parameter
const num_args = @min(call_args.len, lambda_params.len);
for (0..num_args) |i| {
const param_pattern = lambda_params[i];
const arg_expr = call_args[i];
const stmt = CIR.Statement{ .s_decl = .{
.pattern = param_pattern,
.expr = arg_expr,
.anno = null,
} };
const stmt_idx = try self.module_env.store.addStatement(stmt, base.Region.zero());
try self.module_env.store.scratch.?.statements.append(stmt_idx);
}
const stmts_span = try self.module_env.store.statementSpanFrom(stmt_start);
// Create block with bindings and lambda body as final expression
break :blk try self.module_env.store.addExpr(Expr{
.e_block = .{
.stmts = stmts_span,
.final_expr = closure_info.lambda_body,
},
}, base.Region.zero());
} else blk: {
// No arguments, just use the lambda body directly
break :blk closure_info.lambda_body;
};
// Step 4: Create the match branch
const branch_pattern_start = self.module_env.store.scratchMatchBranchPatternTop();
const branch_pattern = try self.module_env.store.addMatchBranchPattern(
Expr.Match.BranchPattern{
.pattern = tag_pattern,
.degenerate = false,
},
base.Region.zero(),
);
try self.module_env.store.addScratchMatchBranchPattern(branch_pattern);
const branch_patterns_span = try self.module_env.store.matchBranchPatternSpanFrom(branch_pattern_start);
// Create a fresh type variable for the redundant field
const redundant_var = try self.module_env.types.fresh();
const branch = Expr.Match.Branch{
.patterns = branch_patterns_span,
.value = body_expr,
.guard = null,
.redundant = redundant_var,
};
const branch_idx = try self.module_env.store.addMatchBranch(branch, base.Region.zero());
// Step 5: Create the match expression
const branch_start = self.module_env.store.scratchMatchBranchTop();
try self.module_env.store.addScratchMatchBranch(branch_idx);
const branches_span = try self.module_env.store.matchBranchSpanFrom(branch_start);
// Create a fresh type variable for exhaustiveness
const exhaustive_var = try self.module_env.types.fresh();
return try self.module_env.store.addExpr(Expr{
.e_match = .{
.cond = closure_var_expr,
.branches = branches_span,
.exhaustive = exhaustive_var,
},
}, base.Region.zero());
}
/// Transform a closure expression into a tag with capture record.
/// Returns the new expression index.
pub fn transformClosure(
self: *Self,
closure_expr_idx: Expr.Idx,
binding_name_hint: ?base.Ident.Idx,
) !Expr.Idx {
const expr = self.module_env.store.getExpr(closure_expr_idx);
switch (expr) {
.e_closure => |closure| {
// Get the lambda body and args
const lambda_expr = self.module_env.store.getExpr(closure.lambda_idx);
const lambda = switch (lambda_expr) {
.e_lambda => |l| l,
else => return closure_expr_idx, // Not a lambda, return as-is
};
// Generate tag name
const tag_name = try self.generateClosureTagName(binding_name_hint);
// Get captures
const captures = self.module_env.store.sliceCaptures(closure.captures);
// Build capture record fields
const scratch_top = self.module_env.store.scratch.?.record_fields.top();
var capture_names = std.ArrayList(base.Ident.Idx).empty;
for (captures) |capture_idx| {
const capture = self.module_env.store.getCapture(capture_idx);
// Create a lookup expression for the captured variable
// Use store.addExpr directly to avoid region sync checks during transformation
const lookup_expr = try self.module_env.store.addExpr(Expr{
.e_lookup_local = .{ .pattern_idx = capture.pattern_idx },
}, base.Region.zero());
// Create record field: { capture_name: capture_value }
const field = RecordField{
.name = capture.name,
.value = lookup_expr,
};
const field_idx = try self.module_env.store.addRecordField(field, base.Region.zero());
try self.module_env.store.scratch.?.record_fields.append(field_idx);
try capture_names.append(self.allocator, capture.name);
}
// Create the record expression
const fields_span = try self.module_env.store.recordFieldSpanFrom(scratch_top);
const record_expr = if (captures.len > 0)
try self.module_env.store.addExpr(Expr{
.e_record = .{ .fields = fields_span, .ext = null },
}, base.Region.zero())
else
try self.module_env.store.addExpr(Expr{
.e_empty_record = .{},
}, base.Region.zero());
// Create the tag expression: `tagName(captureRecord)
// First, add the record as an argument
const args_start = self.module_env.store.scratch.?.exprs.top();
try self.module_env.store.scratch.?.exprs.append(record_expr);
const args_span = try self.module_env.store.exprSpanFrom(args_start);
const tag_expr = try self.module_env.store.addExpr(Expr{
.e_tag = .{
.name = tag_name,
.args = args_span,
},
}, base.Region.zero());
// Store closure info for dispatch function generation
try self.closures.put(closure_expr_idx, ClosureInfo{
.tag_name = tag_name,
.lambda_body = lambda.body,
.lambda_args = lambda.args,
.capture_names = capture_names,
});
return tag_expr;
},
.e_lambda => |lambda| {
// Pure lambda (no captures) - still wrap in a tag with empty record
const tag_name = try self.generateClosureTagName(binding_name_hint);
const empty_record = try self.module_env.store.addExpr(Expr{
.e_empty_record = .{},
}, base.Region.zero());
const args_start = self.module_env.store.scratch.?.exprs.top();
try self.module_env.store.scratch.?.exprs.append(empty_record);
const args_span = try self.module_env.store.exprSpanFrom(args_start);
const tag_expr = try self.module_env.store.addExpr(Expr{
.e_tag = .{
.name = tag_name,
.args = args_span,
},
}, base.Region.zero());
// Store info for dispatch
try self.closures.put(closure_expr_idx, ClosureInfo{
.tag_name = tag_name,
.lambda_body = lambda.body,
.lambda_args = lambda.args,
.capture_names = std.ArrayList(base.Ident.Idx).empty,
});
return tag_expr;
},
else => return closure_expr_idx, // Not a closure, return as-is
}
}
/// Transform an entire expression tree, handling closures and their call sites.
/// This is the main entry point for the transformation.
pub fn transformExpr(self: *Self, expr_idx: Expr.Idx) !Expr.Idx {
const expr = self.module_env.store.getExpr(expr_idx);
switch (expr) {
.e_closure => {
// Transform closure to tag
return try self.transformClosure(expr_idx, null);
},
.e_lambda => {
// Transform pure lambda to tag
return try self.transformClosure(expr_idx, null);
},
.e_block => |block| {
// Transform block: handle statements and final expression
const stmts = self.module_env.store.sliceStatements(block.stmts);
// Create new statements with transformed expressions
const stmt_start = self.module_env.store.scratch.?.statements.top();
for (stmts) |stmt_idx| {
const stmt = self.module_env.store.getStatement(stmt_idx);
switch (stmt) {
.s_decl => |decl| {
// Get binding name hint from pattern
const pattern = self.module_env.store.getPattern(decl.pattern);
const name_hint: ?base.Ident.Idx = switch (pattern) {
.assign => |a| a.ident,
else => null,
};
// Check if this is a closure binding
const decl_expr = self.module_env.store.getExpr(decl.expr);
const new_expr = switch (decl_expr) {
.e_closure, .e_lambda => blk: {
const transformed = try self.transformClosure(decl.expr, name_hint);
// Track this pattern as holding a closure
if (self.closures.get(decl.expr)) |closure_info| {
try self.pattern_closures.put(decl.pattern, closure_info);
}
break :blk transformed;
},
else => try self.transformExpr(decl.expr),
};
// Create new statement with transformed expression
const new_stmt_idx = try self.module_env.store.addStatement(
CIR.Statement{ .s_decl = .{
.pattern = decl.pattern,
.expr = new_expr,
.anno = decl.anno,
} },
base.Region.zero(),
);
try self.module_env.store.scratch.?.statements.append(new_stmt_idx);
},
.s_decl_gen => |decl| {
const pattern = self.module_env.store.getPattern(decl.pattern);
const name_hint: ?base.Ident.Idx = switch (pattern) {
.assign => |a| a.ident,
else => null,
};
const decl_expr = self.module_env.store.getExpr(decl.expr);
const new_expr = switch (decl_expr) {
.e_closure, .e_lambda => blk: {
const transformed = try self.transformClosure(decl.expr, name_hint);
// Track this pattern as holding a closure
if (self.closures.get(decl.expr)) |closure_info| {
try self.pattern_closures.put(decl.pattern, closure_info);
}
break :blk transformed;
},
else => try self.transformExpr(decl.expr),
};
const new_stmt_idx = try self.module_env.store.addStatement(
CIR.Statement{ .s_decl_gen = .{
.pattern = decl.pattern,
.expr = new_expr,
.anno = decl.anno,
} },
base.Region.zero(),
);
try self.module_env.store.scratch.?.statements.append(new_stmt_idx);
},
else => {
// Copy statement as-is
try self.module_env.store.scratch.?.statements.append(stmt_idx);
},
}
}
const new_stmts_span = try self.module_env.store.statementSpanFrom(stmt_start);
// Transform final expression
const new_final = try self.transformExpr(block.final_expr);
// Create new block
return try self.module_env.store.addExpr(Expr{
.e_block = .{
.stmts = new_stmts_span,
.final_expr = new_final,
},
}, base.Region.zero());
},
.e_call => |call| {
// First transform arguments recursively
const args = self.module_env.store.sliceExpr(call.args);
const args_start = self.module_env.store.scratch.?.exprs.top();
for (args) |arg_idx| {
const new_arg = try self.transformExpr(arg_idx);
try self.module_env.store.scratch.?.exprs.append(new_arg);
}
const new_args_span = try self.module_env.store.exprSpanFrom(args_start);
const transformed_args = self.module_env.store.sliceExpr(new_args_span);
// Check if the function is a local variable that holds a closure
const func_expr = self.module_env.store.getExpr(call.func);
switch (func_expr) {
.e_lookup_local => |lookup| {
// Check if this pattern was assigned a closure
if (self.pattern_closures.get(lookup.pattern_idx)) |closure_info| {
// Generate a dispatch match expression
return try self.generateDispatchMatch(
call.func,
closure_info,
transformed_args,
);
}
},
else => {},
}
// Not a closure call, transform normally
const new_func = try self.transformExpr(call.func);
return try self.module_env.store.addExpr(Expr{
.e_call = .{
.func = new_func,
.args = new_args_span,
.called_via = call.called_via,
},
}, base.Region.zero());
},
.e_if => |if_expr| {
const branches = self.module_env.store.sliceIfBranches(if_expr.branches);
const branch_start = self.module_env.store.scratch.?.if_branches.top();
for (branches) |branch_idx| {
const branch = self.module_env.store.getIfBranch(branch_idx);
const new_cond = try self.transformExpr(branch.cond);
const new_body = try self.transformExpr(branch.body);
const new_branch_idx = try self.module_env.store.addIfBranch(
Expr.IfBranch{ .cond = new_cond, .body = new_body },
base.Region.zero(),
);
try self.module_env.store.scratch.?.if_branches.append(new_branch_idx);
}
const new_branches_span = try self.module_env.store.ifBranchSpanFrom(branch_start);
const new_else = try self.transformExpr(if_expr.final_else);
return try self.module_env.store.addExpr(Expr{
.e_if = .{
.branches = new_branches_span,
.final_else = new_else,
},
}, base.Region.zero());
},
.e_binop => |binop| {
const new_lhs = try self.transformExpr(binop.lhs);
const new_rhs = try self.transformExpr(binop.rhs);
return try self.module_env.store.addExpr(Expr{
.e_binop = .{
.op = binop.op,
.lhs = new_lhs,
.rhs = new_rhs,
},
}, base.Region.zero());
},
// Pass through simple expressions unchanged
.e_num,
.e_frac_f32,
.e_frac_f64,
.e_dec,
.e_dec_small,
.e_str_segment,
.e_str,
.e_lookup_local,
.e_lookup_external,
.e_empty_list,
.e_empty_record,
.e_zero_argument_tag,
.e_runtime_error,
.e_ellipsis,
.e_anno_only,
.e_lookup_required,
.e_type_var_dispatch,
.e_hosted_lambda,
.e_low_level_lambda,
=> return expr_idx,
.e_list => |list| {
const elems = self.module_env.store.sliceExpr(list.elems);
const elems_start = self.module_env.store.scratch.?.exprs.top();
for (elems) |elem_idx| {
const new_elem = try self.transformExpr(elem_idx);
try self.module_env.store.scratch.?.exprs.append(new_elem);
}
const new_elems_span = try self.module_env.store.exprSpanFrom(elems_start);
return try self.module_env.store.addExpr(Expr{
.e_list = .{ .elems = new_elems_span },
}, base.Region.zero());
},
.e_tuple => |tuple| {
const elems = self.module_env.store.sliceExpr(tuple.elems);
const elems_start = self.module_env.store.scratch.?.exprs.top();
for (elems) |elem_idx| {
const new_elem = try self.transformExpr(elem_idx);
try self.module_env.store.scratch.?.exprs.append(new_elem);
}
const new_elems_span = try self.module_env.store.exprSpanFrom(elems_start);
return try self.module_env.store.addExpr(Expr{
.e_tuple = .{ .elems = new_elems_span },
}, base.Region.zero());
},
.e_record => |record| {
const field_indices = self.module_env.store.sliceRecordFields(record.fields);
const fields_start = self.module_env.store.scratch.?.record_fields.top();
for (field_indices) |field_idx| {
const field = self.module_env.store.getRecordField(field_idx);
const new_value = try self.transformExpr(field.value);
const new_field = RecordField{
.name = field.name,
.value = new_value,
};
const new_field_idx = try self.module_env.store.addRecordField(new_field, base.Region.zero());
try self.module_env.store.scratch.?.record_fields.append(new_field_idx);
}
const new_fields_span = try self.module_env.store.recordFieldSpanFrom(fields_start);
const new_ext = if (record.ext) |ext| try self.transformExpr(ext) else null;
return try self.module_env.store.addExpr(Expr{
.e_record = .{
.fields = new_fields_span,
.ext = new_ext,
},
}, base.Region.zero());
},
.e_tag => |tag| {
const args = self.module_env.store.sliceExpr(tag.args);
const args_start = self.module_env.store.scratch.?.exprs.top();
for (args) |arg_idx| {
const new_arg = try self.transformExpr(arg_idx);
try self.module_env.store.scratch.?.exprs.append(new_arg);
}
const new_args_span = try self.module_env.store.exprSpanFrom(args_start);
return try self.module_env.store.addExpr(Expr{
.e_tag = .{
.name = tag.name,
.args = new_args_span,
},
}, base.Region.zero());
},
.e_unary_minus => |unary| {
const new_expr = try self.transformExpr(unary.expr);
return try self.module_env.store.addExpr(Expr{
.e_unary_minus = .{ .expr = new_expr },
}, base.Region.zero());
},
.e_unary_not => |unary| {
const new_expr = try self.transformExpr(unary.expr);
return try self.module_env.store.addExpr(Expr{
.e_unary_not = .{ .expr = new_expr },
}, base.Region.zero());
},
.e_dot_access => |dot| {
const new_receiver = try self.transformExpr(dot.receiver);
const new_args = if (dot.args) |args_span| blk: {
const args = self.module_env.store.sliceExpr(args_span);
const args_start = self.module_env.store.scratch.?.exprs.top();
for (args) |arg_idx| {
const new_arg = try self.transformExpr(arg_idx);
try self.module_env.store.scratch.?.exprs.append(new_arg);
}
break :blk try self.module_env.store.exprSpanFrom(args_start);
} else null;
return try self.module_env.store.addExpr(Expr{
.e_dot_access = .{
.receiver = new_receiver,
.field_name = dot.field_name,
.field_name_region = dot.field_name_region,
.args = new_args,
},
}, base.Region.zero());
},
.e_crash => return expr_idx,
.e_dbg => |dbg| {
const new_expr = try self.transformExpr(dbg.expr);
return try self.module_env.store.addExpr(Expr{
.e_dbg = .{
.expr = new_expr,
},
}, base.Region.zero());
},
.e_expect => |expect| {
const new_body = try self.transformExpr(expect.body);
return try self.module_env.store.addExpr(Expr{
.e_expect = .{
.body = new_body,
},
}, base.Region.zero());
},
.e_return => |ret| {
const new_expr = try self.transformExpr(ret.expr);
return try self.module_env.store.addExpr(Expr{
.e_return = .{ .expr = new_expr },
}, base.Region.zero());
},
.e_match => |match| {
const new_cond = try self.transformExpr(match.cond);
// Note: match branches would need deeper transformation for closures in branches
// For now, pass through as-is
return try self.module_env.store.addExpr(Expr{
.e_match = .{
.cond = new_cond,
.branches = match.branches,
.exhaustive = match.exhaustive,
},
}, base.Region.zero());
},
.e_nominal => |nominal| {
const new_backing = try self.transformExpr(nominal.backing_expr);
return try self.module_env.store.addExpr(Expr{
.e_nominal = .{
.nominal_type_decl = nominal.nominal_type_decl,
.backing_expr = new_backing,
.backing_type = nominal.backing_type,
},
}, base.Region.zero());
},
.e_nominal_external => |nominal| {
const new_backing = try self.transformExpr(nominal.backing_expr);
return try self.module_env.store.addExpr(Expr{
.e_nominal_external = .{
.module_idx = nominal.module_idx,
.target_node_idx = nominal.target_node_idx,
.backing_expr = new_backing,
.backing_type = nominal.backing_type,
},
}, base.Region.zero());
},
.e_for => |for_expr| {
const new_expr = try self.transformExpr(for_expr.expr);
const new_body = try self.transformExpr(for_expr.body);
return try self.module_env.store.addExpr(Expr{
.e_for = .{
.patt = for_expr.patt,
.expr = new_expr,
.body = new_body,
},
}, base.Region.zero());
},
}
}
// Tests
const testing = std.testing;
test "ClosureTransformer: init and deinit" {
const allocator = testing.allocator;
const module_env = try allocator.create(ModuleEnv);
module_env.* = try ModuleEnv.init(allocator, "test");
defer {
module_env.deinit();
allocator.destroy(module_env);
}
var transformer = Self.init(allocator, module_env);
defer transformer.deinit();
try testing.expectEqual(@as(u32, 0), transformer.closure_counter);
}
test "ClosureTransformer: generateClosureTagName with hint" {
const allocator = testing.allocator;
const module_env = try allocator.create(ModuleEnv);
module_env.* = try ModuleEnv.init(allocator, "test");
defer {
module_env.deinit();
allocator.destroy(module_env);
}
var transformer = Self.init(allocator, module_env);
defer transformer.deinit();
// Create a hint identifier
const hint = try module_env.insertIdent(base.Ident.for_text("addX"));
const tag_name = try transformer.generateClosureTagName(hint);
const tag_str = module_env.getIdent(tag_name);
try testing.expectEqualStrings("#addX", tag_str);
}
test "ClosureTransformer: generateClosureTagName without hint" {
const allocator = testing.allocator;
const module_env = try allocator.create(ModuleEnv);
module_env.* = try ModuleEnv.init(allocator, "test");
defer {
module_env.deinit();
allocator.destroy(module_env);
}
var transformer = Self.init(allocator, module_env);
defer transformer.deinit();
const tag_name = try transformer.generateClosureTagName(null);
const tag_str = module_env.getIdent(tag_name);
try testing.expectEqualStrings("#1", tag_str);
}

View file

@ -147,20 +147,6 @@ pub fn createSpecializedName(
return specialized_ident;
}
/// Monomorphize a module's top-level expressions
/// Returns a new ModuleEnv with specialized functions
pub fn monomorphize() !void {
// Phase 1: Just traverse and identify what needs specialization
// For now, this is a placeholder for the full implementation
// In a full implementation, we would:
// 1. Find all top-level function definitions
// 2. Analyze their types to find polymorphic ones
// 3. Find all call sites and their concrete types
// 4. Create specialized versions
// 5. Update call sites to use specialized versions
}
// Tests
const testing = std.testing;

View file

@ -18,6 +18,8 @@ pub const HostedCompiler = @import("HostedCompiler.zig");
pub const RocEmitter = @import("RocEmitter.zig");
/// Monomorphizer - specializes polymorphic functions to concrete types
pub const Monomorphizer = @import("Monomorphizer.zig");
/// Closure Transformer - transforms closures with captures into tagged values
pub const ClosureTransformer = @import("ClosureTransformer.zig");
test "compile tests" {
std.testing.refAllDecls(@This());
@ -54,4 +56,5 @@ test "compile tests" {
// Monomorphization
std.testing.refAllDecls(@import("Monomorphizer.zig"));
std.testing.refAllDecls(@import("ClosureTransformer.zig"));
}

View file

@ -346,3 +346,334 @@ test "roundtrip: complex arithmetic produces same result" {
try testing.expectEqual(original_result, emitted_result);
try testing.expectEqual(@as(i128, 16), emitted_result);
}
/// Helper to check if source code contains a closure with captures
fn hasClosureWithCaptures(allocator: std.mem.Allocator, source: []const u8) !bool {
const resources = try helpers.parseAndCanonicalizeExpr(allocator, source);
defer helpers.cleanupParseAndCanonical(allocator, resources);
// Recursively check if any expression is a closure with captures
return checkForCapturesRecursive(resources.module_env, resources.expr_idx);
}
fn checkForCapturesRecursive(module_env: *can.ModuleEnv, expr_idx: can.CIR.Expr.Idx) bool {
const expr = module_env.store.getExpr(expr_idx);
switch (expr) {
.e_closure => |closure| {
if (closure.captures.span.len > 0) {
return true;
}
// Also check the lambda body
return checkForCapturesRecursive(module_env, closure.lambda_idx);
},
.e_lambda => |lambda| {
return checkForCapturesRecursive(module_env, lambda.body);
},
.e_block => |block| {
// Check statements
const stmts = module_env.store.sliceStatements(block.stmts);
for (stmts) |stmt_idx| {
const stmt = module_env.store.getStatement(stmt_idx);
switch (stmt) {
.s_decl => |decl| {
if (checkForCapturesRecursive(module_env, decl.expr)) {
return true;
}
},
.s_decl_gen => |decl| {
if (checkForCapturesRecursive(module_env, decl.expr)) {
return true;
}
},
else => {},
}
}
// Check final expression
return checkForCapturesRecursive(module_env, block.final_expr);
},
.e_call => |call| {
if (checkForCapturesRecursive(module_env, call.func)) {
return true;
}
const args = module_env.store.sliceExpr(call.args);
for (args) |arg_idx| {
if (checkForCapturesRecursive(module_env, arg_idx)) {
return true;
}
}
return false;
},
.e_if => |if_expr| {
const branches = module_env.store.sliceIfBranches(if_expr.branches);
for (branches) |branch_idx| {
const branch = module_env.store.getIfBranch(branch_idx);
if (checkForCapturesRecursive(module_env, branch.cond) or
checkForCapturesRecursive(module_env, branch.body))
{
return true;
}
}
return checkForCapturesRecursive(module_env, if_expr.final_else);
},
.e_binop => |binop| {
return checkForCapturesRecursive(module_env, binop.lhs) or
checkForCapturesRecursive(module_env, binop.rhs);
},
else => return false,
}
}
test "detect closure with single capture" {
const source =
\\{
\\ x = 42
\\ f = |y| x + y
\\ f(10)
\\}
;
const has_captures = try hasClosureWithCaptures(test_allocator, source);
try testing.expect(has_captures);
}
test "detect closure with multiple captures" {
const source =
\\{
\\ a = 1
\\ b = 2
\\ f = |x| a + b + x
\\ f(3)
\\}
;
const has_captures = try hasClosureWithCaptures(test_allocator, source);
try testing.expect(has_captures);
}
test "detect pure lambda (no captures)" {
const source =
\\{
\\ f = |x| x + 1
\\ f(41)
\\}
;
const has_captures = try hasClosureWithCaptures(test_allocator, source);
try testing.expect(!has_captures);
}
/// Helper to transform a single closure expression directly (not in a block)
fn transformClosureExpr(allocator: std.mem.Allocator, source: []const u8) ![]const u8 {
const resources = try helpers.parseAndCanonicalizeExpr(allocator, source);
defer helpers.cleanupParseAndCanonical(allocator, resources);
// Create transformer
var transformer = can.ClosureTransformer.init(allocator, resources.module_env);
defer transformer.deinit();
// Transform just the expression (not the block around it)
const transformed_idx = try transformer.transformClosure(resources.expr_idx, null);
// Emit the transformed expression
var emitter = Emitter.init(allocator, resources.module_env);
defer emitter.deinit();
try emitter.emitExpr(transformed_idx);
return try allocator.dupe(u8, emitter.getOutput());
}
test "transform pure lambda to tag" {
// Test a pure lambda (no captures) - parsed directly without a block
const source = "|x| x + 1";
const output = try transformClosureExpr(test_allocator, source);
defer test_allocator.free(output);
// Pure lambda should be transformed to a # tag with empty record
try testing.expect(std.mem.indexOf(u8, output, "#") != null);
try testing.expect(std.mem.indexOf(u8, output, "{}") != null);
}
/// Helper to transform closures in a block and emit the result
fn transformBlockAndEmit(allocator: std.mem.Allocator, source: []const u8) ![]const u8 {
const resources = try helpers.parseAndCanonicalizeExpr(allocator, source);
defer helpers.cleanupParseAndCanonical(allocator, resources);
// Create transformer
var transformer = can.ClosureTransformer.init(allocator, resources.module_env);
defer transformer.deinit();
// Transform the entire expression tree
const transformed_idx = try transformer.transformExpr(resources.expr_idx);
// Emit the transformed expression
var emitter = Emitter.init(allocator, resources.module_env);
defer emitter.deinit();
try emitter.emitExpr(transformed_idx);
return try allocator.dupe(u8, emitter.getOutput());
}
test "transform closure with single capture to tag" {
const source =
\\{
\\ x = 42
\\ f = |y| x + y
\\ f(10)
\\}
;
const output = try transformBlockAndEmit(test_allocator, source);
defer test_allocator.free(output);
// The closure should have been transformed to a # tag
try testing.expect(std.mem.indexOf(u8, output, "#") != null);
// The capture 'x' should appear in the tag's record argument
try testing.expect(std.mem.indexOf(u8, output, "x:") != null or
std.mem.indexOf(u8, output, "{x") != null);
// The call should have been transformed to a match expression
try testing.expect(std.mem.indexOf(u8, output, "match") != null);
}
test "transform closure with multiple captures" {
const source =
\\{
\\ a = 1
\\ b = 2
\\ f = |x| a + b + x
\\ f(3)
\\}
;
const output = try transformBlockAndEmit(test_allocator, source);
defer test_allocator.free(output);
// The closure should have been transformed to a # tag
try testing.expect(std.mem.indexOf(u8, output, "#") != null);
// Both captures 'a' and 'b' should appear in the tag's record
try testing.expect(std.mem.indexOf(u8, output, "a:") != null or
std.mem.indexOf(u8, output, "{a") != null);
try testing.expect(std.mem.indexOf(u8, output, "b:") != null or
std.mem.indexOf(u8, output, ", b") != null);
// The call should have been transformed to a match expression
try testing.expect(std.mem.indexOf(u8, output, "match") != null);
}
test "verify: closure with single capture transforms correctly" {
const source =
\\{
\\ x = 42
\\ f = |y| x + y
\\ f(10)
\\}
;
// Transform the code
const transformed = try transformBlockAndEmit(test_allocator, source);
defer test_allocator.free(transformed);
// Verify transformation structure:
// - Should have a # tag (internal closure tag)
// - Should have a match expression for the call site
// - Should reference the captured variable x
try testing.expect(std.mem.indexOf(u8, transformed, "#") != null);
try testing.expect(std.mem.indexOf(u8, transformed, "match") != null);
}
test "verify: closure with multiple captures transforms correctly" {
const source =
\\{
\\ a = 1
\\ b = 2
\\ f = |x| a + b + x
\\ f(3)
\\}
;
// Transform the code
const transformed = try transformBlockAndEmit(test_allocator, source);
defer test_allocator.free(transformed);
// Verify transformation structure:
// - Should have a # tag (internal closure tag)
// - Should have a match expression for the call site
try testing.expect(std.mem.indexOf(u8, transformed, "#") != null);
try testing.expect(std.mem.indexOf(u8, transformed, "match") != null);
}
test "verify: pure lambda (no captures) transforms correctly" {
const source =
\\{
\\ f = |x| x + 1
\\ f(41)
\\}
;
// Transform the code
const transformed = try transformBlockAndEmit(test_allocator, source);
defer test_allocator.free(transformed);
// Verify transformation structure:
// - Should have a # tag (internal closure tag)
// - Should have a match expression for the call site
// - Pure lambdas should have empty record {}
try testing.expect(std.mem.indexOf(u8, transformed, "#") != null);
try testing.expect(std.mem.indexOf(u8, transformed, "match") != null);
try testing.expect(std.mem.indexOf(u8, transformed, "{}") != null);
}
test "verify: nested closures with captures transforms correctly" {
// A closure that returns another closure, both with captures
const source =
\\{
\\ x = 10
\\ makeAdder = |y| |z| x + y + z
\\ addFive = makeAdder(5)
\\ addFive(3)
\\}
;
// Transform the code
const transformed = try transformBlockAndEmit(test_allocator, source);
defer test_allocator.free(transformed);
// Verify transformation structure:
// - Should have # tags (internal closure tags)
// - Should have match expressions for the call sites
try testing.expect(std.mem.indexOf(u8, transformed, "#") != null);
try testing.expect(std.mem.indexOf(u8, transformed, "match") != null);
}
test "ClosureTransformer: can generate tag names" {
// Test that the transformer can generate unique tag names
const allocator = test_allocator;
const module_env = try allocator.create(can.ModuleEnv);
module_env.* = try can.ModuleEnv.init(allocator, "test");
defer {
module_env.deinit();
allocator.destroy(module_env);
}
var transformer = can.ClosureTransformer.init(allocator, module_env);
defer transformer.deinit();
// Generate a tag name with a hint
const hint = try module_env.insertIdent(base.Ident.for_text("myFunc"));
const tag_name1 = try transformer.generateClosureTagName(hint);
const tag_str1 = module_env.getIdent(tag_name1);
try testing.expectEqualStrings("#myFunc", tag_str1);
// Generate another tag name without hint
const tag_name2 = try transformer.generateClosureTagName(null);
const tag_str2 = module_env.getIdent(tag_name2);
try testing.expectEqualStrings("#2", tag_str2);
}