use crate::builtins; use roc_can::constraint::Constraint; use roc_can::expected::{Expected, PExpected}; use roc_can::pattern::Pattern::{self, *}; use roc_can::pattern::RecordDestruct; use roc_collections::all::{Index, SendMap}; use roc_module::ident::Lowercase; use roc_module::symbol::Symbol; use roc_region::all::{Located, Region}; use roc_types::subs::Variable; use roc_types::types::{Category, PReason, PatternCategory, Type}; pub struct PatternState { pub headers: SendMap>, pub vars: Vec, pub constraints: Vec, } /// If there is a type annotation, the pattern state headers can be optimized by putting the /// annotation in the headers. Normally /// /// x = 4 /// /// Would add `x => <42>` to the headers (i.e., symbol points to a type variable). If the /// definition has an annotation, we instead now add `x => Int`. pub fn headers_from_annotation( pattern: &Pattern, annotation: &Located, ) -> Option>> { let mut headers = SendMap::default(); // Check that the annotation structurally agrees with the pattern, preventing e.g. `{ x, y } : Int` // in such incorrect cases we don't put the full annotation in headers, just a variable, and let // inference generate a proper error. let is_structurally_valid = headers_from_annotation_help(pattern, annotation, &mut headers); if is_structurally_valid { Some(headers) } else { None } } fn headers_from_annotation_help( pattern: &Pattern, annotation: &Located, headers: &mut SendMap>, ) -> bool { match pattern { Identifier(symbol) => { headers.insert(*symbol, annotation.clone()); true } Underscore | Shadowed(_, _) | MalformedPattern(_, _) | UnsupportedPattern(_) | NumLiteral(_, _) | IntLiteral(_) | FloatLiteral(_) | StrLiteral(_) => true, RecordDestructure { destructs, .. } => match annotation.value.shallow_dealias() { Type::Record(fields, _) => { for destruct in destructs { // NOTE ignores the .guard field. if let Some(field_type) = fields.get(&destruct.value.label) { headers.insert( destruct.value.symbol, Located::at(annotation.region, field_type.clone()), ); } else { return false; } } true } Type::EmptyRec => destructs.is_empty(), _ => false, }, AppliedTag { tag_name, arguments, .. } => match annotation.value.shallow_dealias() { Type::TagUnion(tags, _) => { if let Some((_, arg_types)) = tags.iter().find(|(name, _)| name == tag_name) { if !arguments.len() == arg_types.len() { return false; } arguments .iter() .zip(arg_types.iter()) .all(|(arg_pattern, arg_type)| { headers_from_annotation_help( &arg_pattern.1.value, &Located::at(annotation.region, arg_type.clone()), headers, ) }) } else { false } } _ => false, }, } } /// This accepts PatternState (rather than returning it) so that the caller can /// intiialize the Vecs in PatternState using with_capacity /// based on its knowledge of their lengths. pub fn constrain_pattern( pattern: &Pattern, region: Region, expected: PExpected, state: &mut PatternState, ) { match pattern { Underscore | UnsupportedPattern(_) | MalformedPattern(_, _) | Shadowed(_, _) => { // Neither the _ pattern nor erroneous ones add any constraints. } Identifier(symbol) => { state.headers.insert( *symbol, Located { region, value: expected.get_type(), }, ); } NumLiteral(var, _) => { state.vars.push(*var); state.constraints.push(Constraint::Pattern( region, PatternCategory::Num, builtins::builtin_type(Symbol::NUM_NUM, vec![Type::Variable(*var)]), expected, )); } IntLiteral(_) => { state.constraints.push(Constraint::Pattern( region, PatternCategory::Float, builtins::builtin_type(Symbol::INT_INT, vec![]), expected, )); } FloatLiteral(_) => { state.constraints.push(Constraint::Pattern( region, PatternCategory::Float, builtins::builtin_type(Symbol::FLOAT_FLOAT, vec![]), expected, )); } StrLiteral(_) => { state.constraints.push(Constraint::Pattern( region, PatternCategory::Str, builtins::str_type(), expected, )); } RecordDestructure { whole_var, ext_var, destructs, } => { state.vars.push(*whole_var); state.vars.push(*ext_var); let ext_type = Type::Variable(*ext_var); let mut field_types: SendMap = SendMap::default(); for Located { value: RecordDestruct { var, label, symbol, guard, }, .. } in destructs { let pat_type = Type::Variable(*var); let expected = PExpected::NoExpectation(pat_type.clone()); if !state.headers.contains_key(&symbol) { state .headers .insert(*symbol, Located::at(region, pat_type.clone())); } field_types.insert(label.clone(), pat_type.clone()); if let Some((guard_var, loc_guard)) = guard { state.constraints.push(Constraint::Pattern( region, PatternCategory::PatternGuard, Type::Variable(*guard_var), PExpected::ForReason( PReason::PatternGuard, pat_type.clone(), loc_guard.region, ), )); state.vars.push(*guard_var); constrain_pattern(&loc_guard.value, loc_guard.region, expected, state); } state.vars.push(*var); } let record_type = Type::Record(field_types, Box::new(ext_type)); let whole_con = Constraint::Eq( Type::Variable(*whole_var), Expected::NoExpectation(record_type), Category::Storage, region, ); let record_con = Constraint::Pattern( region, PatternCategory::Record, Type::Variable(*whole_var), expected, ); state.constraints.push(whole_con); state.constraints.push(record_con); } AppliedTag { whole_var, ext_var, tag_name, arguments, } => { let mut argument_types = Vec::with_capacity(arguments.len()); for (index, (pattern_var, loc_pattern)) in arguments.iter().enumerate() { state.vars.push(*pattern_var); let pattern_type = Type::Variable(*pattern_var); argument_types.push(pattern_type.clone()); let expected = PExpected::ForReason( PReason::TagArg { tag_name: tag_name.clone(), index: Index::zero_based(index), }, pattern_type, region, ); constrain_pattern(&loc_pattern.value, loc_pattern.region, expected, state); } let whole_con = Constraint::Eq( Type::Variable(*whole_var), Expected::NoExpectation(Type::TagUnion( vec![(tag_name.clone(), argument_types)], Box::new(Type::Variable(*ext_var)), )), Category::Storage, region, ); let tag_con = Constraint::Pattern( region, PatternCategory::Ctor(tag_name.clone()), Type::Variable(*whole_var), expected, ); state.vars.push(*whole_var); state.vars.push(*ext_var); state.constraints.push(whole_con); state.constraints.push(tag_con); } } }