Merge pull request #3734 from roc-lang/decoding-optional-record-fields-illegal

Report errors for attempting to derive decoding of records with optional field types
This commit is contained in:
Richard Feldman 2022-08-27 21:12:44 -04:00 committed by GitHub
commit adb89bbf82
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 324 additions and 118 deletions

View file

@ -4,9 +4,13 @@ use roc_collections::{VecMap, VecSet};
use roc_error_macros::{internal_error, todo_abilities};
use roc_module::symbol::Symbol;
use roc_region::all::{Loc, Region};
use roc_solve_problem::{TypeError, UnderivableReason, Unfulfilled};
use roc_solve_problem::{
NotDerivableContext, NotDerivableDecode, TypeError, UnderivableReason, Unfulfilled,
};
use roc_types::num::NumericRange;
use roc_types::subs::{instantiate_rigids, Content, FlatType, GetSubsSlice, Rank, Subs, Variable};
use roc_types::subs::{
instantiate_rigids, Content, FlatType, GetSubsSlice, Rank, RecordFields, Subs, Variable,
};
use roc_types::types::{AliasKind, Category, MemberImpl, PatternCategory};
use roc_unify::unify::{Env, MustImplementConstraints};
use roc_unify::unify::{MustImplementAbility, Obligated};
@ -276,11 +280,14 @@ impl ObligationCache {
// can derive!
None
}
Some(Err(DerivableError::NotDerivable(failure_var))) => Some(if failure_var == var {
UnderivableReason::SurfaceNotDerivable
Some(Err(NotDerivable {
var: failure_var,
context,
})) => Some(if failure_var == var {
UnderivableReason::SurfaceNotDerivable(context)
} else {
let (error_type, _skeletons) = subs.var_to_error_type(failure_var);
UnderivableReason::NestedNotDerivable(error_type)
UnderivableReason::NestedNotDerivable(error_type, context)
}),
None => Some(UnderivableReason::NotABuiltin),
};
@ -428,8 +435,9 @@ fn is_builtin_number_alias(symbol: Symbol) -> bool {
)
}
enum DerivableError {
NotDerivable(Variable),
struct NotDerivable {
var: Variable,
context: NotDerivableContext,
}
struct Descend(bool);
@ -443,76 +451,119 @@ trait DerivableVisitor {
}
#[inline(always)]
fn visit_flex_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> {
fn visit_flex_able(var: Variable, ability: Symbol) -> Result<(), NotDerivable> {
if ability != Self::ABILITY {
Err(DerivableError::NotDerivable(var))
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
} else {
Ok(())
}
}
#[inline(always)]
fn visit_rigid_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> {
fn visit_rigid_able(var: Variable, ability: Symbol) -> Result<(), NotDerivable> {
if ability != Self::ABILITY {
Err(DerivableError::NotDerivable(var))
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
} else {
Ok(())
}
}
#[inline(always)]
fn visit_recursion(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_recursion(var: Variable) -> Result<Descend, NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
fn visit_apply(var: Variable, _symbol: Symbol) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_apply(var: Variable, _symbol: Symbol) -> Result<Descend, NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
fn visit_func(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_func(var: Variable) -> Result<Descend, NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::Function,
})
}
#[inline(always)]
fn visit_record(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_record(
_subs: &Subs,
var: Variable,
_fields: RecordFields,
) -> Result<Descend, NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
fn visit_tag_union(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_tag_union(var: Variable) -> Result<Descend, NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
fn visit_recursive_tag_union(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_recursive_tag_union(var: Variable) -> Result<Descend, NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
fn visit_function_or_tag_union(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_function_or_tag_union(var: Variable) -> Result<Descend, NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
fn visit_empty_record(var: Variable) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_empty_record(var: Variable) -> Result<(), NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
fn visit_empty_tag_union(var: Variable) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_empty_tag_union(var: Variable) -> Result<(), NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
fn visit_alias(var: Variable, _symbol: Symbol) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_alias(var: Variable, _symbol: Symbol) -> Result<Descend, NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
fn visit_ranged_number(var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
fn visit_ranged_number(var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
@ -521,7 +572,7 @@ trait DerivableVisitor {
abilities_store: &AbilitiesStore,
subs: &mut Subs,
var: Variable,
) -> Result<(), DerivableError> {
) -> Result<(), NotDerivable> {
let mut stack = vec![var];
let mut seen_recursion_vars = vec![];
@ -539,14 +590,18 @@ trait DerivableVisitor {
let content = subs.get_content_without_compacting(var);
use Content::*;
use DerivableError::*;
use FlatType::*;
match *content {
FlexVar(opt_name) => {
// Promote the flex var to be bound to the ability.
subs.set_content(var, Content::FlexAbleVar(opt_name, Self::ABILITY));
}
RigidVar(_) => return Err(NotDerivable(var)),
RigidVar(_) => {
return Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
FlexAbleVar(_, ability) => Self::visit_flex_able(var, ability)?,
RigidAbleVar(_, ability) => Self::visit_rigid_able(var, ability)?,
RecursionVar {
@ -574,7 +629,7 @@ trait DerivableVisitor {
}
}
Record(fields, ext) => {
let descend = Self::visit_record(var)?;
let descend = Self::visit_record(subs, var, fields)?;
if descend.0 {
push_var_slice!(fields.variables());
if !matches!(
@ -616,7 +671,12 @@ trait DerivableVisitor {
EmptyRecord => Self::visit_empty_record(var)?,
EmptyTagUnion => Self::visit_empty_tag_union(var)?,
Erroneous(_) => return Err(NotDerivable(var)),
Erroneous(_) => {
return Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
},
Alias(
Symbol::NUM_NUM | Symbol::NUM_INTEGER | Symbol::NUM_FLOATINGPOINT,
@ -633,7 +693,10 @@ trait DerivableVisitor {
.is_err()
&& !Self::is_derivable_builtin_opaque(opaque)
{
return Err(NotDerivable(var));
return Err(NotDerivable {
var,
context: NotDerivableContext::Opaque(opaque),
});
}
}
Alias(symbol, _alias_variables, real_var, AliasKind::Structural) => {
@ -644,9 +707,17 @@ trait DerivableVisitor {
}
RangedNumber(range) => Self::visit_ranged_number(var, range)?,
LambdaSet(..) => return Err(NotDerivable(var)),
LambdaSet(..) => {
return Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
Error => {
return Err(NotDerivable(var));
return Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
});
}
}
}
@ -665,54 +736,61 @@ impl DerivableVisitor for DeriveEncoding {
}
#[inline(always)]
fn visit_recursion(_var: Variable) -> Result<Descend, DerivableError> {
fn visit_recursion(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, NotDerivable> {
if matches!(
symbol,
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
) {
Ok(Descend(true))
} else {
Err(DerivableError::NotDerivable(var))
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
}
#[inline(always)]
fn visit_record(_var: Variable) -> Result<Descend, DerivableError> {
fn visit_record(
_subs: &Subs,
_var: Variable,
_fields: RecordFields,
) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
fn visit_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_empty_record(_var: Variable) -> Result<(), DerivableError> {
fn visit_empty_record(_var: Variable) -> Result<(), NotDerivable> {
Ok(())
}
#[inline(always)]
fn visit_empty_tag_union(_var: Variable) -> Result<(), DerivableError> {
fn visit_empty_tag_union(_var: Variable) -> Result<(), NotDerivable> {
Ok(())
}
#[inline(always)]
fn visit_alias(_var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
fn visit_alias(_var: Variable, symbol: Symbol) -> Result<Descend, NotDerivable> {
if is_builtin_number_alias(symbol) {
Ok(Descend(false))
} else {
@ -721,7 +799,7 @@ impl DerivableVisitor for DeriveEncoding {
}
#[inline(always)]
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Ok(())
}
}
@ -736,54 +814,72 @@ impl DerivableVisitor for DeriveDecoding {
}
#[inline(always)]
fn visit_recursion(_var: Variable) -> Result<Descend, DerivableError> {
fn visit_recursion(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, NotDerivable> {
if matches!(
symbol,
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
) {
Ok(Descend(true))
} else {
Err(DerivableError::NotDerivable(var))
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
}
#[inline(always)]
fn visit_record(_var: Variable) -> Result<Descend, DerivableError> {
fn visit_record(
subs: &Subs,
var: Variable,
fields: RecordFields,
) -> Result<Descend, NotDerivable> {
for (field_name, _, field) in fields.iter_all() {
if subs[field].is_optional() {
return Err(NotDerivable {
var,
context: NotDerivableContext::Decode(NotDerivableDecode::OptionalRecordField(
subs[field_name].clone(),
)),
});
}
}
Ok(Descend(true))
}
#[inline(always)]
fn visit_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
fn visit_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, NotDerivable> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_empty_record(_var: Variable) -> Result<(), DerivableError> {
fn visit_empty_record(_var: Variable) -> Result<(), NotDerivable> {
Ok(())
}
#[inline(always)]
fn visit_empty_tag_union(_var: Variable) -> Result<(), DerivableError> {
fn visit_empty_tag_union(_var: Variable) -> Result<(), NotDerivable> {
Ok(())
}
#[inline(always)]
fn visit_alias(_var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
fn visit_alias(_var: Variable, symbol: Symbol) -> Result<Descend, NotDerivable> {
if is_builtin_number_alias(symbol) {
Ok(Descend(false))
} else {
@ -792,7 +888,7 @@ impl DerivableVisitor for DeriveDecoding {
}
#[inline(always)]
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Ok(())
}
}

View file

@ -1,5 +1,5 @@
use roc_can::expected::{Expected, PExpected};
use roc_module::symbol::Symbol;
use roc_module::{ident::Lowercase, symbol::Symbol};
use roc_problem::can::CycleEntry;
use roc_region::all::Region;
@ -55,7 +55,21 @@ pub enum Unfulfilled {
pub enum UnderivableReason {
NotABuiltin,
/// The surface type is not derivable
SurfaceNotDerivable,
SurfaceNotDerivable(NotDerivableContext),
/// A nested type is not derivable
NestedNotDerivable(ErrorType),
NestedNotDerivable(ErrorType, NotDerivableContext),
}
#[derive(PartialEq, Debug, Clone)]
pub enum NotDerivableContext {
NoContext,
Function,
UnboundVar,
Opaque(Symbol),
Decode(NotDerivableDecode),
}
#[derive(PartialEq, Debug, Clone)]
pub enum NotDerivableDecode {
OptionalRecordField(Lowercase),
}

View file

@ -104,7 +104,10 @@ impl<T> RecordField<T> {
}
pub fn is_optional(&self) -> bool {
matches!(self, RecordField::Optional(..))
matches!(
self,
RecordField::Optional(..) | RecordField::RigidOptional(..)
)
}
}