From fdb4fda57754ea16db197d240cbd58356d96e2fe Mon Sep 17 00:00:00 2001 From: Jared Ramirez Date: Thu, 11 Sep 2025 14:03:12 -0400 Subject: [PATCH] Recursive nominal types --- src/canonicalize/Can.zig | 75 +-- src/canonicalize/NodeStore.zig | 27 +- src/check/Check.zig | 494 ++++++++++--------- src/check/test/type_checking_integration.zig | 12 +- src/repl/Repl.zig | 2 +- src/snapshot_tool/main.zig | 4 +- test/snapshots/annotations.md | 80 ++- 7 files changed, 387 insertions(+), 307 deletions(-) diff --git a/src/canonicalize/Can.zig b/src/canonicalize/Can.zig index 46984a16d3..4419e6482a 100644 --- a/src/canonicalize/Can.zig +++ b/src/canonicalize/Can.zig @@ -441,10 +441,10 @@ pub fn canonicalizeFile( }, }; - const placeholder_type_decl_idx = try self.env.addStatementAndTypeVar(placeholder_cir_type_decl, Content{ .flex_var = null }, region); + const type_decl_stmt_idx = try self.env.addStatementAndTypeVar(placeholder_cir_type_decl, .err, region); // Introduce the type name into scope early to support recursive references - try self.scopeIntroduceTypeDecl(type_header.name, placeholder_type_decl_idx, region); + try self.scopeIntroduceTypeDecl(type_header.name, type_decl_stmt_idx, region); // Process type parameters and annotation in a separate scope const anno_idx = blk: { @@ -459,19 +459,6 @@ pub fn canonicalizeFile( break :blk try self.canonicalizeTypeAnno(type_decl.anno, .type_decl_anno); }; - // Get type variables to args (lhs) - const header_arg_vars: []TypeVar = @ptrCast(self.env.store.sliceTypeAnnos(type_header.args)); - - // Get type variable to the backing type (rhs) - const anno_var = ModuleEnv.varFrom(anno_idx); - - // Check if the backing type is already an error type - const backing_resolved = self.env.types.resolveVar(anno_var); - const backing_is_error = backing_resolved.desc.content == .err; - - // The identified of the type - const type_ident = types.TypeIdent{ .ident_idx = type_header.name }; - // Canonicalize where clauses if present if (type_decl.where) |_| { try self.env.pushDiagnostic(Diagnostic{ .where_clause_not_allowed_in_type_decl = .{ @@ -480,58 +467,35 @@ pub fn canonicalizeFile( } // Create the real CIR type declaration statement with the canonicalized annotation - const real_cir_type_decl, const type_decl_content = blk: { + const type_decl_stmt = blk: { switch (type_decl.kind) { .alias => { - const alias_content = if (backing_is_error) - types.Content{ .err = {} } - else - try self.env.types.mkAlias(type_ident, anno_var, header_arg_vars); - - break :blk .{ - Statement{ - .s_alias_decl = .{ - .header = header_idx, - .anno = anno_idx, - }, + break :blk Statement{ + .s_alias_decl = .{ + .header = header_idx, + .anno = anno_idx, }, - alias_content, }; }, .nominal => { - const nominal_content = if (backing_is_error) - types.Content{ .err = {} } - else - try self.env.types.mkNominal( - type_ident, - anno_var, - header_arg_vars, - try self.env.insertIdent(base.Ident.for_text(self.env.module_name)), - ); - - break :blk .{ - Statement{ - .s_nominal_decl = .{ - .header = header_idx, - .anno = anno_idx, - }, + break :blk Statement{ + .s_nominal_decl = .{ + .header = header_idx, + .anno = anno_idx, }, - nominal_content, }; }, } }; // Create the real statement and add it to scratch statements - const type_decl_stmt_idx = try self.env.addStatementAndTypeVar(real_cir_type_decl, type_decl_content, region); + try self.env.store.setStatementNode(type_decl_stmt_idx, type_decl_stmt); try self.env.store.addScratchStatement(type_decl_stmt_idx); - // Update the scope to point to the real statement instead of the placeholder - try self.scopeUpdateTypeDecl(type_header.name, type_decl_stmt_idx); - + // TODO: is this needed? // Remove from exposed_type_texts since the type is now fully defined - const type_text = self.env.getIdent(type_header.name); - _ = self.exposed_type_texts.remove(type_text); + // const type_text = self.env.getIdent(type_header.name); + // _ = self.exposed_type_texts.remove(type_text); }, else => { // Skip non-type-declaration statements in first pass @@ -5312,6 +5276,15 @@ fn canonicalizeTypeHeader(self: *Self, header_idx: AST.TypeHeader.Idx) std.mem.A }, Content{ .flex_var = null }, region); }; + // Check if this is a builtin type + // TODO: Can we compare idents or something here? The byte slice comparison is ineffecient + if (TypeAnno.Builtin.fromBytes(self.env.getIdentText(name_ident))) |_| { + return try self.env.pushMalformed(CIR.TypeHeader.Idx, Diagnostic{ .ident_already_in_scope = .{ + .ident = name_ident, + .region = region, + } }); + } + // Canonicalize type arguments - these are parameter declarations, not references const scratch_top = self.env.store.scratchTypeAnnoTop(); defer self.env.store.clearScratchTypeAnnosFrom(scratch_top); diff --git a/src/canonicalize/NodeStore.zig b/src/canonicalize/NodeStore.zig index 4b31c5869b..6457b95783 100644 --- a/src/canonicalize/NodeStore.zig +++ b/src/canonicalize/NodeStore.zig @@ -1155,6 +1155,29 @@ pub fn getExposedItem(store: *const NodeStore, exposedItem: CIR.ExposedItem.Idx) /// IMPORTANT: You should not use this function directly! Instead, use it's /// corresponding function in `ModuleEnv`. pub fn addStatement(store: *NodeStore, statement: CIR.Statement, region: base.Region) std.mem.Allocator.Error!CIR.Statement.Idx { + const node = try store.makeStatementNode(statement); + const node_idx = try store.nodes.append(store.gpa, node); + _ = try store.regions.append(store.gpa, region); + return @enumFromInt(@intFromEnum(node_idx)); +} + +/// Set a statement idx to the provided statement +/// +/// This is used when defininig recursive type declarations: +/// 1. Make the placeholder node +/// 2. Introduce to scope +/// 3. Canonicalize the annotation +/// 4. Update the placeholder node with the actual annotation +pub fn setStatementNode(store: *NodeStore, stmt_idx: CIR.Statement.Idx, statement: CIR.Statement) std.mem.Allocator.Error!void { + const node = try store.makeStatementNode(statement); + store.nodes.set(@enumFromInt(@intFromEnum(stmt_idx)), node); +} + +/// Creates a statement node, but does not append to the store. +/// IMPORTANT: It *does* append to extra_data though +/// +/// See `setStatementNode` to see why this exists +fn makeStatementNode(store: *NodeStore, statement: CIR.Statement) std.mem.Allocator.Error!Node { var node = Node{ .data_1 = 0, .data_2 = 0, @@ -1285,9 +1308,7 @@ pub fn addStatement(store: *NodeStore, statement: CIR.Statement, region: base.Re }, } - const node_idx = try store.nodes.append(store.gpa, node); - _ = try store.regions.append(store.gpa, region); - return @enumFromInt(@intFromEnum(node_idx)); + return node; } /// Adds an expression node to the store. diff --git a/src/check/Check.zig b/src/check/Check.zig index f082e8c888..7534da1738 100644 --- a/src/check/Check.zig +++ b/src/check/Check.zig @@ -43,23 +43,34 @@ cir: *ModuleEnv, regions: *Region.List, other_modules: []const *ModuleEnv, common_idents: CommonIdents, -// owned -snapshots: SnapshotStore, -problems: ProblemStore, -unify_scratch: unifier.Scratch, -occurs_scratch: occurs.Scratch, -// annos - new -anno_free_vars: base.Scratch(FreeVar), -decl_free_vars: base.Scratch(FreeVar), -scratch_vars: base.Scratch(Var), -scratch_tags: base.Scratch(types_mod.Tag), -scratch_record_fields: base.Scratch(types_mod.RecordField), -var_pool: VarPool, -generalizer: Generalizer, -// annos/instantiation - old +/// type snapshots used in error messages +snapshots: SnapshotStore, +/// type problems +problems: ProblemStore, +/// reusable scratch arrays used in unification +unify_scratch: unifier.Scratch, +/// reusable scratch arrays used in occurs check +occurs_scratch: occurs.Scratch, +/// free vars collected when generation types from annotation +anno_free_vars: base.Scratch(FreeVar), +/// free vars collected when generation types from type decls +decl_free_vars: base.Scratch(FreeVar), +/// stmts we've already seen when generation a type from an annotation +seen_annos: std.AutoHashMap(CIR.TypeAnno.Idx, Var), +/// pool of variables that need to be generalized, built up during checking +var_pool: VarPool, +/// wrapper around generalization, contains some internal state used to do it's work +generalizer: Generalizer, +/// scratch vars used to build up intermediate lists, used for various things +scratch_vars: base.Scratch(Var), +/// scratch tags used to build up intermediate lists, used for various things +scratch_tags: base.Scratch(types_mod.Tag), +/// scratch record fields used to build up intermediate lists, used for various things +scratch_record_fields: base.Scratch(types_mod.RecordField), +/// used in instantiation. TODO: Move into something like Instantiator var_map: std.AutoHashMap(Var, Var), -annotation_rigid_var_subs: Instantiate.RigidToFlexSubs, +/// used in instantiation. TODO: Move into something like Instantiator anonymous_rigid_var_subs: Instantiate.RigidToFlexSubs, /// Cache for imported types. This cache lives for the entire type-checking session /// of a module, so the same imported type can be reused across the entire module. @@ -101,10 +112,10 @@ pub fn init( .var_map = std.AutoHashMap(Var, Var).init(gpa), .anno_free_vars = try base.Scratch(FreeVar).init(gpa), .decl_free_vars = try base.Scratch(FreeVar).init(gpa), + .seen_annos = std.AutoHashMap(CIR.TypeAnno.Idx, Var).init(gpa), .scratch_vars = try base.Scratch(types_mod.Var).init(gpa), .scratch_tags = try base.Scratch(types_mod.Tag).init(gpa), .scratch_record_fields = try base.Scratch(types_mod.RecordField).init(gpa), - .annotation_rigid_var_subs = try Instantiate.RigidToFlexSubs.init(gpa), .anonymous_rigid_var_subs = try Instantiate.RigidToFlexSubs.init(gpa), .import_cache = ImportCache{}, .constraint_origins = std.AutoHashMap(Var, Var).init(gpa), @@ -122,7 +133,7 @@ pub fn deinit(self: *Self) void { self.var_map.deinit(); self.anno_free_vars.deinit(self.gpa); self.decl_free_vars.deinit(self.gpa); - self.annotation_rigid_var_subs.deinit(self.gpa); + self.seen_annos.deinit(); self.anonymous_rigid_var_subs.deinit(self.gpa); self.scratch_vars.deinit(self.gpa); self.scratch_tags.deinit(self.gpa); @@ -457,10 +468,28 @@ pub fn checkFile(self: *Self) std.mem.Allocator.Error!void { try self.checkDefs(); } +// repl // + +/// Check an expr for the repl +pub fn checkExprRepl(self: *Self, expr_idx: CIR.Expr.Idx) std.mem.Allocator.Error!void { + // Push the rank for this definitoin + try self.var_pool.pushRank(); + defer self.var_pool.popRank(); + + // Ensure that the current rank in the pool is top-level + const rank = types_mod.Rank.top_level; + std.debug.assert(rank == self.var_pool.current_rank); + + _ = try self.checkExprNew(expr_idx, rank, .no_expectation); + + // Now that we are existing the scope, we must generalize then pop this rank + try self.generalizer.generalize(&self.var_pool, rank); +} + // defs // /// Check the types for all defs -pub fn checkDefs(self: *Self) std.mem.Allocator.Error!void { +fn checkDefs(self: *Self) std.mem.Allocator.Error!void { const trace = tracy.trace(@src()); defer trace.end(); @@ -477,6 +506,7 @@ fn checkDef(self: *Self, def_idx: CIR.Def.Idx) std.mem.Allocator.Error!void { // Push the rank for this definitoin try self.var_pool.pushRank(); + defer self.var_pool.popRank(); // Ensure that the current rank in the pool is top-level const rank = types_mod.Rank.top_level; @@ -516,7 +546,6 @@ fn checkDef(self: *Self, def_idx: CIR.Def.Idx) std.mem.Allocator.Error!void { // Now that we are existing the scope, we must generalize then pop this rank try self.generalizer.generalize(&self.var_pool, rank); - self.var_pool.popRank(); } // annotations // @@ -541,227 +570,243 @@ fn generateAnnoType(self: *Self, free_vars_ctx: FreeVarCtx, anno_idx: CIR.TypeAn const trace = tracy.trace(@src()); defer trace.end(); + // First, check if we've seen this anno before + // This guards against recursive types + if (self.seen_annos.get(anno_idx)) |var_| { + return var_; + } + + // Get the annotation const anno = self.cir.store.getTypeAnno(anno_idx); const anno_region = self.cir.store.getNodeRegion(ModuleEnv.nodeIdxFrom(anno_idx)); - switch (anno) { - .rigid_var => |rigid| { - // If we have a rigid var, first check if we've seen it before - // If so, we redirect this instance to the previously seen instance - for (free_vars_ctx.sliceFreeVars()) |cached_var| { - if (cached_var.ident.idx == rigid.name.idx) { - return try self.freshRedirect(cached_var.var_, anno_region); + // Create a placeholder and put it into the seen variables + const placeholder_var = try self.fresh(Rank.generalized, anno_region); + try self.seen_annos.put(anno_idx, placeholder_var); + + const anno_var = blk: { + switch (anno) { + .rigid_var => |rigid| { + // If we have a rigid var, first check if we've seen it before + // If so, we redirect this instance to the previously seen instance + for (free_vars_ctx.sliceFreeVars()) |cached_var| { + if (cached_var.ident.idx == rigid.name.idx) { + break :blk try self.freshRedirect(cached_var.var_, anno_region); + } } - } - // If this is the first time we've seen this var, create it and add - // it to the cache - const var_ = blk: { - switch (free_vars_ctx.mode) { - .rigid => break :blk try self.freshFromContent(.{ .rigid_var = rigid.name }, Rank.generalized, anno_region), - .flex => break :blk try self.fresh(Rank.generalized, anno_region), + // If this is the first time we've seen this var, create it and add + // it to the cache + const var_ = inner_blk: { + switch (free_vars_ctx.mode) { + .rigid => break :inner_blk try self.freshFromContent(.{ .rigid_var = rigid.name }, Rank.generalized, anno_region), + .flex => break :inner_blk try self.fresh(Rank.generalized, anno_region), + } + }; + try free_vars_ctx.scratch.append(self.gpa, .{ .ident = rigid.name, .var_ = var_ }); + break :blk var_; + }, + .underscore => { + break :blk try self.fresh(Rank.generalized, anno_region); + }, + .lookup => |lookup| { + switch (lookup.base) { + .builtin => |builtin_type| { + break :blk try self.generateBuiltinTypeInstance(lookup.name, builtin_type, &.{}, anno_region); + }, + .local => |local| { + break :blk try self.generateTypeDeclInstance(&.{}, anno_region, local.decl_idx); + }, + .external => |_| { + @panic("TODO: External type lookups"); + // // TODO External + // const resolved_external = try self.resolveVarFromExternal(external.module_idx, external.target_node_idx) orelse { + // // TODO? + // break :blk try self.freshFromContent(.err, Rank.generalized, anno_region); + // }; + // break :blk try self.instantiateVarAnon(resolved_external.local_var, .{ .explicit = anno_region }); + }, } - }; - try free_vars_ctx.scratch.append(self.gpa, .{ .ident = rigid.name, .var_ = var_ }); - return var_; - }, - .underscore => { - return try self.fresh(Rank.generalized, anno_region); - }, - .lookup => |lookup| { - switch (lookup.base) { - .builtin => |builtin_type| { - return try self.generateBuiltinTypeInstance(lookup.name, builtin_type, &.{}, anno_region); - }, - .local => |local| { - return try self.generateTypeDeclInstance(&.{}, anno_region, local.decl_idx); - }, - .external => |_| { - @panic("TODO: External type lookups"); - // // TODO External - // const resolved_external = try self.resolveVarFromExternal(external.module_idx, external.target_node_idx) orelse { - // // TODO? - // return try self.freshFromContent(.err, Rank.generalized, anno_region); - // }; - // return try self.instantiateVarAnon(resolved_external.local_var, .{ .explicit = anno_region }); - }, - } - }, - .apply => |a| { - const scratch_vars_top = self.scratch_vars.top(); - defer self.scratch_vars.clearFrom(scratch_vars_top); - - // Generate the types for the arguments - const anno_args = self.cir.store.sliceTypeAnnos(a.args); - for (anno_args) |anno_arg| { - try self.scratch_vars.append(self.gpa, try self.generateAnnoType( - free_vars_ctx, - anno_arg, - )); - } - const args_var_slice = self.scratch_vars.sliceFromStart(scratch_vars_top); - - switch (a.base) { - .builtin => |builtin_type| { - return try self.generateBuiltinTypeInstance(a.name, builtin_type, args_var_slice, anno_region); - }, - .local => |local| { - return try self.generateTypeDeclInstance(args_var_slice, anno_region, local.decl_idx); - }, - .external => |_| { - @panic("TODO: External type apply"); - // // TODO External - // const resolved_external = try self.resolveVarFromExternal(external.module_idx, external.target_node_idx) orelse { - // // TODO? - // return try self.freshFromContent(.err, Rank.generalized, anno_region); - // }; - // return try self.instantiateVarAnon(resolved_external.local_var, .{ .explicit = anno_region }); - }, - } - }, - .@"fn" => |func| { - const scratch_vars_top = self.scratch_vars.top(); - defer self.scratch_vars.clearFrom(scratch_vars_top); - - const args_anno_slice = self.cir.store.sliceTypeAnnos(func.args); - for (args_anno_slice) |arg_anno_idx| { - try self.scratch_vars.append(self.gpa, try self.generateAnnoType( - free_vars_ctx, - arg_anno_idx, - )); - } - const args_var_slice = self.scratch_vars.sliceFromStart(scratch_vars_top); - - const fn_ret_var = try self.generateAnnoType(free_vars_ctx, func.ret); - - const fn_type = blk: { - if (func.effectful) { - break :blk try self.types.mkFuncEffectful(args_var_slice, fn_ret_var); - } else { - break :blk try self.types.mkFuncPure(args_var_slice, fn_ret_var); - } - }; - return try self.freshFromContent(fn_type, Rank.generalized, anno_region); - }, - .tag_union => |tag_union| { - const scratch_tags_top = self.scratch_tags.top(); - defer self.scratch_tags.clearFrom(scratch_tags_top); - - const tag_anno_slices = self.cir.store.sliceTypeAnnos(tag_union.tags); - for (tag_anno_slices) |tag_anno_idx| { - // Get the tag anno - const tag_type_anno = self.cir.store.getTypeAnno(tag_anno_idx); - std.debug.assert(tag_type_anno == .tag); - const tag = tag_type_anno.tag; - + }, + .apply => |a| { const scratch_vars_top = self.scratch_vars.top(); defer self.scratch_vars.clearFrom(scratch_vars_top); - // Generate the types for each tag arg - const tag_anno_args_slice = self.cir.store.sliceTypeAnnos(tag.args); - for (tag_anno_args_slice) |tag_arg_idx| { + // Generate the types for the arguments + const anno_args = self.cir.store.sliceTypeAnnos(a.args); + for (anno_args) |anno_arg| { try self.scratch_vars.append(self.gpa, try self.generateAnnoType( free_vars_ctx, - tag_arg_idx, + anno_arg, )); } - const tag_vars_slice = self.scratch_vars.sliceFromStart(scratch_vars_top); + const args_var_slice = self.scratch_vars.sliceFromStart(scratch_vars_top); - // Add the processed tag to scratch - try self.scratch_tags.append(self.gpa, try self.types.mkTag( - tag.name, - tag_vars_slice, - )); - } - - // Get the slice of tags - const tags_slice = self.scratch_tags.sliceFromStart(scratch_tags_top); - std.mem.sort(types_mod.Tag, tags_slice, self.cir.common.getIdentStore(), comptime types_mod.Tag.sortByNameAsc); - - // Process the ext if it exists. Absence means it's a closed union - const ext_var = blk: { - if (tag_union.ext) |ext_anno_idx| { - break :blk try self.generateAnnoType(free_vars_ctx, ext_anno_idx); - } else { - break :blk try self.freshFromContent(.{ .structure = .empty_tag_union }, Rank.generalized, anno_region); + switch (a.base) { + .builtin => |builtin_type| { + break :blk try self.generateBuiltinTypeInstance(a.name, builtin_type, args_var_slice, anno_region); + }, + .local => |local| { + break :blk try self.generateTypeDeclInstance(args_var_slice, anno_region, local.decl_idx); + }, + .external => |_| { + @panic("TODO: External type apply"); + // // TODO External + // const resolved_external = try self.resolveVarFromExternal(external.module_idx, external.target_node_idx) orelse { + // // TODO? + // break :blk try self.freshFromContent(.err, Rank.generalized, anno_region); + // }; + // break :blk try self.instantiateVarAnon(resolved_external.local_var, .{ .explicit = anno_region }); + }, } - }; + }, + .@"fn" => |func| { + const scratch_vars_top = self.scratch_vars.top(); + defer self.scratch_vars.clearFrom(scratch_vars_top); - // Create the type for the anno in the store - return try self.freshFromContent(try self.types.mkTagUnion(tags_slice, ext_var), Rank.generalized, anno_region); - }, - .tag => { - // This indicates a malformed type annotation. Tags should only - // exist as direct childen of tag_unions - std.debug.assert(false); - return try self.freshFromContent(.err, Rank.generalized, anno_region); - }, - .record => |rec| { - const scratch_record_fields_top = self.scratch_record_fields.top(); - defer self.scratch_record_fields.clearFrom(scratch_record_fields_top); + const args_anno_slice = self.cir.store.sliceTypeAnnos(func.args); + for (args_anno_slice) |arg_anno_idx| { + try self.scratch_vars.append(self.gpa, try self.generateAnnoType( + free_vars_ctx, + arg_anno_idx, + )); + } + const args_var_slice = self.scratch_vars.sliceFromStart(scratch_vars_top); - const recs_anno_slice = self.cir.store.sliceAnnoRecordFields(rec.fields); + const fn_ret_var = try self.generateAnnoType(free_vars_ctx, func.ret); - for (recs_anno_slice) |rec_anno_idx| { - const rec_field = self.cir.store.getAnnoRecordField(rec_anno_idx); + const fn_type = inner_blk: { + if (func.effectful) { + break :inner_blk try self.types.mkFuncEffectful(args_var_slice, fn_ret_var); + } else { + break :inner_blk try self.types.mkFuncPure(args_var_slice, fn_ret_var); + } + }; + break :blk try self.freshFromContent(fn_type, Rank.generalized, anno_region); + }, + .tag_union => |tag_union| { + const scratch_tags_top = self.scratch_tags.top(); + defer self.scratch_tags.clearFrom(scratch_tags_top); - const record_field_var = try self.generateAnnoType(free_vars_ctx, rec_field.ty); + const tag_anno_slices = self.cir.store.sliceTypeAnnos(tag_union.tags); + for (tag_anno_slices) |tag_anno_idx| { + // Get the tag anno + const tag_type_anno = self.cir.store.getTypeAnno(tag_anno_idx); + std.debug.assert(tag_type_anno == .tag); + const tag = tag_type_anno.tag; - // Add the processed tag to scratch - try self.scratch_record_fields.append(self.gpa, types_mod.RecordField{ - .name = rec_field.name, - .var_ = record_field_var, - }); - } + const scratch_vars_top = self.scratch_vars.top(); + defer self.scratch_vars.clearFrom(scratch_vars_top); - // Get the slice of record_fields - const record_fields_slice = self.scratch_record_fields.sliceFromStart(scratch_record_fields_top); - std.mem.sort(types_mod.RecordField, record_fields_slice, self.cir.common.getIdentStore(), comptime types_mod.RecordField.sortByNameAsc); - const fields_type_range = try self.types.appendRecordFields(record_fields_slice); + // Generate the types for each tag arg + const tag_anno_args_slice = self.cir.store.sliceTypeAnnos(tag.args); + for (tag_anno_args_slice) |tag_arg_idx| { + try self.scratch_vars.append(self.gpa, try self.generateAnnoType( + free_vars_ctx, + tag_arg_idx, + )); + } + const tag_vars_slice = self.scratch_vars.sliceFromStart(scratch_vars_top); - // Process the ext if it exists. Absence means it's a closed union - // TODO: Capture ext in record field CIR - // const ext_var = blk: { - // if (rec.ext) |ext_anno_idx| { - // try self.generateAnnoType(rigid_vars_ctx, ext_anno_idx); - // break :blk ModuleEnv.varFrom(ext_anno_idx); - // } else { - // break :blk try self.freshFromContent(.{ .structure = .empty_record }, Rank.generalized, anno_region); - // } - // }; - const ext_var = try self.freshFromContent(.{ .structure = .empty_record }, Rank.generalized, anno_region); + // Add the processed tag to scratch + try self.scratch_tags.append(self.gpa, try self.types.mkTag( + tag.name, + tag_vars_slice, + )); + } - // Create the type for the anno in the store - return try self.freshFromContent( - .{ .structure = types_mod.FlatType{ .record = .{ - .fields = fields_type_range, - .ext = ext_var, - } } }, - Rank.generalized, - anno_region, - ); - }, - .tuple => |tuple| { - const scratch_vars_top = self.scratch_vars.top(); - defer self.scratch_vars.clearFrom(scratch_vars_top); + // Get the slice of tags + const tags_slice = self.scratch_tags.sliceFromStart(scratch_tags_top); + std.mem.sort(types_mod.Tag, tags_slice, self.cir.common.getIdentStore(), comptime types_mod.Tag.sortByNameAsc); - const elems_anno_slice = self.cir.store.sliceTypeAnnos(tuple.elems); - for (elems_anno_slice) |arg_anno_idx| { - try self.scratch_vars.append(self.gpa, try self.generateAnnoType( - free_vars_ctx, - arg_anno_idx, - )); - } - const elems_range = try self.types.appendVars(@ptrCast(elems_anno_slice)); - return try self.freshFromContent(.{ .structure = .{ .tuple = .{ .elems = elems_range } } }, Rank.generalized, anno_region); - }, - .parens => |parens| { - return try self.generateAnnoType(free_vars_ctx, parens.anno); - }, - .malformed => { - return try self.freshFromContent(.err, Rank.generalized, anno_region); - }, - } + // Process the ext if it exists. Absence means it's a closed union + const ext_var = inner_blk: { + if (tag_union.ext) |ext_anno_idx| { + break :inner_blk try self.generateAnnoType(free_vars_ctx, ext_anno_idx); + } else { + break :inner_blk try self.freshFromContent(.{ .structure = .empty_tag_union }, Rank.generalized, anno_region); + } + }; + + // Create the type for the anno in the store + break :blk try self.freshFromContent(try self.types.mkTagUnion(tags_slice, ext_var), Rank.generalized, anno_region); + }, + .tag => { + // This indicates a malformed type annotation. Tags should only + // exist as direct childen of tag_unions + std.debug.assert(false); + break :blk try self.freshFromContent(.err, Rank.generalized, anno_region); + }, + .record => |rec| { + const scratch_record_fields_top = self.scratch_record_fields.top(); + defer self.scratch_record_fields.clearFrom(scratch_record_fields_top); + + const recs_anno_slice = self.cir.store.sliceAnnoRecordFields(rec.fields); + + for (recs_anno_slice) |rec_anno_idx| { + const rec_field = self.cir.store.getAnnoRecordField(rec_anno_idx); + + const record_field_var = try self.generateAnnoType(free_vars_ctx, rec_field.ty); + + // Add the processed tag to scratch + try self.scratch_record_fields.append(self.gpa, types_mod.RecordField{ + .name = rec_field.name, + .var_ = record_field_var, + }); + } + + // Get the slice of record_fields + const record_fields_slice = self.scratch_record_fields.sliceFromStart(scratch_record_fields_top); + std.mem.sort(types_mod.RecordField, record_fields_slice, self.cir.common.getIdentStore(), comptime types_mod.RecordField.sortByNameAsc); + const fields_type_range = try self.types.appendRecordFields(record_fields_slice); + + // Process the ext if it exists. Absence means it's a closed union + // TODO: Capture ext in record field CIR + // const ext_var = inner_blk: { + // if (rec.ext) |ext_anno_idx| { + // try self.generateAnnoType(rigid_vars_ctx, ext_anno_idx); + // break :inner_blk ModuleEnv.varFrom(ext_anno_idx); + // } else { + // break :inner_blk try self.freshFromContent(.{ .structure = .empty_record }, Rank.generalized, anno_region); + // } + // }; + const ext_var = try self.freshFromContent(.{ .structure = .empty_record }, Rank.generalized, anno_region); + + // Create the type for the anno in the store + break :blk try self.freshFromContent( + .{ .structure = types_mod.FlatType{ .record = .{ + .fields = fields_type_range, + .ext = ext_var, + } } }, + Rank.generalized, + anno_region, + ); + }, + .tuple => |tuple| { + const scratch_vars_top = self.scratch_vars.top(); + defer self.scratch_vars.clearFrom(scratch_vars_top); + + const elems_anno_slice = self.cir.store.sliceTypeAnnos(tuple.elems); + for (elems_anno_slice) |arg_anno_idx| { + try self.scratch_vars.append(self.gpa, try self.generateAnnoType( + free_vars_ctx, + arg_anno_idx, + )); + } + const elems_range = try self.types.appendVars(@ptrCast(elems_anno_slice)); + break :blk try self.freshFromContent(.{ .structure = .{ .tuple = .{ .elems = elems_range } } }, Rank.generalized, anno_region); + }, + .parens => |parens| { + break :blk try self.generateAnnoType(free_vars_ctx, parens.anno); + }, + .malformed => { + break :blk try self.freshFromContent(.err, Rank.generalized, anno_region); + }, + } + }; + + try self.types.setVarRedirect(placeholder_var, anno_var); + return placeholder_var; } /// Generate a type variable from the provided type declaration, substituting @@ -864,6 +909,8 @@ fn generateTypeDeclInstance( // name, args, and origin module, if the rhs of the nominal type is // invalid (ie an .err) then that error must propgate "through" the // nominal type. So the whole thing must be materialized here. + // + // TODO: Should this be flex instead of rigid? const backing_var = try self.generateAnnoType( FreeVarCtx{ .scratch = &self.decl_free_vars, .start = decl_free_vars_top, .mode = .rigid }, nominal.anno, @@ -1023,7 +1070,7 @@ pub const Expected = union(enum) { expected: struct { var_: Var, from_annotation: bool }, }; -pub fn checkExprNew(self: *Self, expr_idx: CIR.Expr.Idx, rank: types_mod.Rank, expected: Expected) std.mem.Allocator.Error!bool { +fn checkExprNew(self: *Self, expr_idx: CIR.Expr.Idx, rank: types_mod.Rank, expected: Expected) std.mem.Allocator.Error!bool { const trace = tracy.trace(@src()); defer trace.end(); @@ -1402,8 +1449,10 @@ pub fn checkExprNew(self: *Self, expr_idx: CIR.Expr.Idx, rank: types_mod.Rank, e defer self.anno_free_vars.clearFrom(anno_free_vars_top); // Enter a new rank - const next_rank = rank.next(); try self.var_pool.pushRank(); + defer self.var_pool.popRank(); + + const next_rank = rank.next(); std.debug.assert(next_rank == self.var_pool.current_rank); // Check all statements in the block @@ -1457,7 +1506,6 @@ pub fn checkExprNew(self: *Self, expr_idx: CIR.Expr.Idx, rank: types_mod.Rank, e // Now that we are existing the scope, we must generalize then pop this rank try self.generalizer.generalize(&self.var_pool, next_rank); - self.var_pool.popRank(); }, // function // .e_lambda => |lambda| { @@ -1503,8 +1551,11 @@ pub fn checkExprNew(self: *Self, expr_idx: CIR.Expr.Idx, rank: types_mod.Rank, e // Check the argument patterns const arg_pattern_idxs = self.cir.store.slicePatterns(lambda.args); - const next_rank = rank.next(); + // Enter the next rank try self.var_pool.pushRank(); + defer self.var_pool.popRank(); + + const next_rank = rank.next(); std.debug.assert(next_rank == self.var_pool.current_rank); // Now, check if the expected function has the the same number of @@ -1563,7 +1614,6 @@ pub fn checkExprNew(self: *Self, expr_idx: CIR.Expr.Idx, rank: types_mod.Rank, e // Now that we are existing the scope, we must generalize then pop this rank try self.generalizer.generalize(&self.var_pool, next_rank); - self.var_pool.popRank(); }, .e_closure => |closure| { does_fx = try self.checkExprNew(closure.lambda_idx, rank, expected) or does_fx; @@ -1788,7 +1838,7 @@ pub fn checkExprNew(self: *Self, expr_idx: CIR.Expr.Idx, rank: types_mod.Rank, e // pattern // /// Check the types for the provided pattern -pub fn checkPatternNew(self: *Self, pattern_idx: CIR.Pattern.Idx, rank: types_mod.Rank, expected: Expected) std.mem.Allocator.Error!void { +fn checkPatternNew(self: *Self, pattern_idx: CIR.Pattern.Idx, rank: types_mod.Rank, expected: Expected) std.mem.Allocator.Error!void { const trace = tracy.trace(@src()); defer trace.end(); diff --git a/src/check/test/type_checking_integration.zig b/src/check/test/type_checking_integration.zig index e54a537f9e..0be680552a 100644 --- a/src/check/test/type_checking_integration.zig +++ b/src/check/test/type_checking_integration.zig @@ -388,12 +388,12 @@ test "check type - nominal recursive type" { const source = \\module [] \\ - \\List(a) := [Nil, Cons(a, List(a))] + \\ConsList(a) := [Nil, Cons(a, ConsList(a))] \\ - \\x : List(Str) - \\x = List.Cons("hello", List.Nil) + \\x : ConsList(Str) + \\x = ConsList.Cons("hello", ConsList.Nil) ; - try assertFileTypeCheckPass(test_allocator, source, "List(Str)"); + try assertFileTypeCheckPass(test_allocator, source, "ConsList(Str)"); } // helpers - expr // @@ -463,7 +463,7 @@ fn assertFileTypeCheckPass(allocator: std.mem.Allocator, source: []const u8, exp // Type check var checker = try Check.init(allocator, &module_env.types, &module_env, &.{}, &module_env.store.regions, module_common_idents); defer checker.deinit(); - try checker.checkDefs(); + try checker.checkFile(); // Assert no problems var report_buf = try std.ArrayList(u8).initCapacity(allocator, 256); @@ -528,7 +528,7 @@ fn assertFileTypeCheckFail(allocator: std.mem.Allocator, source: []const u8, exp // Type check var checker = try Check.init(allocator, &module_env.types, &module_env, &.{}, &module_env.store.regions, module_common_idents); defer checker.deinit(); - try checker.checkDefs(); + try checker.checkFile(); // Assert no problems try testing.expectEqual(1, checker.problems.problems.items.len); diff --git a/src/repl/Repl.zig b/src/repl/Repl.zig index ef9a7ba8e1..e6045cb4bd 100644 --- a/src/repl/Repl.zig +++ b/src/repl/Repl.zig @@ -307,7 +307,7 @@ fn evaluatePureExpression(self: *Repl, expr_source: []const u8) ![]const u8 { }; defer checker.deinit(); - _ = checker.checkExpr(canonical_expr_idx.get_idx()) catch |err| { + _ = checker.checkExprRepl(canonical_expr_idx.get_idx()) catch |err| { return try std.fmt.allocPrint(self.allocator, "Type check error: {}", .{err}); }; diff --git a/src/snapshot_tool/main.zig b/src/snapshot_tool/main.zig index 500e4038b1..603291288d 100644 --- a/src/snapshot_tool/main.zig +++ b/src/snapshot_tool/main.zig @@ -1135,9 +1135,9 @@ fn processSnapshotContent( solver.debugAssertArraysInSync(); if (maybe_expr_idx) |expr_idx| { - _ = try solver.checkExpr(expr_idx.idx); + _ = try solver.checkExprRepl(expr_idx.idx); } else { - try solver.checkDefs(); + try solver.checkFile(); } // Cache round-trip validation - ensure ModuleCache serialization/deserialization works diff --git a/test/snapshots/annotations.md b/test/snapshots/annotations.md index 60274eb239..14931f1a4e 100644 --- a/test/snapshots/annotations.md +++ b/test/snapshots/annotations.md @@ -7,7 +7,10 @@ type=file ~~~roc module [] -f = |g, v| g(v) +ConsList(a) := [Nil, Cons(a, ConsList(a))] + +x : ConsList(Str) +x = ConsList.Cons("hello", ConsList.Nil) ~~~ # EXPECTED TYPE MISMATCH - annotations.md:18:28:18:28 @@ -19,24 +22,41 @@ NIL # TOKENS ~~~zig KwModule(1:1-1:7),OpenSquare(1:8-1:9),CloseSquare(1:9-1:10), -LowerIdent(3:1-3:2),OpAssign(3:3-3:4),OpBar(3:5-3:6),LowerIdent(3:6-3:7),Comma(3:7-3:8),LowerIdent(3:9-3:10),OpBar(3:10-3:11),LowerIdent(3:12-3:13),NoSpaceOpenRound(3:13-3:14),LowerIdent(3:14-3:15),CloseRound(3:15-3:16), -EndOfFile(4:1-4:1), +UpperIdent(3:1-3:9),NoSpaceOpenRound(3:9-3:10),LowerIdent(3:10-3:11),CloseRound(3:11-3:12),OpColonEqual(3:13-3:15),OpenSquare(3:16-3:17),UpperIdent(3:17-3:20),Comma(3:20-3:21),UpperIdent(3:22-3:26),NoSpaceOpenRound(3:26-3:27),LowerIdent(3:27-3:28),Comma(3:28-3:29),UpperIdent(3:30-3:38),NoSpaceOpenRound(3:38-3:39),LowerIdent(3:39-3:40),CloseRound(3:40-3:41),CloseRound(3:41-3:42),CloseSquare(3:42-3:43), +LowerIdent(5:1-5:2),OpColon(5:3-5:4),UpperIdent(5:5-5:13),NoSpaceOpenRound(5:13-5:14),UpperIdent(5:14-5:17),CloseRound(5:17-5:18), +LowerIdent(6:1-6:2),OpAssign(6:3-6:4),UpperIdent(6:5-6:13),NoSpaceDotUpperIdent(6:13-6:18),NoSpaceOpenRound(6:18-6:19),StringStart(6:19-6:20),StringPart(6:20-6:25),StringEnd(6:25-6:26),Comma(6:26-6:27),UpperIdent(6:28-6:36),NoSpaceDotUpperIdent(6:36-6:40),CloseRound(6:40-6:41), +EndOfFile(7:1-7:1), ~~~ # PARSE ~~~clojure -(file @1.1-3.16 +(file @1.1-6.41 (module @1.1-1.10 (exposes @1.8-1.10)) (statements - (s-decl @3.1-3.16 - (p-ident @3.1-3.2 (raw "f")) - (e-lambda @3.5-3.16 + (s-type-decl @3.1-3.43 + (header @3.1-3.12 (name "ConsList") (args - (p-ident @3.6-3.7 (raw "g")) - (p-ident @3.9-3.10 (raw "v"))) - (e-apply @3.12-3.16 - (e-ident @3.12-3.13 (raw "g")) - (e-ident @3.14-3.15 (raw "v"))))))) + (ty-var @3.10-3.11 (raw "a")))) + (ty-tag-union @3.16-3.43 + (tags + (ty @3.17-3.20 (name "Nil")) + (ty-apply @3.22-3.42 + (ty @3.22-3.26 (name "Cons")) + (ty-var @3.27-3.28 (raw "a")) + (ty-apply @3.30-3.41 + (ty @3.30-3.38 (name "ConsList")) + (ty-var @3.39-3.40 (raw "a"))))))) + (s-type-anno @5.1-5.18 (name "x") + (ty-apply @5.5-5.18 + (ty @5.5-5.13 (name "ConsList")) + (ty @5.14-5.17 (name "Str")))) + (s-decl @6.1-6.41 + (p-ident @6.1-6.2 (raw "x")) + (e-apply @6.5-6.41 + (e-tag @6.5-6.18 (raw "ConsList.Cons")) + (e-string @6.19-6.26 + (e-string-part @6.20-6.25 (raw "hello"))) + (e-tag @6.28-6.40 (raw "ConsList.Nil")))))) ~~~ # FORMATTED ~~~roc @@ -46,20 +66,36 @@ NO CHANGE ~~~clojure (can-ir (d-let - (p-assign @3.1-3.2 (ident "f")) - (e-lambda @3.5-3.16 - (args - (p-assign @3.6-3.7 (ident "g")) - (p-assign @3.9-3.10 (ident "v"))) - (e-call @3.12-3.16 - (e-lookup-local @3.14-3.15 - (p-assign @3.9-3.10 (ident "v"))))))) + (p-assign @6.1-6.2 (ident "x")) + (e-nominal @6.5-6.41 (nominal "ConsList") + (e-tag @6.5-6.41 (name "Cons") + (args + (e-string @6.19-6.26 + (e-literal @6.20-6.25 (string "hello"))) + (e-nominal @6.28-6.40 (nominal "ConsList") + (e-tag @6.28-6.40 (name "Nil")))))) + (annotation @6.1-6.2 + (declared-type + (ty-apply @5.5-5.18 (name "ConsList") (local) + (ty-lookup @5.14-5.17 (name "Str") (builtin)))))) + (s-nominal-decl @3.1-3.43 + (ty-header @3.1-3.12 (name "ConsList") + (ty-args + (ty-rigid-var @3.10-3.11 (name "a")))) + (ty-tag-union @3.16-3.43 + (tag_name @3.17-3.20 (name "Nil")) + (tag_name @3.22-3.42 (name "Cons"))))) ~~~ # TYPES ~~~clojure (inferred-types (defs - (patt @3.1-3.2 (type "a -> b, a -> b"))) + (patt @6.1-6.2 (type "ConsList(Str)"))) + (type_decls + (nominal @3.1-3.43 (type "Error") + (ty-header @3.1-3.12 (name "ConsList") + (ty-args + (ty-rigid-var @3.10-3.11 (name "a")))))) (expressions - (expr @3.5-3.16 (type "a -> b, a -> b")))) + (expr @6.5-6.41 (type "ConsList(Str)")))) ~~~