Implement obligation checking for the Eq ability

Every type can have `Eq.isEq` derived for it, as long as

- it does not transitively contain a function
- it does not transitively contain a floating point value
- it does not transitively contain an opaque type that does not support
  `Eq`
This commit is contained in:
Ayaz Hafiz 2022-10-05 15:07:27 -05:00
parent 5931dd5fc2
commit b587bcf0c2
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
5 changed files with 437 additions and 81 deletions

View file

@ -5,7 +5,8 @@ use roc_error_macros::{internal_error, todo_abilities};
use roc_module::symbol::Symbol;
use roc_region::all::{Loc, Region};
use roc_solve_problem::{
NotDerivableContext, NotDerivableDecode, TypeError, UnderivableReason, Unfulfilled,
NotDerivableContext, NotDerivableDecode, NotDerivableEq, TypeError, UnderivableReason,
Unfulfilled,
};
use roc_types::num::NumericRange;
use roc_types::subs::{
@ -276,6 +277,8 @@ impl ObligationCache {
Some(DeriveHash::is_derivable(self, abilities_store, subs, var))
}
Symbol::EQ_EQ => Some(DeriveEq::is_derivable(self, abilities_store, subs, var)),
_ => None,
};
@ -420,7 +423,7 @@ impl ObligationCache {
#[inline(always)]
#[rustfmt::skip]
fn is_builtin_number_alias(symbol: Symbol) -> bool {
fn is_builtin_int_alias(symbol: Symbol) -> bool {
matches!(symbol,
Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8
| Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16
@ -433,12 +436,32 @@ fn is_builtin_number_alias(symbol: Symbol) -> bool {
| Symbol::NUM_I64 | Symbol::NUM_SIGNED64
| Symbol::NUM_I128 | Symbol::NUM_SIGNED128
| Symbol::NUM_NAT | Symbol::NUM_NATURAL
)
}
#[inline(always)]
#[rustfmt::skip]
fn is_builtin_float_alias(symbol: Symbol) -> bool {
matches!(symbol,
| Symbol::NUM_F32 | Symbol::NUM_BINARY32
| Symbol::NUM_F64 | Symbol::NUM_BINARY64
)
}
#[inline(always)]
#[rustfmt::skip]
fn is_builtin_dec_alias(symbol: Symbol) -> bool {
matches!(symbol,
| Symbol::NUM_DEC | Symbol::NUM_DECIMAL,
)
}
#[inline(always)]
#[rustfmt::skip]
fn is_builtin_number_alias(symbol: Symbol) -> bool {
is_builtin_int_alias(symbol) || is_builtin_float_alias(symbol) || is_builtin_dec_alias(symbol)
}
struct NotDerivable {
var: Variable,
context: NotDerivableContext,
@ -986,6 +1009,102 @@ impl DerivableVisitor for DeriveHash {
}
}
struct DeriveEq;
impl DerivableVisitor for DeriveEq {
const ABILITY: Symbol = Symbol::EQ_EQ;
#[inline(always)]
fn is_derivable_builtin_opaque(symbol: Symbol) -> bool {
is_builtin_int_alias(symbol) || is_builtin_dec_alias(symbol)
}
#[inline(always)]
fn visit_recursion(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, NotDerivable> {
if matches!(
symbol,
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
) {
Ok(Descend(true))
} else {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
}
#[inline(always)]
fn visit_record(
subs: &Subs,
var: Variable,
fields: RecordFields,
) -> Result<Descend, NotDerivable> {
for (field_name, _, field) in fields.iter_all() {
if subs[field].is_optional() {
return Err(NotDerivable {
var,
context: NotDerivableContext::Decode(NotDerivableDecode::OptionalRecordField(
subs[field_name].clone(),
)),
});
}
}
Ok(Descend(true))
}
#[inline(always)]
fn visit_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_empty_record(_var: Variable) -> Result<(), NotDerivable> {
Ok(())
}
#[inline(always)]
fn visit_empty_tag_union(_var: Variable) -> Result<(), NotDerivable> {
Ok(())
}
#[inline(always)]
fn visit_alias(var: Variable, symbol: Symbol) -> Result<Descend, NotDerivable> {
if is_builtin_float_alias(symbol) {
Err(NotDerivable {
var,
context: NotDerivableContext::Eq(NotDerivableEq::FloatingPoint),
})
} else if is_builtin_number_alias(symbol) {
Ok(Descend(false))
} else {
Ok(Descend(true))
}
}
#[inline(always)]
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
// Ranged numbers are allowed, because they are always possibly ints - floats can not have
// `isEq` derived, but if something were to be a float, we'd see it exactly as a float.
Ok(())
}
}
/// Determines what type implements an ability member of a specialized signature, given the
/// [MustImplementAbility] constraints of the signature.
pub fn type_implementing_specialization(