diff --git a/src/check/Check.zig b/src/check/Check.zig index 961b9c2580..dcd4fafd9e 100644 --- a/src/check/Check.zig +++ b/src/check/Check.zig @@ -4157,7 +4157,7 @@ fn checkBinopExpr( // num, otherwise the propgate error _ = try self.unify(expr_var, lhs_var, env); }, - .lt, .gt, .le, .ge, .eq, .ne => { + .lt, .gt, .le, .ge => { // Ensure the operands are the same type const result = try self.unify(lhs_var, rhs_var, env); @@ -4168,6 +4168,54 @@ fn checkBinopExpr( try self.unifyWith(expr_var, .err, env); } }, + .eq, .ne => { + // For == and !=, we need to check if the type implements is_eq/is_ne + // Create a static dispatch constraint for the is_eq/is_ne method + + // Ensure the operands are the same type + const lhs_rhs_result = try self.unify(lhs_var, rhs_var, env); + if (lhs_rhs_result.isProblem()) { + try self.unifyWith(expr_var, .err, env); + } else { + // Get the appropriate method name + const method_name = if (binop.op == .eq) self.cir.is_eq_ident else self.cir.is_ne_ident; + + // Create the function type: lhs_type, rhs_type -> Bool + const args_range = try self.types.appendVars(&.{ lhs_var, rhs_var }); + + // The return type is Bool + const ret_var = try self.freshBool(env, expr_region); + + // Create the constraint function type + const constraint_fn_var = try self.freshFromContent(.{ .structure = .{ .fn_unbound = Func{ + .args = args_range, + .ret = ret_var, + .needs_instantiation = false, + } } }, env, expr_region); + try env.var_pool.addVarToRank(constraint_fn_var, env.rank()); + + // Create the static dispatch constraint + const constraint = StaticDispatchConstraint{ + .fn_name = method_name, + .fn_var = constraint_fn_var, + .origin = .desugared_binop, + }; + const constraint_range = try self.types.appendStaticDispatchConstraints(&.{constraint}); + + // Create a constrained flex and unify it with the lhs (receiver) + const constrained_var = try self.freshFromContent( + .{ .flex = Flex{ .name = null, .constraints = constraint_range } }, + env, + expr_region, + ); + try env.var_pool.addVarToRank(constrained_var, env.rank()); + + _ = try self.unify(constrained_var, lhs_var, env); + + // Set the expression to redirect to the return type (Bool) + _ = try self.unify(expr_var, ret_var, env); + } + }, .@"and" => { const lhs_fresh_bool = try self.freshBool(env, expr_region); const lhs_result = try self.unify(lhs_fresh_bool, lhs_var, env); @@ -4738,8 +4786,51 @@ fn checkDeferredStaticDispatchConstraints(self: *Self, env: *Env) std.mem.Alloca ); } } + } else if (dispatcher_content == .structure and + (dispatcher_content.structure == .record or + dispatcher_content.structure == .tuple or + dispatcher_content.structure == .tag_union or + dispatcher_content.structure == .empty_record or + dispatcher_content.structure == .empty_tag_union)) + { + // Anonymous structural types (records, tuples, tag unions) have implicit is_eq + // only if all their components also support is_eq + const constraints = self.types.sliceStaticDispatchConstraints(deferred_constraint.constraints); + for (constraints) |constraint| { + const constraint_fn_name_bytes = self.cir.getIdent(constraint.fn_name); + + // Check if this is a call to is_eq (anonymous types have implicit structural equality) + if (std.mem.eql(u8, constraint_fn_name_bytes, "is_eq")) { + // Check if all components of this anonymous type support is_eq + if (self.typeSupportsIsEq(dispatcher_content.structure)) { + // All components support is_eq, unify return type with Bool + const resolved_constraint = self.types.resolveVar(constraint.fn_var); + const mb_resolved_func = resolved_constraint.desc.content.unwrapFunc(); + if (mb_resolved_func) |resolved_func| { + const region = self.getRegionAt(deferred_constraint.var_); + const bool_var = try self.freshBool(env, region); + _ = try self.unify(bool_var, resolved_func.ret, env); + } + } else { + // Some component doesn't support is_eq (e.g., contains a function) + try self.reportEqualityError( + deferred_constraint.var_, + constraint, + env, + ); + } + } else { + // Other methods are not supported on anonymous types + try self.reportConstraintError( + deferred_constraint.var_, + constraint, + .not_nominal, + env, + ); + } + } } else { - // If the root type is anything but a nominal type, push an error + // If the root type is anything but a nominal type or anonymous structural type, push an error const constraints = self.types.sliceStaticDispatchConstraints(deferred_constraint.constraints); if (constraints.len > 0) { @@ -4761,6 +4852,75 @@ fn checkDeferredStaticDispatchConstraints(self: *Self, env: *Env) std.mem.Alloca env.deferred_static_dispatch_constraints.items.clearRetainingCapacity(); } +/// Check if a structural type supports is_eq. +/// A type supports is_eq if: +/// - It's not a function type +/// - All of its components (record fields, tuple elements, tag payloads) also support is_eq +/// - For nominal types, we assume they support is_eq (TODO: actually check for is_eq method) +fn typeSupportsIsEq(self: *Self, flat_type: types_mod.FlatType) bool { + return switch (flat_type) { + // Function types do not support is_eq + .fn_pure, .fn_effectful, .fn_unbound => false, + + // Empty types trivially support is_eq + .empty_record, .empty_tag_union => true, + + // Records support is_eq if all field types support is_eq + .record => |record| { + const fields_slice = self.types.getRecordFieldsSlice(record.fields); + for (fields_slice.items(.var_)) |field_var| { + if (!self.varSupportsIsEq(field_var)) return false; + } + return true; + }, + + // Tuples support is_eq if all element types support is_eq + .tuple => |tuple| { + const elems = self.types.sliceVars(tuple.elems); + for (elems) |elem_var| { + if (!self.varSupportsIsEq(elem_var)) return false; + } + return true; + }, + + // Tag unions support is_eq if all payload types support is_eq + .tag_union => |tag_union| { + const tags_slice = self.types.getTagsSlice(tag_union.tags); + for (tags_slice.items(.args)) |tag_args| { + const args = self.types.sliceVars(tag_args); + for (args) |arg_var| { + if (!self.varSupportsIsEq(arg_var)) return false; + } + } + return true; + }, + + // Nominal types: TODO: actually check if they have an is_eq method + // For now, assume they do (numbers, Bool, Str, etc. all have is_eq) + .nominal_type => true, + + // Unbound records need to be resolved first + .record_unbound => true, // TODO: check resolved type + }; +} + +/// Check if a type variable supports is_eq by resolving it and checking its content +fn varSupportsIsEq(self: *Self, var_: Var) bool { + const resolved = self.types.resolveVar(var_); + return switch (resolved.desc.content) { + .structure => |s| self.typeSupportsIsEq(s), + // Flex/rigid vars could be anything, assume they support is_eq for now + // (the actual constraint will be checked when the type is known) + .flex, .rigid => true, + // Aliases: check the underlying type + .alias => |alias| self.varSupportsIsEq(self.types.getAliasBackingVar(alias)), + // Recursion vars: assume they support is_eq (recursive types like List are ok) + .recursion_var => true, + // Error types: allow them to proceed + .err => true, + }; +} + /// Mark a constraint function's return type as error fn markConstraintFunctionAsError(self: *Self, constraint: StaticDispatchConstraint, env: *Env) !void { const resolved_constraint = self.types.resolveVar(constraint.fn_var); @@ -4807,6 +4967,26 @@ fn reportConstraintError( try self.markConstraintFunctionAsError(constraint, env); } +/// Report an error when an anonymous type doesn't support equality +fn reportEqualityError( + self: *Self, + dispatcher_var: Var, + constraint: StaticDispatchConstraint, + env: *Env, +) !void { + const snapshot = try self.snapshots.deepCopyVar(self.types, dispatcher_var); + const equality_problem = problem.Problem{ .static_dispach = .{ + .type_does_not_support_equality = .{ + .dispatcher_var = dispatcher_var, + .dispatcher_snapshot = snapshot, + .fn_var = constraint.fn_var, + }, + } }; + _ = try self.problems.appendProblem(self.cir.gpa, equality_problem); + + try self.markConstraintFunctionAsError(constraint, env); +} + /// Pool for reusing Env instances to avoid repeated allocations const EnvPool = struct { available: std.ArrayList(Env), diff --git a/src/check/problem.zig b/src/check/problem.zig index a947550bc7..e9e55d7901 100644 --- a/src/check/problem.zig +++ b/src/check/problem.zig @@ -217,6 +217,7 @@ pub const InvalidBoolBinop = struct { pub const StaticDispatch = union(enum) { dispatcher_not_nominal: DispatcherNotNominal, dispatcher_does_not_impl_method: DispatcherDoesNotImplMethod, + type_does_not_support_equality: TypeDoesNotSupportEquality, }; /// Error when you try to static dispatch on something that's not a nominal type @@ -240,6 +241,14 @@ pub const DispatcherDoesNotImplMethod = struct { pub const DispatcherType = enum { nominal, rigid }; }; +/// Error when an anonymous type (record, tuple, tag union) doesn't support equality +/// because one or more of its components contain types that don't have is_eq +pub const TypeDoesNotSupportEquality = struct { + dispatcher_var: Var, + dispatcher_snapshot: SnapshotContentIdx, + fn_var: Var, +}; + // bug // /// Error when you try to apply the wrong number of arguments to a type in @@ -367,6 +376,7 @@ pub const ReportBuilder = struct { switch (detail) { .dispatcher_not_nominal => |data| return self.buildStaticDispatchDispatcherNotNominal(data), .dispatcher_does_not_impl_method => |data| return self.buildStaticDispatchDispatcherDoesNotImplMethod(data), + .type_does_not_support_equality => |data| return self.buildTypeDoesNotSupportEquality(data), } }, .number_does_not_fit => |data| { @@ -1768,6 +1778,288 @@ pub const ReportBuilder = struct { return report; } + /// Build a report for when an anonymous type doesn't support equality + fn buildTypeDoesNotSupportEquality( + self: *Self, + data: TypeDoesNotSupportEquality, + ) !Report { + var report = Report.init(self.gpa, "TYPE DOES NOT SUPPORT EQUALITY", .runtime_error); + errdefer report.deinit(); + + self.snapshot_writer.resetContext(); + try self.snapshot_writer.write(data.dispatcher_snapshot); + const snapshot_str = try report.addOwnedString(self.snapshot_writer.get()); + + const region = self.can_ir.store.regions.get(@enumFromInt(@intFromEnum(data.fn_var))); + const region_info = self.module_env.calcRegionInfo(region.*); + + try report.document.addReflowingText("This expression uses "); + try report.document.addAnnotated("==", .emphasized); + try report.document.addReflowingText(" or "); + try report.document.addAnnotated("!=", .emphasized); + try report.document.addReflowingText(" on a type that doesn't support equality:"); + try report.document.addLineBreak(); + + try report.document.addSourceRegion( + region_info, + .error_highlight, + self.filename, + self.source, + self.module_env.getLineStarts(), + ); + try report.document.addLineBreak(); + + try report.document.addReflowingText("The type is:"); + try report.document.addLineBreak(); + try report.document.addText(" "); + try report.document.addAnnotated(snapshot_str, .type_variable); + try report.document.addLineBreak(); + try report.document.addLineBreak(); + + // Get the content and explain which parts don't support equality + const content = self.snapshots.getContent(data.dispatcher_snapshot); + if (content == .structure) { + switch (content.structure) { + .record => |record| { + try self.explainRecordEqualityFailure(&report, record); + }, + .tuple => |tuple| { + try self.explainTupleEqualityFailure(&report, tuple); + }, + .tag_union => |tag_union| { + try self.explainTagUnionEqualityFailure(&report, tag_union); + }, + .fn_pure, .fn_effectful, .fn_unbound => { + try report.document.addReflowingText("Functions cannot be compared for equality."); + try report.document.addLineBreak(); + }, + else => {}, + } + } + + return report; + } + + /// Explain which record fields don't support equality + fn explainRecordEqualityFailure( + self: *Self, + report: *Report, + record: snapshot.SnapshotRecord, + ) !void { + const fields = self.snapshots.sliceRecordFields(record.fields); + var has_problem_fields = false; + + // First pass: check if any fields don't support equality + for (fields.items(.content)) |field_content_idx| { + if (!self.snapshotSupportsEquality(field_content_idx)) { + has_problem_fields = true; + break; + } + } + + if (has_problem_fields) { + try report.document.addReflowingText("This record does not support equality because these fields have types that don't support "); + try report.document.addAnnotated("is_eq", .emphasized); + try report.document.addReflowingText(":"); + try report.document.addLineBreak(); + + const field_names = fields.items(.name); + const field_contents = fields.items(.content); + for (field_names, field_contents) |name, field_content_idx| { + if (!self.snapshotSupportsEquality(field_content_idx)) { + const field_name = self.can_ir.getIdentText(name); + + self.snapshot_writer.resetContext(); + try self.snapshot_writer.write(field_content_idx); + const field_type_str = try report.addOwnedString(self.snapshot_writer.get()); + + try report.document.addText(" "); + try report.document.addAnnotated(field_name, .emphasized); + try report.document.addText(": "); + try report.document.addAnnotated(field_type_str, .type_variable); + try report.document.addLineBreak(); + } + } + try report.document.addLineBreak(); + try report.document.addAnnotated("Hint: ", .emphasized); + try report.document.addReflowingText("Anonymous records only have an "); + try report.document.addAnnotated("is_eq", .emphasized); + try report.document.addReflowingText(" method if all of their fields have "); + try report.document.addAnnotated("is_eq", .emphasized); + try report.document.addReflowingText(" methods."); + try report.document.addLineBreak(); + } + } + + /// Explain which tuple elements don't support equality + fn explainTupleEqualityFailure( + self: *Self, + report: *Report, + tuple: snapshot.SnapshotTuple, + ) !void { + const elems = self.snapshots.sliceVars(tuple.elems); + var has_problem_elems = false; + + // First pass: check if any elements don't support equality + for (elems) |elem_content_idx| { + if (!self.snapshotSupportsEquality(elem_content_idx)) { + has_problem_elems = true; + break; + } + } + + if (has_problem_elems) { + try report.document.addReflowingText("This tuple does not support equality because these elements have types that don't support "); + try report.document.addAnnotated("is_eq", .emphasized); + try report.document.addReflowingText(":"); + try report.document.addLineBreak(); + + for (elems, 0..) |elem_content_idx, i| { + if (!self.snapshotSupportsEquality(elem_content_idx)) { + self.snapshot_writer.resetContext(); + try self.snapshot_writer.write(elem_content_idx); + const elem_type_str = try report.addOwnedString(self.snapshot_writer.get()); + + try report.document.addText(" element "); + var buf: [20]u8 = undefined; + const index_str = std.fmt.bufPrint(&buf, "{}", .{i}) catch "?"; + try report.document.addAnnotated(index_str, .emphasized); + try report.document.addText(": "); + try report.document.addAnnotated(elem_type_str, .type_variable); + try report.document.addLineBreak(); + } + } + try report.document.addLineBreak(); + try report.document.addAnnotated("Hint: ", .emphasized); + try report.document.addReflowingText("Tuples only have an "); + try report.document.addAnnotated("is_eq", .emphasized); + try report.document.addReflowingText(" method if all of their elements have "); + try report.document.addAnnotated("is_eq", .emphasized); + try report.document.addReflowingText(" methods."); + try report.document.addLineBreak(); + } + } + + /// Explain which tag union payloads don't support equality + fn explainTagUnionEqualityFailure( + self: *Self, + report: *Report, + tag_union: snapshot.SnapshotTagUnion, + ) !void { + const tags = self.snapshots.sliceTags(tag_union.tags); + var has_problem_tags = false; + + // First pass: check if any tag payloads don't support equality + for (tags.items(.args)) |tag_args| { + const args = self.snapshots.sliceVars(tag_args); + for (args) |arg_content_idx| { + if (!self.snapshotSupportsEquality(arg_content_idx)) { + has_problem_tags = true; + break; + } + } + if (has_problem_tags) break; + } + + if (has_problem_tags) { + try report.document.addReflowingText("This tag union does not support equality because these tags have payload types that don't support "); + try report.document.addAnnotated("is_eq", .emphasized); + try report.document.addReflowingText(":"); + try report.document.addLineBreak(); + + const tag_names = tags.items(.name); + const tag_args_list = tags.items(.args); + for (tag_names, tag_args_list) |name, tag_args| { + const args = self.snapshots.sliceVars(tag_args); + var tag_has_problem = false; + for (args) |arg_content_idx| { + if (!self.snapshotSupportsEquality(arg_content_idx)) { + tag_has_problem = true; + break; + } + } + if (tag_has_problem) { + const tag_name = self.can_ir.getIdentText(name); + try report.document.addText(" "); + try report.document.addAnnotated(tag_name, .emphasized); + + // Show the problematic payload types + if (args.len > 0) { + try report.document.addText(" ("); + var first = true; + for (args) |arg_content_idx| { + if (!first) try report.document.addText(", "); + first = false; + + self.snapshot_writer.resetContext(); + try self.snapshot_writer.write(arg_content_idx); + const arg_type_str = try report.addOwnedString(self.snapshot_writer.get()); + try report.document.addAnnotated(arg_type_str, .type_variable); + } + try report.document.addText(")"); + } + try report.document.addLineBreak(); + } + } + try report.document.addLineBreak(); + try report.document.addAnnotated("Hint: ", .emphasized); + try report.document.addReflowingText("Tag unions only have an "); + try report.document.addAnnotated("is_eq", .emphasized); + try report.document.addReflowingText(" method if all of their payload types have "); + try report.document.addAnnotated("is_eq", .emphasized); + try report.document.addReflowingText(" methods."); + try report.document.addLineBreak(); + } + } + + /// Check if a snapshotted type supports equality + fn snapshotSupportsEquality(self: *Self, content_idx: snapshot.SnapshotContentIdx) bool { + const content = self.snapshots.getContent(content_idx); + return switch (content) { + .structure => |s| switch (s) { + // Functions never support equality + .fn_pure, .fn_effectful, .fn_unbound => false, + // Empty types trivially support equality + .empty_record, .empty_tag_union => true, + // Records: all fields must support equality + .record => |record| { + const fields = self.snapshots.sliceRecordFields(record.fields); + for (fields.items(.content)) |field_content| { + if (!self.snapshotSupportsEquality(field_content)) return false; + } + return true; + }, + // Tuples: all elements must support equality + .tuple => |tuple| { + const elems = self.snapshots.sliceVars(tuple.elems); + for (elems) |elem_content| { + if (!self.snapshotSupportsEquality(elem_content)) return false; + } + return true; + }, + // Tag unions: all payloads must support equality + .tag_union => |tag_union| { + const tags_slice = self.snapshots.sliceTags(tag_union.tags); + for (tags_slice.items(.args)) |tag_args| { + const args = self.snapshots.sliceVars(tag_args); + for (args) |arg_content| { + if (!self.snapshotSupportsEquality(arg_content)) return false; + } + } + return true; + }, + // Other types (nominal, box, etc.) assumed to support equality + else => true, + }, + // Aliases: check the underlying type + .alias => |alias| self.snapshotSupportsEquality(alias.backing), + // Recursion vars: assume they support equality + .recursion_var => true, + // Other types (flex, rigid, recursive, err) assumed to support equality + else => true, + }; + } + // number problems // /// Build a report for "number does not fit in type" diagnostic diff --git a/src/check/snapshot.zig b/src/check/snapshot.zig index 5bb927622b..2ed61f8e66 100644 --- a/src/check/snapshot.zig +++ b/src/check/snapshot.zig @@ -1217,6 +1217,10 @@ pub const Store = struct { return self.static_dispatch_constraints.sliceRange(range); } + pub fn sliceTags(self: *const Self, range: SnapshotTagSafeList.Range) SnapshotTagSafeList.Slice { + return self.tags.sliceRange(range); + } + pub fn getContent(self: *const Self, idx: SnapshotContentIdx) SnapshotContent { return self.contents.get(idx).*; } diff --git a/src/check/test/type_checking_integration.zig b/src/check/test/type_checking_integration.zig index ee89820765..5bb1901df0 100644 --- a/src/check/test/type_checking_integration.zig +++ b/src/check/test/type_checking_integration.zig @@ -196,6 +196,85 @@ test "check type - record" { try checkTypesExpr(source, .pass, "{ hello: Str, world: a } where [a.from_numeral : Numeral -> Try(a, [InvalidNumeral(Str)])]"); } +// anonymous type equality (is_eq) // + +test "check type - record equality - same records are equal" { + const source = + \\{ x: 1, y: 2 } == { x: 1, y: 2 } + ; + try checkTypesExpr(source, .pass, "Bool"); +} + +test "check type - tuple equality - same tuples are equal" { + const source = + \\(1, 2) == (1, 2) + ; + try checkTypesExpr(source, .pass, "Bool"); +} + +test "check type - empty record equality" { + const source = + \\{} == {} + ; + try checkTypesExpr(source, .pass, "Bool"); +} + +test "check type - record with function field - no is_eq" { + // Records containing functions should not have is_eq because functions don't have is_eq + const source = + \\{ x: 1, f: |a| a + 1 } == { x: 1, f: |a| a + 1 } + ; + try checkTypesExpr(source, .fail, "TYPE DOES NOT SUPPORT EQUALITY"); +} + +test "check type - tuple with function element - no is_eq" { + // Tuples containing functions should not have is_eq because functions don't have is_eq + const source = + \\(1, |a| a) == (1, |a| a) + ; + try checkTypesExpr(source, .fail, "TYPE DOES NOT SUPPORT EQUALITY"); +} + +test "check type - nested record equality" { + // Nested records should type-check as Bool + const source = + \\{ a: { x: 1 }, b: 2 } == { a: { x: 1 }, b: 2 } + ; + try checkTypesExpr(source, .pass, "Bool"); +} + +test "check type - nested tuple equality" { + // Nested tuples should type-check as Bool + const source = + \\((1, 2), 3) == ((1, 2), 3) + ; + try checkTypesExpr(source, .pass, "Bool"); +} + +test "check type - nested record with function - no is_eq" { + // Nested records containing functions should not have is_eq + const source = + \\{ a: { f: |x| x } } == { a: { f: |x| x } } + ; + try checkTypesExpr(source, .fail, "TYPE DOES NOT SUPPORT EQUALITY"); +} + +test "check type - tag union equality" { + // Tag unions should type-check for equality + const source = + \\Ok(1) == Ok(1) + ; + try checkTypesExpr(source, .pass, "Bool"); +} + +test "check type - tag union with function payload - no is_eq" { + // Tag unions with function payloads should not have is_eq + const source = + \\Fn(|x| x) == Fn(|x| x) + ; + try checkTypesExpr(source, .fail, "TYPE DOES NOT SUPPORT EQUALITY"); +} + // tags // test "check type - tag" { diff --git a/src/check/unify.zig b/src/check/unify.zig index 73f27a93be..86afe48987 100644 --- a/src/check/unify.zig +++ b/src/check/unify.zig @@ -2038,12 +2038,18 @@ const Unifier = struct { const trace = tracy.trace(@src()); defer trace.end(); - const range_start: u32 = self.types_store.record_fields.len(); - - // Here, iterate over shared fields, sub unifying the field variables. - // At this point, the fields are know to be identical, so we arbitrary choose b + // First, unify all field types. This may cause nested record unifications + // which will append their own fields to the store. We must NOT interleave + // our field appends with these nested calls. for (shared_fields) |shared| { try self.unifyGuarded(shared.a.var_, shared.b.var_); + } + + // Now that all nested unifications are complete, append OUR fields. + // This ensures our fields form a contiguous range. + const range_start: u32 = self.types_store.record_fields.len(); + + for (shared_fields) |shared| { _ = self.types_store.appendRecordFields(&[_]RecordField{.{ .name = shared.b.name, .var_ = shared.b.var_, @@ -2058,7 +2064,7 @@ const Unifier = struct { _ = self.types_store.appendRecordFields(extended_fields) catch return Error.AllocatorError; } - // Merge vars + // Merge vars - now the range correctly contains only THIS record's fields self.merge(vars, Content{ .structure = FlatType{ .record = .{ .fields = self.types_store.record_fields.rangeToEnd(range_start), .ext = ext, diff --git a/src/eval/interpreter.zig b/src/eval/interpreter.zig index 1f4ff06d0a..c859b1da77 100644 --- a/src/eval/interpreter.zig +++ b/src/eval/interpreter.zig @@ -4442,13 +4442,13 @@ pub const Interpreter = struct { .tag_union => { return try self.structuralEqualTag(lhs, rhs, lhs_var); }, - .list => |elem_var| { - return try self.structuralEqualList(lhs, rhs, elem_var); - }, .empty_record => true, - .list_unbound, .record_unbound, .fn_pure, .fn_effectful, .fn_unbound, .nominal_type, .empty_tag_union, .box => error.NotImplemented, - .str => error.NotImplemented, - .num => error.NotImplemented, + .empty_tag_union => true, + .nominal_type => |nom| { + // For nominal types, dispatch to their is_eq method + return try self.dispatchNominalIsEq(lhs, rhs, nom, lhs_var); + }, + .record_unbound, .fn_pure, .fn_effectful, .fn_unbound => error.NotImplemented, }; } @@ -4630,6 +4630,41 @@ pub const Interpreter = struct { return true; } + /// Dispatch is_eq method call for a nominal type + fn dispatchNominalIsEq( + self: *Interpreter, + lhs: StackValue, + rhs: StackValue, + nom: types.NominalType, + lhs_var: types.Var, + ) StructuralEqError!bool { + // TODO: Properly dispatch to the nominal type's is_eq method + // For now, use a simplified approach: + // - If the nominal type is a wrapper around a scalar or simple structure, compare directly + // - Otherwise, fall back to structural comparison of the backing type + + // Get the backing var of the nominal type + const backing_var = self.runtime_types.getNominalBackingVar(nom); + const backing_resolved = self.runtime_types.resolveVar(backing_var); + + // If the backing type is a structure, recursively compare + if (backing_resolved.desc.content == .structure) { + return self.valuesStructurallyEqual(lhs, backing_var, rhs, backing_var); + } + + // For other cases, fall back to attempting scalar comparison + // This handles cases like Bool which wraps a tag union but is represented as a scalar + if (lhs.layout.tag == .scalar and rhs.layout.tag == .scalar) { + const order = self.compareNumericScalars(lhs, rhs) catch return error.NotImplemented; + return order == .eq; + } + + // Can't compare - likely a user-defined nominal type that needs is_eq dispatch + // TODO: Implement proper method dispatch by looking up is_eq in the nominal type's module + _ = lhs_var; + return error.NotImplemented; + } + pub fn getCanonicalBoolRuntimeVar(self: *Interpreter) !types.Var { if (self.canonical_bool_rt_var) |cached| return cached; // Use the dynamic bool_stmt index (from the Bool module) @@ -5437,22 +5472,40 @@ pub const Interpreter = struct { var rhs = try self.evalExprMinimal(rhs_expr, roc_ops, rhs_rt_var); defer rhs.decref(&self.runtime_layout_store, roc_ops); - // Get the nominal type information from lhs - const nominal_info = switch (lhs_resolved.desc.content) { + // Get the nominal type information from lhs, or handle anonymous structural types + const nominal_info: ?struct { origin: base_pkg.Ident.Idx, ident: base_pkg.Ident.Idx } = switch (lhs_resolved.desc.content) { .structure => |s| switch (s) { .nominal_type => |nom| .{ .origin = nom.origin_module, .ident = nom.ident.ident_idx, }, - else => return error.InvalidMethodReceiver, + .record, .tuple, .tag_union, .empty_record, .empty_tag_union => blk: { + // Anonymous structural types have implicit is_eq + if (method_ident == self.env.is_eq_ident) { + const result = self.valuesStructurallyEqual(lhs, lhs_rt_var, rhs, rhs_rt_var) catch |err| { + // If structural equality is not implemented for this type, return false + if (err == error.NotImplemented) { + return try self.makeBoolValue(false); + } + return err; + }; + return try self.makeBoolValue(result); + } + break :blk null; + }, + else => null, }, - else => return error.InvalidMethodReceiver, + else => null, }; + if (nominal_info == null) { + return error.InvalidMethodReceiver; + } + // Resolve the method function const method_func = self.resolveMethodFunction( - nominal_info.origin, - nominal_info.ident, + nominal_info.?.origin, + nominal_info.?.ident, method_ident, roc_ops, ) catch |err| return err; diff --git a/src/eval/test/eval_test.zig b/src/eval/test/eval_test.zig index 0774eb4133..90aceed62c 100644 --- a/src/eval/test/eval_test.zig +++ b/src/eval/test/eval_test.zig @@ -848,3 +848,50 @@ test "ModuleEnv serialization and interpreter evaluation" { } } } + +// Tests for anonymous type equality (is_eq on records, tuples, and tag unions) + +test "anonymous record equality" { + // Same records should be equal + try runExpectBool("{ x: 1, y: 2 } == { x: 1, y: 2 }", true, .no_trace); + // Different values should not be equal + try runExpectBool("{ x: 1, y: 2 } == { x: 1, y: 3 }", false, .no_trace); + // Field order shouldn't matter + try runExpectBool("{ x: 1, y: 2 } == { y: 2, x: 1 }", true, .no_trace); +} + +test "anonymous tuple equality" { + // Same tuples should be equal + try runExpectBool("(1, 2) == (1, 2)", true, .no_trace); + // Different values should not be equal + try runExpectBool("(1, 2) == (1, 3)", false, .no_trace); +} + +test "empty record equality" { + try runExpectBool("{} == {}", true, .no_trace); +} + +test "string field equality" { + try runExpectBool("{ name: \"hello\" } == { name: \"hello\" }", true, .no_trace); + try runExpectBool("{ name: \"hello\" } == { name: \"world\" }", false, .no_trace); +} + +test "nested record equality" { + try runExpectBool("{ a: { x: 1 }, b: 2 } == { a: { x: 1 }, b: 2 }", true, .no_trace); + try runExpectBool("{ a: { x: 1 }, b: 2 } == { a: { x: 2 }, b: 2 }", false, .no_trace); + try runExpectBool("{ outer: { inner: { deep: 42 } } } == { outer: { inner: { deep: 42 } } }", true, .no_trace); + try runExpectBool("{ outer: { inner: { deep: 42 } } } == { outer: { inner: { deep: 99 } } }", false, .no_trace); +} + +test "bool field equality" { + // Use comparison expressions to produce boolean values for record fields + try runExpectBool("{ flag: (1 == 1) } == { flag: (1 == 1) }", true, .no_trace); + try runExpectBool("{ flag: (1 == 1) } == { flag: (1 != 1) }", false, .no_trace); +} + +test "nested tuple equality" { + try runExpectBool("((1, 2), 3) == ((1, 2), 3)", true, .no_trace); + try runExpectBool("((1, 2), 3) == ((1, 9), 3)", false, .no_trace); + try runExpectBool("(1, (2, 3)) == (1, (2, 3))", true, .no_trace); + try runExpectBool("(1, (2, 3)) == (1, (2, 9))", false, .no_trace); +} diff --git a/test/snapshots/if_then_else/if_then_else_nested_chain.md b/test/snapshots/if_then_else/if_then_else_nested_chain.md index 8ff0e2d100..10d7a36919 100644 --- a/test/snapshots/if_then_else/if_then_else_nested_chain.md +++ b/test/snapshots/if_then_else/if_then_else_nested_chain.md @@ -125,7 +125,7 @@ NO CHANGE ~~~clojure (inferred-types (defs - (patt (type "a -> Str where [a.from_numeral : Numeral -> Try(a, [InvalidNumeral(Str)])]"))) + (patt (type "a -> Str where [a.from_numeral : Numeral -> Try(a, [InvalidNumeral(Str)]), a.is_eq : a, a -> Bool]"))) (expressions - (expr (type "a -> Str where [a.from_numeral : Numeral -> Try(a, [InvalidNumeral(Str)])]")))) + (expr (type "a -> Str where [a.from_numeral : Numeral -> Try(a, [InvalidNumeral(Str)]), a.is_eq : a, a -> Bool]")))) ~~~