mirror of
https://github.com/roc-lang/roc.git
synced 2025-09-27 13:59:08 +00:00
be smarter
This commit is contained in:
parent
cfccb92bf9
commit
db06c10b5f
4 changed files with 209 additions and 40 deletions
|
@ -11,7 +11,7 @@ use roc_can::expected::PExpected;
|
||||||
use roc_can::expr::Expr::{self, *};
|
use roc_can::expr::Expr::{self, *};
|
||||||
use roc_can::expr::{ClosureData, Field, WhenBranch};
|
use roc_can::expr::{ClosureData, Field, WhenBranch};
|
||||||
use roc_can::pattern::Pattern;
|
use roc_can::pattern::Pattern;
|
||||||
use roc_collections::all::{ImMap, Index, MutSet, SendMap};
|
use roc_collections::all::{ImMap, Index, MutMap, SendMap};
|
||||||
use roc_module::ident::{Lowercase, TagName};
|
use roc_module::ident::{Lowercase, TagName};
|
||||||
use roc_module::symbol::{ModuleId, Symbol};
|
use roc_module::symbol::{ModuleId, Symbol};
|
||||||
use roc_region::all::{Loc, Region};
|
use roc_region::all::{Loc, Region};
|
||||||
|
@ -52,7 +52,7 @@ pub struct Env {
|
||||||
/// Whenever we encounter a user-defined type variable (a "rigid" var for short),
|
/// Whenever we encounter a user-defined type variable (a "rigid" var for short),
|
||||||
/// for example `a` in the annotation `identity : a -> a`, we add it to this
|
/// for example `a` in the annotation `identity : a -> a`, we add it to this
|
||||||
/// map so that expressions within that annotation can share these vars.
|
/// map so that expressions within that annotation can share these vars.
|
||||||
pub rigids: ImMap<Lowercase, Variable>,
|
pub rigids: MutMap<Lowercase, Variable>,
|
||||||
pub home: ModuleId,
|
pub home: ModuleId,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -693,7 +693,7 @@ pub fn constrain_expr(
|
||||||
let constraint = constrain_expr(
|
let constraint = constrain_expr(
|
||||||
&Env {
|
&Env {
|
||||||
home: env.home,
|
home: env.home,
|
||||||
rigids: ImMap::default(),
|
rigids: MutMap::default(),
|
||||||
},
|
},
|
||||||
region,
|
region,
|
||||||
&loc_expr.value,
|
&loc_expr.value,
|
||||||
|
@ -1170,7 +1170,7 @@ pub fn constrain_decls(home: ModuleId, decls: &[Declaration]) -> Constraint {
|
||||||
|
|
||||||
let mut env = Env {
|
let mut env = Env {
|
||||||
home,
|
home,
|
||||||
rigids: ImMap::default(),
|
rigids: MutMap::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
for decl in decls.iter().rev() {
|
for decl in decls.iter().rev() {
|
||||||
|
@ -1227,13 +1227,29 @@ fn constrain_def(env: &Env, def: &Def, body_con: Constraint) -> Constraint {
|
||||||
def_pattern_state.vars.push(expr_var);
|
def_pattern_state.vars.push(expr_var);
|
||||||
|
|
||||||
let mut new_rigids = Vec::new();
|
let mut new_rigids = Vec::new();
|
||||||
|
|
||||||
let expr_con = match &def.annotation {
|
let expr_con = match &def.annotation {
|
||||||
Some(annotation) => {
|
Some(annotation) => {
|
||||||
let arity = annotation.signature.arity();
|
let arity = annotation.signature.arity();
|
||||||
let rigids = &env.rigids;
|
let rigids = &env.rigids;
|
||||||
let mut ftv = rigids.clone();
|
let mut ftv = rigids.clone();
|
||||||
|
|
||||||
|
/*
|
||||||
|
let mut new_rigids = Vec::new();
|
||||||
|
|
||||||
|
// pub wildcards: Vec<Variable>,
|
||||||
|
// pub var_by_name: SendMap<Lowercase, Variable>,
|
||||||
|
'outer: for (outer_name, outer_var) in rigids.iter() {
|
||||||
|
for (inner_name, inner_var) in annotation.introduced_variables.var_by_name.iter() {
|
||||||
|
if outer_name == inner_name {
|
||||||
|
debug_assert_eq!(inner_var, outer_var);
|
||||||
|
continue 'outer;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// the inner name is not in the outer scope; it's introduced here
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
let signature = instantiate_rigids(
|
let signature = instantiate_rigids(
|
||||||
&annotation.signature,
|
&annotation.signature,
|
||||||
&annotation.introduced_variables,
|
&annotation.introduced_variables,
|
||||||
|
@ -1514,46 +1530,47 @@ fn instantiate_rigids(
|
||||||
annotation: &Type,
|
annotation: &Type,
|
||||||
introduced_vars: &IntroducedVariables,
|
introduced_vars: &IntroducedVariables,
|
||||||
new_rigids: &mut Vec<Variable>,
|
new_rigids: &mut Vec<Variable>,
|
||||||
ftv: &mut ImMap<Lowercase, Variable>, // rigids defined before the current annotation
|
ftv: &mut MutMap<Lowercase, Variable>, // rigids defined before the current annotation
|
||||||
loc_pattern: &Loc<Pattern>,
|
loc_pattern: &Loc<Pattern>,
|
||||||
headers: &mut SendMap<Symbol, Loc<Type>>,
|
headers: &mut SendMap<Symbol, Loc<Type>>,
|
||||||
) -> Type {
|
) -> Type {
|
||||||
let mut annotation = annotation.clone();
|
// find out if rigid type variables first occur in this annotation,
|
||||||
|
// or if they are already introduced in an outer annotation
|
||||||
let mut rigid_substitution: ImMap<Variable, Type> = ImMap::default();
|
let mut rigid_substitution: ImMap<Variable, Type> = ImMap::default();
|
||||||
|
|
||||||
let outside_rigids: MutSet<Variable> = ftv.values().copied().collect();
|
|
||||||
|
|
||||||
for (name, var) in introduced_vars.var_by_name.iter() {
|
for (name, var) in introduced_vars.var_by_name.iter() {
|
||||||
if let Some(existing_rigid) = ftv.get(name) {
|
use std::collections::hash_map::Entry::*;
|
||||||
rigid_substitution.insert(*var, Type::Variable(*existing_rigid));
|
|
||||||
} else {
|
match ftv.entry(name.clone()) {
|
||||||
// It's possible to use this rigid in nested defs
|
Occupied(occupied) => {
|
||||||
ftv.insert(name.clone(), *var);
|
let existing_rigid = occupied.get();
|
||||||
|
rigid_substitution.insert(*var, Type::Variable(*existing_rigid));
|
||||||
|
}
|
||||||
|
Vacant(vacant) => {
|
||||||
|
// It's possible to use this rigid in nested defs
|
||||||
|
vacant.insert(*var);
|
||||||
|
new_rigids.push(*var);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Instantiate rigid variables
|
// wildcards are always freshly introduced in this annotation
|
||||||
|
for (i, wildcard) in introduced_vars.wildcards.iter().enumerate() {
|
||||||
|
ftv.insert(format!("*{}", i).into(), *wildcard);
|
||||||
|
new_rigids.push(*wildcard);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut annotation = annotation.clone();
|
||||||
if !rigid_substitution.is_empty() {
|
if !rigid_substitution.is_empty() {
|
||||||
annotation.substitute(&rigid_substitution);
|
annotation.substitute(&rigid_substitution);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(new_headers) = crate::pattern::headers_from_annotation(
|
let loc_annotation_ref = Loc::at(loc_pattern.region, &annotation);
|
||||||
&loc_pattern.value,
|
if let Pattern::Identifier(symbol) = loc_pattern.value {
|
||||||
&Loc::at(loc_pattern.region, annotation.clone()),
|
headers.insert(symbol, Loc::at(loc_pattern.region, annotation.clone()));
|
||||||
) {
|
} else if let Some(new_headers) =
|
||||||
for (symbol, loc_type) in new_headers {
|
crate::pattern::headers_from_annotation(&loc_pattern.value, &loc_annotation_ref)
|
||||||
for var in loc_type.value.variables() {
|
{
|
||||||
// a rigid is only new if this annotation is the first occurrence of this rigid
|
headers.extend(new_headers)
|
||||||
if !outside_rigids.contains(&var) {
|
|
||||||
new_rigids.push(var);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
headers.insert(symbol, loc_type);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (i, wildcard) in introduced_vars.wildcards.iter().enumerate() {
|
|
||||||
ftv.insert(format!("*{}", i).into(), *wildcard);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
annotation
|
annotation
|
||||||
|
@ -1584,7 +1601,6 @@ pub fn rec_defs_help(
|
||||||
|
|
||||||
def_pattern_state.vars.push(expr_var);
|
def_pattern_state.vars.push(expr_var);
|
||||||
|
|
||||||
let mut new_rigids = Vec::new();
|
|
||||||
match &def.annotation {
|
match &def.annotation {
|
||||||
None => {
|
None => {
|
||||||
let expr_con = constrain_expr(
|
let expr_con = constrain_expr(
|
||||||
|
@ -1611,6 +1627,7 @@ pub fn rec_defs_help(
|
||||||
Some(annotation) => {
|
Some(annotation) => {
|
||||||
let arity = annotation.signature.arity();
|
let arity = annotation.signature.arity();
|
||||||
let mut ftv = env.rigids.clone();
|
let mut ftv = env.rigids.clone();
|
||||||
|
let mut new_rigids = Vec::new();
|
||||||
|
|
||||||
let signature = instantiate_rigids(
|
let signature = instantiate_rigids(
|
||||||
&annotation.signature,
|
&annotation.signature,
|
||||||
|
|
|
@ -27,7 +27,7 @@ pub struct PatternState {
|
||||||
/// definition has an annotation, we instead now add `x => Int`.
|
/// definition has an annotation, we instead now add `x => Int`.
|
||||||
pub fn headers_from_annotation(
|
pub fn headers_from_annotation(
|
||||||
pattern: &Pattern,
|
pattern: &Pattern,
|
||||||
annotation: &Loc<Type>,
|
annotation: &Loc<&Type>,
|
||||||
) -> Option<SendMap<Symbol, Loc<Type>>> {
|
) -> Option<SendMap<Symbol, Loc<Type>>> {
|
||||||
let mut headers = SendMap::default();
|
let mut headers = SendMap::default();
|
||||||
// Check that the annotation structurally agrees with the pattern, preventing e.g. `{ x, y } : Int`
|
// Check that the annotation structurally agrees with the pattern, preventing e.g. `{ x, y } : Int`
|
||||||
|
@ -44,12 +44,13 @@ pub fn headers_from_annotation(
|
||||||
|
|
||||||
fn headers_from_annotation_help(
|
fn headers_from_annotation_help(
|
||||||
pattern: &Pattern,
|
pattern: &Pattern,
|
||||||
annotation: &Loc<Type>,
|
annotation: &Loc<&Type>,
|
||||||
headers: &mut SendMap<Symbol, Loc<Type>>,
|
headers: &mut SendMap<Symbol, Loc<Type>>,
|
||||||
) -> bool {
|
) -> bool {
|
||||||
match pattern {
|
match pattern {
|
||||||
Identifier(symbol) | Shadowed(_, _, symbol) => {
|
Identifier(symbol) | Shadowed(_, _, symbol) => {
|
||||||
headers.insert(*symbol, annotation.clone());
|
let typ = Loc::at(annotation.region, annotation.value.clone());
|
||||||
|
headers.insert(*symbol, typ);
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
Underscore
|
Underscore
|
||||||
|
@ -106,7 +107,7 @@ fn headers_from_annotation_help(
|
||||||
.all(|(arg_pattern, arg_type)| {
|
.all(|(arg_pattern, arg_type)| {
|
||||||
headers_from_annotation_help(
|
headers_from_annotation_help(
|
||||||
&arg_pattern.1.value,
|
&arg_pattern.1.value,
|
||||||
&Loc::at(annotation.region, arg_type.clone()),
|
&Loc::at(annotation.region, arg_type),
|
||||||
headers,
|
headers,
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
@ -135,12 +136,13 @@ fn headers_from_annotation_help(
|
||||||
&& type_arguments.len() == pat_type_arguments.len()
|
&& type_arguments.len() == pat_type_arguments.len()
|
||||||
&& lambda_set_variables.len() == pat_lambda_set_variables.len() =>
|
&& lambda_set_variables.len() == pat_lambda_set_variables.len() =>
|
||||||
{
|
{
|
||||||
headers.insert(*opaque, annotation.clone());
|
let typ = Loc::at(annotation.region, annotation.value.clone());
|
||||||
|
headers.insert(*opaque, typ);
|
||||||
|
|
||||||
let (_, argument_pat) = &**argument;
|
let (_, argument_pat) = &**argument;
|
||||||
headers_from_annotation_help(
|
headers_from_annotation_help(
|
||||||
&argument_pat.value,
|
&argument_pat.value,
|
||||||
&Loc::at(annotation.region, (**actual).clone()),
|
&Loc::at(annotation.region, actual),
|
||||||
headers,
|
headers,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5509,4 +5509,42 @@ mod solve_expr {
|
||||||
r#"Id [ A, B, C { a : Str }e ] -> Str"#,
|
r#"Id [ A, B, C { a : Str }e ] -> Str"#,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn inner_annotation_rigid() {
|
||||||
|
infer_eq_without_problem(
|
||||||
|
indoc!(
|
||||||
|
r#"
|
||||||
|
f : a -> a
|
||||||
|
f =
|
||||||
|
g : b -> b
|
||||||
|
g = \x -> x
|
||||||
|
|
||||||
|
g
|
||||||
|
|
||||||
|
f
|
||||||
|
"#
|
||||||
|
),
|
||||||
|
r#"a -> a"#,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn inner_annotation_rigid_2() {
|
||||||
|
infer_eq_without_problem(
|
||||||
|
indoc!(
|
||||||
|
r#"
|
||||||
|
f : {} -> List a
|
||||||
|
f =
|
||||||
|
g : List a
|
||||||
|
g = []
|
||||||
|
|
||||||
|
\{} -> g
|
||||||
|
|
||||||
|
f
|
||||||
|
"#
|
||||||
|
),
|
||||||
|
r#"{} -> List a"#,
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -991,6 +991,118 @@ fn symbols_help(tipe: &Type, accum: &mut ImSet<Symbol>) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct VariablesIter<'a> {
|
||||||
|
stack: Vec<&'a Type>,
|
||||||
|
recursion_variables: Vec<Variable>,
|
||||||
|
variables: Vec<Variable>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> VariablesIter<'a> {
|
||||||
|
pub fn new(typ: &'a Type) -> Self {
|
||||||
|
Self {
|
||||||
|
stack: vec![typ],
|
||||||
|
recursion_variables: vec![],
|
||||||
|
variables: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Iterator for VariablesIter<'a> {
|
||||||
|
type Item = Variable;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
use Type::*;
|
||||||
|
|
||||||
|
if let Some(var) = self.variables.pop() {
|
||||||
|
debug_assert!(!self.recursion_variables.contains(&var));
|
||||||
|
|
||||||
|
return Some(var);
|
||||||
|
}
|
||||||
|
|
||||||
|
while let Some(tipe) = self.stack.pop() {
|
||||||
|
match tipe {
|
||||||
|
EmptyRec | EmptyTagUnion | Erroneous(_) => {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
ClosureTag { ext: v, .. } | Variable(v) => {
|
||||||
|
if !self.recursion_variables.contains(v) {
|
||||||
|
return Some(*v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Function(args, closure, ret) => {
|
||||||
|
self.stack.push(ret);
|
||||||
|
self.stack.push(closure);
|
||||||
|
self.stack.extend(args.iter().rev());
|
||||||
|
}
|
||||||
|
Record(fields, ext) => {
|
||||||
|
use RecordField::*;
|
||||||
|
|
||||||
|
self.stack.push(ext);
|
||||||
|
|
||||||
|
for (_, field) in fields {
|
||||||
|
match field {
|
||||||
|
Optional(x) => self.stack.push(x),
|
||||||
|
Required(x) => self.stack.push(x),
|
||||||
|
Demanded(x) => self.stack.push(x),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TagUnion(tags, ext) => {
|
||||||
|
self.stack.push(ext);
|
||||||
|
|
||||||
|
for (_, args) in tags {
|
||||||
|
self.stack.extend(args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
FunctionOrTagUnion(_, _, ext) => {
|
||||||
|
self.stack.push(ext);
|
||||||
|
}
|
||||||
|
RecursiveTagUnion(rec, tags, ext) => {
|
||||||
|
self.recursion_variables.push(*rec);
|
||||||
|
self.stack.push(ext);
|
||||||
|
|
||||||
|
for (_, args) in tags.iter().rev() {
|
||||||
|
self.stack.extend(args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Alias {
|
||||||
|
type_arguments,
|
||||||
|
actual,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
self.stack.push(actual);
|
||||||
|
|
||||||
|
for (_, args) in type_arguments.iter().rev() {
|
||||||
|
self.stack.push(args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
HostExposedAlias {
|
||||||
|
type_arguments: arguments,
|
||||||
|
actual,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
self.stack.push(actual);
|
||||||
|
|
||||||
|
for (_, args) in arguments.iter().rev() {
|
||||||
|
self.stack.push(args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RangedNumber(typ, vars) => {
|
||||||
|
self.stack.push(typ);
|
||||||
|
self.variables.extend(vars.iter().copied());
|
||||||
|
}
|
||||||
|
Apply(_, args, _) => {
|
||||||
|
self.stack.extend(args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn variables_help(tipe: &Type, accum: &mut ImSet<Variable>) {
|
fn variables_help(tipe: &Type, accum: &mut ImSet<Variable>) {
|
||||||
use Type::*;
|
use Type::*;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue