diff --git a/crates/compiler/solve/src/ability.rs b/crates/compiler/solve/src/ability.rs index 392f129fa9..c9808842db 100644 --- a/crates/compiler/solve/src/ability.rs +++ b/crates/compiler/solve/src/ability.rs @@ -5,6 +5,7 @@ use roc_error_macros::internal_error; use roc_module::symbol::Symbol; use roc_region::all::{Loc, Region}; use roc_solve_problem::{TypeError, UnderivableReason, Unfulfilled}; +use roc_types::num::NumericRange; use roc_types::subs::{instantiate_rigids, Content, FlatType, GetSubsSlice, Rank, Subs, Variable}; use roc_types::types::{AliasKind, Category, MemberImpl, PatternCategory}; use roc_unify::unify::{Env, MustImplementConstraints}; @@ -253,7 +254,12 @@ impl ObligationCache { // independent queries. let opt_can_derive_builtin = match ability { - Symbol::ENCODE_ENCODING => Some(self.can_derive_encoding(subs, abilities_store, var)), + Symbol::ENCODE_ENCODING => Some(DeriveEncoding::is_derivable( + self, + abilities_store, + subs, + var, + )), _ => None, }; @@ -262,7 +268,7 @@ impl ObligationCache { // can derive! None } - Some(Err(failure_var)) => Some(if failure_var == var { + Some(Err(DerivableError::NotDerivable(failure_var))) => Some(if failure_var == var { UnderivableReason::SurfaceNotDerivable } else { let (error_type, _skeletons) = subs.var_to_error_type(failure_var); @@ -391,16 +397,128 @@ impl ObligationCache { let check_has_fake = self.derive_cache.insert(derive_key, root_result); debug_assert_eq!(check_has_fake, Some(fake_fulfilled)); } +} - // If we have a lot of these, consider using a visitor. - // It will be very similar for most types (can't derive functions, can't derive unbound type - // variables, can only derive opaques if they have an impl, etc). - fn can_derive_encoding( - &mut self, - subs: &mut Subs, +#[inline(always)] +#[rustfmt::skip] +fn is_builtin_number_alias(symbol: Symbol) -> bool { + matches!(symbol, + Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8 + | Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16 + | Symbol::NUM_U32 | Symbol::NUM_UNSIGNED32 + | Symbol::NUM_U64 | Symbol::NUM_UNSIGNED64 + | Symbol::NUM_U128 | Symbol::NUM_UNSIGNED128 + | Symbol::NUM_I8 | Symbol::NUM_SIGNED8 + | Symbol::NUM_I16 | Symbol::NUM_SIGNED16 + | Symbol::NUM_I32 | Symbol::NUM_SIGNED32 + | Symbol::NUM_I64 | Symbol::NUM_SIGNED64 + | Symbol::NUM_I128 | Symbol::NUM_SIGNED128 + | Symbol::NUM_NAT | Symbol::NUM_NATURAL + | Symbol::NUM_F32 | Symbol::NUM_BINARY32 + | Symbol::NUM_F64 | Symbol::NUM_BINARY64 + | Symbol::NUM_DEC | Symbol::NUM_DECIMAL, + ) +} + +enum DerivableError { + NotDerivable(Variable), +} + +struct Descend(bool); + +trait DerivableVisitor { + const ABILITY: Symbol; + + #[inline(always)] + fn visit_flex(var: Variable) -> Result<(), DerivableError> { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_rigid(var: Variable) -> Result<(), DerivableError> { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_flex_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> { + if ability != Self::ABILITY { + Err(DerivableError::NotDerivable(var)) + } else { + Ok(()) + } + } + + #[inline(always)] + fn visit_rigid_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> { + if ability != Self::ABILITY { + Err(DerivableError::NotDerivable(var)) + } else { + Ok(()) + } + } + + #[inline(always)] + fn visit_recursion(var: Variable) -> Result { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_apply(var: Variable, _symbol: Symbol) -> Result { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_func(var: Variable) -> Result { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_record(var: Variable) -> Result { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_tag_union(var: Variable) -> Result { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_recursive_tag_union(var: Variable) -> Result { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_function_or_tag_union(var: Variable) -> Result { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_empty_record(var: Variable) -> Result<(), DerivableError> { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_empty_tag_union(var: Variable) -> Result<(), DerivableError> { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_alias(var: Variable, _symbol: Symbol) -> Result { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn visit_ranged_number(var: Variable, _range: NumericRange) -> Result<(), DerivableError> { + Err(DerivableError::NotDerivable(var)) + } + + #[inline(always)] + fn is_derivable( + obligation_cache: &mut ObligationCache, abilities_store: &AbilitiesStore, + subs: &Subs, var: Variable, - ) -> Result<(), Variable> { + ) -> Result<(), DerivableError> { let mut stack = vec![var]; let mut seen_recursion_vars = vec![]; @@ -418,102 +536,93 @@ impl ObligationCache { let content = subs.get_content_without_compacting(var); use Content::*; + use DerivableError::*; use FlatType::*; - match content { - FlexVar(_) | RigidVar(_) => return Err(var), - FlexAbleVar(_, ability) | RigidAbleVar(_, ability) => { - if *ability != Symbol::ENCODE_ENCODING { - return Err(var); - } - // Any concrete type this variables is instantiated with will also gain a "does - // implement" check so this is okay. - } + match *content { + FlexVar(_) => Self::visit_flex(var)?, + RigidVar(_) => Self::visit_rigid(var)?, + FlexAbleVar(_, ability) => Self::visit_flex_able(var, ability)?, + RigidAbleVar(_, ability) => Self::visit_rigid_able(var, ability)?, RecursionVar { structure, opt_name: _, } => { - seen_recursion_vars.push(var); - stack.push(*structure); + let descend = Self::visit_recursion(var)?; + if descend.0 { + seen_recursion_vars.push(var); + stack.push(structure); + } } Structure(flat_type) => match flat_type { - Apply( - Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR, - vars, - ) => push_var_slice!(*vars), - Apply(..) => return Err(var), - Func(..) => { - return Err(var); - } - Record(fields, var) => { - push_var_slice!(fields.variables()); - stack.push(*var); - } - TagUnion(tags, ext_var) => { - for i in tags.variables() { - push_var_slice!(subs[i]); + Apply(symbol, vars) => { + let descend = Self::visit_apply(var, symbol)?; + if descend.0 { + push_var_slice!(vars) } - stack.push(*ext_var); } - FunctionOrTagUnion(_, _, var) => stack.push(*var), - RecursiveTagUnion(rec_var, tags, ext_var) => { - seen_recursion_vars.push(*rec_var); - for i in tags.variables() { - push_var_slice!(subs[i]); + Func(args, _clos, ret) => { + let descend = Self::visit_func(var)?; + if descend.0 { + push_var_slice!(args); + stack.push(ret); } - stack.push(*ext_var); } - EmptyRecord | EmptyTagUnion => { - // yes + Record(fields, ext) => { + let descend = Self::visit_record(var)?; + if descend.0 { + push_var_slice!(fields.variables()); + stack.push(ext); + } } - Erroneous(_) => return Err(var), + TagUnion(tags, ext) => { + let descend = Self::visit_tag_union(var)?; + if descend.0 { + for i in tags.variables() { + push_var_slice!(subs[i]); + } + stack.push(ext); + } + } + FunctionOrTagUnion(_tag_name, _fn_name, ext) => { + let descend = Self::visit_function_or_tag_union(var)?; + if descend.0 { + stack.push(ext); + } + } + RecursiveTagUnion(rec, tags, ext) => { + let descend = Self::visit_recursive_tag_union(var)?; + if descend.0 { + seen_recursion_vars.push(rec); + for i in tags.variables() { + push_var_slice!(subs[i]); + } + stack.push(ext); + } + } + EmptyRecord => Self::visit_empty_record(var)?, + EmptyTagUnion => Self::visit_empty_tag_union(var)?, + + Erroneous(_) => return Err(NotDerivable(var)), }, - #[rustfmt::skip] - Alias( - Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8 - | Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16 - | Symbol::NUM_U32 | Symbol::NUM_UNSIGNED32 - | Symbol::NUM_U64 | Symbol::NUM_UNSIGNED64 - | Symbol::NUM_U128 | Symbol::NUM_UNSIGNED128 - | Symbol::NUM_I8 | Symbol::NUM_SIGNED8 - | Symbol::NUM_I16 | Symbol::NUM_SIGNED16 - | Symbol::NUM_I32 | Symbol::NUM_SIGNED32 - | Symbol::NUM_I64 | Symbol::NUM_SIGNED64 - | Symbol::NUM_I128 | Symbol::NUM_SIGNED128 - | Symbol::NUM_NAT | Symbol::NUM_NATURAL - | Symbol::NUM_F32 | Symbol::NUM_BINARY32 - | Symbol::NUM_F64 | Symbol::NUM_BINARY64 - | Symbol::NUM_DEC | Symbol::NUM_DECIMAL, - _, - _, - _, - ) => { - // yes - } - Alias( - Symbol::NUM_NUM | Symbol::NUM_INTEGER | Symbol::NUM_FLOATINGPOINT, - _, - real_var, - _, - ) => stack.push(*real_var), - Alias(name, _, _, AliasKind::Opaque) => { - let opaque = *name; - if self - .check_opaque_and_read(abilities_store, opaque, Symbol::ENCODE_ENCODING) + Alias(opaque, _alias_variables, _real_var, AliasKind::Opaque) => { + if obligation_cache + .check_opaque_and_read(abilities_store, opaque, Self::ABILITY) .is_err() { - return Err(var); + return Err(NotDerivable(var)); } } - Alias(_, arguments, real_type_var, _) => { - push_var_slice!(arguments.all_variables()); - stack.push(*real_type_var); + Alias(symbol, _alias_variables, real_var, AliasKind::Structural) => { + let descend = Self::visit_alias(var, symbol)?; + if descend.0 { + stack.push(real_var); + } } - RangedNumber(..) => { - // yes, all numbers can - } - LambdaSet(..) => return Err(var), + RangedNumber(range) => Self::visit_ranged_number(var, range)?, + + LambdaSet(..) => return Err(NotDerivable(var)), Error => { - return Err(var); + return Err(NotDerivable(var)); } } } @@ -522,6 +631,66 @@ impl ObligationCache { } } +struct DeriveEncoding; +impl DerivableVisitor for DeriveEncoding { + const ABILITY: Symbol = Symbol::ENCODE_ENCODING; + + #[inline(always)] + fn visit_recursion(_var: Variable) -> Result { + Ok(Descend(true)) + } + + #[inline(always)] + fn visit_apply(var: Variable, symbol: Symbol) -> Result { + if matches!( + symbol, + Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR, + ) { + Ok(Descend(true)) + } else { + Err(DerivableError::NotDerivable(var)) + } + } + + fn visit_record(_var: Variable) -> Result { + Ok(Descend(true)) + } + + fn visit_tag_union(_var: Variable) -> Result { + Ok(Descend(true)) + } + + fn visit_recursive_tag_union(_var: Variable) -> Result { + Ok(Descend(true)) + } + + fn visit_function_or_tag_union(_var: Variable) -> Result { + Ok(Descend(true)) + } + + #[inline(always)] + fn visit_empty_record(_var: Variable) -> Result<(), DerivableError> { + Ok(()) + } + + #[inline(always)] + fn visit_empty_tag_union(_var: Variable) -> Result<(), DerivableError> { + Ok(()) + } + + fn visit_alias(_var: Variable, symbol: Symbol) -> Result { + if is_builtin_number_alias(symbol) { + Ok(Descend(false)) + } else { + Ok(Descend(true)) + } + } + + fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), DerivableError> { + Ok(()) + } +} + /// Determines what type implements an ability member of a specialized signature, given the /// [MustImplementAbility] constraints of the signature. pub fn type_implementing_specialization(