mirror of
https://github.com/roc-lang/roc.git
synced 2025-12-23 08:48:03 +00:00
438 lines
19 KiB
Zig
438 lines
19 KiB
Zig
//! Type instantiation for Hindley-Milner type inference.
|
||
//!
|
||
//! This module provides functionality to instantiate polymorphic types with fresh
|
||
//! type variables while preserving type aliases and structure. This is a critical
|
||
//! component for proper handling of annotated functions in the type system.
|
||
|
||
const std = @import("std");
|
||
const base = @import("base");
|
||
const collections = @import("collections");
|
||
const types_store = @import("store.zig");
|
||
const types_mod = @import("types.zig");
|
||
|
||
const TypesStore = types_store.Store;
|
||
const Var = types_mod.Var;
|
||
const Flex = types_mod.Flex;
|
||
const StaticDispatchConstraint = types_mod.StaticDispatchConstraint;
|
||
const Rigid = types_mod.Rigid;
|
||
const Content = types_mod.Content;
|
||
const FlatType = types_mod.FlatType;
|
||
const Alias = types_mod.Alias;
|
||
const Func = types_mod.Func;
|
||
const Record = types_mod.Record;
|
||
const TagUnion = types_mod.TagUnion;
|
||
const RecordField = types_mod.RecordField;
|
||
const Tag = types_mod.Tag;
|
||
const Num = types_mod.Num;
|
||
const NominalType = types_mod.NominalType;
|
||
const Tuple = types_mod.Tuple;
|
||
const Rank = types_mod.Rank;
|
||
const Mark = types_mod.Mark;
|
||
const Ident = base.Ident;
|
||
|
||
/// Type to manage instantiation.
///
/// Entry point is `instantiateVar`.
///
/// This type does not own any of its fields – it's a convenience wrapper that
/// makes threading its fields through all the recursive functions easier.
pub const Instantiator = struct {
    // not owned
    store: *TypesStore,
    idents: *const base.Ident.Store,
    var_map: *std.AutoHashMap(Var, Var),

    current_rank: Rank = Rank.top_level,
    rigid_behavior: RigidBehavior,

    /// The mode to use when instantiating rigid variables.
    pub const RigidBehavior = union(enum) {
        /// In this mode, all rigids are instantiated as new flex vars.
        /// The rigid var structure is preserved: in `a -> a`, both `a`s
        /// reference the same new flex var.
        fresh_flex,

        /// In this mode, all rigids are instantiated as new rigid vars.
        /// The rigid var structure is preserved: in `a -> a`, both `a`s
        /// reference the same new rigid var.
        fresh_rigid,

        /// In this mode, every rigid is replaced by the var mapped to its name
        /// in the provided map. If a rigid's name is not in the map, builds
        /// with runtime safety (Debug/ReleaseSafe) panic; other builds
        /// substitute a fresh `.err` var instead.
        substitute_rigids: *std.AutoHashMapUnmanaged(Ident.Idx, Var),
    };

    const Self = @This();

    // instantiation //

    /// Instantiate a variable, returning a fresh var with the same structure.
    ///
    /// Results are memoized in `var_map` (keyed by the resolved var), so every
    /// occurrence of the same input var maps to the same fresh var and
    /// recursive references terminate. Rigid vars are handled according to
    /// `rigid_behavior`; all other content is copied recursively and stored at
    /// `current_rank`.
    pub fn instantiateVar(
        self: *Self,
        initial_var: Var,
    ) std.mem.Allocator.Error!Var {
        const resolved = self.store.resolveVar(initial_var);
        const resolved_var = resolved.var_;

        // Check if we've already instantiated this variable
        if (self.var_map.get(resolved_var)) |fresh_var| {
            return fresh_var;
        }

        switch (resolved.desc.content) {
            .rigid => |rigid| {
                // Decide what kind of var replaces this rigid. The
                // `substitute_rigids` mode does not create a new var at all:
                // it returns the caller-provided substitute directly.
                const fresh_type: enum { flex, rigid } = switch (self.rigid_behavior) {
                    .fresh_flex => .flex,
                    .fresh_rigid => .rigid,
                    .substitute_rigids => |rigid_subs| {
                        const substitute = rigid_subs.get(rigid.name) orelse missing: {
                            // Every rigid is expected to have a substitution.
                            // `std.debug.assert(false)` here would be `unreachable`,
                            // i.e. undefined behavior in unsafe builds, letting the
                            // optimizer discard the `.err` fallback below. Panic only
                            // when runtime safety is enabled instead.
                            if (std.debug.runtime_safety) {
                                std.debug.panic(
                                    "instantiateVar: no substitution provided for rigid var '{s}'",
                                    .{self.getIdent(rigid.name)},
                                );
                            }
                            break :missing try self.store.freshFromContentWithRank(
                                .err,
                                self.current_rank,
                            );
                        };

                        // Remember this substitution for recursive references
                        try self.var_map.put(resolved_var, substitute);
                        return substitute;
                    },
                };

                // Remember this substitution for recursive references
                // IMPORTANT: This has to be inserted _before_ we recurse into the constraints
                const fresh_var = try self.store.freshFromContentWithRank(.{ .flex = Flex.init() }, self.current_rank);
                try self.var_map.put(resolved_var, fresh_var);

                // Copy the rigid var's constraints
                const fresh_constraints = try self.instantiateStaticDispatchConstraints(rigid.constraints);

                // Build the replacement content, keeping the rigid's name
                const fresh_content = switch (fresh_type) {
                    .flex => Content{ .flex = Flex{ .name = rigid.name, .constraints = fresh_constraints } },
                    .rigid => Content{ .rigid = Rigid{ .name = rigid.name, .constraints = fresh_constraints } },
                };

                try self.fillPlaceholder(fresh_var, fresh_content);
                return fresh_var;
            },
            else => {
                // Remember this substitution for recursive references
                // IMPORTANT: This has to be inserted _before_ we recurse into `instantiateContent`
                const fresh_var = try self.store.fresh();
                try self.var_map.put(resolved_var, fresh_var);

                const fresh_content = try self.instantiateContent(resolved.desc.content);

                try self.fillPlaceholder(fresh_var, fresh_content);
                return fresh_var;
            },
        }
    }

    /// Overwrite a placeholder var created during instantiation with its final
    /// content, at `current_rank` and with no mark.
    fn fillPlaceholder(self: *Self, placeholder: Var, content: Content) std.mem.Allocator.Error!void {
        try self.store.dangerousSetVarDesc(
            placeholder,
            .{
                .content = content,
                .rank = self.current_rank,
                .mark = Mark.none,
            },
        );
    }

    /// Instantiate non-rigid content recursively.
    fn instantiateContent(self: *Self, content: Content) std.mem.Allocator.Error!Content {
        return switch (content) {
            .flex => |flex| Content{ .flex = try self.instantiateFlex(flex) },
            // Rigids are handled by `instantiateVar` (their treatment depends on
            // `rigid_behavior`). If we have run into one here, it is a bug.
            .rigid => unreachable,
            .alias => |alias| try self.instantiateAlias(alias),
            .structure => |flat_type| Content{ .structure = try self.instantiateFlatType(flat_type) },
            // Instantiate the structure the recursion var points to
            .recursion_var => |rec_var| Content{ .recursion_var = .{
                .structure = try self.instantiateVar(rec_var.structure),
                .name = rec_var.name,
            } },
            .err => Content.err,
        };
    }

    /// Copy a flex var, instantiating its static dispatch constraints.
    fn instantiateFlex(self: *Self, flex: Flex) std.mem.Allocator.Error!Flex {
        const fresh_constraints = try self.instantiateStaticDispatchConstraints(flex.constraints);
        return Flex{ .name = flex.name, .constraints = fresh_constraints };
    }

    /// Instantiate an alias: its type arguments first, then its backing var.
    fn instantiateAlias(self: *Self, alias: Alias) std.mem.Allocator.Error!Content {
        var fresh_vars = std.ArrayList(Var).empty;
        defer fresh_vars.deinit(self.store.gpa);

        // NOTE(review): unlike the index-based loops below, this iterates with a
        // store iterator while recursing; confirm `iterAliasArgs` does not cache a
        // slice that `appendVars` could invalidate.
        var iter = self.store.iterAliasArgs(alias);
        while (iter.next()) |arg_var| {
            const fresh_elem = try self.instantiateVar(arg_var);
            try fresh_vars.append(self.store.gpa, fresh_elem);
        }

        const backing_var = self.store.getAliasBackingVar(alias);
        const fresh_backing_var = try self.instantiateVar(backing_var);

        return self.store.mkAlias(alias.ident, fresh_backing_var, fresh_vars.items);
    }

    /// Instantiate every variant of a flat (structural) type.
    fn instantiateFlatType(self: *Self, flat_type: FlatType) std.mem.Allocator.Error!FlatType {
        return switch (flat_type) {
            .tuple => |tuple| FlatType{ .tuple = try self.instantiateTuple(tuple) },
            .nominal_type => |nominal| FlatType{ .nominal_type = try self.instantiateNominalType(nominal) },
            .fn_pure => |func| FlatType{ .fn_pure = try self.instantiateFunc(func) },
            .fn_effectful => |func| FlatType{ .fn_effectful = try self.instantiateFunc(func) },
            .fn_unbound => |func| FlatType{ .fn_unbound = try self.instantiateFunc(func) },
            .record => |record| FlatType{ .record = try self.instantiateRecord(record) },
            .record_unbound => |fields| FlatType{ .record_unbound = try self.instantiateRecordFields(fields) },
            .empty_record => FlatType.empty_record,
            .tag_union => |tag_union| FlatType{ .tag_union = try self.instantiateTagUnion(tag_union) },
            .empty_tag_union => FlatType.empty_tag_union,
        };
    }

    /// Instantiate a nominal type: its backing var first, then its type arguments.
    fn instantiateNominalType(self: *Self, nominal: NominalType) std.mem.Allocator.Error!NominalType {
        const backing_var = self.store.getNominalBackingVar(nominal);
        const fresh_backing_var = try self.instantiateVar(backing_var);

        var fresh_vars = std.ArrayList(Var).empty;
        defer fresh_vars.deinit(self.store.gpa);

        // NOTE(review): see the iterator note in `instantiateAlias`.
        var iter = self.store.iterNominalArgs(nominal);
        while (iter.next()) |arg_var| {
            const fresh_elem = try self.instantiateVar(arg_var);
            try fresh_vars.append(self.store.gpa, fresh_elem);
        }

        return (try self.store.mkNominal(nominal.ident, fresh_backing_var, fresh_vars.items, nominal.origin_module, nominal.is_opaque)).structure.nominal_type;
    }

    /// Instantiate every var in a contiguous range of `store.vars`, appending
    /// the fresh vars to `out` in order.
    ///
    /// `vars` is any var range with `start` (an enum index) and `count` fields.
    /// When `count` is 0, `start` is never read, since it may be undefined then.
    ///
    /// IMPORTANT: We must use index-based iteration here, not slice-based.
    /// A slice would point into the backing ArrayList, but instantiateVar
    /// can recursively call appendVars which may reallocate the array,
    /// invalidating the slice pointer. Each element is re-read from the live
    /// backing array instead.
    fn instantiateVarRange(self: *Self, vars: anytype, out: *std.ArrayList(Var)) std.mem.Allocator.Error!void {
        if (vars.count == 0) return;

        const start: usize = @intFromEnum(vars.start);
        try out.ensureUnusedCapacity(self.store.gpa, vars.count);
        for (0..vars.count) |i| {
            const var_ = self.store.vars.items.items[start + i];
            // `out` is owned by the caller's frame, so recursion cannot shrink it.
            out.appendAssumeCapacity(try self.instantiateVar(var_));
        }
    }

    /// Instantiate each tuple element.
    fn instantiateTuple(self: *Self, tuple: Tuple) std.mem.Allocator.Error!Tuple {
        var fresh_elems = std.ArrayList(Var).empty;
        defer fresh_elems.deinit(self.store.gpa);

        try self.instantiateVarRange(tuple.elems, &fresh_elems);

        const fresh_elems_range = try self.store.appendVars(fresh_elems.items);
        return Tuple{ .elems = fresh_elems_range };
    }

    /// Instantiate a function's arguments and return type. The result is always
    /// marked `needs_instantiation = true`.
    fn instantiateFunc(self: *Self, func: Func) std.mem.Allocator.Error!Func {
        var fresh_args = std.ArrayList(Var).empty;
        defer fresh_args.deinit(self.store.gpa);

        try self.instantiateVarRange(func.args, &fresh_args);

        const fresh_ret = try self.instantiateVar(func.ret);
        const fresh_args_range = try self.store.appendVars(fresh_args.items);
        return Func{
            .args = fresh_args_range,
            .ret = fresh_ret,
            .needs_instantiation = true,
        };
    }

    /// Instantiate the type of every field in a record field range and append
    /// the fresh fields to the store as a new range.
    fn instantiateRecordFields(self: *Self, fields: RecordField.SafeMultiList.Range) std.mem.Allocator.Error!RecordField.SafeMultiList.Range {
        // Avoid reading `fields.start`, which may be undefined for an empty range.
        if (fields.count == 0) {
            return try self.store.appendRecordFields(&.{});
        }

        var fresh_fields = std.ArrayList(RecordField).empty;
        defer fresh_fields.deinit(self.store.gpa);
        try fresh_fields.ensureTotalCapacity(self.store.gpa, fields.count);

        // IMPORTANT: We must use index-based iteration here, not slice-based.
        // The slice would point into the backing MultiArrayList, but instantiateVar
        // can recursively call appendRecordFields which may reallocate the array,
        // invalidating the slice pointers.
        const fields_start: usize = @intFromEnum(fields.start);
        for (0..fields.count) |i| {
            // Re-fetch the field data on each iteration since the backing array may have moved
            const field = self.store.record_fields.get(@enumFromInt(fields_start + i));
            const fresh_type = try self.instantiateVar(field.var_);
            fresh_fields.appendAssumeCapacity(RecordField{
                .name = field.name,
                .var_ = fresh_type,
            });
        }

        return try self.store.appendRecordFields(fresh_fields.items);
    }

    /// Instantiate a record: its fields first, then its extension var.
    fn instantiateRecord(self: *Self, record: Record) std.mem.Allocator.Error!Record {
        const fields_range = try self.instantiateRecordFields(record.fields);
        const fresh_ext = try self.instantiateVar(record.ext);
        return Record{
            .fields = fields_range,
            .ext = fresh_ext,
        };
    }

    /// Instantiate a tag union: every tag's arguments, then the extension var.
    /// The fresh tags are sorted by name before being stored.
    fn instantiateTagUnion(self: *Self, tag_union: TagUnion) std.mem.Allocator.Error!TagUnion {
        // Avoid reading `tags.start`, which may be undefined for an empty range.
        if (tag_union.tags.count == 0) {
            return TagUnion{
                .tags = try self.store.appendTags(&.{}),
                .ext = try self.instantiateVar(tag_union.ext),
            };
        }

        var fresh_tags = std.ArrayList(Tag).empty;
        defer fresh_tags.deinit(self.store.gpa);
        try fresh_tags.ensureTotalCapacity(self.store.gpa, tag_union.tags.count);

        // Scratch list reused across tags; cleared at the top of each iteration.
        var fresh_args = std.ArrayList(Var).empty;
        defer fresh_args.deinit(self.store.gpa);

        // IMPORTANT: We must use index-based iteration here, not slice-based.
        // The slice would point into the backing MultiArrayList, but instantiateVar
        // can recursively call appendTags which may reallocate the array,
        // invalidating the slice pointers.
        const tags_start: usize = @intFromEnum(tag_union.tags.start);
        for (0..tag_union.tags.count) |tag_i| {
            // Re-fetch the tag data on each iteration since the backing array may have moved.
            // `tag` is a value copy, so it stays valid across the recursion below.
            const tag = self.store.tags.get(@enumFromInt(tags_start + tag_i));

            fresh_args.clearRetainingCapacity();
            try self.instantiateVarRange(tag.args, &fresh_args);

            const fresh_args_range = try self.store.appendVars(fresh_args.items);
            fresh_tags.appendAssumeCapacity(Tag{
                .name = tag.name,
                .args = fresh_args_range,
            });
        }

        // Sort the fresh tags alphabetically by name before appending.
        // This ensures tag discriminants are consistent after instantiation.
        std.mem.sort(Tag, fresh_tags.items, self.idents, comptime Tag.sortByNameAsc);

        const tags_range = try self.store.appendTags(fresh_tags.items);
        return TagUnion{
            .tags = tags_range,
            .ext = try self.instantiateVar(tag_union.ext),
        };
    }

    /// Look up the text of an identifier in the instantiator's ident store.
    pub fn getIdent(self: *const Self, idx: Ident.Idx) []const u8 {
        return self.idents.getText(idx);
    }

    /// Instantiate every static dispatch constraint in `constraints` and append
    /// the fresh constraints to the store as a new range. An empty input yields
    /// the empty range without touching the store.
    fn instantiateStaticDispatchConstraints(self: *Self, constraints: StaticDispatchConstraint.SafeList.Range) std.mem.Allocator.Error!StaticDispatchConstraint.SafeList.Range {
        const constraints_len = constraints.len();
        if (constraints_len == 0) {
            return StaticDispatchConstraint.SafeList.Range.empty();
        }

        var fresh_constraints = try std.ArrayList(StaticDispatchConstraint).initCapacity(self.store.gpa, constraints_len);
        defer fresh_constraints.deinit(self.store.gpa);

        // Use index-based iteration: instantiating a constraint's fn var can
        // append to the store and reallocate its backing arrays, so a slice taken
        // up front could dangle. Each constraint is copied out by value.
        const constraints_start: usize = @intFromEnum(constraints.start);
        for (0..constraints_len) |i| {
            const constraint = self.store.static_dispatch_constraints.items.items[constraints_start + i];
            fresh_constraints.appendAssumeCapacity(try self.instantiateStaticDispatchConstraint(constraint));
        }

        return try self.store.appendStaticDispatchConstraints(fresh_constraints.items);
    }

    /// Copy a single static dispatch constraint, instantiating its fn var.
    fn instantiateStaticDispatchConstraint(self: *Self, constraint: StaticDispatchConstraint) std.mem.Allocator.Error!StaticDispatchConstraint {
        return StaticDispatchConstraint{
            .fn_name = constraint.fn_name,
            .fn_var = try self.instantiateVar(constraint.fn_var),
            .origin = constraint.origin,
        };
    }
};
|