mirror of
https://github.com/roc-lang/roc.git
synced 2025-12-23 08:48:03 +00:00
Merge pull request #8415 from roc-lang/is_eq_fixes
Infer is_eq for anonymous records, tuples, and tag unions
This commit is contained in:
commit
47e0f7f752
8 changed files with 682 additions and 21 deletions
|
|
@ -4157,7 +4157,7 @@ fn checkBinopExpr(
|
|||
// num, otherwise the propgate error
|
||||
_ = try self.unify(expr_var, lhs_var, env);
|
||||
},
|
||||
.lt, .gt, .le, .ge, .eq, .ne => {
|
||||
.lt, .gt, .le, .ge => {
|
||||
// Ensure the operands are the same type
|
||||
const result = try self.unify(lhs_var, rhs_var, env);
|
||||
|
||||
|
|
@ -4168,6 +4168,54 @@ fn checkBinopExpr(
|
|||
try self.unifyWith(expr_var, .err, env);
|
||||
}
|
||||
},
|
||||
.eq, .ne => {
|
||||
// For == and !=, we need to check if the type implements is_eq/is_ne
|
||||
// Create a static dispatch constraint for the is_eq/is_ne method
|
||||
|
||||
// Ensure the operands are the same type
|
||||
const lhs_rhs_result = try self.unify(lhs_var, rhs_var, env);
|
||||
if (lhs_rhs_result.isProblem()) {
|
||||
try self.unifyWith(expr_var, .err, env);
|
||||
} else {
|
||||
// Get the appropriate method name
|
||||
const method_name = if (binop.op == .eq) self.cir.is_eq_ident else self.cir.is_ne_ident;
|
||||
|
||||
// Create the function type: lhs_type, rhs_type -> Bool
|
||||
const args_range = try self.types.appendVars(&.{ lhs_var, rhs_var });
|
||||
|
||||
// The return type is Bool
|
||||
const ret_var = try self.freshBool(env, expr_region);
|
||||
|
||||
// Create the constraint function type
|
||||
const constraint_fn_var = try self.freshFromContent(.{ .structure = .{ .fn_unbound = Func{
|
||||
.args = args_range,
|
||||
.ret = ret_var,
|
||||
.needs_instantiation = false,
|
||||
} } }, env, expr_region);
|
||||
try env.var_pool.addVarToRank(constraint_fn_var, env.rank());
|
||||
|
||||
// Create the static dispatch constraint
|
||||
const constraint = StaticDispatchConstraint{
|
||||
.fn_name = method_name,
|
||||
.fn_var = constraint_fn_var,
|
||||
.origin = .desugared_binop,
|
||||
};
|
||||
const constraint_range = try self.types.appendStaticDispatchConstraints(&.{constraint});
|
||||
|
||||
// Create a constrained flex and unify it with the lhs (receiver)
|
||||
const constrained_var = try self.freshFromContent(
|
||||
.{ .flex = Flex{ .name = null, .constraints = constraint_range } },
|
||||
env,
|
||||
expr_region,
|
||||
);
|
||||
try env.var_pool.addVarToRank(constrained_var, env.rank());
|
||||
|
||||
_ = try self.unify(constrained_var, lhs_var, env);
|
||||
|
||||
// Set the expression to redirect to the return type (Bool)
|
||||
_ = try self.unify(expr_var, ret_var, env);
|
||||
}
|
||||
},
|
||||
.@"and" => {
|
||||
const lhs_fresh_bool = try self.freshBool(env, expr_region);
|
||||
const lhs_result = try self.unify(lhs_fresh_bool, lhs_var, env);
|
||||
|
|
@ -4738,8 +4786,51 @@ fn checkDeferredStaticDispatchConstraints(self: *Self, env: *Env) std.mem.Alloca
|
|||
);
|
||||
}
|
||||
}
|
||||
} else if (dispatcher_content == .structure and
|
||||
(dispatcher_content.structure == .record or
|
||||
dispatcher_content.structure == .tuple or
|
||||
dispatcher_content.structure == .tag_union or
|
||||
dispatcher_content.structure == .empty_record or
|
||||
dispatcher_content.structure == .empty_tag_union))
|
||||
{
|
||||
// Anonymous structural types (records, tuples, tag unions) have implicit is_eq
|
||||
// only if all their components also support is_eq
|
||||
const constraints = self.types.sliceStaticDispatchConstraints(deferred_constraint.constraints);
|
||||
for (constraints) |constraint| {
|
||||
const constraint_fn_name_bytes = self.cir.getIdent(constraint.fn_name);
|
||||
|
||||
// Check if this is a call to is_eq (anonymous types have implicit structural equality)
|
||||
if (std.mem.eql(u8, constraint_fn_name_bytes, "is_eq")) {
|
||||
// Check if all components of this anonymous type support is_eq
|
||||
if (self.typeSupportsIsEq(dispatcher_content.structure)) {
|
||||
// All components support is_eq, unify return type with Bool
|
||||
const resolved_constraint = self.types.resolveVar(constraint.fn_var);
|
||||
const mb_resolved_func = resolved_constraint.desc.content.unwrapFunc();
|
||||
if (mb_resolved_func) |resolved_func| {
|
||||
const region = self.getRegionAt(deferred_constraint.var_);
|
||||
const bool_var = try self.freshBool(env, region);
|
||||
_ = try self.unify(bool_var, resolved_func.ret, env);
|
||||
}
|
||||
} else {
|
||||
// Some component doesn't support is_eq (e.g., contains a function)
|
||||
try self.reportEqualityError(
|
||||
deferred_constraint.var_,
|
||||
constraint,
|
||||
env,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Other methods are not supported on anonymous types
|
||||
try self.reportConstraintError(
|
||||
deferred_constraint.var_,
|
||||
constraint,
|
||||
.not_nominal,
|
||||
env,
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// If the root type is anything but a nominal type, push an error
|
||||
// If the root type is anything but a nominal type or anonymous structural type, push an error
|
||||
|
||||
const constraints = self.types.sliceStaticDispatchConstraints(deferred_constraint.constraints);
|
||||
if (constraints.len > 0) {
|
||||
|
|
@ -4761,6 +4852,75 @@ fn checkDeferredStaticDispatchConstraints(self: *Self, env: *Env) std.mem.Alloca
|
|||
env.deferred_static_dispatch_constraints.items.clearRetainingCapacity();
|
||||
}
|
||||
|
||||
/// Check if a structural type supports is_eq.
|
||||
/// A type supports is_eq if:
|
||||
/// - It's not a function type
|
||||
/// - All of its components (record fields, tuple elements, tag payloads) also support is_eq
|
||||
/// - For nominal types, we assume they support is_eq (TODO: actually check for is_eq method)
|
||||
fn typeSupportsIsEq(self: *Self, flat_type: types_mod.FlatType) bool {
|
||||
return switch (flat_type) {
|
||||
// Function types do not support is_eq
|
||||
.fn_pure, .fn_effectful, .fn_unbound => false,
|
||||
|
||||
// Empty types trivially support is_eq
|
||||
.empty_record, .empty_tag_union => true,
|
||||
|
||||
// Records support is_eq if all field types support is_eq
|
||||
.record => |record| {
|
||||
const fields_slice = self.types.getRecordFieldsSlice(record.fields);
|
||||
for (fields_slice.items(.var_)) |field_var| {
|
||||
if (!self.varSupportsIsEq(field_var)) return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
|
||||
// Tuples support is_eq if all element types support is_eq
|
||||
.tuple => |tuple| {
|
||||
const elems = self.types.sliceVars(tuple.elems);
|
||||
for (elems) |elem_var| {
|
||||
if (!self.varSupportsIsEq(elem_var)) return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
|
||||
// Tag unions support is_eq if all payload types support is_eq
|
||||
.tag_union => |tag_union| {
|
||||
const tags_slice = self.types.getTagsSlice(tag_union.tags);
|
||||
for (tags_slice.items(.args)) |tag_args| {
|
||||
const args = self.types.sliceVars(tag_args);
|
||||
for (args) |arg_var| {
|
||||
if (!self.varSupportsIsEq(arg_var)) return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
},
|
||||
|
||||
// Nominal types: TODO: actually check if they have an is_eq method
|
||||
// For now, assume they do (numbers, Bool, Str, etc. all have is_eq)
|
||||
.nominal_type => true,
|
||||
|
||||
// Unbound records need to be resolved first
|
||||
.record_unbound => true, // TODO: check resolved type
|
||||
};
|
||||
}
|
||||
|
||||
/// Check if a type variable supports is_eq by resolving it and checking its content
|
||||
fn varSupportsIsEq(self: *Self, var_: Var) bool {
|
||||
const resolved = self.types.resolveVar(var_);
|
||||
return switch (resolved.desc.content) {
|
||||
.structure => |s| self.typeSupportsIsEq(s),
|
||||
// Flex/rigid vars could be anything, assume they support is_eq for now
|
||||
// (the actual constraint will be checked when the type is known)
|
||||
.flex, .rigid => true,
|
||||
// Aliases: check the underlying type
|
||||
.alias => |alias| self.varSupportsIsEq(self.types.getAliasBackingVar(alias)),
|
||||
// Recursion vars: assume they support is_eq (recursive types like List are ok)
|
||||
.recursion_var => true,
|
||||
// Error types: allow them to proceed
|
||||
.err => true,
|
||||
};
|
||||
}
|
||||
|
||||
/// Mark a constraint function's return type as error
|
||||
fn markConstraintFunctionAsError(self: *Self, constraint: StaticDispatchConstraint, env: *Env) !void {
|
||||
const resolved_constraint = self.types.resolveVar(constraint.fn_var);
|
||||
|
|
@ -4807,6 +4967,26 @@ fn reportConstraintError(
|
|||
try self.markConstraintFunctionAsError(constraint, env);
|
||||
}
|
||||
|
||||
/// Report an error when an anonymous type doesn't support equality
|
||||
fn reportEqualityError(
|
||||
self: *Self,
|
||||
dispatcher_var: Var,
|
||||
constraint: StaticDispatchConstraint,
|
||||
env: *Env,
|
||||
) !void {
|
||||
const snapshot = try self.snapshots.deepCopyVar(self.types, dispatcher_var);
|
||||
const equality_problem = problem.Problem{ .static_dispach = .{
|
||||
.type_does_not_support_equality = .{
|
||||
.dispatcher_var = dispatcher_var,
|
||||
.dispatcher_snapshot = snapshot,
|
||||
.fn_var = constraint.fn_var,
|
||||
},
|
||||
} };
|
||||
_ = try self.problems.appendProblem(self.cir.gpa, equality_problem);
|
||||
|
||||
try self.markConstraintFunctionAsError(constraint, env);
|
||||
}
|
||||
|
||||
/// Pool for reusing Env instances to avoid repeated allocations
|
||||
const EnvPool = struct {
|
||||
available: std.ArrayList(Env),
|
||||
|
|
|
|||
|
|
@ -217,6 +217,7 @@ pub const InvalidBoolBinop = struct {
|
|||
pub const StaticDispatch = union(enum) {
|
||||
dispatcher_not_nominal: DispatcherNotNominal,
|
||||
dispatcher_does_not_impl_method: DispatcherDoesNotImplMethod,
|
||||
type_does_not_support_equality: TypeDoesNotSupportEquality,
|
||||
};
|
||||
|
||||
/// Error when you try to static dispatch on something that's not a nominal type
|
||||
|
|
@ -240,6 +241,14 @@ pub const DispatcherDoesNotImplMethod = struct {
|
|||
pub const DispatcherType = enum { nominal, rigid };
|
||||
};
|
||||
|
||||
/// Error when an anonymous type (record, tuple, tag union) doesn't support equality
|
||||
/// because one or more of its components contain types that don't have is_eq
|
||||
pub const TypeDoesNotSupportEquality = struct {
|
||||
dispatcher_var: Var,
|
||||
dispatcher_snapshot: SnapshotContentIdx,
|
||||
fn_var: Var,
|
||||
};
|
||||
|
||||
// bug //
|
||||
|
||||
/// Error when you try to apply the wrong number of arguments to a type in
|
||||
|
|
@ -367,6 +376,7 @@ pub const ReportBuilder = struct {
|
|||
switch (detail) {
|
||||
.dispatcher_not_nominal => |data| return self.buildStaticDispatchDispatcherNotNominal(data),
|
||||
.dispatcher_does_not_impl_method => |data| return self.buildStaticDispatchDispatcherDoesNotImplMethod(data),
|
||||
.type_does_not_support_equality => |data| return self.buildTypeDoesNotSupportEquality(data),
|
||||
}
|
||||
},
|
||||
.number_does_not_fit => |data| {
|
||||
|
|
@ -1768,6 +1778,288 @@ pub const ReportBuilder = struct {
|
|||
return report;
|
||||
}
|
||||
|
||||
/// Build a report for when an anonymous type doesn't support equality
|
||||
fn buildTypeDoesNotSupportEquality(
|
||||
self: *Self,
|
||||
data: TypeDoesNotSupportEquality,
|
||||
) !Report {
|
||||
var report = Report.init(self.gpa, "TYPE DOES NOT SUPPORT EQUALITY", .runtime_error);
|
||||
errdefer report.deinit();
|
||||
|
||||
self.snapshot_writer.resetContext();
|
||||
try self.snapshot_writer.write(data.dispatcher_snapshot);
|
||||
const snapshot_str = try report.addOwnedString(self.snapshot_writer.get());
|
||||
|
||||
const region = self.can_ir.store.regions.get(@enumFromInt(@intFromEnum(data.fn_var)));
|
||||
const region_info = self.module_env.calcRegionInfo(region.*);
|
||||
|
||||
try report.document.addReflowingText("This expression uses ");
|
||||
try report.document.addAnnotated("==", .emphasized);
|
||||
try report.document.addReflowingText(" or ");
|
||||
try report.document.addAnnotated("!=", .emphasized);
|
||||
try report.document.addReflowingText(" on a type that doesn't support equality:");
|
||||
try report.document.addLineBreak();
|
||||
|
||||
try report.document.addSourceRegion(
|
||||
region_info,
|
||||
.error_highlight,
|
||||
self.filename,
|
||||
self.source,
|
||||
self.module_env.getLineStarts(),
|
||||
);
|
||||
try report.document.addLineBreak();
|
||||
|
||||
try report.document.addReflowingText("The type is:");
|
||||
try report.document.addLineBreak();
|
||||
try report.document.addText(" ");
|
||||
try report.document.addAnnotated(snapshot_str, .type_variable);
|
||||
try report.document.addLineBreak();
|
||||
try report.document.addLineBreak();
|
||||
|
||||
// Get the content and explain which parts don't support equality
|
||||
const content = self.snapshots.getContent(data.dispatcher_snapshot);
|
||||
if (content == .structure) {
|
||||
switch (content.structure) {
|
||||
.record => |record| {
|
||||
try self.explainRecordEqualityFailure(&report, record);
|
||||
},
|
||||
.tuple => |tuple| {
|
||||
try self.explainTupleEqualityFailure(&report, tuple);
|
||||
},
|
||||
.tag_union => |tag_union| {
|
||||
try self.explainTagUnionEqualityFailure(&report, tag_union);
|
||||
},
|
||||
.fn_pure, .fn_effectful, .fn_unbound => {
|
||||
try report.document.addReflowingText("Functions cannot be compared for equality.");
|
||||
try report.document.addLineBreak();
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
}
|
||||
|
||||
return report;
|
||||
}
|
||||
|
||||
/// Explain which record fields don't support equality
|
||||
fn explainRecordEqualityFailure(
|
||||
self: *Self,
|
||||
report: *Report,
|
||||
record: snapshot.SnapshotRecord,
|
||||
) !void {
|
||||
const fields = self.snapshots.sliceRecordFields(record.fields);
|
||||
var has_problem_fields = false;
|
||||
|
||||
// First pass: check if any fields don't support equality
|
||||
for (fields.items(.content)) |field_content_idx| {
|
||||
if (!self.snapshotSupportsEquality(field_content_idx)) {
|
||||
has_problem_fields = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_problem_fields) {
|
||||
try report.document.addReflowingText("This record does not support equality because these fields have types that don't support ");
|
||||
try report.document.addAnnotated("is_eq", .emphasized);
|
||||
try report.document.addReflowingText(":");
|
||||
try report.document.addLineBreak();
|
||||
|
||||
const field_names = fields.items(.name);
|
||||
const field_contents = fields.items(.content);
|
||||
for (field_names, field_contents) |name, field_content_idx| {
|
||||
if (!self.snapshotSupportsEquality(field_content_idx)) {
|
||||
const field_name = self.can_ir.getIdentText(name);
|
||||
|
||||
self.snapshot_writer.resetContext();
|
||||
try self.snapshot_writer.write(field_content_idx);
|
||||
const field_type_str = try report.addOwnedString(self.snapshot_writer.get());
|
||||
|
||||
try report.document.addText(" ");
|
||||
try report.document.addAnnotated(field_name, .emphasized);
|
||||
try report.document.addText(": ");
|
||||
try report.document.addAnnotated(field_type_str, .type_variable);
|
||||
try report.document.addLineBreak();
|
||||
}
|
||||
}
|
||||
try report.document.addLineBreak();
|
||||
try report.document.addAnnotated("Hint: ", .emphasized);
|
||||
try report.document.addReflowingText("Anonymous records only have an ");
|
||||
try report.document.addAnnotated("is_eq", .emphasized);
|
||||
try report.document.addReflowingText(" method if all of their fields have ");
|
||||
try report.document.addAnnotated("is_eq", .emphasized);
|
||||
try report.document.addReflowingText(" methods.");
|
||||
try report.document.addLineBreak();
|
||||
}
|
||||
}
|
||||
|
||||
/// Explain which tuple elements don't support equality
|
||||
fn explainTupleEqualityFailure(
|
||||
self: *Self,
|
||||
report: *Report,
|
||||
tuple: snapshot.SnapshotTuple,
|
||||
) !void {
|
||||
const elems = self.snapshots.sliceVars(tuple.elems);
|
||||
var has_problem_elems = false;
|
||||
|
||||
// First pass: check if any elements don't support equality
|
||||
for (elems) |elem_content_idx| {
|
||||
if (!self.snapshotSupportsEquality(elem_content_idx)) {
|
||||
has_problem_elems = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (has_problem_elems) {
|
||||
try report.document.addReflowingText("This tuple does not support equality because these elements have types that don't support ");
|
||||
try report.document.addAnnotated("is_eq", .emphasized);
|
||||
try report.document.addReflowingText(":");
|
||||
try report.document.addLineBreak();
|
||||
|
||||
for (elems, 0..) |elem_content_idx, i| {
|
||||
if (!self.snapshotSupportsEquality(elem_content_idx)) {
|
||||
self.snapshot_writer.resetContext();
|
||||
try self.snapshot_writer.write(elem_content_idx);
|
||||
const elem_type_str = try report.addOwnedString(self.snapshot_writer.get());
|
||||
|
||||
try report.document.addText(" element ");
|
||||
var buf: [20]u8 = undefined;
|
||||
const index_str = std.fmt.bufPrint(&buf, "{}", .{i}) catch "?";
|
||||
try report.document.addAnnotated(index_str, .emphasized);
|
||||
try report.document.addText(": ");
|
||||
try report.document.addAnnotated(elem_type_str, .type_variable);
|
||||
try report.document.addLineBreak();
|
||||
}
|
||||
}
|
||||
try report.document.addLineBreak();
|
||||
try report.document.addAnnotated("Hint: ", .emphasized);
|
||||
try report.document.addReflowingText("Tuples only have an ");
|
||||
try report.document.addAnnotated("is_eq", .emphasized);
|
||||
try report.document.addReflowingText(" method if all of their elements have ");
|
||||
try report.document.addAnnotated("is_eq", .emphasized);
|
||||
try report.document.addReflowingText(" methods.");
|
||||
try report.document.addLineBreak();
|
||||
}
|
||||
}
|
||||
|
||||
/// Explain which tag union payloads don't support equality
|
||||
fn explainTagUnionEqualityFailure(
|
||||
self: *Self,
|
||||
report: *Report,
|
||||
tag_union: snapshot.SnapshotTagUnion,
|
||||
) !void {
|
||||
const tags = self.snapshots.sliceTags(tag_union.tags);
|
||||
var has_problem_tags = false;
|
||||
|
||||
// First pass: check if any tag payloads don't support equality
|
||||
for (tags.items(.args)) |tag_args| {
|
||||
const args = self.snapshots.sliceVars(tag_args);
|
||||
for (args) |arg_content_idx| {
|
||||
if (!self.snapshotSupportsEquality(arg_content_idx)) {
|
||||
has_problem_tags = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (has_problem_tags) break;
|
||||
}
|
||||
|
||||
if (has_problem_tags) {
|
||||
try report.document.addReflowingText("This tag union does not support equality because these tags have payload types that don't support ");
|
||||
try report.document.addAnnotated("is_eq", .emphasized);
|
||||
try report.document.addReflowingText(":");
|
||||
try report.document.addLineBreak();
|
||||
|
||||
const tag_names = tags.items(.name);
|
||||
const tag_args_list = tags.items(.args);
|
||||
for (tag_names, tag_args_list) |name, tag_args| {
|
||||
const args = self.snapshots.sliceVars(tag_args);
|
||||
var tag_has_problem = false;
|
||||
for (args) |arg_content_idx| {
|
||||
if (!self.snapshotSupportsEquality(arg_content_idx)) {
|
||||
tag_has_problem = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (tag_has_problem) {
|
||||
const tag_name = self.can_ir.getIdentText(name);
|
||||
try report.document.addText(" ");
|
||||
try report.document.addAnnotated(tag_name, .emphasized);
|
||||
|
||||
// Show the problematic payload types
|
||||
if (args.len > 0) {
|
||||
try report.document.addText(" (");
|
||||
var first = true;
|
||||
for (args) |arg_content_idx| {
|
||||
if (!first) try report.document.addText(", ");
|
||||
first = false;
|
||||
|
||||
self.snapshot_writer.resetContext();
|
||||
try self.snapshot_writer.write(arg_content_idx);
|
||||
const arg_type_str = try report.addOwnedString(self.snapshot_writer.get());
|
||||
try report.document.addAnnotated(arg_type_str, .type_variable);
|
||||
}
|
||||
try report.document.addText(")");
|
||||
}
|
||||
try report.document.addLineBreak();
|
||||
}
|
||||
}
|
||||
try report.document.addLineBreak();
|
||||
try report.document.addAnnotated("Hint: ", .emphasized);
|
||||
try report.document.addReflowingText("Tag unions only have an ");
|
||||
try report.document.addAnnotated("is_eq", .emphasized);
|
||||
try report.document.addReflowingText(" method if all of their payload types have ");
|
||||
try report.document.addAnnotated("is_eq", .emphasized);
|
||||
try report.document.addReflowingText(" methods.");
|
||||
try report.document.addLineBreak();
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a snapshotted type supports equality
|
||||
fn snapshotSupportsEquality(self: *Self, content_idx: snapshot.SnapshotContentIdx) bool {
|
||||
const content = self.snapshots.getContent(content_idx);
|
||||
return switch (content) {
|
||||
.structure => |s| switch (s) {
|
||||
// Functions never support equality
|
||||
.fn_pure, .fn_effectful, .fn_unbound => false,
|
||||
// Empty types trivially support equality
|
||||
.empty_record, .empty_tag_union => true,
|
||||
// Records: all fields must support equality
|
||||
.record => |record| {
|
||||
const fields = self.snapshots.sliceRecordFields(record.fields);
|
||||
for (fields.items(.content)) |field_content| {
|
||||
if (!self.snapshotSupportsEquality(field_content)) return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
// Tuples: all elements must support equality
|
||||
.tuple => |tuple| {
|
||||
const elems = self.snapshots.sliceVars(tuple.elems);
|
||||
for (elems) |elem_content| {
|
||||
if (!self.snapshotSupportsEquality(elem_content)) return false;
|
||||
}
|
||||
return true;
|
||||
},
|
||||
// Tag unions: all payloads must support equality
|
||||
.tag_union => |tag_union| {
|
||||
const tags_slice = self.snapshots.sliceTags(tag_union.tags);
|
||||
for (tags_slice.items(.args)) |tag_args| {
|
||||
const args = self.snapshots.sliceVars(tag_args);
|
||||
for (args) |arg_content| {
|
||||
if (!self.snapshotSupportsEquality(arg_content)) return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
},
|
||||
// Other types (nominal, box, etc.) assumed to support equality
|
||||
else => true,
|
||||
},
|
||||
// Aliases: check the underlying type
|
||||
.alias => |alias| self.snapshotSupportsEquality(alias.backing),
|
||||
// Recursion vars: assume they support equality
|
||||
.recursion_var => true,
|
||||
// Other types (flex, rigid, recursive, err) assumed to support equality
|
||||
else => true,
|
||||
};
|
||||
}
|
||||
|
||||
// number problems //
|
||||
|
||||
/// Build a report for "number does not fit in type" diagnostic
|
||||
|
|
|
|||
|
|
@ -1217,6 +1217,10 @@ pub const Store = struct {
|
|||
return self.static_dispatch_constraints.sliceRange(range);
|
||||
}
|
||||
|
||||
pub fn sliceTags(self: *const Self, range: SnapshotTagSafeList.Range) SnapshotTagSafeList.Slice {
|
||||
return self.tags.sliceRange(range);
|
||||
}
|
||||
|
||||
pub fn getContent(self: *const Self, idx: SnapshotContentIdx) SnapshotContent {
|
||||
return self.contents.get(idx).*;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -196,6 +196,85 @@ test "check type - record" {
|
|||
try checkTypesExpr(source, .pass, "{ hello: Str, world: a } where [a.from_numeral : Numeral -> Try(a, [InvalidNumeral(Str)])]");
|
||||
}
|
||||
|
||||
// anonymous type equality (is_eq) //
|
||||
|
||||
test "check type - record equality - same records are equal" {
|
||||
const source =
|
||||
\\{ x: 1, y: 2 } == { x: 1, y: 2 }
|
||||
;
|
||||
try checkTypesExpr(source, .pass, "Bool");
|
||||
}
|
||||
|
||||
test "check type - tuple equality - same tuples are equal" {
|
||||
const source =
|
||||
\\(1, 2) == (1, 2)
|
||||
;
|
||||
try checkTypesExpr(source, .pass, "Bool");
|
||||
}
|
||||
|
||||
test "check type - empty record equality" {
|
||||
const source =
|
||||
\\{} == {}
|
||||
;
|
||||
try checkTypesExpr(source, .pass, "Bool");
|
||||
}
|
||||
|
||||
test "check type - record with function field - no is_eq" {
|
||||
// Records containing functions should not have is_eq because functions don't have is_eq
|
||||
const source =
|
||||
\\{ x: 1, f: |a| a + 1 } == { x: 1, f: |a| a + 1 }
|
||||
;
|
||||
try checkTypesExpr(source, .fail, "TYPE DOES NOT SUPPORT EQUALITY");
|
||||
}
|
||||
|
||||
test "check type - tuple with function element - no is_eq" {
|
||||
// Tuples containing functions should not have is_eq because functions don't have is_eq
|
||||
const source =
|
||||
\\(1, |a| a) == (1, |a| a)
|
||||
;
|
||||
try checkTypesExpr(source, .fail, "TYPE DOES NOT SUPPORT EQUALITY");
|
||||
}
|
||||
|
||||
test "check type - nested record equality" {
|
||||
// Nested records should type-check as Bool
|
||||
const source =
|
||||
\\{ a: { x: 1 }, b: 2 } == { a: { x: 1 }, b: 2 }
|
||||
;
|
||||
try checkTypesExpr(source, .pass, "Bool");
|
||||
}
|
||||
|
||||
test "check type - nested tuple equality" {
|
||||
// Nested tuples should type-check as Bool
|
||||
const source =
|
||||
\\((1, 2), 3) == ((1, 2), 3)
|
||||
;
|
||||
try checkTypesExpr(source, .pass, "Bool");
|
||||
}
|
||||
|
||||
test "check type - nested record with function - no is_eq" {
|
||||
// Nested records containing functions should not have is_eq
|
||||
const source =
|
||||
\\{ a: { f: |x| x } } == { a: { f: |x| x } }
|
||||
;
|
||||
try checkTypesExpr(source, .fail, "TYPE DOES NOT SUPPORT EQUALITY");
|
||||
}
|
||||
|
||||
test "check type - tag union equality" {
|
||||
// Tag unions should type-check for equality
|
||||
const source =
|
||||
\\Ok(1) == Ok(1)
|
||||
;
|
||||
try checkTypesExpr(source, .pass, "Bool");
|
||||
}
|
||||
|
||||
test "check type - tag union with function payload - no is_eq" {
|
||||
// Tag unions with function payloads should not have is_eq
|
||||
const source =
|
||||
\\Fn(|x| x) == Fn(|x| x)
|
||||
;
|
||||
try checkTypesExpr(source, .fail, "TYPE DOES NOT SUPPORT EQUALITY");
|
||||
}
|
||||
|
||||
// tags //
|
||||
|
||||
test "check type - tag" {
|
||||
|
|
|
|||
|
|
@ -2038,12 +2038,18 @@ const Unifier = struct {
|
|||
const trace = tracy.trace(@src());
|
||||
defer trace.end();
|
||||
|
||||
const range_start: u32 = self.types_store.record_fields.len();
|
||||
|
||||
// Here, iterate over shared fields, sub unifying the field variables.
|
||||
// At this point, the fields are know to be identical, so we arbitrary choose b
|
||||
// First, unify all field types. This may cause nested record unifications
|
||||
// which will append their own fields to the store. We must NOT interleave
|
||||
// our field appends with these nested calls.
|
||||
for (shared_fields) |shared| {
|
||||
try self.unifyGuarded(shared.a.var_, shared.b.var_);
|
||||
}
|
||||
|
||||
// Now that all nested unifications are complete, append OUR fields.
|
||||
// This ensures our fields form a contiguous range.
|
||||
const range_start: u32 = self.types_store.record_fields.len();
|
||||
|
||||
for (shared_fields) |shared| {
|
||||
_ = self.types_store.appendRecordFields(&[_]RecordField{.{
|
||||
.name = shared.b.name,
|
||||
.var_ = shared.b.var_,
|
||||
|
|
@ -2058,7 +2064,7 @@ const Unifier = struct {
|
|||
_ = self.types_store.appendRecordFields(extended_fields) catch return Error.AllocatorError;
|
||||
}
|
||||
|
||||
// Merge vars
|
||||
// Merge vars - now the range correctly contains only THIS record's fields
|
||||
self.merge(vars, Content{ .structure = FlatType{ .record = .{
|
||||
.fields = self.types_store.record_fields.rangeToEnd(range_start),
|
||||
.ext = ext,
|
||||
|
|
|
|||
|
|
@ -4442,13 +4442,13 @@ pub const Interpreter = struct {
|
|||
.tag_union => {
|
||||
return try self.structuralEqualTag(lhs, rhs, lhs_var);
|
||||
},
|
||||
.list => |elem_var| {
|
||||
return try self.structuralEqualList(lhs, rhs, elem_var);
|
||||
},
|
||||
.empty_record => true,
|
||||
.list_unbound, .record_unbound, .fn_pure, .fn_effectful, .fn_unbound, .nominal_type, .empty_tag_union, .box => error.NotImplemented,
|
||||
.str => error.NotImplemented,
|
||||
.num => error.NotImplemented,
|
||||
.empty_tag_union => true,
|
||||
.nominal_type => |nom| {
|
||||
// For nominal types, dispatch to their is_eq method
|
||||
return try self.dispatchNominalIsEq(lhs, rhs, nom, lhs_var);
|
||||
},
|
||||
.record_unbound, .fn_pure, .fn_effectful, .fn_unbound => error.NotImplemented,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -4630,6 +4630,41 @@ pub const Interpreter = struct {
|
|||
return true;
|
||||
}
|
||||
|
||||
/// Dispatch is_eq method call for a nominal type
|
||||
fn dispatchNominalIsEq(
|
||||
self: *Interpreter,
|
||||
lhs: StackValue,
|
||||
rhs: StackValue,
|
||||
nom: types.NominalType,
|
||||
lhs_var: types.Var,
|
||||
) StructuralEqError!bool {
|
||||
// TODO: Properly dispatch to the nominal type's is_eq method
|
||||
// For now, use a simplified approach:
|
||||
// - If the nominal type is a wrapper around a scalar or simple structure, compare directly
|
||||
// - Otherwise, fall back to structural comparison of the backing type
|
||||
|
||||
// Get the backing var of the nominal type
|
||||
const backing_var = self.runtime_types.getNominalBackingVar(nom);
|
||||
const backing_resolved = self.runtime_types.resolveVar(backing_var);
|
||||
|
||||
// If the backing type is a structure, recursively compare
|
||||
if (backing_resolved.desc.content == .structure) {
|
||||
return self.valuesStructurallyEqual(lhs, backing_var, rhs, backing_var);
|
||||
}
|
||||
|
||||
// For other cases, fall back to attempting scalar comparison
|
||||
// This handles cases like Bool which wraps a tag union but is represented as a scalar
|
||||
if (lhs.layout.tag == .scalar and rhs.layout.tag == .scalar) {
|
||||
const order = self.compareNumericScalars(lhs, rhs) catch return error.NotImplemented;
|
||||
return order == .eq;
|
||||
}
|
||||
|
||||
// Can't compare - likely a user-defined nominal type that needs is_eq dispatch
|
||||
// TODO: Implement proper method dispatch by looking up is_eq in the nominal type's module
|
||||
_ = lhs_var;
|
||||
return error.NotImplemented;
|
||||
}
|
||||
|
||||
pub fn getCanonicalBoolRuntimeVar(self: *Interpreter) !types.Var {
|
||||
if (self.canonical_bool_rt_var) |cached| return cached;
|
||||
// Use the dynamic bool_stmt index (from the Bool module)
|
||||
|
|
@ -5437,22 +5472,40 @@ pub const Interpreter = struct {
|
|||
var rhs = try self.evalExprMinimal(rhs_expr, roc_ops, rhs_rt_var);
|
||||
defer rhs.decref(&self.runtime_layout_store, roc_ops);
|
||||
|
||||
// Get the nominal type information from lhs
|
||||
const nominal_info = switch (lhs_resolved.desc.content) {
|
||||
// Get the nominal type information from lhs, or handle anonymous structural types
|
||||
const nominal_info: ?struct { origin: base_pkg.Ident.Idx, ident: base_pkg.Ident.Idx } = switch (lhs_resolved.desc.content) {
|
||||
.structure => |s| switch (s) {
|
||||
.nominal_type => |nom| .{
|
||||
.origin = nom.origin_module,
|
||||
.ident = nom.ident.ident_idx,
|
||||
},
|
||||
else => return error.InvalidMethodReceiver,
|
||||
.record, .tuple, .tag_union, .empty_record, .empty_tag_union => blk: {
|
||||
// Anonymous structural types have implicit is_eq
|
||||
if (method_ident == self.env.is_eq_ident) {
|
||||
const result = self.valuesStructurallyEqual(lhs, lhs_rt_var, rhs, rhs_rt_var) catch |err| {
|
||||
// If structural equality is not implemented for this type, return false
|
||||
if (err == error.NotImplemented) {
|
||||
return try self.makeBoolValue(false);
|
||||
}
|
||||
return err;
|
||||
};
|
||||
return try self.makeBoolValue(result);
|
||||
}
|
||||
break :blk null;
|
||||
},
|
||||
else => null,
|
||||
},
|
||||
else => return error.InvalidMethodReceiver,
|
||||
else => null,
|
||||
};
|
||||
|
||||
if (nominal_info == null) {
|
||||
return error.InvalidMethodReceiver;
|
||||
}
|
||||
|
||||
// Resolve the method function
|
||||
const method_func = self.resolveMethodFunction(
|
||||
nominal_info.origin,
|
||||
nominal_info.ident,
|
||||
nominal_info.?.origin,
|
||||
nominal_info.?.ident,
|
||||
method_ident,
|
||||
roc_ops,
|
||||
) catch |err| return err;
|
||||
|
|
|
|||
|
|
@ -848,3 +848,50 @@ test "ModuleEnv serialization and interpreter evaluation" {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tests for anonymous type equality (is_eq on records, tuples, and tag unions)
|
||||
|
||||
test "anonymous record equality" {
|
||||
// Same records should be equal
|
||||
try runExpectBool("{ x: 1, y: 2 } == { x: 1, y: 2 }", true, .no_trace);
|
||||
// Different values should not be equal
|
||||
try runExpectBool("{ x: 1, y: 2 } == { x: 1, y: 3 }", false, .no_trace);
|
||||
// Field order shouldn't matter
|
||||
try runExpectBool("{ x: 1, y: 2 } == { y: 2, x: 1 }", true, .no_trace);
|
||||
}
|
||||
|
||||
test "anonymous tuple equality" {
|
||||
// Same tuples should be equal
|
||||
try runExpectBool("(1, 2) == (1, 2)", true, .no_trace);
|
||||
// Different values should not be equal
|
||||
try runExpectBool("(1, 2) == (1, 3)", false, .no_trace);
|
||||
}
|
||||
|
||||
test "empty record equality" {
|
||||
try runExpectBool("{} == {}", true, .no_trace);
|
||||
}
|
||||
|
||||
test "string field equality" {
|
||||
try runExpectBool("{ name: \"hello\" } == { name: \"hello\" }", true, .no_trace);
|
||||
try runExpectBool("{ name: \"hello\" } == { name: \"world\" }", false, .no_trace);
|
||||
}
|
||||
|
||||
test "nested record equality" {
|
||||
try runExpectBool("{ a: { x: 1 }, b: 2 } == { a: { x: 1 }, b: 2 }", true, .no_trace);
|
||||
try runExpectBool("{ a: { x: 1 }, b: 2 } == { a: { x: 2 }, b: 2 }", false, .no_trace);
|
||||
try runExpectBool("{ outer: { inner: { deep: 42 } } } == { outer: { inner: { deep: 42 } } }", true, .no_trace);
|
||||
try runExpectBool("{ outer: { inner: { deep: 42 } } } == { outer: { inner: { deep: 99 } } }", false, .no_trace);
|
||||
}
|
||||
|
||||
test "bool field equality" {
|
||||
// Use comparison expressions to produce boolean values for record fields
|
||||
try runExpectBool("{ flag: (1 == 1) } == { flag: (1 == 1) }", true, .no_trace);
|
||||
try runExpectBool("{ flag: (1 == 1) } == { flag: (1 != 1) }", false, .no_trace);
|
||||
}
|
||||
|
||||
test "nested tuple equality" {
|
||||
try runExpectBool("((1, 2), 3) == ((1, 2), 3)", true, .no_trace);
|
||||
try runExpectBool("((1, 2), 3) == ((1, 9), 3)", false, .no_trace);
|
||||
try runExpectBool("(1, (2, 3)) == (1, (2, 3))", true, .no_trace);
|
||||
try runExpectBool("(1, (2, 3)) == (1, (2, 9))", false, .no_trace);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ NO CHANGE
|
|||
~~~clojure
|
||||
(inferred-types
|
||||
(defs
|
||||
(patt (type "a -> Str where [a.from_numeral : Numeral -> Try(a, [InvalidNumeral(Str)])]")))
|
||||
(patt (type "a -> Str where [a.from_numeral : Numeral -> Try(a, [InvalidNumeral(Str)]), a.is_eq : a, a -> Bool]")))
|
||||
(expressions
|
||||
(expr (type "a -> Str where [a.from_numeral : Numeral -> Try(a, [InvalidNumeral(Str)])]"))))
|
||||
(expr (type "a -> Str where [a.from_numeral : Numeral -> Try(a, [InvalidNumeral(Str)]), a.is_eq : a, a -> Bool]"))))
|
||||
~~~
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue