be smarter

This commit is contained in:
Folkert 2022-03-04 23:02:10 +01:00
parent cfccb92bf9
commit db06c10b5f
No known key found for this signature in database
GPG key ID: 1F17F6FFD112B97C
4 changed files with 209 additions and 40 deletions

View file

@ -11,7 +11,7 @@ use roc_can::expected::PExpected;
use roc_can::expr::Expr::{self, *}; use roc_can::expr::Expr::{self, *};
use roc_can::expr::{ClosureData, Field, WhenBranch}; use roc_can::expr::{ClosureData, Field, WhenBranch};
use roc_can::pattern::Pattern; use roc_can::pattern::Pattern;
use roc_collections::all::{ImMap, Index, MutSet, SendMap}; use roc_collections::all::{ImMap, Index, MutMap, SendMap};
use roc_module::ident::{Lowercase, TagName}; use roc_module::ident::{Lowercase, TagName};
use roc_module::symbol::{ModuleId, Symbol}; use roc_module::symbol::{ModuleId, Symbol};
use roc_region::all::{Loc, Region}; use roc_region::all::{Loc, Region};
@ -52,7 +52,7 @@ pub struct Env {
/// Whenever we encounter a user-defined type variable (a "rigid" var for short), /// Whenever we encounter a user-defined type variable (a "rigid" var for short),
/// for example `a` in the annotation `identity : a -> a`, we add it to this /// for example `a` in the annotation `identity : a -> a`, we add it to this
/// map so that expressions within that annotation can share these vars. /// map so that expressions within that annotation can share these vars.
pub rigids: ImMap<Lowercase, Variable>, pub rigids: MutMap<Lowercase, Variable>,
pub home: ModuleId, pub home: ModuleId,
} }
@ -693,7 +693,7 @@ pub fn constrain_expr(
let constraint = constrain_expr( let constraint = constrain_expr(
&Env { &Env {
home: env.home, home: env.home,
rigids: ImMap::default(), rigids: MutMap::default(),
}, },
region, region,
&loc_expr.value, &loc_expr.value,
@ -1170,7 +1170,7 @@ pub fn constrain_decls(home: ModuleId, decls: &[Declaration]) -> Constraint {
let mut env = Env { let mut env = Env {
home, home,
rigids: ImMap::default(), rigids: MutMap::default(),
}; };
for decl in decls.iter().rev() { for decl in decls.iter().rev() {
@ -1227,13 +1227,29 @@ fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint {
def_pattern_state.vars.push(expr_var); def_pattern_state.vars.push(expr_var);
let mut new_rigids = Vec::new(); let mut new_rigids = Vec::new();
let expr_con = match &def.annotation { let expr_con = match &def.annotation {
Some(annotation) => { Some(annotation) => {
let arity = annotation.signature.arity(); let arity = annotation.signature.arity();
let rigids = &env.rigids; let rigids = &env.rigids;
let mut ftv = rigids.clone(); let mut ftv = rigids.clone();
/*
let mut new_rigids = Vec::new();
// pub wildcards: Vec<Variable>,
// pub var_by_name: SendMap<Lowercase, Variable>,
'outer: for (outer_name, outer_var) in rigids.iter() {
for (inner_name, inner_var) in annotation.introduced_variables.var_by_name.iter() {
if outer_name == inner_name {
debug_assert_eq!(inner_var, outer_var);
continue 'outer;
}
}
// the inner name is not in the outer scope; it's introduced here
}
*/
let signature = instantiate_rigids( let signature = instantiate_rigids(
&annotation.signature, &annotation.signature,
&annotation.introduced_variables, &annotation.introduced_variables,
@ -1514,46 +1530,47 @@ fn instantiate_rigids(
annotation: &Type, annotation: &Type,
introduced_vars: &IntroducedVariables, introduced_vars: &IntroducedVariables,
new_rigids: &mut Vec<Variable>, new_rigids: &mut Vec<Variable>,
ftv: &mut ImMap<Lowercase, Variable>, // rigids defined before the current annotation ftv: &mut MutMap<Lowercase, Variable>, // rigids defined before the current annotation
loc_pattern: &Loc<Pattern>, loc_pattern: &Loc<Pattern>,
headers: &mut SendMap<Symbol, Loc<Type>>, headers: &mut SendMap<Symbol, Loc<Type>>,
) -> Type { ) -> Type {
let mut annotation = annotation.clone(); // find out if rigid type variables first occur in this annotation,
// or if they are already introduced in an outer annotation
let mut rigid_substitution: ImMap<Variable, Type> = ImMap::default(); let mut rigid_substitution: ImMap<Variable, Type> = ImMap::default();
let outside_rigids: MutSet<Variable> = ftv.values().copied().collect();
for (name, var) in introduced_vars.var_by_name.iter() { for (name, var) in introduced_vars.var_by_name.iter() {
if let Some(existing_rigid) = ftv.get(name) { use std::collections::hash_map::Entry::*;
rigid_substitution.insert(*var, Type::Variable(*existing_rigid));
} else { match ftv.entry(name.clone()) {
// It's possible to use this rigid in nested defs Occupied(occupied) => {
ftv.insert(name.clone(), *var); let existing_rigid = occupied.get();
rigid_substitution.insert(*var, Type::Variable(*existing_rigid));
}
Vacant(vacant) => {
// It's possible to use this rigid in nested defs
vacant.insert(*var);
new_rigids.push(*var);
}
} }
} }
// Instantiate rigid variables // wildcards are always freshly introduced in this annotation
for (i, wildcard) in introduced_vars.wildcards.iter().enumerate() {
ftv.insert(format!("*{}", i).into(), *wildcard);
new_rigids.push(*wildcard);
}
let mut annotation = annotation.clone();
if !rigid_substitution.is_empty() { if !rigid_substitution.is_empty() {
annotation.substitute(&rigid_substitution); annotation.substitute(&rigid_substitution);
} }
if let Some(new_headers) = crate::pattern::headers_from_annotation( let loc_annotation_ref = Loc::at(loc_pattern.region, &annotation);
&loc_pattern.value, if let Pattern::Identifier(symbol) = loc_pattern.value {
&Loc::at(loc_pattern.region, annotation.clone()), headers.insert(symbol, Loc::at(loc_pattern.region, annotation.clone()));
) { } else if let Some(new_headers) =
for (symbol, loc_type) in new_headers { crate::pattern::headers_from_annotation(&loc_pattern.value, &loc_annotation_ref)
for var in loc_type.value.variables() { {
// a rigid is only new if this annotation is the first occurrence of this rigid headers.extend(new_headers)
if !outside_rigids.contains(&var) {
new_rigids.push(var);
}
}
headers.insert(symbol, loc_type);
}
}
for (i, wildcard) in introduced_vars.wildcards.iter().enumerate() {
ftv.insert(format!("*{}", i).into(), *wildcard);
} }
annotation annotation
@ -1584,7 +1601,6 @@ pub fn rec_defs_help(
def_pattern_state.vars.push(expr_var); def_pattern_state.vars.push(expr_var);
let mut new_rigids = Vec::new();
match &def.annotation { match &def.annotation {
None => { None => {
let expr_con = constrain_expr( let expr_con = constrain_expr(
@ -1611,6 +1627,7 @@ pub fn rec_defs_help(
Some(annotation) => { Some(annotation) => {
let arity = annotation.signature.arity(); let arity = annotation.signature.arity();
let mut ftv = env.rigids.clone(); let mut ftv = env.rigids.clone();
let mut new_rigids = Vec::new();
let signature = instantiate_rigids( let signature = instantiate_rigids(
&annotation.signature, &annotation.signature,

View file

@ -27,7 +27,7 @@ pub struct PatternState {
/// definition has an annotation, we instead now add `x => Int`. /// definition has an annotation, we instead now add `x => Int`.
pub fn headers_from_annotation( pub fn headers_from_annotation(
pattern: &Pattern, pattern: &Pattern,
annotation: &Loc<Type>, annotation: &Loc<&Type>,
) -> Option<SendMap<Symbol, Loc<Type>>> { ) -> Option<SendMap<Symbol, Loc<Type>>> {
let mut headers = SendMap::default(); let mut headers = SendMap::default();
// Check that the annotation structurally agrees with the pattern, preventing e.g. `{ x, y } : Int` // Check that the annotation structurally agrees with the pattern, preventing e.g. `{ x, y } : Int`
@ -44,12 +44,13 @@ pub fn headers_from_annotation(
fn headers_from_annotation_help( fn headers_from_annotation_help(
pattern: &Pattern, pattern: &Pattern,
annotation: &Loc<Type>, annotation: &Loc<&Type>,
headers: &mut SendMap<Symbol, Loc<Type>>, headers: &mut SendMap<Symbol, Loc<Type>>,
) -> bool { ) -> bool {
match pattern { match pattern {
Identifier(symbol) | Shadowed(_, _, symbol) => { Identifier(symbol) | Shadowed(_, _, symbol) => {
headers.insert(*symbol, annotation.clone()); let typ = Loc::at(annotation.region, annotation.value.clone());
headers.insert(*symbol, typ);
true true
} }
Underscore Underscore
@ -106,7 +107,7 @@ fn headers_from_annotation_help(
.all(|(arg_pattern, arg_type)| { .all(|(arg_pattern, arg_type)| {
headers_from_annotation_help( headers_from_annotation_help(
&arg_pattern.1.value, &arg_pattern.1.value,
&Loc::at(annotation.region, arg_type.clone()), &Loc::at(annotation.region, arg_type),
headers, headers,
) )
}) })
@ -135,12 +136,13 @@ fn headers_from_annotation_help(
&& type_arguments.len() == pat_type_arguments.len() && type_arguments.len() == pat_type_arguments.len()
&& lambda_set_variables.len() == pat_lambda_set_variables.len() => && lambda_set_variables.len() == pat_lambda_set_variables.len() =>
{ {
headers.insert(*opaque, annotation.clone()); let typ = Loc::at(annotation.region, annotation.value.clone());
headers.insert(*opaque, typ);
let (_, argument_pat) = &**argument; let (_, argument_pat) = &**argument;
headers_from_annotation_help( headers_from_annotation_help(
&argument_pat.value, &argument_pat.value,
&Loc::at(annotation.region, (**actual).clone()), &Loc::at(annotation.region, actual),
headers, headers,
) )
} }

View file

@ -5509,4 +5509,42 @@ mod solve_expr {
r#"Id [ A, B, C { a : Str }e ] -> Str"#, r#"Id [ A, B, C { a : Str }e ] -> Str"#,
) )
} }
#[test]
fn inner_annotation_rigid() {
infer_eq_without_problem(
indoc!(
r#"
f : a -> a
f =
g : b -> b
g = \x -> x
g
f
"#
),
r#"a -> a"#,
)
}
#[test]
fn inner_annotation_rigid_2() {
infer_eq_without_problem(
indoc!(
r#"
f : {} -> List a
f =
g : List a
g = []
\{} -> g
f
"#
),
r#"{} -> List a"#,
)
}
} }

View file

@ -991,6 +991,118 @@ fn symbols_help(tipe: &Type, accum: &mut ImSet<Symbol>) {
} }
} }
pub struct VariablesIter<'a> {
stack: Vec<&'a Type>,
recursion_variables: Vec<Variable>,
variables: Vec<Variable>,
}
impl<'a> VariablesIter<'a> {
pub fn new(typ: &'a Type) -> Self {
Self {
stack: vec![typ],
recursion_variables: vec![],
variables: vec![],
}
}
}
impl<'a> Iterator for VariablesIter<'a> {
type Item = Variable;
fn next(&mut self) -> Option<Self::Item> {
use Type::*;
if let Some(var) = self.variables.pop() {
debug_assert!(!self.recursion_variables.contains(&var));
return Some(var);
}
while let Some(tipe) = self.stack.pop() {
match tipe {
EmptyRec | EmptyTagUnion | Erroneous(_) => {
continue;
}
ClosureTag { ext: v, .. } | Variable(v) => {
if !self.recursion_variables.contains(v) {
return Some(*v);
}
}
Function(args, closure, ret) => {
self.stack.push(ret);
self.stack.push(closure);
self.stack.extend(args.iter().rev());
}
Record(fields, ext) => {
use RecordField::*;
self.stack.push(ext);
for (_, field) in fields {
match field {
Optional(x) => self.stack.push(x),
Required(x) => self.stack.push(x),
Demanded(x) => self.stack.push(x),
};
}
}
TagUnion(tags, ext) => {
self.stack.push(ext);
for (_, args) in tags {
self.stack.extend(args);
}
}
FunctionOrTagUnion(_, _, ext) => {
self.stack.push(ext);
}
RecursiveTagUnion(rec, tags, ext) => {
self.recursion_variables.push(*rec);
self.stack.push(ext);
for (_, args) in tags.iter().rev() {
self.stack.extend(args);
}
}
Alias {
type_arguments,
actual,
..
} => {
self.stack.push(actual);
for (_, args) in type_arguments.iter().rev() {
self.stack.push(args);
}
}
HostExposedAlias {
type_arguments: arguments,
actual,
..
} => {
self.stack.push(actual);
for (_, args) in arguments.iter().rev() {
self.stack.push(args);
}
}
RangedNumber(typ, vars) => {
self.stack.push(typ);
self.variables.extend(vars.iter().copied());
}
Apply(_, args, _) => {
self.stack.extend(args);
}
}
}
None
}
}
fn variables_help(tipe: &Type, accum: &mut ImSet<Variable>) { fn variables_help(tipe: &Type, accum: &mut ImSet<Variable>) {
use Type::*; use Type::*;