mirror of
https://github.com/roc-lang/roc.git
synced 2025-10-02 16:21:11 +00:00
Switch obligation checking to use a visitor
My hope is this will make obligation checking for other abilities easier to add when we do so
This commit is contained in:
parent
fa14146054
commit
a7bc8cf4f2
1 changed files with 255 additions and 86 deletions
|
@ -5,6 +5,7 @@ use roc_error_macros::internal_error;
|
||||||
use roc_module::symbol::Symbol;
|
use roc_module::symbol::Symbol;
|
||||||
use roc_region::all::{Loc, Region};
|
use roc_region::all::{Loc, Region};
|
||||||
use roc_solve_problem::{TypeError, UnderivableReason, Unfulfilled};
|
use roc_solve_problem::{TypeError, UnderivableReason, Unfulfilled};
|
||||||
|
use roc_types::num::NumericRange;
|
||||||
use roc_types::subs::{instantiate_rigids, Content, FlatType, GetSubsSlice, Rank, Subs, Variable};
|
use roc_types::subs::{instantiate_rigids, Content, FlatType, GetSubsSlice, Rank, Subs, Variable};
|
||||||
use roc_types::types::{AliasKind, Category, MemberImpl, PatternCategory};
|
use roc_types::types::{AliasKind, Category, MemberImpl, PatternCategory};
|
||||||
use roc_unify::unify::{Env, MustImplementConstraints};
|
use roc_unify::unify::{Env, MustImplementConstraints};
|
||||||
|
@ -253,7 +254,12 @@ impl ObligationCache {
|
||||||
// independent queries.
|
// independent queries.
|
||||||
|
|
||||||
let opt_can_derive_builtin = match ability {
|
let opt_can_derive_builtin = match ability {
|
||||||
Symbol::ENCODE_ENCODING => Some(self.can_derive_encoding(subs, abilities_store, var)),
|
Symbol::ENCODE_ENCODING => Some(DeriveEncoding::is_derivable(
|
||||||
|
self,
|
||||||
|
abilities_store,
|
||||||
|
subs,
|
||||||
|
var,
|
||||||
|
)),
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -262,7 +268,7 @@ impl ObligationCache {
|
||||||
// can derive!
|
// can derive!
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
Some(Err(failure_var)) => Some(if failure_var == var {
|
Some(Err(DerivableError::NotDerivable(failure_var))) => Some(if failure_var == var {
|
||||||
UnderivableReason::SurfaceNotDerivable
|
UnderivableReason::SurfaceNotDerivable
|
||||||
} else {
|
} else {
|
||||||
let (error_type, _skeletons) = subs.var_to_error_type(failure_var);
|
let (error_type, _skeletons) = subs.var_to_error_type(failure_var);
|
||||||
|
@ -391,16 +397,128 @@ impl ObligationCache {
|
||||||
let check_has_fake = self.derive_cache.insert(derive_key, root_result);
|
let check_has_fake = self.derive_cache.insert(derive_key, root_result);
|
||||||
debug_assert_eq!(check_has_fake, Some(fake_fulfilled));
|
debug_assert_eq!(check_has_fake, Some(fake_fulfilled));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// If we have a lot of these, consider using a visitor.
|
#[inline(always)]
|
||||||
// It will be very similar for most types (can't derive functions, can't derive unbound type
|
#[rustfmt::skip]
|
||||||
// variables, can only derive opaques if they have an impl, etc).
|
fn is_builtin_number_alias(symbol: Symbol) -> bool {
|
||||||
fn can_derive_encoding(
|
matches!(symbol,
|
||||||
&mut self,
|
Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8
|
||||||
subs: &mut Subs,
|
| Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16
|
||||||
|
| Symbol::NUM_U32 | Symbol::NUM_UNSIGNED32
|
||||||
|
| Symbol::NUM_U64 | Symbol::NUM_UNSIGNED64
|
||||||
|
| Symbol::NUM_U128 | Symbol::NUM_UNSIGNED128
|
||||||
|
| Symbol::NUM_I8 | Symbol::NUM_SIGNED8
|
||||||
|
| Symbol::NUM_I16 | Symbol::NUM_SIGNED16
|
||||||
|
| Symbol::NUM_I32 | Symbol::NUM_SIGNED32
|
||||||
|
| Symbol::NUM_I64 | Symbol::NUM_SIGNED64
|
||||||
|
| Symbol::NUM_I128 | Symbol::NUM_SIGNED128
|
||||||
|
| Symbol::NUM_NAT | Symbol::NUM_NATURAL
|
||||||
|
| Symbol::NUM_F32 | Symbol::NUM_BINARY32
|
||||||
|
| Symbol::NUM_F64 | Symbol::NUM_BINARY64
|
||||||
|
| Symbol::NUM_DEC | Symbol::NUM_DECIMAL,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
enum DerivableError {
|
||||||
|
NotDerivable(Variable),
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Descend(bool);
|
||||||
|
|
||||||
|
trait DerivableVisitor {
|
||||||
|
const ABILITY: Symbol;
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_flex(var: Variable) -> Result<(), DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_rigid(var: Variable) -> Result<(), DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_flex_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> {
|
||||||
|
if ability != Self::ABILITY {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_rigid_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> {
|
||||||
|
if ability != Self::ABILITY {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_recursion(var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_apply(var: Variable, _symbol: Symbol) -> Result<Descend, DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_func(var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_record(var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_tag_union(var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_recursive_tag_union(var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_function_or_tag_union(var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_empty_record(var: Variable) -> Result<(), DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_empty_tag_union(var: Variable) -> Result<(), DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_alias(var: Variable, _symbol: Symbol) -> Result<Descend, DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_ranged_number(var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn is_derivable(
|
||||||
|
obligation_cache: &mut ObligationCache,
|
||||||
abilities_store: &AbilitiesStore,
|
abilities_store: &AbilitiesStore,
|
||||||
|
subs: &Subs,
|
||||||
var: Variable,
|
var: Variable,
|
||||||
) -> Result<(), Variable> {
|
) -> Result<(), DerivableError> {
|
||||||
let mut stack = vec![var];
|
let mut stack = vec![var];
|
||||||
let mut seen_recursion_vars = vec![];
|
let mut seen_recursion_vars = vec![];
|
||||||
|
|
||||||
|
@ -418,102 +536,93 @@ impl ObligationCache {
|
||||||
let content = subs.get_content_without_compacting(var);
|
let content = subs.get_content_without_compacting(var);
|
||||||
|
|
||||||
use Content::*;
|
use Content::*;
|
||||||
|
use DerivableError::*;
|
||||||
use FlatType::*;
|
use FlatType::*;
|
||||||
match content {
|
match *content {
|
||||||
FlexVar(_) | RigidVar(_) => return Err(var),
|
FlexVar(_) => Self::visit_flex(var)?,
|
||||||
FlexAbleVar(_, ability) | RigidAbleVar(_, ability) => {
|
RigidVar(_) => Self::visit_rigid(var)?,
|
||||||
if *ability != Symbol::ENCODE_ENCODING {
|
FlexAbleVar(_, ability) => Self::visit_flex_able(var, ability)?,
|
||||||
return Err(var);
|
RigidAbleVar(_, ability) => Self::visit_rigid_able(var, ability)?,
|
||||||
}
|
|
||||||
// Any concrete type this variables is instantiated with will also gain a "does
|
|
||||||
// implement" check so this is okay.
|
|
||||||
}
|
|
||||||
RecursionVar {
|
RecursionVar {
|
||||||
structure,
|
structure,
|
||||||
opt_name: _,
|
opt_name: _,
|
||||||
} => {
|
} => {
|
||||||
seen_recursion_vars.push(var);
|
let descend = Self::visit_recursion(var)?;
|
||||||
stack.push(*structure);
|
if descend.0 {
|
||||||
|
seen_recursion_vars.push(var);
|
||||||
|
stack.push(structure);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Structure(flat_type) => match flat_type {
|
Structure(flat_type) => match flat_type {
|
||||||
Apply(
|
Apply(symbol, vars) => {
|
||||||
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
|
let descend = Self::visit_apply(var, symbol)?;
|
||||||
vars,
|
if descend.0 {
|
||||||
) => push_var_slice!(*vars),
|
push_var_slice!(vars)
|
||||||
Apply(..) => return Err(var),
|
|
||||||
Func(..) => {
|
|
||||||
return Err(var);
|
|
||||||
}
|
|
||||||
Record(fields, var) => {
|
|
||||||
push_var_slice!(fields.variables());
|
|
||||||
stack.push(*var);
|
|
||||||
}
|
|
||||||
TagUnion(tags, ext_var) => {
|
|
||||||
for i in tags.variables() {
|
|
||||||
push_var_slice!(subs[i]);
|
|
||||||
}
|
}
|
||||||
stack.push(*ext_var);
|
|
||||||
}
|
}
|
||||||
FunctionOrTagUnion(_, _, var) => stack.push(*var),
|
Func(args, _clos, ret) => {
|
||||||
RecursiveTagUnion(rec_var, tags, ext_var) => {
|
let descend = Self::visit_func(var)?;
|
||||||
seen_recursion_vars.push(*rec_var);
|
if descend.0 {
|
||||||
for i in tags.variables() {
|
push_var_slice!(args);
|
||||||
push_var_slice!(subs[i]);
|
stack.push(ret);
|
||||||
}
|
}
|
||||||
stack.push(*ext_var);
|
|
||||||
}
|
}
|
||||||
EmptyRecord | EmptyTagUnion => {
|
Record(fields, ext) => {
|
||||||
// yes
|
let descend = Self::visit_record(var)?;
|
||||||
|
if descend.0 {
|
||||||
|
push_var_slice!(fields.variables());
|
||||||
|
stack.push(ext);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Erroneous(_) => return Err(var),
|
TagUnion(tags, ext) => {
|
||||||
|
let descend = Self::visit_tag_union(var)?;
|
||||||
|
if descend.0 {
|
||||||
|
for i in tags.variables() {
|
||||||
|
push_var_slice!(subs[i]);
|
||||||
|
}
|
||||||
|
stack.push(ext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
FunctionOrTagUnion(_tag_name, _fn_name, ext) => {
|
||||||
|
let descend = Self::visit_function_or_tag_union(var)?;
|
||||||
|
if descend.0 {
|
||||||
|
stack.push(ext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RecursiveTagUnion(rec, tags, ext) => {
|
||||||
|
let descend = Self::visit_recursive_tag_union(var)?;
|
||||||
|
if descend.0 {
|
||||||
|
seen_recursion_vars.push(rec);
|
||||||
|
for i in tags.variables() {
|
||||||
|
push_var_slice!(subs[i]);
|
||||||
|
}
|
||||||
|
stack.push(ext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
EmptyRecord => Self::visit_empty_record(var)?,
|
||||||
|
EmptyTagUnion => Self::visit_empty_tag_union(var)?,
|
||||||
|
|
||||||
|
Erroneous(_) => return Err(NotDerivable(var)),
|
||||||
},
|
},
|
||||||
#[rustfmt::skip]
|
Alias(opaque, _alias_variables, _real_var, AliasKind::Opaque) => {
|
||||||
Alias(
|
if obligation_cache
|
||||||
Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8
|
.check_opaque_and_read(abilities_store, opaque, Self::ABILITY)
|
||||||
| Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16
|
|
||||||
| Symbol::NUM_U32 | Symbol::NUM_UNSIGNED32
|
|
||||||
| Symbol::NUM_U64 | Symbol::NUM_UNSIGNED64
|
|
||||||
| Symbol::NUM_U128 | Symbol::NUM_UNSIGNED128
|
|
||||||
| Symbol::NUM_I8 | Symbol::NUM_SIGNED8
|
|
||||||
| Symbol::NUM_I16 | Symbol::NUM_SIGNED16
|
|
||||||
| Symbol::NUM_I32 | Symbol::NUM_SIGNED32
|
|
||||||
| Symbol::NUM_I64 | Symbol::NUM_SIGNED64
|
|
||||||
| Symbol::NUM_I128 | Symbol::NUM_SIGNED128
|
|
||||||
| Symbol::NUM_NAT | Symbol::NUM_NATURAL
|
|
||||||
| Symbol::NUM_F32 | Symbol::NUM_BINARY32
|
|
||||||
| Symbol::NUM_F64 | Symbol::NUM_BINARY64
|
|
||||||
| Symbol::NUM_DEC | Symbol::NUM_DECIMAL,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
) => {
|
|
||||||
// yes
|
|
||||||
}
|
|
||||||
Alias(
|
|
||||||
Symbol::NUM_NUM | Symbol::NUM_INTEGER | Symbol::NUM_FLOATINGPOINT,
|
|
||||||
_,
|
|
||||||
real_var,
|
|
||||||
_,
|
|
||||||
) => stack.push(*real_var),
|
|
||||||
Alias(name, _, _, AliasKind::Opaque) => {
|
|
||||||
let opaque = *name;
|
|
||||||
if self
|
|
||||||
.check_opaque_and_read(abilities_store, opaque, Symbol::ENCODE_ENCODING)
|
|
||||||
.is_err()
|
.is_err()
|
||||||
{
|
{
|
||||||
return Err(var);
|
return Err(NotDerivable(var));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Alias(_, arguments, real_type_var, _) => {
|
Alias(symbol, _alias_variables, real_var, AliasKind::Structural) => {
|
||||||
push_var_slice!(arguments.all_variables());
|
let descend = Self::visit_alias(var, symbol)?;
|
||||||
stack.push(*real_type_var);
|
if descend.0 {
|
||||||
|
stack.push(real_var);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
RangedNumber(..) => {
|
RangedNumber(range) => Self::visit_ranged_number(var, range)?,
|
||||||
// yes, all numbers can
|
|
||||||
}
|
LambdaSet(..) => return Err(NotDerivable(var)),
|
||||||
LambdaSet(..) => return Err(var),
|
|
||||||
Error => {
|
Error => {
|
||||||
return Err(var);
|
return Err(NotDerivable(var));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -522,6 +631,66 @@ impl ObligationCache {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DeriveEncoding;
|
||||||
|
impl DerivableVisitor for DeriveEncoding {
|
||||||
|
const ABILITY: Symbol = Symbol::ENCODE_ENCODING;
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_recursion(_var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Ok(Descend(true))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
|
||||||
|
if matches!(
|
||||||
|
symbol,
|
||||||
|
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
|
||||||
|
) {
|
||||||
|
Ok(Descend(true))
|
||||||
|
} else {
|
||||||
|
Err(DerivableError::NotDerivable(var))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_record(_var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Ok(Descend(true))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Ok(Descend(true))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Ok(Descend(true))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
|
||||||
|
Ok(Descend(true))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_empty_record(_var: Variable) -> Result<(), DerivableError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
fn visit_empty_tag_union(_var: Variable) -> Result<(), DerivableError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_alias(_var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
|
||||||
|
if is_builtin_number_alias(symbol) {
|
||||||
|
Ok(Descend(false))
|
||||||
|
} else {
|
||||||
|
Ok(Descend(true))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Determines what type implements an ability member of a specialized signature, given the
|
/// Determines what type implements an ability member of a specialized signature, given the
|
||||||
/// [MustImplementAbility] constraints of the signature.
|
/// [MustImplementAbility] constraints of the signature.
|
||||||
pub fn type_implementing_specialization(
|
pub fn type_implementing_specialization(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue