Clean up instantiation

This commit is contained in:
Jared Ramirez 2025-09-19 11:08:10 -04:00
parent e1b8ebee92
commit 570bbdb2c2
No known key found for this signature in database
GPG key ID: 41158983F521D68C
5 changed files with 358 additions and 303 deletions

View file

@ -1825,6 +1825,11 @@ pub fn getIdentStore(self: *Self) *Ident.Store {
return &self.common.idents;
}
/// Returns an immutable reference to the identifier store.
pub fn getIdentStoreConst(self: *const Self) *const Ident.Store {
return &self.common.idents;
}
/// Retrieves the text of an identifier by its index.
pub fn getIdent(self: *const Self, idx: Ident.Idx) []const u8 {
return self.common.getIdent(idx);

View file

@ -28,7 +28,7 @@ const Content = types_mod.Content;
const Rank = types_mod.Rank;
const Num = types_mod.Num;
const testing = std.testing;
const Instantiate = types_mod.instantiate.Instantiate;
const Instantiator = types_mod.instantiate.Instantiator;
const Generalizer = types_mod.generalize.Generalizer;
const VarPool = types_mod.generalize.VarPool;
const SnapshotStore = @import("snapshot.zig").Store;
@ -62,17 +62,17 @@ seen_annos: std.AutoHashMap(CIR.TypeAnno.Idx, Var),
var_pool: VarPool,
/// wrapper around generalization, contains some internal state used to do it's work
generalizer: Generalizer,
/// A map from one var to another. Used in instantiation and var copying
var_map: std.AutoHashMap(Var, Var),
/// A map from one var to another. Used to apply type arguments in instantation
rigid_var_substitutions: std.AutoHashMapUnmanaged(Ident.Idx, Var),
/// scratch vars used to build up intermediate lists, used for various things
scratch_vars: base.Scratch(Var),
/// scratch tags used to build up intermediate lists, used for various things
scratch_tags: base.Scratch(types_mod.Tag),
/// scratch record fields used to build up intermediate lists, used for various things
scratch_record_fields: base.Scratch(types_mod.RecordField),
/// used in instantiation. TODO: Move into something like Instantiator
var_map: std.AutoHashMap(Var, Var),
/// used in instantiation. TODO: Move into something like Instantiator
anonymous_rigid_var_subs: Instantiate.RigidSubstitutions,
/// Cache for imported types. This cache lives for the entire type-checking session
// Cache for imported types. This cache lives for the entire type-checking session
/// of a module, so the same imported type can be reused across the entire module.
import_cache: ImportCache,
/// Maps variables to the expressions that constrained them (for better error regions)
@ -109,18 +109,18 @@ pub fn init(
.problems = try ProblemStore.initCapacity(gpa, 64),
.unify_scratch = try unifier.Scratch.init(gpa),
.occurs_scratch = try occurs.Scratch.init(gpa),
.var_map = std.AutoHashMap(Var, Var).init(gpa),
.anno_free_vars = try base.Scratch(FreeVar).init(gpa),
.decl_free_vars = try base.Scratch(FreeVar).init(gpa),
.seen_annos = std.AutoHashMap(CIR.TypeAnno.Idx, Var).init(gpa),
.var_pool = try VarPool.init(gpa),
.generalizer = try Generalizer.init(gpa, types),
.var_map = std.AutoHashMap(Var, Var).init(gpa),
.rigid_var_substitutions = std.AutoHashMapUnmanaged(Ident.Idx, Var){},
.scratch_vars = try base.Scratch(types_mod.Var).init(gpa),
.scratch_tags = try base.Scratch(types_mod.Tag).init(gpa),
.scratch_record_fields = try base.Scratch(types_mod.RecordField).init(gpa),
.anonymous_rigid_var_subs = try Instantiate.RigidSubstitutions.init(gpa),
.import_cache = ImportCache{},
.constraint_origins = std.AutoHashMap(Var, Var).init(gpa),
.var_pool = try VarPool.init(gpa),
.generalizer = try Generalizer.init(gpa, types),
};
}
@ -130,18 +130,18 @@ pub fn deinit(self: *Self) void {
self.snapshots.deinit();
self.unify_scratch.deinit();
self.occurs_scratch.deinit();
self.var_map.deinit();
self.anno_free_vars.deinit(self.gpa);
self.decl_free_vars.deinit(self.gpa);
self.seen_annos.deinit();
self.anonymous_rigid_var_subs.deinit(self.gpa);
self.var_pool.deinit();
self.generalizer.deinit();
self.var_map.deinit();
self.rigid_var_substitutions.deinit(self.gpa);
self.scratch_vars.deinit(self.gpa);
self.scratch_tags.deinit(self.gpa);
self.scratch_record_fields.deinit(self.gpa);
self.import_cache.deinit(self.gpa);
self.constraint_origins.deinit();
self.var_pool.deinit();
self.generalizer.deinit();
}
/// Assert that type vars and regions in sync
@ -331,42 +331,91 @@ const InstantiateRegionBehavior = union(enum) {
use_last_var,
};
/// Instantiate a variable
/// Instantiate a variable, substituting any encountered rigids with flex vars
///
/// Note that the the rigid var structure will be preserved.
/// E.g. `a -> a`, `a` will reference the same new flex var
fn instantiateVar(
self: *Self,
var_to_instantiate: Var,
rank: types_mod.Rank,
region_behavior: InstantiateRegionBehavior,
) std.mem.Allocator.Error!Var {
self.anonymous_rigid_var_subs.items.clearRetainingCapacity();
return self.instantiateVarWithSubs(var_to_instantiate, &self.anonymous_rigid_var_subs, rank, region_behavior);
var instantiate_ctx = Instantiator{
.store = self.types,
.idents = self.cir.getIdentStoreConst(),
.var_map = &self.var_map,
.current_rank = rank,
.rigid_behavior = .fresh_flex,
};
return self.instantiateVarHelp(var_to_instantiate, &instantiate_ctx, region_behavior);
}
/// Instantiate a variable, substituting any encountered rigids with *new* rigid vars
///
/// Note that the the rigid var structure will be preserved.
/// E.g. `a -> a`, `a` will reference the same new rigid var
fn instantiateVarPreserveRigids(
self: *Self,
var_to_instantiate: Var,
rank: types_mod.Rank,
region_behavior: InstantiateRegionBehavior,
) std.mem.Allocator.Error!Var {
var instantiate_ctx = Instantiator{
.store = self.types,
.idents = self.cir.getIdentStoreConst(),
.var_map = &self.var_map,
.current_rank = rank,
.rigid_behavior = .fresh_flex,
};
return self.instantiateVarHelp(var_to_instantiate, &instantiate_ctx, region_behavior);
}
/// Instantiate a variable
fn instantiateVarWithSubs(
self: *Self,
var_to_instantiate: Var,
subs: *Instantiate.RigidSubstitutions,
subs: *std.AutoHashMapUnmanaged(Ident.Idx, Var),
rank: types_mod.Rank,
region_behavior: InstantiateRegionBehavior,
) std.mem.Allocator.Error!Var {
self.var_map.clearRetainingCapacity();
var instantiate_ctx = Instantiator{
.store = self.types,
.idents = self.cir.getIdentStoreConst(),
.var_map = &self.var_map,
var instantiate = Instantiate.init(self.types, self.cir.getIdentStore(), &self.var_map);
var instantiate_ctx = Instantiate.Ctx{ .rigid_var_subs = subs, .current_rank = rank };
const instantiated_var = try instantiate.instantiateVar(var_to_instantiate, &instantiate_ctx);
.current_rank = rank,
.rigid_behavior = .{ .substitute_rigids = subs },
};
return self.instantiateVarHelp(var_to_instantiate, &instantiate_ctx, region_behavior);
}
/// Instantiate a variable
fn instantiateVarHelp(
self: *Self,
var_to_instantiate: Var,
instantiator: *Instantiator,
region_behavior: InstantiateRegionBehavior,
) std.mem.Allocator.Error!Var {
// First, reset state
instantiator.var_map.clearRetainingCapacity();
// Then, instantiate the variable with the provided context
const instantiated_var = try instantiator.instantiateVar(var_to_instantiate);
// If we had to insert any new type variables, ensure that we have
// corresponding regions for them. This is essential for error reporting.
const root_instantiated_region = self.regions.get(@enumFromInt(@intFromEnum(var_to_instantiate))).*;
if (self.var_map.count() > 0) {
var iterator = self.var_map.iterator();
if (instantiator.var_map.count() > 0) {
var iterator = instantiator.var_map.iterator();
while (iterator.next()) |x| {
// Get the newly created var
const fresh_var = x.value_ptr.*;
// Add to pool
try self.var_pool.addVarToRank(fresh_var, rank);
try self.var_pool.addVarToRank(fresh_var, instantiator.current_rank);
// Set the region
try self.fillInRegionsThrough(fresh_var);
@ -389,6 +438,7 @@ fn instantiateVarWithSubs(
// Assert that we have regions for every type variable
self.debugAssertArraysInSync();
// Return the instantiated var
return instantiated_var;
}
@ -735,8 +785,8 @@ fn generateAnnoTypeInPlace(self: *Self, anno_idx: CIR.TypeAnno.Idx, ctx: GenType
_ = try self.problems.appendProblem(self.gpa, .{ .type_apply_mismatch_arities = .{
.type_name = this_decl.name,
.region = anno_region,
.num_expected_args = this_decl.num_args,
.num_actual_args = 0,
.num_expected_args = 0,
.num_actual_args = this_decl.num_args,
} });
try self.updateVar(anno_var, .err, Rank.generalized);
return;
@ -892,19 +942,14 @@ fn generateAnnoTypeInPlace(self: *Self, anno_idx: CIR.TypeAnno.Idx, ctx: GenType
}
// Then, built the map of applied variables
// TODO: Recursive rigid vars subs
self.anonymous_rigid_var_subs.items.clearRetainingCapacity();
self.rigid_var_substitutions.clearRetainingCapacity();
for (decl_arg_vars, anno_arg_vars) |decl_arg_var, anno_arg_var| {
const decl_arg_resolved = self.types.resolveVar(decl_arg_var).desc.content;
std.debug.assert(decl_arg_resolved == .rigid_var);
const decl_arg_rigid_ident = decl_arg_resolved.rigid_var;
try self.anonymous_rigid_var_subs.append(
self.gpa,
.{
.ident = self.cir.getIdentText(decl_arg_rigid_ident),
.var_ = anno_arg_var,
},
);
try self.rigid_var_substitutions.put(self.gpa, decl_arg_rigid_ident, anno_arg_var);
}
// Then instantiate the variable, substituting the rigid
@ -912,7 +957,7 @@ fn generateAnnoTypeInPlace(self: *Self, anno_idx: CIR.TypeAnno.Idx, ctx: GenType
// the annotation
const instantiated_var = try self.instantiateVarWithSubs(
decl_var,
&self.anonymous_rigid_var_subs,
&self.rigid_var_substitutions,
Rank.generalized,
.{ .explicit = anno_region },
);
@ -4304,42 +4349,6 @@ fn setProblemTypeMismatchDetail(self: *Self, problem_idx: problem.Problem.Idx, m
// copy type from other module //
/// Instantiate a variable, writing su
fn copyVar(
self: *Self,
other_module_var: Var,
other_module_env: *ModuleEnv,
) std.mem.Allocator.Error!Var {
self.var_map.clearRetainingCapacity();
const copied_var = try copy_import.copyVar(
&other_module_env.*.types,
self.types,
other_module_var,
&self.var_map,
other_module_env.getIdentStore(),
self.cir.getIdentStore(),
self.gpa,
);
// If we had to insert any new type variables, ensure that we have
// corresponding regions for them. This is essential for error reporting.
if (self.var_map.count() > 0) {
var iterator = self.var_map.iterator();
while (iterator.next()) |x| {
// Get the newly created var
const fresh_var = x.value_ptr.*;
try self.fillInRegionsThrough(fresh_var);
self.setRegionAt(fresh_var, base.Region.zero());
}
}
// Assert that we have regions for every type variable
self.debugAssertArraysInSync();
return copied_var;
}
// external type lookups //
const ExternalType = struct {
@ -4392,3 +4401,42 @@ fn resolveVarFromExternal(
return null;
}
}
/// Instantiate a variable, writing su
fn copyVar(
self: *Self,
other_module_var: Var,
other_module_env: *ModuleEnv,
) std.mem.Allocator.Error!Var {
// First, reset state
self.var_map.clearRetainingCapacity();
// Then, copy the var from the dest type store into this type store
const copied_var = try copy_import.copyVar(
&other_module_env.*.types,
self.types,
other_module_var,
&self.var_map,
other_module_env.getIdentStore(),
self.cir.getIdentStore(),
self.gpa,
);
// If we had to insert any new type variables, ensure that we have
// corresponding regions for them. This is essential for error reporting.
if (self.var_map.count() > 0) {
var iterator = self.var_map.iterator();
while (iterator.next()) |x| {
// Get the newly created var
const fresh_var = x.value_ptr.*;
try self.fillInRegionsThrough(fresh_var);
self.setRegionAt(fresh_var, base.Region.zero());
}
}
// Assert that we have regions for every type variable
self.debugAssertArraysInSync();
return copied_var;
}

View file

@ -55,7 +55,11 @@ pub fn copyVar(
const dest_content = try copyContent(source_store, dest_store, resolved.desc.content, var_mapping, source_idents, dest_idents, allocator);
// Update the placeholder with the actual content
try dest_store.setVarContent(placeholder_var, dest_content);
try dest_store.setVarDesc(placeholder_var, .{
.content = dest_content,
.rank = types_mod.Rank.generalized,
.mark = types_mod.Mark.none,
});
return placeholder_var;
}

View file

@ -13,166 +13,166 @@ const Instantiate = types.instantiate.Instantiate;
// test env //
const TestEnv = struct {
module_env: *ModuleEnv,
store: *TypesStore,
var_subs: *Instantiate.SeenVars,
rigid_var_subs: *Instantiate.RigidSubstitutions,
// const TestEnv = struct {
// module_env: *ModuleEnv,
// store: *TypesStore,
// var_subs: *Instantiate.SeenVars,
// rigid_var_subs: *Instantiate.RigidSubstitutions,
fn init(allocator: std.mem.Allocator) !TestEnv {
const module_env = try allocator.create(ModuleEnv);
module_env.* = try ModuleEnv.init(allocator, "");
// fn init(allocator: std.mem.Allocator) !TestEnv {
// const module_env = try allocator.create(ModuleEnv);
// module_env.* = try ModuleEnv.init(allocator, "");
const store = try allocator.create(TypesStore);
store.* = try TypesStore.init(allocator);
// const store = try allocator.create(TypesStore);
// store.* = try TypesStore.init(allocator);
const var_subs = try allocator.create(Instantiate.SeenVars);
var_subs.* = Instantiate.SeenVars.init(allocator);
// const var_subs = try allocator.create(Instantiate.SeenVars);
// var_subs.* = Instantiate.SeenVars.init(allocator);
const rigid_var_subs = try allocator.create(Instantiate.RigidSubstitutions);
rigid_var_subs.* = try Instantiate.RigidSubstitutions.init(allocator);
// const rigid_var_subs = try allocator.create(Instantiate.RigidSubstitutions);
// rigid_var_subs.* = try Instantiate.RigidSubstitutions.init(allocator);
return .{
.module_env = module_env,
.store = store,
.var_subs = var_subs,
.rigid_var_subs = rigid_var_subs,
};
}
// return .{
// .module_env = module_env,
// .store = store,
// .var_subs = var_subs,
// .rigid_var_subs = rigid_var_subs,
// };
// }
fn deinit(self: *TestEnv, allocator: std.mem.Allocator) void {
self.store.deinit();
allocator.destroy(self.store);
self.module_env.deinit();
allocator.destroy(self.module_env);
self.var_subs.deinit();
allocator.destroy(self.var_subs);
self.rigid_var_subs.deinit(allocator);
allocator.destroy(self.rigid_var_subs);
}
// fn deinit(self: *TestEnv, allocator: std.mem.Allocator) void {
// self.store.deinit();
// allocator.destroy(self.store);
// self.module_env.deinit();
// allocator.destroy(self.module_env);
// self.var_subs.deinit();
// allocator.destroy(self.var_subs);
// self.rigid_var_subs.deinit(allocator);
// allocator.destroy(self.rigid_var_subs);
// }
fn instantiate(self: *TestEnv, var_to_inst: types.Var, rigid_subs: []const struct { ident: []const u8, var_: types.Var }) !types.Var {
self.var_subs.clearRetainingCapacity();
self.rigid_var_subs.clearFrom(0);
// fn instantiate(self: *TestEnv, var_to_inst: types.Var, rigid_subs: []const struct { ident: []const u8, var_: types.Var }) !types.Var {
// self.var_subs.clearRetainingCapacity();
// self.rigid_var_subs.clearFrom(0);
for (rigid_subs) |sub| {
_ = try self.module_env.insertIdent(base.Ident.for_text(sub.ident));
try self.rigid_var_subs.append(self.module_env.gpa, .{ .ident = sub.ident, .var_ = sub.var_ });
}
// for (rigid_subs) |sub| {
// _ = try self.module_env.insertIdent(base.Ident.for_text(sub.ident));
// try self.rigid_var_subs.append(self.module_env.gpa, .{ .ident = sub.ident, .var_ = sub.var_ });
// }
var inst = Instantiate.init(self.store, self.module_env.getIdentStore(), self.var_subs);
var instantiate_ctx = Instantiate.Ctx{
.rigid_var_subs = self.rigid_var_subs,
};
return inst.instantiateVar(var_to_inst, &instantiate_ctx);
}
};
// var inst = Instantiate.init(self.store, self.module_env.getIdentStore(), self.var_subs);
// var instantiate_ctx = Instantiate.Ctx{
// .rigid_var_subs = self.rigid_var_subs,
// };
// return inst.instantiateVar(var_to_inst, &instantiate_ctx);
// }
// };
test "let-polymorphism with empty list" {
var env = try TestEnv.init(test_allocator);
defer env.deinit(test_allocator);
// test "let-polymorphism with empty list" {
// var env = try TestEnv.init(test_allocator);
// defer env.deinit(test_allocator);
// forall a. List a
const a_ident = try env.module_env.insertIdent(base.Ident.for_text("a"));
const list_elem_var = try env.store.freshFromContent(.{ .rigid_var = a_ident });
const poly_list_var = try env.store.freshFromContent(.{ .structure = .{ .list = list_elem_var } });
// // forall a. List a
// const a_ident = try env.module_env.insertIdent(base.Ident.for_text("a"));
// const list_elem_var = try env.store.freshFromContent(.{ .rigid_var = a_ident });
// const poly_list_var = try env.store.freshFromContent(.{ .structure = .{ .list = list_elem_var } });
try testing.expect(env.store.needsInstantiation(poly_list_var));
// try testing.expect(env.store.needsInstantiation(poly_list_var));
const int_var = try env.store.freshFromContent(.{ .structure = .{ .num = .{ .num_compact = .{ .int = .i32 } } } });
const int_list = try env.instantiate(poly_list_var, &.{.{ .ident = "a", .var_ = int_var }});
// const int_var = try env.store.freshFromContent(.{ .structure = .{ .num = .{ .num_compact = .{ .int = .i32 } } } });
// const int_list = try env.instantiate(poly_list_var, &.{.{ .ident = "a", .var_ = int_var }});
const str_var = try env.store.freshFromContent(.{ .structure = .str });
const str_list = try env.instantiate(poly_list_var, &.{.{ .ident = "a", .var_ = str_var }});
// const str_var = try env.store.freshFromContent(.{ .structure = .str });
// const str_list = try env.instantiate(poly_list_var, &.{.{ .ident = "a", .var_ = str_var }});
try testing.expect(int_list != str_list);
try testing.expect(int_list != poly_list_var);
}
// try testing.expect(int_list != str_list);
// try testing.expect(int_list != poly_list_var);
// }
test "let-polymorphism with polymorphic function" {
var env = try TestEnv.init(test_allocator);
defer env.deinit(test_allocator);
// test "let-polymorphism with polymorphic function" {
// var env = try TestEnv.init(test_allocator);
// defer env.deinit(test_allocator);
// forall a. a -> a
const a_ident = try env.module_env.insertIdent(base.Ident.for_text("a"));
const type_param = try env.store.freshFromContent(.{ .rigid_var = a_ident });
const func_content = try env.store.mkFuncPure(&.{type_param}, type_param);
const func_var = try env.store.freshFromContent(func_content);
// // forall a. a -> a
// const a_ident = try env.module_env.insertIdent(base.Ident.for_text("a"));
// const type_param = try env.store.freshFromContent(.{ .rigid_var = a_ident });
// const func_content = try env.store.mkFuncPure(&.{type_param}, type_param);
// const func_var = try env.store.freshFromContent(func_content);
try testing.expect(env.store.needsInstantiation(func_var));
// try testing.expect(env.store.needsInstantiation(func_var));
const str_var = try env.store.freshFromContent(.{ .structure = .str });
const str_func = try env.instantiate(func_var, &.{.{ .ident = "a", .var_ = str_var }});
// const str_var = try env.store.freshFromContent(.{ .structure = .str });
// const str_func = try env.instantiate(func_var, &.{.{ .ident = "a", .var_ = str_var }});
const num_var = try env.store.freshFromContent(.{ .structure = .{ .num = .{ .num_compact = .{ .int = .u32 } } } });
const num_func = try env.instantiate(func_var, &.{.{ .ident = "a", .var_ = num_var }});
// const num_var = try env.store.freshFromContent(.{ .structure = .{ .num = .{ .num_compact = .{ .int = .u32 } } } });
// const num_func = try env.instantiate(func_var, &.{.{ .ident = "a", .var_ = num_var }});
try testing.expect(str_func != num_func);
}
// try testing.expect(str_func != num_func);
// }
test "let-polymorphism with multiple type parameters" {
var env = try TestEnv.init(test_allocator);
defer env.deinit(test_allocator);
// test "let-polymorphism with multiple type parameters" {
// var env = try TestEnv.init(test_allocator);
// defer env.deinit(test_allocator);
// forall a b. (a, b) -> (b, a)
const a_ident = try env.module_env.insertIdent(base.Ident.for_text("a"));
const b_ident = try env.module_env.insertIdent(base.Ident.for_text("b"));
const type_a = try env.store.freshFromContent(.{ .rigid_var = a_ident });
const type_b = try env.store.freshFromContent(.{ .rigid_var = b_ident });
// // forall a b. (a, b) -> (b, a)
// const a_ident = try env.module_env.insertIdent(base.Ident.for_text("a"));
// const b_ident = try env.module_env.insertIdent(base.Ident.for_text("b"));
// const type_a = try env.store.freshFromContent(.{ .rigid_var = a_ident });
// const type_b = try env.store.freshFromContent(.{ .rigid_var = b_ident });
const func_content = try env.store.mkFuncPure(&.{ type_a, type_b }, type_b); // Simplified for test
const func_var = try env.store.freshFromContent(func_content);
// const func_content = try env.store.mkFuncPure(&.{ type_a, type_b }, type_b); // Simplified for test
// const func_var = try env.store.freshFromContent(func_content);
const str_var = try env.store.freshFromContent(.{ .structure = .str });
const int_var = try env.store.freshFromContent(.{ .structure = .{ .num = .{ .num_compact = .{ .int = .i32 } } } });
// const str_var = try env.store.freshFromContent(.{ .structure = .str });
// const int_var = try env.store.freshFromContent(.{ .structure = .{ .num = .{ .num_compact = .{ .int = .i32 } } } });
const inst1 = try env.instantiate(func_var, &.{
.{ .ident = "a", .var_ = str_var },
.{ .ident = "b", .var_ = int_var },
});
// const inst1 = try env.instantiate(func_var, &.{
// .{ .ident = "a", .var_ = str_var },
// .{ .ident = "b", .var_ = int_var },
// });
const inst2 = try env.instantiate(func_var, &.{
.{ .ident = "a", .var_ = int_var },
.{ .ident = "b", .var_ = str_var },
});
// const inst2 = try env.instantiate(func_var, &.{
// .{ .ident = "a", .var_ = int_var },
// .{ .ident = "b", .var_ = str_var },
// });
try testing.expect(inst1 != inst2);
}
// try testing.expect(inst1 != inst2);
// }
test "let-polymorphism preserves sharing within single instantiation" {
var env = try TestEnv.init(test_allocator);
defer env.deinit(test_allocator);
// test "let-polymorphism preserves sharing within single instantiation" {
// var env = try TestEnv.init(test_allocator);
// defer env.deinit(test_allocator);
// forall a. { first: a, second: a }
const a_ident = try env.module_env.insertIdent(base.Ident.for_text("a"));
const type_param = try env.store.freshFromContent(.{ .rigid_var = a_ident });
// // forall a. { first: a, second: a }
// const a_ident = try env.module_env.insertIdent(base.Ident.for_text("a"));
// const type_param = try env.store.freshFromContent(.{ .rigid_var = a_ident });
const fields_range = try env.store.record_fields.appendSlice(env.module_env.gpa, &[_]types.RecordField{
.{ .name = try env.module_env.insertIdent(base.Ident.for_text("first")), .var_ = type_param },
.{ .name = try env.module_env.insertIdent(base.Ident.for_text("second")), .var_ = type_param },
});
const empty_ext = try env.store.freshFromContent(.{ .structure = .empty_record });
const record_var = try env.store.freshFromContent(.{ .structure = .{ .record = .{ .fields = fields_range, .ext = empty_ext } } });
// const fields_range = try env.store.record_fields.appendSlice(env.module_env.gpa, &[_]types.RecordField{
// .{ .name = try env.module_env.insertIdent(base.Ident.for_text("first")), .var_ = type_param },
// .{ .name = try env.module_env.insertIdent(base.Ident.for_text("second")), .var_ = type_param },
// });
// const empty_ext = try env.store.freshFromContent(.{ .structure = .empty_record });
// const record_var = try env.store.freshFromContent(.{ .structure = .{ .record = .{ .fields = fields_range, .ext = empty_ext } } });
const int_var = try env.store.freshFromContent(.{ .structure = .{ .num = .{ .num_compact = .{ .int = .i32 } } } });
const instantiated_rec = try env.instantiate(record_var, &.{.{ .ident = "a", .var_ = int_var }});
// const int_var = try env.store.freshFromContent(.{ .structure = .{ .num = .{ .num_compact = .{ .int = .i32 } } } });
// const instantiated_rec = try env.instantiate(record_var, &.{.{ .ident = "a", .var_ = int_var }});
// Verify that both fields now point to the same, new concrete type.
const content = env.store.resolveVar(instantiated_rec).desc.content;
const rec = content.structure.record;
const fields = env.store.record_fields.sliceRange(rec.fields);
// // Verify that both fields now point to the same, new concrete type.
// const content = env.store.resolveVar(instantiated_rec).desc.content;
// const rec = content.structure.record;
// const fields = env.store.record_fields.sliceRange(rec.fields);
try testing.expectEqual(fields.get(0).var_, fields.get(1).var_);
try testing.expect(env.store.resolveVar(fields.get(0).var_).desc.content.structure.num.num_compact.int == .i32);
}
// try testing.expectEqual(fields.get(0).var_, fields.get(1).var_);
// try testing.expect(env.store.resolveVar(fields.get(0).var_).desc.content.structure.num.num_compact.int == .i32);
// }
test "let-polymorphism prevents over-generalization of concrete types" {
var env = try TestEnv.init(test_allocator);
defer env.deinit(test_allocator);
// test "let-polymorphism prevents over-generalization of concrete types" {
// var env = try TestEnv.init(test_allocator);
// defer env.deinit(test_allocator);
const i32_var = try env.store.freshFromContent(.{ .structure = .{ .num = .{ .num_compact = .{ .int = .i32 } } } });
const list_i32_var = try env.store.freshFromContent(.{ .structure = .{ .list = i32_var } });
// const i32_var = try env.store.freshFromContent(.{ .structure = .{ .num = .{ .num_compact = .{ .int = .i32 } } } });
// const list_i32_var = try env.store.freshFromContent(.{ .structure = .{ .list = i32_var } });
// This should NOT need instantiation because it's already concrete.
try testing.expect(!env.store.needsInstantiation(list_i32_var));
}
// // This should NOT need instantiation because it's already concrete.
// try testing.expect(!env.store.needsInstantiation(list_i32_var));
// }

View file

@ -31,104 +31,102 @@ const Ident = base.Ident;
///
/// This type does not own any of it's fields it's a convenience wrapper to
/// making threading it's field through all the recursive functions easier
pub const Instantiate = struct {
pub const Instantiator = struct {
// not owned
store: *TypesStore,
idents: *const base.Ident.Store,
seen_vars_subs: *SeenVars,
var_map: *std.AutoHashMap(Var, Var),
current_rank: Rank = Rank.top_level,
rigid_behavior: RigidBehavior,
/// The mode to use when instantiating
pub const RigidBehavior = union(enum) {
/// In this mode, all rigids are instantiated as new flex vars
/// Note that the the rigid var structure will be preserved.
/// E.g. `a -> a`, `a` will reference the same new rigid var
fresh_flex,
/// In this mode, all rigids are instantiated as new rigid variables
/// Note that the the rigid var structure will be preserved.
/// E.g. `a -> a`, `a` will reference the same new flex var
fresh_rigid,
/// In this mode, all rigids we be substituted with values in the provided map.
/// If a rigid var is not in the map, then that variable will be set to
/// `.err` & in debug mode it will error
substitute_rigids: *std.AutoHashMapUnmanaged(Ident.Idx, Var),
};
const Self = @This();
pub const IdentVar = struct { ident: []const u8, var_: Var };
pub const RigidSubstitutions = base.Scratch(IdentVar);
pub const SeenVars = std.AutoHashMap(Var, Var);
// general //
pub fn init(
store: *TypesStore,
idents: *const base.Ident.Store,
seen_vars_subs: *SeenVars,
) Self {
return .{
.store = store,
.idents = idents,
.seen_vars_subs = seen_vars_subs,
};
}
// rigid vars //
/// Check if, for the provided rigid var ident, we have a variable to substitute
fn getRigidVarSub(rigid_vars_subs: *RigidSubstitutions, ident: []const u8) ?Var {
for (rigid_vars_subs.items.items) |elem| {
if (std.mem.eql(u8, ident, elem.ident)) {
return elem.var_;
}
}
return null;
}
// instantiation //
pub const Ctx = struct {
rigid_var_subs: *RigidSubstitutions,
current_rank: Rank = Rank.top_level,
};
// instantiation //
/// Instantiate a variable
///
/// The caller provides a map that's used to substitute rigid variables,
/// as depending on the context this map should contain vars from an
/// annotation, or not
pub fn instantiateVar(self: *Self, initial_var: Var, ctx: *Ctx) std.mem.Allocator.Error!Var {
pub fn instantiateVar(
self: *Self,
initial_var: Var,
) std.mem.Allocator.Error!Var {
const resolved = self.store.resolveVar(initial_var);
const resolved_var = resolved.var_;
// Check if we've already instantiated this variable
if (self.seen_vars_subs.get(resolved_var)) |fresh_var| {
if (self.var_map.get(resolved_var)) |fresh_var| {
return fresh_var;
}
switch (resolved.desc.content) {
.rigid_var => |ident| {
// Get the ident of the rigid var
const ident_bytes = self.getIdent(ident);
// If this var is rigid, then create a new var depending on the
// provided behavior
const fresh_var = blk: {
switch (self.rigid_behavior) {
.fresh_rigid => {
break :blk try self.store.freshFromContentWithRank(
Content{ .rigid_var = ident },
self.current_rank,
);
},
.fresh_flex => {
break :blk try self.store.freshFromContentWithRank(
Content{ .flex_var = null },
self.current_rank,
);
},
.substitute_rigids => |rigid_var_subs| {
if (rigid_var_subs.get(ident)) |existing_flex_var| {
break :blk existing_flex_var;
} else {
std.debug.assert(false);
break :blk try self.store.freshFromContentWithRank(
.err,
self.current_rank,
);
}
},
}
};
if (Self.getRigidVarSub(ctx.rigid_var_subs, ident_bytes)) |existing_flex_var| {
try self.seen_vars_subs.put(resolved_var, existing_flex_var);
return existing_flex_var;
} else {
// Create a new flex variable for this rigid variable name
const fresh_var = try self.store.freshFromContentWithRank(
Content{ .flex_var = null },
ctx.current_rank,
);
try ctx.rigid_var_subs.append(self.store.gpa, .{ .ident = ident_bytes, .var_ = fresh_var });
// Remember this substitution for recursive references
try self.var_map.put(resolved_var, fresh_var);
// Remember this substitution for recursive references
try self.seen_vars_subs.put(resolved_var, fresh_var);
return fresh_var;
}
return fresh_var;
},
else => {
// Remember this substitution for recursive references
// IMPORTANT: This has to be inserted _before_ we recurse into `instantiateContent`
const fresh_var = try self.store.fresh();
try self.seen_vars_subs.put(resolved_var, fresh_var);
try self.var_map.put(resolved_var, fresh_var);
// Generate the content
const fresh_content = try self.instantiateContent(resolved.desc.content, ctx);
const fresh_content = try self.instantiateContent(resolved.desc.content);
// Update the placeholder fresh var with the real content
try self.store.setVarDesc(
fresh_var,
.{
.content = fresh_content,
.rank = ctx.current_rank,
.rank = self.current_rank,
.mark = Mark.none,
},
);
@ -138,36 +136,36 @@ pub const Instantiate = struct {
}
}
fn instantiateContent(self: *Self, content: Content, ctx: *Ctx) std.mem.Allocator.Error!Content {
fn instantiateContent(self: *Self, content: Content) std.mem.Allocator.Error!Content {
return switch (content) {
.flex_var => |maybe_ident| Content{ .flex_var = maybe_ident },
// .rigid_var => |maybe_ident| Content{ .rigid_var = maybe_ident },
.rigid_var => unreachable,
.alias => |alias| {
// Instantiate the structure recursively
const fresh_alias = try self.instantiateAlias(alias, ctx);
const fresh_alias = try self.instantiateAlias(alias);
return Content{ .alias = fresh_alias };
},
.structure => |flat_type| blk: {
// Instantiate the structure recursively
const fresh_flat_type = try self.instantiateFlatType(flat_type, ctx);
const fresh_flat_type = try self.instantiateFlatType(flat_type);
break :blk Content{ .structure = fresh_flat_type };
},
.err => Content.err,
};
}
fn instantiateAlias(self: *Self, alias: Alias, ctx: *Ctx) std.mem.Allocator.Error!Alias {
fn instantiateAlias(self: *Self, alias: Alias) std.mem.Allocator.Error!Alias {
var fresh_vars = std.ArrayList(Var).init(self.store.gpa);
defer fresh_vars.deinit();
const backing_var = self.store.getAliasBackingVar(alias);
const fresh_backing_var = try self.instantiateVar(backing_var, ctx);
const fresh_backing_var = try self.instantiateVar(backing_var);
try fresh_vars.append(fresh_backing_var);
var iter = self.store.iterAliasArgs(alias);
while (iter.next()) |arg_var| {
const fresh_elem = try self.instantiateVar(arg_var, ctx);
const fresh_elem = try self.instantiateVar(arg_var);
try fresh_vars.append(fresh_elem);
}
@ -178,37 +176,37 @@ pub const Instantiate = struct {
};
}
fn instantiateFlatType(self: *Self, flat_type: FlatType, ctx: *Ctx) std.mem.Allocator.Error!FlatType {
fn instantiateFlatType(self: *Self, flat_type: FlatType) std.mem.Allocator.Error!FlatType {
return switch (flat_type) {
.str => FlatType.str,
.box => |box_var| FlatType{ .box = try self.instantiateVar(box_var, ctx) },
.list => |list_var| FlatType{ .list = try self.instantiateVar(list_var, ctx) },
.box => |box_var| FlatType{ .box = try self.instantiateVar(box_var) },
.list => |list_var| FlatType{ .list = try self.instantiateVar(list_var) },
.list_unbound => FlatType.list_unbound,
.tuple => |tuple| FlatType{ .tuple = try self.instantiateTuple(tuple, ctx) },
.num => |num| FlatType{ .num = try self.instantiateNum(num, ctx) },
.nominal_type => |nominal| FlatType{ .nominal_type = try self.instantiateNominalType(nominal, ctx) },
.fn_pure => |func| FlatType{ .fn_pure = try self.instantiateFunc(func, ctx) },
.fn_effectful => |func| FlatType{ .fn_effectful = try self.instantiateFunc(func, ctx) },
.fn_unbound => |func| FlatType{ .fn_unbound = try self.instantiateFunc(func, ctx) },
.record => |record| FlatType{ .record = try self.instantiateRecord(record, ctx) },
.record_unbound => |fields| FlatType{ .record_unbound = try self.instantiateRecordFields(fields, ctx) },
.tuple => |tuple| FlatType{ .tuple = try self.instantiateTuple(tuple) },
.num => |num| FlatType{ .num = try self.instantiateNum(num) },
.nominal_type => |nominal| FlatType{ .nominal_type = try self.instantiateNominalType(nominal) },
.fn_pure => |func| FlatType{ .fn_pure = try self.instantiateFunc(func) },
.fn_effectful => |func| FlatType{ .fn_effectful = try self.instantiateFunc(func) },
.fn_unbound => |func| FlatType{ .fn_unbound = try self.instantiateFunc(func) },
.record => |record| FlatType{ .record = try self.instantiateRecord(record) },
.record_unbound => |fields| FlatType{ .record_unbound = try self.instantiateRecordFields(fields) },
.empty_record => FlatType.empty_record,
.tag_union => |tag_union| FlatType{ .tag_union = try self.instantiateTagUnion(tag_union, ctx) },
.tag_union => |tag_union| FlatType{ .tag_union = try self.instantiateTagUnion(tag_union) },
.empty_tag_union => FlatType.empty_tag_union,
};
}
fn instantiateNominalType(self: *Self, nominal: NominalType, ctx: *Ctx) std.mem.Allocator.Error!NominalType {
fn instantiateNominalType(self: *Self, nominal: NominalType) std.mem.Allocator.Error!NominalType {
var fresh_vars = std.ArrayList(Var).init(self.store.gpa);
defer fresh_vars.deinit();
const backing_var = self.store.getNominalBackingVar(nominal);
const fresh_backing_var = try self.instantiateVar(backing_var, ctx);
const fresh_backing_var = try self.instantiateVar(backing_var);
try fresh_vars.append(fresh_backing_var);
var iter = self.store.iterNominalArgs(nominal);
while (iter.next()) |arg_var| {
const fresh_elem = try self.instantiateVar(arg_var, ctx);
const fresh_elem = try self.instantiateVar(arg_var);
try fresh_vars.append(fresh_elem);
}
@ -220,13 +218,13 @@ pub const Instantiate = struct {
};
}
fn instantiateTuple(self: *Self, tuple: Tuple, ctx: *Ctx) std.mem.Allocator.Error!Tuple {
fn instantiateTuple(self: *Self, tuple: Tuple) std.mem.Allocator.Error!Tuple {
const elems_slice = self.store.sliceVars(tuple.elems);
var fresh_elems = std.ArrayList(Var).init(self.store.gpa);
defer fresh_elems.deinit();
for (elems_slice) |elem_var| {
const fresh_elem = try self.instantiateVar(elem_var, ctx);
const fresh_elem = try self.instantiateVar(elem_var);
try fresh_elems.append(fresh_elem);
}
@ -234,11 +232,11 @@ pub const Instantiate = struct {
return Tuple{ .elems = fresh_elems_range };
}
fn instantiateNum(self: *Self, num: Num, ctx: *Ctx) std.mem.Allocator.Error!Num {
fn instantiateNum(self: *Self, num: Num) std.mem.Allocator.Error!Num {
return switch (num) {
.num_poly => |poly_var| Num{ .num_poly = try self.instantiateVar(poly_var, ctx) },
.int_poly => |poly_var| Num{ .int_poly = try self.instantiateVar(poly_var, ctx) },
.frac_poly => |poly_var| Num{ .frac_poly = try self.instantiateVar(poly_var, ctx) },
.num_poly => |poly_var| Num{ .num_poly = try self.instantiateVar(poly_var) },
.int_poly => |poly_var| Num{ .int_poly = try self.instantiateVar(poly_var) },
.frac_poly => |poly_var| Num{ .frac_poly = try self.instantiateVar(poly_var) },
// Concrete types remain unchanged
.int_precision => |precision| Num{ .int_precision = precision },
.frac_precision => |precision| Num{ .frac_precision = precision },
@ -249,17 +247,17 @@ pub const Instantiate = struct {
};
}
fn instantiateFunc(self: *Self, func: Func, ctx: *Ctx) std.mem.Allocator.Error!Func {
fn instantiateFunc(self: *Self, func: Func) std.mem.Allocator.Error!Func {
const args_slice = self.store.sliceVars(func.args);
var fresh_args = std.ArrayList(Var).init(self.store.gpa);
defer fresh_args.deinit();
for (args_slice) |arg_var| {
const fresh_arg = try self.instantiateVar(arg_var, ctx);
const fresh_arg = try self.instantiateVar(arg_var);
try fresh_args.append(fresh_arg);
}
const fresh_ret = try self.instantiateVar(func.ret, ctx);
const fresh_ret = try self.instantiateVar(func.ret);
const fresh_args_range = try self.store.appendVars(fresh_args.items);
return Func{
.args = fresh_args_range,
@ -268,14 +266,14 @@ pub const Instantiate = struct {
};
}
fn instantiateRecordFields(self: *Self, fields: RecordField.SafeMultiList.Range, ctx: *Ctx) std.mem.Allocator.Error!RecordField.SafeMultiList.Range {
fn instantiateRecordFields(self: *Self, fields: RecordField.SafeMultiList.Range) std.mem.Allocator.Error!RecordField.SafeMultiList.Range {
const fields_slice = self.store.getRecordFieldsSlice(fields);
var fresh_fields = std.ArrayList(RecordField).init(self.store.gpa);
defer fresh_fields.deinit();
for (fields_slice.items(.name), fields_slice.items(.var_)) |name, type_var| {
const fresh_type = try self.instantiateVar(type_var, ctx);
const fresh_type = try self.instantiateVar(type_var);
_ = try fresh_fields.append(RecordField{
.name = name,
.var_ = fresh_type,
@ -285,14 +283,14 @@ pub const Instantiate = struct {
return try self.store.appendRecordFields(fresh_fields.items);
}
fn instantiateRecord(self: *Self, record: Record, ctx: *Ctx) std.mem.Allocator.Error!Record {
fn instantiateRecord(self: *Self, record: Record) std.mem.Allocator.Error!Record {
const fields_slice = self.store.getRecordFieldsSlice(record.fields);
var fresh_fields = std.ArrayList(RecordField).init(self.store.gpa);
defer fresh_fields.deinit();
for (fields_slice.items(.name), fields_slice.items(.var_)) |name, type_var| {
const fresh_type = try self.instantiateVar(type_var, ctx);
const fresh_type = try self.instantiateVar(type_var);
_ = try fresh_fields.append(RecordField{
.name = name,
.var_ = fresh_type,
@ -302,11 +300,11 @@ pub const Instantiate = struct {
const fields_range = try self.store.appendRecordFields(fresh_fields.items);
return Record{
.fields = fields_range,
.ext = try self.instantiateVar(record.ext, ctx),
.ext = try self.instantiateVar(record.ext),
};
}
fn instantiateTagUnion(self: *Self, tag_union: TagUnion, ctx: *Ctx) std.mem.Allocator.Error!TagUnion {
fn instantiateTagUnion(self: *Self, tag_union: TagUnion) std.mem.Allocator.Error!TagUnion {
const tags_slice = self.store.getTagsSlice(tag_union.tags);
var fresh_tags = std.ArrayList(Tag).init(self.store.gpa);
@ -318,7 +316,7 @@ pub const Instantiate = struct {
const args_slice = self.store.sliceVars(tag_args);
for (args_slice) |arg_var| {
const fresh_arg = try self.instantiateVar(arg_var, ctx);
const fresh_arg = try self.instantiateVar(arg_var);
try fresh_args.append(fresh_arg);
}
@ -333,7 +331,7 @@ pub const Instantiate = struct {
const tags_range = try self.store.appendTags(fresh_tags.items);
return TagUnion{
.tags = tags_range,
.ext = try self.instantiateVar(tag_union.ext, ctx),
.ext = try self.instantiateVar(tag_union.ext),
};
}