mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-26 21:39:07 +00:00
Implement obligation checking for the Eq
ability
Every type can have `Eq.isEq` derived for it, as long as - it does not transitively contain a function - it does not transitively contain a floating point value - it does not transitively contain an opaque type that does not support `Eq`
This commit is contained in:
parent
5931dd5fc2
commit
b587bcf0c2
5 changed files with 437 additions and 81 deletions
|
@ -5,7 +5,8 @@ use roc_error_macros::{internal_error, todo_abilities};
|
|||
use roc_module::symbol::Symbol;
|
||||
use roc_region::all::{Loc, Region};
|
||||
use roc_solve_problem::{
|
||||
NotDerivableContext, NotDerivableDecode, TypeError, UnderivableReason, Unfulfilled,
|
||||
NotDerivableContext, NotDerivableDecode, NotDerivableEq, TypeError, UnderivableReason,
|
||||
Unfulfilled,
|
||||
};
|
||||
use roc_types::num::NumericRange;
|
||||
use roc_types::subs::{
|
||||
|
@ -276,6 +277,8 @@ impl ObligationCache {
|
|||
Some(DeriveHash::is_derivable(self, abilities_store, subs, var))
|
||||
}
|
||||
|
||||
Symbol::EQ_EQ => Some(DeriveEq::is_derivable(self, abilities_store, subs, var)),
|
||||
|
||||
_ => None,
|
||||
};
|
||||
|
||||
|
@ -420,7 +423,7 @@ impl ObligationCache {
|
|||
|
||||
#[inline(always)]
|
||||
#[rustfmt::skip]
|
||||
fn is_builtin_number_alias(symbol: Symbol) -> bool {
|
||||
fn is_builtin_int_alias(symbol: Symbol) -> bool {
|
||||
matches!(symbol,
|
||||
Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8
|
||||
| Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16
|
||||
|
@ -433,12 +436,32 @@ fn is_builtin_number_alias(symbol: Symbol) -> bool {
|
|||
| Symbol::NUM_I64 | Symbol::NUM_SIGNED64
|
||||
| Symbol::NUM_I128 | Symbol::NUM_SIGNED128
|
||||
| Symbol::NUM_NAT | Symbol::NUM_NATURAL
|
||||
)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[rustfmt::skip]
|
||||
fn is_builtin_float_alias(symbol: Symbol) -> bool {
|
||||
matches!(symbol,
|
||||
| Symbol::NUM_F32 | Symbol::NUM_BINARY32
|
||||
| Symbol::NUM_F64 | Symbol::NUM_BINARY64
|
||||
)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[rustfmt::skip]
|
||||
fn is_builtin_dec_alias(symbol: Symbol) -> bool {
|
||||
matches!(symbol,
|
||||
| Symbol::NUM_DEC | Symbol::NUM_DECIMAL,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[rustfmt::skip]
|
||||
fn is_builtin_number_alias(symbol: Symbol) -> bool {
|
||||
is_builtin_int_alias(symbol) || is_builtin_float_alias(symbol) || is_builtin_dec_alias(symbol)
|
||||
}
|
||||
|
||||
struct NotDerivable {
|
||||
var: Variable,
|
||||
context: NotDerivableContext,
|
||||
|
@ -986,6 +1009,102 @@ impl DerivableVisitor for DeriveHash {
|
|||
}
|
||||
}
|
||||
|
||||
struct DeriveEq;
|
||||
impl DerivableVisitor for DeriveEq {
|
||||
const ABILITY: Symbol = Symbol::EQ_EQ;
|
||||
|
||||
#[inline(always)]
|
||||
fn is_derivable_builtin_opaque(symbol: Symbol) -> bool {
|
||||
is_builtin_int_alias(symbol) || is_builtin_dec_alias(symbol)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_recursion(_var: Variable) -> Result<Descend, NotDerivable> {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, NotDerivable> {
|
||||
if matches!(
|
||||
symbol,
|
||||
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
|
||||
) {
|
||||
Ok(Descend(true))
|
||||
} else {
|
||||
Err(NotDerivable {
|
||||
var,
|
||||
context: NotDerivableContext::NoContext,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_record(
|
||||
subs: &Subs,
|
||||
var: Variable,
|
||||
fields: RecordFields,
|
||||
) -> Result<Descend, NotDerivable> {
|
||||
for (field_name, _, field) in fields.iter_all() {
|
||||
if subs[field].is_optional() {
|
||||
return Err(NotDerivable {
|
||||
var,
|
||||
context: NotDerivableContext::Decode(NotDerivableDecode::OptionalRecordField(
|
||||
subs[field_name].clone(),
|
||||
)),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Descend(true))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_empty_record(_var: Variable) -> Result<(), NotDerivable> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_empty_tag_union(_var: Variable) -> Result<(), NotDerivable> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_alias(var: Variable, symbol: Symbol) -> Result<Descend, NotDerivable> {
|
||||
if is_builtin_float_alias(symbol) {
|
||||
Err(NotDerivable {
|
||||
var,
|
||||
context: NotDerivableContext::Eq(NotDerivableEq::FloatingPoint),
|
||||
})
|
||||
} else if is_builtin_number_alias(symbol) {
|
||||
Ok(Descend(false))
|
||||
} else {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
|
||||
// Ranged numbers are allowed, because they are always possibly ints - floats can not have
|
||||
// `isEq` derived, but if something were to be a float, we'd see it exactly as a float.
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines what type implements an ability member of a specialized signature, given the
|
||||
/// [MustImplementAbility] constraints of the signature.
|
||||
pub fn type_implementing_specialization(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue