From 05e8e6de6fdf8716d492804a302becd87fb4a93c Mon Sep 17 00:00:00 2001 From: Ayaz Hafiz Date: Tue, 18 Oct 2022 15:50:20 -0500 Subject: [PATCH] Disallow typing optional fields when required fields are annotated Closes #4313 --- crates/ast/src/lang/core/types.rs | 19 +----- crates/ast/src/solve_type.rs | 9 +++ crates/compiler/can/src/annotation.rs | 4 +- crates/compiler/mono/src/layout.rs | 6 +- crates/compiler/mono/src/layout_soa.rs | 4 +- crates/compiler/solve/src/solve.rs | 20 ++++-- crates/compiler/types/src/pretty_print.rs | 6 +- crates/compiler/types/src/subs.rs | 2 + crates/compiler/types/src/types.rs | 77 +++++++++++++---------- crates/compiler/unify/src/unify.rs | 16 ++++- crates/glue/src/types.rs | 4 +- crates/reporting/src/error/type.rs | 24 +++---- crates/reporting/tests/test_reporting.rs | 27 ++++++++ 13 files changed, 133 insertions(+), 85 deletions(-) diff --git a/crates/ast/src/lang/core/types.rs b/crates/ast/src/lang/core/types.rs index 5216d43bc2..25ce0cc0bf 100644 --- a/crates/ast/src/lang/core/types.rs +++ b/crates/ast/src/lang/core/types.rs @@ -414,24 +414,7 @@ pub fn to_type2<'a>( for (node_id, (label, field)) in field_types.iter_node_ids().zip(field_types_map) { let poolstr = PoolStr::new(label.as_str(), env.pool); - let rec_field = match field { - RecordField::Optional(_) => { - let field_id = env.pool.add(field.into_inner()); - RecordField::Optional(field_id) - } - RecordField::RigidOptional(_) => { - let field_id = env.pool.add(field.into_inner()); - RecordField::RigidOptional(field_id) - } - RecordField::Demanded(_) => { - let field_id = env.pool.add(field.into_inner()); - RecordField::Demanded(field_id) - } - RecordField::Required(_) => { - let field_id = env.pool.add(field.into_inner()); - RecordField::Required(field_id) - } - }; + let rec_field = field.map_owned(|field| env.pool.add(field)); env.pool[node_id] = (poolstr, rec_field); } diff --git a/crates/ast/src/solve_type.rs b/crates/ast/src/solve_type.rs index fabd83e825..b84ead1cb9 100644 --- a/crates/ast/src/solve_type.rs +++ b/crates/ast/src/solve_type.rs @@ -834,6 +834,15 @@ fn type_to_variable<'a>( cached, mempool.get(*type_id), )), + RigidRequired(type_id) => RigidRequired(type_to_variable( + arena, + mempool, + subs, + rank, + pools, + cached, + mempool.get(*type_id), + )), Optional(type_id) => Optional(type_to_variable( arena, mempool, diff --git a/crates/compiler/can/src/annotation.rs b/crates/compiler/can/src/annotation.rs index 7c5e789f4c..eda2d8d13e 100644 --- a/crates/compiler/can/src/annotation.rs +++ b/crates/compiler/can/src/annotation.rs @@ -1209,7 +1209,7 @@ fn can_assigned_fields<'a>( ); let label = Lowercase::from(field_name.value); - field_types.insert(label.clone(), Required(field_type)); + field_types.insert(label.clone(), RigidRequired(field_type)); break 'inner label; } @@ -1246,7 +1246,7 @@ fn can_assigned_fields<'a>( } }; - field_types.insert(field_name.clone(), Required(field_type)); + field_types.insert(field_name.clone(), RigidRequired(field_type)); break 'inner field_name; } diff --git a/crates/compiler/mono/src/layout.rs b/crates/compiler/mono/src/layout.rs index a184594e01..3b22bd379c 100644 --- a/crates/compiler/mono/src/layout.rs +++ b/crates/compiler/mono/src/layout.rs @@ -3117,7 +3117,9 @@ fn layout_from_flat_type<'a>( for (label, field) in it { match field { - RecordField::Required(field_var) | RecordField::Demanded(field_var) => { + RecordField::Required(field_var) + | RecordField::Demanded(field_var) + | RecordField::RigidRequired(field_var) => { sortables .push((label, cached!(Layout::from_var(env, field_var), criteria))); } @@ -3220,7 +3222,7 @@ fn sort_record_fields_help<'a>( for (label, field) in fields_map { match field { - RecordField::Demanded(v) | RecordField::Required(v) => { + RecordField::Demanded(v) | RecordField::Required(v) | RecordField::RigidRequired(v) => { let Cacheable(layout, _) = Layout::from_var(env, v); sorted_fields.push((label, v, Ok(layout?))); } diff --git a/crates/compiler/mono/src/layout_soa.rs b/crates/compiler/mono/src/layout_soa.rs index 0d2262273e..63c8079078 100644 --- a/crates/compiler/mono/src/layout_soa.rs +++ b/crates/compiler/mono/src/layout_soa.rs @@ -807,7 +807,9 @@ impl Layout { RecordField::Optional(_) | RecordField::RigidOptional(_) => { // do nothing } - RecordField::Required(_) | RecordField::Demanded(_) => { + RecordField::Required(_) + | RecordField::Demanded(_) + | RecordField::RigidRequired(_) => { let var = subs.variables[var_index.index as usize]; let layout = Layout::from_var_help(layouts, subs, var)?; diff --git a/crates/compiler/solve/src/solve.rs b/crates/compiler/solve/src/solve.rs index bdfcb24265..2676d6f25c 100644 --- a/crates/compiler/solve/src/solve.rs +++ b/crates/compiler/solve/src/solve.rs @@ -2427,6 +2427,7 @@ fn type_to_variable<'a>( Optional(t) => Optional(helper!(t)), Required(t) => Required(helper!(t)), Demanded(t) => Demanded(helper!(t)), + RigidRequired(t) => RigidRequired(helper!(t)), RigidOptional(t) => RigidOptional(helper!(t)), } }; @@ -3413,11 +3414,17 @@ fn adjust_rank_content( let var = subs[var_index]; rank = rank.max(adjust_rank(subs, young_mark, visit_mark, group_rank, var)); - // When generalizing annotations with rigid optionals, we want to promote - // them to non-rigid, so that usages at specialized sites don't have to - // exactly include the optional field. - if let RecordField::RigidOptional(()) = subs[field_index] { - subs[field_index] = RecordField::Optional(()); + // When generalizing annotations with rigid optional/required fields, + // we want to promote them to non-rigid, so that usages at + // specialized sites don't have to exactly include the optional/required field. + match subs[field_index] { + RecordField::RigidOptional(()) => { + subs[field_index] = RecordField::Optional(()); + } + RecordField::RigidRequired(()) => { + subs[field_index] = RecordField::Required(()); + } + _ => {} } } @@ -3791,7 +3798,8 @@ fn deep_copy_var_help( let slice = SubsSlice::extend_new( &mut subs.record_fields, field_types.into_iter().map(|f| match f { - RecordField::RigidOptional(()) => internal_error!("RigidOptionals should be generalized to non-rigid by this point"), + RecordField::RigidOptional(()) + | RecordField::RigidRequired(()) => internal_error!("Rigid optional/required should be generalized to non-rigid by this point"), RecordField::Demanded(_) | RecordField::Required(_) diff --git a/crates/compiler/types/src/pretty_print.rs b/crates/compiler/types/src/pretty_print.rs index 0c930cf25f..ffdf9ea618 100644 --- a/crates/compiler/types/src/pretty_print.rs +++ b/crates/compiler/types/src/pretty_print.rs @@ -1104,10 +1104,8 @@ fn write_flat_type<'a>( buf.push_str(label.as_str()); match record_field { - Optional(_) => buf.push_str(" ? "), - Required(_) => buf.push_str(" : "), - Demanded(_) => buf.push_str(" : "), - RigidOptional(_) => buf.push_str(" ? "), + Optional(_) | RigidOptional(_) => buf.push_str(" ? "), + Required(_) | Demanded(_) | RigidRequired(_) => buf.push_str(" : "), }; write_content( diff --git a/crates/compiler/types/src/subs.rs b/crates/compiler/types/src/subs.rs index b433675a2e..a72811eb22 100644 --- a/crates/compiler/types/src/subs.rs +++ b/crates/compiler/types/src/subs.rs @@ -920,6 +920,7 @@ fn subs_fmt_flat_type(this: &FlatType, subs: &Subs, f: &mut fmt::Formatter) -> f RecordField::RigidOptional(_) => "r?", RecordField::Required(_) => ":", RecordField::Demanded(_) => ":", + RecordField::RigidRequired(_) => "r:", }; write!( f, @@ -3831,6 +3832,7 @@ fn flat_type_to_err_type( Required(_) => Required(error_type), Demanded(_) => Demanded(error_type), RigidOptional(_) => RigidOptional(error_type), + RigidRequired(_) => RigidRequired(error_type), }; err_fields.insert(label, err_record_field); diff --git a/crates/compiler/types/src/types.rs b/crates/compiler/types/src/types.rs index 7fba003170..07fb21507d 100644 --- a/crates/compiler/types/src/types.rs +++ b/crates/compiler/types/src/types.rs @@ -27,10 +27,12 @@ const GREEK_LETTERS: &[char] = &[ /// /// - Demanded: only introduced by pattern matches, e.g. { x } -> /// Cannot unify with an Optional field, but can unify with a Required field -/// - Required: introduced by record literals and type annotations. +/// - Required: introduced by record literals /// Can unify with Optional and Demanded /// - Optional: introduced by pattern matches, e.g. { x ? "" } -> /// Can unify with Required, but not with Demanded +/// - RigidRequired: introduced by annotations, e.g. { x : Str} +/// Can only unify with Required and Demanded, to prevent an optional field being typed as Required /// - RigidOptional: introduced by annotations, e.g. { x ? Str} /// Can only unify with Optional, to prevent a required field being typed as Optional #[derive(PartialEq, Eq, Clone, Hash)] @@ -38,6 +40,7 @@ pub enum RecordField { Demanded(T), Required(T), Optional(T), + RigidRequired(T), RigidOptional(T), } @@ -51,6 +54,7 @@ impl fmt::Debug for RecordField { Optional(typ) => write!(f, "Optional({:?})", typ), Required(typ) => write!(f, "Required({:?})", typ), Demanded(typ) => write!(f, "Demanded({:?})", typ), + RigidRequired(typ) => write!(f, "RigidRequired({:?})", typ), RigidOptional(typ) => write!(f, "RigidOptional({:?})", typ), } } @@ -64,6 +68,7 @@ impl RecordField { Optional(t) => t, Required(t) => t, Demanded(t) => t, + RigidRequired(t) => t, RigidOptional(t) => t, } } @@ -75,6 +80,7 @@ impl RecordField { Optional(t) => t, Required(t) => t, Demanded(t) => t, + RigidRequired(t) => t, RigidOptional(t) => t, } } @@ -86,23 +92,43 @@ impl RecordField { Optional(t) => t, Required(t) => t, Demanded(t) => t, + RigidRequired(t) => t, RigidOptional(t) => t, } } - pub fn map(&self, mut f: F) -> RecordField + pub fn map(&self, f: F) -> RecordField where - F: FnMut(&T) -> U, + F: FnOnce(&T) -> U, + { + self.replace(f(self.as_inner())) + } + + pub fn map_owned(self, f: F) -> RecordField + where + F: FnOnce(T) -> U, { use RecordField::*; match self { Optional(t) => Optional(f(t)), Required(t) => Required(f(t)), Demanded(t) => Demanded(f(t)), + RigidRequired(t) => RigidRequired(f(t)), RigidOptional(t) => RigidOptional(f(t)), } } + pub fn replace(&self, u: U) -> RecordField { + use RecordField::*; + match self { + Optional(_) => Optional(u), + Required(_) => Required(u), + Demanded(_) => Demanded(u), + RigidRequired(_) => RigidRequired(u), + RigidOptional(_) => RigidOptional(u), + } + } + pub fn is_optional(&self) -> bool { matches!( self, @@ -119,6 +145,7 @@ impl RecordField { Optional(typ) => typ.substitute(substitutions), Required(typ) => typ.substitute(substitutions), Demanded(typ) => typ.substitute(substitutions), + RigidRequired(typ) => typ.substitute(substitutions), RigidOptional(typ) => typ.substitute(substitutions), } } @@ -135,6 +162,7 @@ impl RecordField { Optional(typ) => typ.substitute_alias(rep_symbol, rep_args, actual), Required(typ) => typ.substitute_alias(rep_symbol, rep_args, actual), Demanded(typ) => typ.substitute_alias(rep_symbol, rep_args, actual), + RigidRequired(typ) => typ.substitute_alias(rep_symbol, rep_args, actual), RigidOptional(typ) => typ.substitute_alias(rep_symbol, rep_args, actual), } } @@ -154,6 +182,7 @@ impl RecordField { Optional(typ) => typ.instantiate_aliases(region, aliases, var_store, introduced), Required(typ) => typ.instantiate_aliases(region, aliases, var_store, introduced), Demanded(typ) => typ.instantiate_aliases(region, aliases, var_store, introduced), + RigidRequired(typ) => typ.instantiate_aliases(region, aliases, var_store, introduced), RigidOptional(typ) => typ.instantiate_aliases(region, aliases, var_store, introduced), } } @@ -165,6 +194,7 @@ impl RecordField { Optional(typ) => typ.contains_symbol(rep_symbol), Required(typ) => typ.contains_symbol(rep_symbol), Demanded(typ) => typ.contains_symbol(rep_symbol), + RigidRequired(typ) => typ.contains_symbol(rep_symbol), RigidOptional(typ) => typ.contains_symbol(rep_symbol), } } @@ -175,6 +205,7 @@ impl RecordField { Optional(typ) => typ.contains_variable(rep_variable), Required(typ) => typ.contains_variable(rep_variable), Demanded(typ) => typ.contains_variable(rep_variable), + RigidRequired(typ) => typ.contains_variable(rep_variable), RigidOptional(typ) => typ.contains_variable(rep_variable), } } @@ -581,12 +612,14 @@ impl fmt::Debug for Type { for (label, field_type) in fields { match field_type { - RecordField::Optional(_) => write!(f, "{:?} ? {:?}", label, field_type)?, - RecordField::Required(_) => write!(f, "{:?} : {:?}", label, field_type)?, - RecordField::Demanded(_) => write!(f, "{:?} : {:?}", label, field_type)?, - RecordField::RigidOptional(_) => { + RecordField::Optional(_) | RecordField::RigidOptional(_) => { write!(f, "{:?} ? {:?}", label, field_type)? } + RecordField::Required(_) + | RecordField::Demanded(_) + | RecordField::RigidRequired(_) => { + write!(f, "{:?} : {:?}", label, field_type)? + } } if any_written_yet { @@ -1636,15 +1669,8 @@ fn variables_help(tipe: &Type, accum: &mut ImSet) { variables_help(ret, accum); } Record(fields, ext) => { - use RecordField::*; - for (_, field) in fields { - match field { - Optional(x) => variables_help(x, accum), - Required(x) => variables_help(x, accum), - Demanded(x) => variables_help(x, accum), - RigidOptional(x) => variables_help(x, accum), - }; + variables_help(field.as_inner(), accum); } if let TypeExtension::Open(ext) = ext { @@ -1778,15 +1804,8 @@ fn variables_help_detailed(tipe: &Type, accum: &mut VariableDetail) { variables_help_detailed(ret, accum); } Record(fields, ext) => { - use RecordField::*; - for (_, field) in fields { - match field { - Optional(x) => variables_help_detailed(x, accum), - Required(x) => variables_help_detailed(x, accum), - Demanded(x) => variables_help_detailed(x, accum), - RigidOptional(x) => variables_help_detailed(x, accum), - }; + variables_help_detailed(field.as_inner(), accum); } if let TypeExtension::Open(ext) = ext { @@ -2366,11 +2385,7 @@ fn write_error_type_help( buf.push_str(" ? "); content } - Required(content) => { - buf.push_str(" : "); - content - } - Demanded(content) => { + Required(content) | Demanded(content) | RigidRequired(content) => { buf.push_str(" : "); content } @@ -2520,11 +2535,7 @@ fn write_debug_error_type_help(error_type: ErrorType, buf: &mut String, parens: buf.push_str(" ? "); content } - Required(content) => { - buf.push_str(" : "); - content - } - Demanded(content) => { + Required(content) | Demanded(content) | RigidRequired(content) => { buf.push_str(" : "); content } diff --git a/crates/compiler/unify/src/unify.rs b/crates/compiler/unify/src/unify.rs index 08bcf62e39..19a4d21d0f 100644 --- a/crates/compiler/unify/src/unify.rs +++ b/crates/compiler/unify/src/unify.rs @@ -1845,6 +1845,7 @@ fn unify_shared_fields( // // Demanded does not unify with Optional // RigidOptional does not unify with Required or Demanded + // RigidRequired does not unify with Optional // Unifying Required with Demanded => Demanded // Unifying Optional with Required => Required // Unifying Optional with RigidOptional => RigidOptional @@ -1861,15 +1862,26 @@ fn unify_shared_fields( (Required(val), Optional(_)) => Required(val), (Optional(val), Required(_)) => Required(val), (Optional(val), Optional(_)) => Optional(val), + + // rigid optional (RigidOptional(val), Optional(_)) | (Optional(_), RigidOptional(val)) => { RigidOptional(val) } - (RigidOptional(_), Demanded(_) | Required(_)) - | (Demanded(_) | Required(_), RigidOptional(_)) => { + (RigidOptional(_), Demanded(_) | Required(_) | RigidRequired(_)) + | (Demanded(_) | Required(_) | RigidRequired(_), RigidOptional(_)) => { // this is an error, but we continue to give better error messages continue; } (RigidOptional(val), RigidOptional(_)) => RigidOptional(val), + + // rigid required + (RigidRequired(_), Optional(_)) | (Optional(_), RigidRequired(_)) => { + // this is an error, but we continue to give better error messages + continue; + } + (RigidRequired(val), Demanded(_) | Required(_)) + | (Demanded(_) | Required(_), RigidRequired(val)) => RigidRequired(val), + (RigidRequired(val), RigidRequired(_)) => RigidRequired(val), }; matching_fields.push((name, actual)); diff --git a/crates/glue/src/types.rs b/crates/glue/src/types.rs index 311fa71857..317461d148 100644 --- a/crates/glue/src/types.rs +++ b/crates/glue/src/types.rs @@ -768,7 +768,9 @@ fn add_type_help<'a>( .expect("something weird in content") .flat_map(|(label, field)| { match field { - RecordField::Required(field_var) | RecordField::Demanded(field_var) => { + RecordField::Required(field_var) + | RecordField::Demanded(field_var) + | RecordField::RigidRequired(field_var) => { Some((label.to_string(), field_var)) } RecordField::Optional(_) | RecordField::RigidOptional(_) => { diff --git a/crates/reporting/src/error/type.rs b/crates/reporting/src/error/type.rs index 3ed8c75082..52cb8bd025 100644 --- a/crates/reporting/src/error/type.rs +++ b/crates/reporting/src/error/type.rs @@ -2317,6 +2317,9 @@ fn to_doc_help<'b>( Parens::Unnecessary, v, )), + RecordField::RigidRequired(v) => RecordField::RigidRequired( + to_doc_help(ctx, alloc, Parens::Unnecessary, v), + ), RecordField::Demanded(v) => RecordField::Demanded(to_doc_help( ctx, alloc, @@ -2794,22 +2797,12 @@ fn diff_record<'b>( left: ( field.clone(), alloc.string(field.as_str().to_string()), - match t1 { - RecordField::Optional(_) => RecordField::Optional(diff.left), - RecordField::RigidOptional(_) => RecordField::RigidOptional(diff.left), - RecordField::Required(_) => RecordField::Required(diff.left), - RecordField::Demanded(_) => RecordField::Demanded(diff.left), - }, + t1.replace(diff.left), ), right: ( field.clone(), alloc.string(field.as_str().to_string()), - match t2 { - RecordField::Optional(_) => RecordField::Optional(diff.right), - RecordField::RigidOptional(_) => RecordField::RigidOptional(diff.right), - RecordField::Required(_) => RecordField::Required(diff.right), - RecordField::Demanded(_) => RecordField::Demanded(diff.right), - }, + t2.replace(diff.right), ), status: { match (&t1, &t2) { @@ -3252,10 +3245,9 @@ mod report_text { let entry_to_doc = |(field_name, field_type): (RocDocBuilder<'b>, RecordField>)| { match field_type { - RecordField::Demanded(field) => { - field_name.append(alloc.text(" : ")).append(field) - } - RecordField::Required(field) => { + RecordField::Demanded(field) + | RecordField::Required(field) + | RecordField::RigidRequired(field) => { field_name.append(alloc.text(" : ")).append(field) } RecordField::Optional(field) | RecordField::RigidOptional(field) => { diff --git a/crates/reporting/tests/test_reporting.rs b/crates/reporting/tests/test_reporting.rs index f26dab8ea9..401b567112 100644 --- a/crates/reporting/tests/test_reporting.rs +++ b/crates/reporting/tests/test_reporting.rs @@ -11457,4 +11457,31 @@ All branches in an `if` must have the same type! Note: `Hash` cannot be generated for functions. "### ); + + test_report!( + demanded_vs_optional_record_field, + indoc!( + r#" + foo : { a : Str } -> Str + foo = \{ a ? "" } -> a + foo + "# + ), + @r###" + ── TYPE MISMATCH ───────────────────────────────────────── /code/proj/Main.roc ─ + + The 1st argument to `foo` is weird: + + 5│ foo = \{ a ? "" } -> a + ^^^^^^^^^^ + + The argument is a pattern that matches record values of type: + + { a ? Str } + + But the annotation on `foo` says the 1st argument should be: + + { a : Str } + "### + ); }