Switch obligation checking to use a visitor

My hope is this will make obligation checking for other abilities easier
to add when we do so
This commit is contained in:
Ayaz Hafiz 2022-08-01 12:52:32 -05:00
parent fa14146054
commit a7bc8cf4f2
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58

View file

@ -5,6 +5,7 @@ use roc_error_macros::internal_error;
use roc_module::symbol::Symbol; use roc_module::symbol::Symbol;
use roc_region::all::{Loc, Region}; use roc_region::all::{Loc, Region};
use roc_solve_problem::{TypeError, UnderivableReason, Unfulfilled}; use roc_solve_problem::{TypeError, UnderivableReason, Unfulfilled};
use roc_types::num::NumericRange;
use roc_types::subs::{instantiate_rigids, Content, FlatType, GetSubsSlice, Rank, Subs, Variable}; use roc_types::subs::{instantiate_rigids, Content, FlatType, GetSubsSlice, Rank, Subs, Variable};
use roc_types::types::{AliasKind, Category, MemberImpl, PatternCategory}; use roc_types::types::{AliasKind, Category, MemberImpl, PatternCategory};
use roc_unify::unify::{Env, MustImplementConstraints}; use roc_unify::unify::{Env, MustImplementConstraints};
@ -253,7 +254,12 @@ impl ObligationCache {
// independent queries. // independent queries.
let opt_can_derive_builtin = match ability { let opt_can_derive_builtin = match ability {
Symbol::ENCODE_ENCODING => Some(self.can_derive_encoding(subs, abilities_store, var)), Symbol::ENCODE_ENCODING => Some(DeriveEncoding::is_derivable(
self,
abilities_store,
subs,
var,
)),
_ => None, _ => None,
}; };
@ -262,7 +268,7 @@ impl ObligationCache {
// can derive! // can derive!
None None
} }
Some(Err(failure_var)) => Some(if failure_var == var { Some(Err(DerivableError::NotDerivable(failure_var))) => Some(if failure_var == var {
UnderivableReason::SurfaceNotDerivable UnderivableReason::SurfaceNotDerivable
} else { } else {
let (error_type, _skeletons) = subs.var_to_error_type(failure_var); let (error_type, _skeletons) = subs.var_to_error_type(failure_var);
@ -391,16 +397,128 @@ impl ObligationCache {
let check_has_fake = self.derive_cache.insert(derive_key, root_result); let check_has_fake = self.derive_cache.insert(derive_key, root_result);
debug_assert_eq!(check_has_fake, Some(fake_fulfilled)); debug_assert_eq!(check_has_fake, Some(fake_fulfilled));
} }
}
// If we have a lot of these, consider using a visitor. #[inline(always)]
// It will be very similar for most types (can't derive functions, can't derive unbound type #[rustfmt::skip]
// variables, can only derive opaques if they have an impl, etc). fn is_builtin_number_alias(symbol: Symbol) -> bool {
fn can_derive_encoding( matches!(symbol,
&mut self, Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8
subs: &mut Subs, | Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16
| Symbol::NUM_U32 | Symbol::NUM_UNSIGNED32
| Symbol::NUM_U64 | Symbol::NUM_UNSIGNED64
| Symbol::NUM_U128 | Symbol::NUM_UNSIGNED128
| Symbol::NUM_I8 | Symbol::NUM_SIGNED8
| Symbol::NUM_I16 | Symbol::NUM_SIGNED16
| Symbol::NUM_I32 | Symbol::NUM_SIGNED32
| Symbol::NUM_I64 | Symbol::NUM_SIGNED64
| Symbol::NUM_I128 | Symbol::NUM_SIGNED128
| Symbol::NUM_NAT | Symbol::NUM_NATURAL
| Symbol::NUM_F32 | Symbol::NUM_BINARY32
| Symbol::NUM_F64 | Symbol::NUM_BINARY64
| Symbol::NUM_DEC | Symbol::NUM_DECIMAL,
)
}
enum DerivableError {
NotDerivable(Variable),
}
struct Descend(bool);
trait DerivableVisitor {
const ABILITY: Symbol;
#[inline(always)]
fn visit_flex(var: Variable) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_rigid(var: Variable) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_flex_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> {
if ability != Self::ABILITY {
Err(DerivableError::NotDerivable(var))
} else {
Ok(())
}
}
#[inline(always)]
fn visit_rigid_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> {
if ability != Self::ABILITY {
Err(DerivableError::NotDerivable(var))
} else {
Ok(())
}
}
#[inline(always)]
fn visit_recursion(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_apply(var: Variable, _symbol: Symbol) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_func(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_record(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_tag_union(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_recursive_tag_union(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_function_or_tag_union(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_empty_record(var: Variable) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_empty_tag_union(var: Variable) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_alias(var: Variable, _symbol: Symbol) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_ranged_number(var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn is_derivable(
obligation_cache: &mut ObligationCache,
abilities_store: &AbilitiesStore, abilities_store: &AbilitiesStore,
subs: &Subs,
var: Variable, var: Variable,
) -> Result<(), Variable> { ) -> Result<(), DerivableError> {
let mut stack = vec![var]; let mut stack = vec![var];
let mut seen_recursion_vars = vec![]; let mut seen_recursion_vars = vec![];
@ -418,102 +536,93 @@ impl ObligationCache {
let content = subs.get_content_without_compacting(var); let content = subs.get_content_without_compacting(var);
use Content::*; use Content::*;
use DerivableError::*;
use FlatType::*; use FlatType::*;
match content { match *content {
FlexVar(_) | RigidVar(_) => return Err(var), FlexVar(_) => Self::visit_flex(var)?,
FlexAbleVar(_, ability) | RigidAbleVar(_, ability) => { RigidVar(_) => Self::visit_rigid(var)?,
if *ability != Symbol::ENCODE_ENCODING { FlexAbleVar(_, ability) => Self::visit_flex_able(var, ability)?,
return Err(var); RigidAbleVar(_, ability) => Self::visit_rigid_able(var, ability)?,
}
// Any concrete type this variables is instantiated with will also gain a "does
// implement" check so this is okay.
}
RecursionVar { RecursionVar {
structure, structure,
opt_name: _, opt_name: _,
} => { } => {
seen_recursion_vars.push(var); let descend = Self::visit_recursion(var)?;
stack.push(*structure); if descend.0 {
seen_recursion_vars.push(var);
stack.push(structure);
}
} }
Structure(flat_type) => match flat_type { Structure(flat_type) => match flat_type {
Apply( Apply(symbol, vars) => {
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR, let descend = Self::visit_apply(var, symbol)?;
vars, if descend.0 {
) => push_var_slice!(*vars), push_var_slice!(vars)
Apply(..) => return Err(var),
Func(..) => {
return Err(var);
}
Record(fields, var) => {
push_var_slice!(fields.variables());
stack.push(*var);
}
TagUnion(tags, ext_var) => {
for i in tags.variables() {
push_var_slice!(subs[i]);
} }
stack.push(*ext_var);
} }
FunctionOrTagUnion(_, _, var) => stack.push(*var), Func(args, _clos, ret) => {
RecursiveTagUnion(rec_var, tags, ext_var) => { let descend = Self::visit_func(var)?;
seen_recursion_vars.push(*rec_var); if descend.0 {
for i in tags.variables() { push_var_slice!(args);
push_var_slice!(subs[i]); stack.push(ret);
} }
stack.push(*ext_var);
} }
EmptyRecord | EmptyTagUnion => { Record(fields, ext) => {
// yes let descend = Self::visit_record(var)?;
if descend.0 {
push_var_slice!(fields.variables());
stack.push(ext);
}
} }
Erroneous(_) => return Err(var), TagUnion(tags, ext) => {
let descend = Self::visit_tag_union(var)?;
if descend.0 {
for i in tags.variables() {
push_var_slice!(subs[i]);
}
stack.push(ext);
}
}
FunctionOrTagUnion(_tag_name, _fn_name, ext) => {
let descend = Self::visit_function_or_tag_union(var)?;
if descend.0 {
stack.push(ext);
}
}
RecursiveTagUnion(rec, tags, ext) => {
let descend = Self::visit_recursive_tag_union(var)?;
if descend.0 {
seen_recursion_vars.push(rec);
for i in tags.variables() {
push_var_slice!(subs[i]);
}
stack.push(ext);
}
}
EmptyRecord => Self::visit_empty_record(var)?,
EmptyTagUnion => Self::visit_empty_tag_union(var)?,
Erroneous(_) => return Err(NotDerivable(var)),
}, },
#[rustfmt::skip] Alias(opaque, _alias_variables, _real_var, AliasKind::Opaque) => {
Alias( if obligation_cache
Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8 .check_opaque_and_read(abilities_store, opaque, Self::ABILITY)
| Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16
| Symbol::NUM_U32 | Symbol::NUM_UNSIGNED32
| Symbol::NUM_U64 | Symbol::NUM_UNSIGNED64
| Symbol::NUM_U128 | Symbol::NUM_UNSIGNED128
| Symbol::NUM_I8 | Symbol::NUM_SIGNED8
| Symbol::NUM_I16 | Symbol::NUM_SIGNED16
| Symbol::NUM_I32 | Symbol::NUM_SIGNED32
| Symbol::NUM_I64 | Symbol::NUM_SIGNED64
| Symbol::NUM_I128 | Symbol::NUM_SIGNED128
| Symbol::NUM_NAT | Symbol::NUM_NATURAL
| Symbol::NUM_F32 | Symbol::NUM_BINARY32
| Symbol::NUM_F64 | Symbol::NUM_BINARY64
| Symbol::NUM_DEC | Symbol::NUM_DECIMAL,
_,
_,
_,
) => {
// yes
}
Alias(
Symbol::NUM_NUM | Symbol::NUM_INTEGER | Symbol::NUM_FLOATINGPOINT,
_,
real_var,
_,
) => stack.push(*real_var),
Alias(name, _, _, AliasKind::Opaque) => {
let opaque = *name;
if self
.check_opaque_and_read(abilities_store, opaque, Symbol::ENCODE_ENCODING)
.is_err() .is_err()
{ {
return Err(var); return Err(NotDerivable(var));
} }
} }
Alias(_, arguments, real_type_var, _) => { Alias(symbol, _alias_variables, real_var, AliasKind::Structural) => {
push_var_slice!(arguments.all_variables()); let descend = Self::visit_alias(var, symbol)?;
stack.push(*real_type_var); if descend.0 {
stack.push(real_var);
}
} }
RangedNumber(..) => { RangedNumber(range) => Self::visit_ranged_number(var, range)?,
// yes, all numbers can
} LambdaSet(..) => return Err(NotDerivable(var)),
LambdaSet(..) => return Err(var),
Error => { Error => {
return Err(var); return Err(NotDerivable(var));
} }
} }
} }
@ -522,6 +631,66 @@ impl ObligationCache {
} }
} }
struct DeriveEncoding;
impl DerivableVisitor for DeriveEncoding {
const ABILITY: Symbol = Symbol::ENCODE_ENCODING;
#[inline(always)]
fn visit_recursion(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
if matches!(
symbol,
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
) {
Ok(Descend(true))
} else {
Err(DerivableError::NotDerivable(var))
}
}
fn visit_record(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
fn visit_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_empty_record(_var: Variable) -> Result<(), DerivableError> {
Ok(())
}
#[inline(always)]
fn visit_empty_tag_union(_var: Variable) -> Result<(), DerivableError> {
Ok(())
}
fn visit_alias(_var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
if is_builtin_number_alias(symbol) {
Ok(Descend(false))
} else {
Ok(Descend(true))
}
}
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
Ok(())
}
}
/// Determines what type implements an ability member of a specialized signature, given the /// Determines what type implements an ability member of a specialized signature, given the
/// [MustImplementAbility] constraints of the signature. /// [MustImplementAbility] constraints of the signature.
pub fn type_implementing_specialization( pub fn type_implementing_specialization(