Add derivability obligation checking for Decode

This commit is contained in:
Ayaz Hafiz 2022-08-01 13:04:52 -05:00
parent a7bc8cf4f2
commit 4bbc6b74fc
No known key found for this signature in database
GPG key ID: 0E2A37416A25EF58
3 changed files with 147 additions and 2 deletions

View file

@ -47,8 +47,10 @@ const SYMBOL_HAS_NICHE: () =
#[cfg(debug_assertions)]
const PRETTY_PRINT_DEBUG_SYMBOLS: bool = true;
pub const DERIVABLE_ABILITIES: &[(Symbol, &[Symbol])] =
&[(Symbol::ENCODE_ENCODING, &[Symbol::ENCODE_TO_ENCODER])];
pub const DERIVABLE_ABILITIES: &[(Symbol, &[Symbol])] = &[
(Symbol::ENCODE_ENCODING, &[Symbol::ENCODE_TO_ENCODER]),
(Symbol::DECODE_DECODING, &[Symbol::DECODE_DECODER]),
];
/// In Debug builds only, Symbol has a name() method that lets
/// you look up its name in a global intern table. This table is

View file

@ -260,6 +260,14 @@ impl ObligationCache {
subs,
var,
)),
Symbol::DECODE_DECODING => Some(DeriveDecoding::is_derivable(
self,
abilities_store,
subs,
var,
)),
_ => None,
};
@ -691,6 +699,66 @@ impl DerivableVisitor for DeriveEncoding {
}
}
struct DeriveDecoding;
impl DerivableVisitor for DeriveDecoding {
const ABILITY: Symbol = Symbol::DECODE_DECODING;
#[inline(always)]
fn visit_recursion(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
if matches!(
symbol,
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
) {
Ok(Descend(true))
} else {
Err(DerivableError::NotDerivable(var))
}
}
fn visit_record(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
fn visit_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_empty_record(_var: Variable) -> Result<(), DerivableError> {
Ok(())
}
#[inline(always)]
fn visit_empty_tag_union(_var: Variable) -> Result<(), DerivableError> {
Ok(())
}
fn visit_alias(_var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
if is_builtin_number_alias(symbol) {
Ok(Descend(false))
} else {
Ok(Descend(true))
}
}
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
Ok(())
}
}
/// Determines what type implements an ability member of a specialized signature, given the
/// [MustImplementAbility] constraints of the signature.
pub fn type_implementing_specialization(