roc/src/types/instantiate.zig
2025-12-14 10:09:54 -05:00

438 lines
19 KiB
Zig
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

//! Type instantiation for Hindley-Milner type inference.
//!
//! This module provides functionality to instantiate polymorphic types with fresh
//! type variables while preserving type aliases and structure. This is a critical
//! component for proper handling of annotated functions in the type system.
const std = @import("std");
const base = @import("base");
const collections = @import("collections");
const types_store = @import("store.zig");
const types_mod = @import("types.zig");
const TypesStore = types_store.Store;
const Var = types_mod.Var;
const Flex = types_mod.Flex;
const StaticDispatchConstraint = types_mod.StaticDispatchConstraint;
const Rigid = types_mod.Rigid;
const Content = types_mod.Content;
const FlatType = types_mod.FlatType;
const Alias = types_mod.Alias;
const Func = types_mod.Func;
const Record = types_mod.Record;
const TagUnion = types_mod.TagUnion;
const RecordField = types_mod.RecordField;
const Tag = types_mod.Tag;
const Num = types_mod.Num;
const NominalType = types_mod.NominalType;
const Tuple = types_mod.Tuple;
const Rank = types_mod.Rank;
const Mark = types_mod.Mark;
const Ident = base.Ident;
/// Type to manage instantiation.
///
/// Entry point is `instantiateVar`.
///
/// This type does not own any of its fields; it is a convenience wrapper that
/// makes threading those fields through all the recursive functions easier.
pub const Instantiator = struct {
// not owned
store: *TypesStore,
idents: *const base.Ident.Store,
var_map: *std.AutoHashMap(Var, Var),
current_rank: Rank = Rank.top_level,
rigid_behavior: RigidBehavior,
/// The mode to use when instantiating rigid type variables.
pub const RigidBehavior = union(enum) {
/// In this mode, all rigids are instantiated as new flex vars.
/// Note that the rigid var structure will be preserved.
/// E.g. in `a -> a`, both `a`s will reference the same new flex var.
fresh_flex,
/// In this mode, all rigids are instantiated as new rigid vars.
/// Note that the rigid var structure will be preserved.
/// E.g. in `a -> a`, both `a`s will reference the same new rigid var.
fresh_rigid,
/// In this mode, all rigids will be substituted with the vars in the
/// provided map, which is keyed by the rigid's name.
/// If a rigid var is not in the map, `instantiateVar` treats it as a bug:
/// safety-checked builds fail, and the fallback is a fresh `.err` var.
substitute_rigids: *std.AutoHashMapUnmanaged(Ident.Idx, Var),
};
const Self = @This();
// instantiation //
/// Instantiate a variable, returning its fresh counterpart.
///
/// `initial_var` is resolved first; if the resolved var already has an entry
/// in `var_map`, that entry is returned so that shared and recursive
/// references keep pointing at a single fresh var.
///
/// Rigid vars follow `rigid_behavior`:
/// - `.fresh_flex` / `.fresh_rigid`: a new flex / rigid var is created with the
///   same name and freshly instantiated constraints.
/// - `.substitute_rigids`: the var stored under the rigid's name in the
///   substitution map is returned. A missing entry is a caller bug: builds
///   with runtime safety panic, all other builds fall back to a fresh `.err`
///   var instead of invoking undefined behaviour.
///
/// Every other kind of content is instantiated through `instantiateContent`.
/// New vars end up at `current_rank` with `Mark.none`.
///
/// Returns an allocator error if the store or `var_map` cannot grow.
pub fn instantiateVar(
    self: *Self,
    initial_var: Var,
) std.mem.Allocator.Error!Var {
    const resolved = self.store.resolveVar(initial_var);
    const resolved_var = resolved.var_;

    // Check if we've already instantiated this variable
    if (self.var_map.get(resolved_var)) |fresh_var| {
        return fresh_var;
    }

    switch (resolved.desc.content) {
        .rigid => |rigid| {
            // Decide what kind of var replaces this rigid. The substitution
            // mode returns directly from its prong.
            const fresh_type: enum { flex, rigid } = switch (self.rigid_behavior) {
                .fresh_rigid => .rigid,
                .fresh_flex => .flex,
                .substitute_rigids => |rigid_subs| {
                    const existing_var = rigid_subs.get(rigid.name) orelse missing: {
                        // Only trap when runtime safety is on. A bare
                        // `std.debug.assert(false)` is `unreachable` in
                        // ReleaseFast/ReleaseSmall, which would make the
                        // `.err` fallback below undefined behaviour.
                        if (std.debug.runtime_safety) {
                            std.debug.panic(
                                "instantiateVar: rigid var '{s}' has no entry in the substitution map",
                                .{self.getIdent(rigid.name)},
                            );
                        }
                        break :missing try self.store.freshFromContentWithRank(
                            .err,
                            self.current_rank,
                        );
                    };

                    // Remember this substitution for recursive references
                    try self.var_map.put(resolved_var, existing_var);
                    return existing_var;
                },
            };

            // Register a placeholder for recursive references.
            // IMPORTANT: This has to be inserted _before_ the constraints are
            // instantiated, because that step calls back into `instantiateVar`
            // and may reach this variable again.
            const fresh_var = try self.store.freshFromContentWithRank(.{ .flex = Flex.init() }, self.current_rank);
            try self.var_map.put(resolved_var, fresh_var);

            // Copy the rigid var's constraints
            const fresh_constraints = try self.instantiateStaticDispatchConstraints(rigid.constraints);

            // Build the final content, keeping the rigid's name
            const fresh_content = switch (fresh_type) {
                .flex => Content{ .flex = Flex{ .name = rigid.name, .constraints = fresh_constraints } },
                .rigid => Content{ .rigid = Rigid{ .name = rigid.name, .constraints = fresh_constraints } },
            };

            // Update the placeholder fresh var with the real content
            try self.store.dangerousSetVarDesc(
                fresh_var,
                .{
                    .content = fresh_content,
                    .rank = self.current_rank,
                    .mark = Mark.none,
                },
            );

            return fresh_var;
        },
        else => {
            // Remember this substitution for recursive references
            // IMPORTANT: This has to be inserted _before_ we recurse into `instantiateContent`
            const fresh_var = try self.store.fresh();
            try self.var_map.put(resolved_var, fresh_var);

            // Generate the content
            const fresh_content = try self.instantiateContent(resolved.desc.content);

            // Update the placeholder fresh var with the real content
            try self.store.dangerousSetVarDesc(
                fresh_var,
                .{
                    .content = fresh_content,
                    .rank = self.current_rank,
                    .mark = Mark.none,
                },
            );

            return fresh_var;
        },
    }
}
/// Produce a fresh copy of `content`, instantiating every var it refers to.
///
/// Rigid content must never reach this function: `instantiateVar` handles
/// rigids itself before delegating here.
fn instantiateContent(self: *Self, content: Content) std.mem.Allocator.Error!Content {
    switch (content) {
        .flex => |flex_payload| {
            const fresh_flex = try self.instantiateFlex(flex_payload);
            return Content{ .flex = fresh_flex };
        },
        // Rigids are handled by `instantiateVar`; running into one here is a bug.
        .rigid => unreachable,
        // Instantiate the alias arguments and backing var recursively
        .alias => |alias_payload| return try self.instantiateAlias(alias_payload),
        .structure => |flat_type| {
            // Instantiate the structure recursively
            const fresh_flat_type = try self.instantiateFlatType(flat_type);
            return Content{ .structure = fresh_flat_type };
        },
        .recursion_var => |rec_var| {
            // Instantiate the structure the recursion var points to
            const fresh_structure = try self.instantiateVar(rec_var.structure);
            return Content{ .recursion_var = .{ .structure = fresh_structure, .name = rec_var.name } };
        },
        .err => return Content.err,
    }
}
/// Copy a flex payload: the name is kept and the static dispatch constraints
/// are instantiated afresh.
fn instantiateFlex(self: *Self, flex: Flex) std.mem.Allocator.Error!Flex {
    return Flex{
        .name = flex.name,
        .constraints = try self.instantiateStaticDispatchConstraints(flex.constraints),
    };
}
/// Instantiate an alias: first each type argument, then the backing var, and
/// finally rebuild the alias content under the same identifier.
fn instantiateAlias(self: *Self, alias: Alias) std.mem.Allocator.Error!Content {
    const gpa = self.store.gpa;

    var fresh_args = std.ArrayList(Var).empty;
    defer fresh_args.deinit(gpa);

    // NOTE(review): unlike `instantiateFunc`, this walks a store iterator while
    // `instantiateVar` may append to the store. Presumably the iterator is
    // index-based and therefore safe - confirm against the store implementation.
    var arg_iter = self.store.iterAliasArgs(alias);
    while (arg_iter.next()) |arg| {
        try fresh_args.append(gpa, try self.instantiateVar(arg));
    }

    const fresh_backing = try self.instantiateVar(self.store.getAliasBackingVar(alias));

    return self.store.mkAlias(alias.ident, fresh_backing, fresh_args.items);
}
/// Instantiate a flat type by dispatching to the helper for its variant.
/// The empty record and empty tag union carry no vars and are returned as-is.
fn instantiateFlatType(self: *Self, flat_type: FlatType) std.mem.Allocator.Error!FlatType {
    switch (flat_type) {
        .tuple => |elems| return FlatType{ .tuple = try self.instantiateTuple(elems) },
        .nominal_type => |nominal_type| return FlatType{ .nominal_type = try self.instantiateNominalType(nominal_type) },
        .fn_pure => |pure_fn| return FlatType{ .fn_pure = try self.instantiateFunc(pure_fn) },
        .fn_effectful => |effectful_fn| return FlatType{ .fn_effectful = try self.instantiateFunc(effectful_fn) },
        .fn_unbound => |unbound_fn| return FlatType{ .fn_unbound = try self.instantiateFunc(unbound_fn) },
        .record => |rec| return FlatType{ .record = try self.instantiateRecord(rec) },
        .record_unbound => |unbound_fields| return FlatType{ .record_unbound = try self.instantiateRecordFields(unbound_fields) },
        .empty_record => return FlatType.empty_record,
        .tag_union => |tags| return FlatType{ .tag_union = try self.instantiateTagUnion(tags) },
        .empty_tag_union => return FlatType.empty_tag_union,
    }
}
/// Instantiate a nominal type: first the backing var, then each type argument,
/// and finally rebuild the nominal type with the original identifier, origin
/// module and opacity.
fn instantiateNominalType(self: *Self, nominal: NominalType) std.mem.Allocator.Error!NominalType {
    const gpa = self.store.gpa;

    const fresh_backing = try self.instantiateVar(self.store.getNominalBackingVar(nominal));

    var fresh_args = std.ArrayList(Var).empty;
    defer fresh_args.deinit(gpa);

    var arg_iter = self.store.iterNominalArgs(nominal);
    while (arg_iter.next()) |arg| {
        try fresh_args.append(gpa, try self.instantiateVar(arg));
    }

    // `mkNominal` hands back a full `Content`; only the nominal payload is needed.
    const fresh_content = try self.store.mkNominal(
        nominal.ident,
        fresh_backing,
        fresh_args.items,
        nominal.origin_module,
        nominal.is_opaque,
    );
    return fresh_content.structure.nominal_type;
}
/// Instantiate every element of a tuple and store the fresh elements as a new
/// var range.
fn instantiateTuple(self: *Self, tuple: Tuple) std.mem.Allocator.Error!Tuple {
    const gpa = self.store.gpa;

    var collected = std.ArrayList(Var).empty;
    defer collected.deinit(gpa);

    // Index into the store on every step instead of holding a slice:
    // `instantiateVar` can grow the store's var list and move its memory
    // (see the comment in `instantiateFunc`).
    const first: usize = @intFromEnum(tuple.elems.start);
    for (0..tuple.elems.count) |offset| {
        const original = self.store.vars.items.items[first + offset];
        try collected.append(gpa, try self.instantiateVar(original));
    }

    return Tuple{ .elems = try self.store.appendVars(collected.items) };
}
/// Instantiate a function type: each argument in order, then the return var.
/// The result is always flagged with `needs_instantiation = true`.
fn instantiateFunc(self: *Self, func: Func) std.mem.Allocator.Error!Func {
    const gpa = self.store.gpa;

    var collected_args = std.ArrayList(Var).empty;
    defer collected_args.deinit(gpa);

    // IMPORTANT: iterate by index, never by slice. A slice would point into
    // the store's backing list, and `instantiateVar` can recursively call
    // `appendVars`, which may reallocate that list and leave the slice dangling.
    const first_arg: usize = @intFromEnum(func.args.start);
    for (0..func.args.count) |offset| {
        // Look the var up again each time, since the backing array may have moved
        const original_arg = self.store.vars.items.items[first_arg + offset];
        try collected_args.append(gpa, try self.instantiateVar(original_arg));
    }

    const fresh_ret = try self.instantiateVar(func.ret);
    const fresh_args_range = try self.store.appendVars(collected_args.items);

    return Func{
        .args = fresh_args_range,
        .ret = fresh_ret,
        .needs_instantiation = true,
    };
}
/// Instantiate the var of every field in `fields`, keeping the field names,
/// and store the result as a new record-field range. An empty input still
/// appends an empty range to the store.
fn instantiateRecordFields(self: *Self, fields: RecordField.SafeMultiList.Range) std.mem.Allocator.Error!RecordField.SafeMultiList.Range {
    if (fields.count == 0) {
        return try self.store.appendRecordFields(&.{});
    }

    const gpa = self.store.gpa;
    var collected = std.ArrayList(RecordField).empty;
    defer collected.deinit(gpa);

    // IMPORTANT: iterate by index, never by slice. A slice would point into
    // the backing MultiArrayList, and `instantiateVar` can recursively call
    // `appendRecordFields`, which may reallocate it and leave the slice dangling.
    const first: usize = @intFromEnum(fields.start);
    for (0..fields.count) |offset| {
        // Look the field up again each time, since the backing array may have moved
        const original = self.store.record_fields.get(@enumFromInt(first + offset));
        const fresh_var = try self.instantiateVar(original.var_);
        try collected.append(gpa, RecordField{
            .name = original.name,
            .var_ = fresh_var,
        });
    }

    return try self.store.appendRecordFields(collected.items);
}
/// Instantiate a record: its fields first, then its extension var.
///
/// The field handling is delegated to `instantiateRecordFields` rather than
/// duplicated here. That helper covers the empty-fields case (it still appends
/// an empty range) and iterates by index so store reallocation during
/// recursion cannot invalidate what it reads.
///
/// Returns an allocator error if the store cannot grow.
fn instantiateRecord(self: *Self, record: Record) std.mem.Allocator.Error!Record {
    // Fields are instantiated and appended before the extension var is
    // visited, so fresh vars and ranges are created in a fixed order.
    const fresh_fields = try self.instantiateRecordFields(record.fields);
    const fresh_ext = try self.instantiateVar(record.ext);

    return Record{
        .fields = fresh_fields,
        .ext = fresh_ext,
    };
}
/// Instantiate a tag union: the argument vars of every tag, then the
/// extension var.
///
/// The fresh tags are sorted by name before being appended to the store, so
/// tag discriminants stay consistent after instantiation. An empty tag list
/// still appends an empty range.
///
/// Returns an allocator error if the store or the scratch buffers cannot grow.
fn instantiateTagUnion(self: *Self, tag_union: TagUnion) std.mem.Allocator.Error!TagUnion {
    if (tag_union.tags.count == 0) {
        return TagUnion{
            .tags = try self.store.appendTags(&.{}),
            .ext = try self.instantiateVar(tag_union.ext),
        };
    }

    const gpa = self.store.gpa;

    var fresh_tags = std.ArrayList(Tag).empty;
    defer fresh_tags.deinit(gpa);

    // One scratch buffer shared by all tags, cleared per tag, instead of
    // allocating and freeing a new list on every iteration. It is local to
    // this call, so recursive instantiations use their own buffer.
    var fresh_args = std.ArrayList(Var).empty;
    defer fresh_args.deinit(gpa);

    // IMPORTANT: We must use index-based iteration here, not slice-based.
    // The slice would point into the backing MultiArrayList, but instantiateVar
    // can recursively call appendTags which may reallocate the array,
    // invalidating the slice pointers.
    const tags_start: usize = @intFromEnum(tag_union.tags.start);
    for (0..tag_union.tags.count) |tag_i| {
        // Re-fetch the tag data on each iteration since the backing array may have moved
        const tag = self.store.tags.get(@enumFromInt(tags_start + tag_i));
        const tag_name = tag.name;
        const tag_args = tag.args;

        fresh_args.clearRetainingCapacity();

        // Skip the loop entirely for tags with no arguments.
        // This avoids accessing tag_args.start which may be undefined when count is 0.
        if (tag_args.count > 0) {
            // Index-based for the same reason as above: instantiateVar may
            // grow the store's var list.
            const args_start: usize = @intFromEnum(tag_args.start);
            for (0..tag_args.count) |i| {
                const arg_var = self.store.vars.items.items[args_start + i];
                const fresh_arg = try self.instantiateVar(arg_var);
                try fresh_args.append(gpa, fresh_arg);
            }
        }

        const fresh_args_range = try self.store.appendVars(fresh_args.items);
        try fresh_tags.append(gpa, Tag{
            .name = tag_name,
            .args = fresh_args_range,
        });
    }

    // Sort the fresh tags alphabetically by name before appending.
    // This ensures tag discriminants are consistent after instantiation.
    std.mem.sort(Tag, fresh_tags.items, self.idents, comptime Tag.sortByNameAsc);

    const tags_range = try self.store.appendTags(fresh_tags.items);
    return TagUnion{
        .tags = tags_range,
        .ext = try self.instantiateVar(tag_union.ext),
    };
}
/// Look up the text of the identifier `idx` in the ident store.
pub fn getIdent(self: *const Self, idx: Ident.Idx) []const u8 {
    const ident_text = self.idents.getText(idx);
    return ident_text;
}
/// Instantiate every constraint in `constraints` and store the fresh ones as
/// a new range. An empty input yields the empty range without touching the
/// store.
fn instantiateStaticDispatchConstraints(self: *Self, constraints: StaticDispatchConstraint.SafeList.Range) std.mem.Allocator.Error!StaticDispatchConstraint.SafeList.Range {
    const total = constraints.len();
    if (total == 0) {
        return StaticDispatchConstraint.SafeList.Range.empty();
    }

    const gpa = self.store.gpa;
    var collected = try std.ArrayList(StaticDispatchConstraint).initCapacity(gpa, total);
    defer collected.deinit(gpa);

    // Index into the store on every step instead of holding a slice:
    // instantiating a constraint can grow the store's lists
    // (see the comment in `instantiateFunc`).
    const first: usize = @intFromEnum(constraints.start);
    for (0..total) |offset| {
        const original = self.store.static_dispatch_constraints.items.items[first + offset];
        const fresh = try self.instantiateStaticDispatchConstraint(original);
        // Capacity for all `total` entries was reserved above.
        collected.appendAssumeCapacity(fresh);
    }

    return try self.store.appendStaticDispatchConstraints(collected.items);
}
/// Copy one static dispatch constraint, instantiating its function var and
/// keeping the function name and origin unchanged.
fn instantiateStaticDispatchConstraint(self: *Self, constraint: StaticDispatchConstraint) std.mem.Allocator.Error!StaticDispatchConstraint {
    const fresh_fn_var = try self.instantiateVar(constraint.fn_var);
    return StaticDispatchConstraint{
        .fn_name = constraint.fn_name,
        .fn_var = fresh_fn_var,
        .origin = constraint.origin,
    };
}
};