diff --git a/compiler/constrain/src/expr.rs b/compiler/constrain/src/expr.rs index da99c6c50e..2d3d131019 100644 --- a/compiler/constrain/src/expr.rs +++ b/compiler/constrain/src/expr.rs @@ -11,7 +11,7 @@ use roc_can::expected::PExpected; use roc_can::expr::Expr::{self, *}; use roc_can::expr::{ClosureData, Field, WhenBranch}; use roc_can::pattern::Pattern; -use roc_collections::all::{ImMap, Index, MutSet, SendMap}; +use roc_collections::all::{ImMap, Index, MutMap, SendMap}; use roc_module::ident::{Lowercase, TagName}; use roc_module::symbol::{ModuleId, Symbol}; use roc_region::all::{Loc, Region}; @@ -52,7 +52,7 @@ pub struct Env { /// Whenever we encounter a user-defined type variable (a "rigid" var for short), /// for example `a` in the annotation `identity : a -> a`, we add it to this /// map so that expressions within that annotation can share these vars. - pub rigids: ImMap, + pub rigids: MutMap, pub home: ModuleId, } @@ -693,7 +693,7 @@ pub fn constrain_expr( let constraint = constrain_expr( &Env { home: env.home, - rigids: ImMap::default(), + rigids: MutMap::default(), }, region, &loc_expr.value, @@ -1170,7 +1170,7 @@ pub fn constrain_decls(home: ModuleId, decls: &[Declaration]) -> Constraint { let mut env = Env { home, - rigids: ImMap::default(), + rigids: MutMap::default(), }; for decl in decls.iter().rev() { @@ -1227,13 +1227,29 @@ fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint { def_pattern_state.vars.push(expr_var); let mut new_rigids = Vec::new(); - let expr_con = match &def.annotation { Some(annotation) => { let arity = annotation.signature.arity(); let rigids = &env.rigids; let mut ftv = rigids.clone(); + /* + let mut new_rigids = Vec::new(); + + // pub wildcards: Vec, + // pub var_by_name: SendMap, + 'outer: for (outer_name, outer_var) in rigids.iter() { + for (inner_name, inner_var) in annotation.introduced_variables.var_by_name.iter() { + if outer_name == inner_name { + debug_assert_eq!(inner_var, outer_var); + continue 'outer; + } + } + + // the inner name is not in the outer scope; it's introduced here + } + */ + let signature = instantiate_rigids( &annotation.signature, &annotation.introduced_variables, @@ -1514,46 +1530,47 @@ fn instantiate_rigids( annotation: &Type, introduced_vars: &IntroducedVariables, new_rigids: &mut Vec, - ftv: &mut ImMap, // rigids defined before the current annotation + ftv: &mut MutMap, // rigids defined before the current annotation loc_pattern: &Loc, headers: &mut SendMap>, ) -> Type { - let mut annotation = annotation.clone(); + // find out if rigid type variables first occur in this annotation, + // or if they are already introduced in an outer annotation let mut rigid_substitution: ImMap = ImMap::default(); - - let outside_rigids: MutSet = ftv.values().copied().collect(); - for (name, var) in introduced_vars.var_by_name.iter() { - if let Some(existing_rigid) = ftv.get(name) { - rigid_substitution.insert(*var, Type::Variable(*existing_rigid)); - } else { - // It's possible to use this rigid in nested defs - ftv.insert(name.clone(), *var); + use std::collections::hash_map::Entry::*; + + match ftv.entry(name.clone()) { + Occupied(occupied) => { + let existing_rigid = occupied.get(); + rigid_substitution.insert(*var, Type::Variable(*existing_rigid)); + } + Vacant(vacant) => { + // It's possible to use this rigid in nested defs + vacant.insert(*var); + new_rigids.push(*var); + } } } - // Instantiate rigid variables + // wildcards are always freshly introduced in this annotation + for (i, wildcard) in introduced_vars.wildcards.iter().enumerate() { + ftv.insert(format!("*{}", i).into(), *wildcard); + new_rigids.push(*wildcard); + } + + let mut annotation = annotation.clone(); if !rigid_substitution.is_empty() { annotation.substitute(&rigid_substitution); } - if let Some(new_headers) = crate::pattern::headers_from_annotation( - &loc_pattern.value, - &Loc::at(loc_pattern.region, annotation.clone()), - ) { - for (symbol, loc_type) in new_headers { - for var in loc_type.value.variables() { - // a rigid is only new if this annotation is the first occurrence of this rigid - if !outside_rigids.contains(&var) { - new_rigids.push(var); - } - } - headers.insert(symbol, loc_type); - } - } - - for (i, wildcard) in introduced_vars.wildcards.iter().enumerate() { - ftv.insert(format!("*{}", i).into(), *wildcard); + let loc_annotation_ref = Loc::at(loc_pattern.region, &annotation); + if let Pattern::Identifier(symbol) = loc_pattern.value { + headers.insert(symbol, Loc::at(loc_pattern.region, annotation.clone())); + } else if let Some(new_headers) = + crate::pattern::headers_from_annotation(&loc_pattern.value, &loc_annotation_ref) + { + headers.extend(new_headers) } annotation @@ -1584,7 +1601,6 @@ pub fn rec_defs_help( def_pattern_state.vars.push(expr_var); - let mut new_rigids = Vec::new(); match &def.annotation { None => { let expr_con = constrain_expr( @@ -1611,6 +1627,7 @@ pub fn rec_defs_help( Some(annotation) => { let arity = annotation.signature.arity(); let mut ftv = env.rigids.clone(); + let mut new_rigids = Vec::new(); let signature = instantiate_rigids( &annotation.signature, diff --git a/compiler/constrain/src/pattern.rs b/compiler/constrain/src/pattern.rs index 200f38dc74..c0aee74614 100644 --- a/compiler/constrain/src/pattern.rs +++ b/compiler/constrain/src/pattern.rs @@ -27,7 +27,7 @@ pub struct PatternState { /// definition has an annotation, we instead now add `x => Int`. pub fn headers_from_annotation( pattern: &Pattern, - annotation: &Loc, + annotation: &Loc<&Type>, ) -> Option>> { let mut headers = SendMap::default(); // Check that the annotation structurally agrees with the pattern, preventing e.g. `{ x, y } : Int` @@ -44,12 +44,13 @@ pub fn headers_from_annotation( fn headers_from_annotation_help( pattern: &Pattern, - annotation: &Loc, + annotation: &Loc<&Type>, headers: &mut SendMap>, ) -> bool { match pattern { Identifier(symbol) | Shadowed(_, _, symbol) => { - headers.insert(*symbol, annotation.clone()); + let typ = Loc::at(annotation.region, annotation.value.clone()); + headers.insert(*symbol, typ); true } Underscore @@ -106,7 +107,7 @@ fn headers_from_annotation_help( .all(|(arg_pattern, arg_type)| { headers_from_annotation_help( &arg_pattern.1.value, - &Loc::at(annotation.region, arg_type.clone()), + &Loc::at(annotation.region, arg_type), headers, ) }) @@ -135,12 +136,13 @@ fn headers_from_annotation_help( && type_arguments.len() == pat_type_arguments.len() && lambda_set_variables.len() == pat_lambda_set_variables.len() => { - headers.insert(*opaque, annotation.clone()); + let typ = Loc::at(annotation.region, annotation.value.clone()); + headers.insert(*opaque, typ); let (_, argument_pat) = &**argument; headers_from_annotation_help( &argument_pat.value, - &Loc::at(annotation.region, (**actual).clone()), + &Loc::at(annotation.region, actual), headers, ) } diff --git a/compiler/solve/tests/solve_expr.rs b/compiler/solve/tests/solve_expr.rs index 6987ac4b4c..28566441d7 100644 --- a/compiler/solve/tests/solve_expr.rs +++ b/compiler/solve/tests/solve_expr.rs @@ -5509,4 +5509,42 @@ mod solve_expr { r#"Id [ A, B, C { a : Str }e ] -> Str"#, ) } + + #[test] + fn inner_annotation_rigid() { + infer_eq_without_problem( + indoc!( + r#" + f : a -> a + f = + g : b -> b + g = \x -> x + + g + + f + "# + ), + r#"a -> a"#, + ) + } + + #[test] + fn inner_annotation_rigid_2() { + infer_eq_without_problem( + indoc!( + r#" + f : {} -> List a + f = + g : List a + g = [] + + \{} -> g + + f + "# + ), + r#"{} -> List a"#, + ) + } } diff --git a/compiler/types/src/types.rs b/compiler/types/src/types.rs index 298fd32be8..c37e0f38ed 100644 --- a/compiler/types/src/types.rs +++ b/compiler/types/src/types.rs @@ -991,6 +991,118 @@ fn symbols_help(tipe: &Type, accum: &mut ImSet) { } } +pub struct VariablesIter<'a> { + stack: Vec<&'a Type>, + recursion_variables: Vec, + variables: Vec, +} + +impl<'a> VariablesIter<'a> { + pub fn new(typ: &'a Type) -> Self { + Self { + stack: vec![typ], + recursion_variables: vec![], + variables: vec![], + } + } +} + +impl<'a> Iterator for VariablesIter<'a> { + type Item = Variable; + + fn next(&mut self) -> Option { + use Type::*; + + if let Some(var) = self.variables.pop() { + debug_assert!(!self.recursion_variables.contains(&var)); + + return Some(var); + } + + while let Some(tipe) = self.stack.pop() { + match tipe { + EmptyRec | EmptyTagUnion | Erroneous(_) => { + continue; + } + + ClosureTag { ext: v, .. } | Variable(v) => { + if !self.recursion_variables.contains(v) { + return Some(*v); + } + } + + Function(args, closure, ret) => { + self.stack.push(ret); + self.stack.push(closure); + self.stack.extend(args.iter().rev()); + } + Record(fields, ext) => { + use RecordField::*; + + self.stack.push(ext); + + for (_, field) in fields { + match field { + Optional(x) => self.stack.push(x), + Required(x) => self.stack.push(x), + Demanded(x) => self.stack.push(x), + }; + } + } + TagUnion(tags, ext) => { + self.stack.push(ext); + + for (_, args) in tags { + self.stack.extend(args); + } + } + FunctionOrTagUnion(_, _, ext) => { + self.stack.push(ext); + } + RecursiveTagUnion(rec, tags, ext) => { + self.recursion_variables.push(*rec); + self.stack.push(ext); + + for (_, args) in tags.iter().rev() { + self.stack.extend(args); + } + } + Alias { + type_arguments, + actual, + .. + } => { + self.stack.push(actual); + + for (_, args) in type_arguments.iter().rev() { + self.stack.push(args); + } + } + HostExposedAlias { + type_arguments: arguments, + actual, + .. + } => { + self.stack.push(actual); + + for (_, args) in arguments.iter().rev() { + self.stack.push(args); + } + } + RangedNumber(typ, vars) => { + self.stack.push(typ); + self.variables.extend(vars.iter().copied()); + } + Apply(_, args, _) => { + self.stack.extend(args); + } + } + } + + None + } +} + fn variables_help(tipe: &Type, accum: &mut ImSet) { use Type::*;