mirror of
https://github.com/roc-lang/roc.git
synced 2025-10-03 16:44:33 +00:00
Trying out ena
This commit is contained in:
parent
338be03bdd
commit
5635561fca
8 changed files with 535 additions and 44 deletions
|
@ -51,7 +51,7 @@ pub fn literal_to_string<'a>(literal: &'a Literal<'a>) -> String {
|
||||||
elem_strings.push(literal_to_string(eval(elem_expr)));
|
elem_strings.push(literal_to_string(eval(elem_expr)));
|
||||||
}
|
}
|
||||||
|
|
||||||
format!("[ {} ]", elem_strings.join(", "))
|
format!("[{}]", elem_strings.join(", "))
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
#![feature(box_patterns)]
|
#![feature(box_patterns)]
|
||||||
|
|
||||||
pub mod unify;
|
// pub mod unify;
|
||||||
pub mod interpret;
|
// pub mod interpret;
|
||||||
pub mod repl;
|
// pub mod repl;
|
||||||
|
|
||||||
|
pub mod solve;
|
||||||
|
|
||||||
|
extern crate ena;
|
||||||
|
|
274
src/solve.rs
274
src/solve.rs
|
@ -1,9 +1,19 @@
|
||||||
use self::Type::*;
|
use std::collections::BTreeSet;
|
||||||
|
use self::VarContent::*;
|
||||||
|
use self::Operator::*;
|
||||||
|
use ena::unify::UnificationTable;
|
||||||
|
use ena::unify::UnifyValue;
|
||||||
|
use ena::unify::InPlace;
|
||||||
|
|
||||||
pub type Name<'a> = &'a str;
|
pub type Name<'a> = &'a str;
|
||||||
|
|
||||||
pub type ModuleName<'a> = &'a str;
|
pub type ModuleName<'a> = &'a str;
|
||||||
|
|
||||||
|
type UTable<'a> = UnificationTable<InPlace<Variable<'a>>>;
|
||||||
|
|
||||||
|
type TypeUnion<'a> = BTreeSet<Type<'a>>;
|
||||||
|
type VarUnion<'a> = BTreeSet<VarContent<'a>>;
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
|
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||||
pub enum Type<'a> {
|
pub enum Type<'a> {
|
||||||
Symbol(&'a str),
|
Symbol(&'a str),
|
||||||
|
@ -11,6 +21,7 @@ pub enum Type<'a> {
|
||||||
Float,
|
Float,
|
||||||
Number,
|
Number,
|
||||||
Function(Box<Type<'a>>, Box<Type<'a>>),
|
Function(Box<Type<'a>>, Box<Type<'a>>),
|
||||||
|
CallOperator(Operator, Box<&'a Type<'a>>, Box<&'a Type<'a>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -34,6 +45,267 @@ pub enum Problem {
|
||||||
Mismatch
|
Mismatch
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq, Clone)]
|
||||||
|
pub struct Variable<'a> {
|
||||||
|
content: VarContent<'a>,
|
||||||
|
rank: u8
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
enum VarContent<'a> {
|
||||||
|
Wildcard,
|
||||||
|
RigidVar(&'a Name<'a>),
|
||||||
|
FlexUnion(TypeUnion<'a>),
|
||||||
|
RigidUnion(TypeUnion<'a>),
|
||||||
|
Structure(FlatType<'a>),
|
||||||
|
Mismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unify_rigid<'a>(named: &'a VarContent<'a>, other: &'a VarContent<'a>) -> &'a VarContent<'a> {
|
||||||
|
match other {
|
||||||
|
Wildcard => named,
|
||||||
|
RigidVar(_) => Mismatch,
|
||||||
|
FlexUnion(_) => Mismatch,
|
||||||
|
RigidUnion(_) => Mismatch,
|
||||||
|
Mismatch => other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unify_rigid_union<'a>(rigid_union: &'a VarUnion<'a>, var: &'a VarContent<'a>, other: &'a VarContent<'a>) -> &'a VarContent<'a> {
|
||||||
|
match other {
|
||||||
|
Wildcard => var,
|
||||||
|
RigidVar(_) => Mismatch,
|
||||||
|
FlexUnion(flex_union) => {
|
||||||
|
// a flex union can conform to a rigid one, as long as
|
||||||
|
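// the rigid union contains all the flex union's options. NOTE(review): the `rigid_union.is_subset(flex_union)` check below tests the opposite direction - confirm.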
|
||||||
|
if rigid_union.is_subset(flex_union) {
|
||||||
|
var
|
||||||
|
} else {
|
||||||
|
Mismatch
|
||||||
|
}
|
||||||
|
},
|
||||||
|
RigidUnion(_) => Mismatch,
|
||||||
|
Mismatch => other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unify_flex_union<'a>(flex_union: &'a VarUnion<'a>, var: &'a VarContent<'a>, other: &'a VarContent<'a>) -> &'a VarContent<'a> {
|
||||||
|
match other {
|
||||||
|
Wildcard => var,
|
||||||
|
RigidVar(_) => Mismatch,
|
||||||
|
RigidUnion(rigid_union) => {
|
||||||
|
// a flex union can conform to a rigid one, as long as
|
||||||
|
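// the rigid union contains all the flex union's options. NOTE(review): the `rigid_union.is_subset(flex_union)` check below tests the opposite direction - confirm.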
|
||||||
|
if rigid_union.is_subset(flex_union) {
|
||||||
|
other
|
||||||
|
} else {
|
||||||
|
Mismatch
|
||||||
|
}
|
||||||
|
},
|
||||||
|
FlexUnion(other_union) => unify_flex_unions(flex_union, var, other_union, other),
|
||||||
|
Structure(flat_type) => unify_flex_union_with_flat_type(flex_union, flat_type),
|
||||||
|
Mismatch => other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unify_flex_unions<'a>(my_union: &'a VarUnion<'a>, my_var: &'a VarContent<'a>, other_union: &'a VarUnion<'a>, other_var: &'a VarContent<'a>) -> &'a VarContent<'a> {
|
||||||
|
// Prioritize not allocating a new BTreeSet if possible.
|
||||||
|
if my_union == other_union {
|
||||||
|
return my_var;
|
||||||
|
}
|
||||||
|
|
||||||
|
let types_in_common = my_union.intersection(other_union);
|
||||||
|
|
||||||
|
if types_in_common.is_empty() {
|
||||||
|
Mismatch
|
||||||
|
} else {
|
||||||
|
let unified_union: VarUnion<'a> = types_in_common.into_iter().collect();
|
||||||
|
|
||||||
|
FlexUnion(unified_union)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn actually_unify<'a>(first: &'a VarContent<'a>, second: &'a VarContent<'a>) -> &'a VarContent<'a> {
|
||||||
|
match first {
|
||||||
|
// wildcard types defer to whatever the other type happens to be.
|
||||||
|
Wildcard => second,
|
||||||
|
FlexUnion(union) => unify_flex_union(union, first, second),
|
||||||
|
RigidVar(Name) => unify_rigid(first, second),
|
||||||
|
RigidUnion(union) => unify_rigid_union(union, first, second),
|
||||||
|
Structure(flat_type) => unify_structure(flat_type, first, second),
|
||||||
|
// Mismatches propagate.
|
||||||
|
Mismatch => first
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type CanonicalModuleName = String;
|
||||||
|
|
||||||
|
enum FlatType<'a> {
|
||||||
|
Function(Variable<'a>, Variable<'a>),
|
||||||
|
// Apply a higher-kinded type constructor by name
|
||||||
|
// e.g. apply `Array` to the variable `Int` to form `Array Int`
|
||||||
|
// ApplyTypeConstructor(CanonicalModuleName, Name, &'a Variable<'a>)
|
||||||
|
Tuple2(Variable<'a>, Variable<'a>),
|
||||||
|
// Tuple3(Variable<'a>, Variable<'a>, Variable<'a>),
|
||||||
|
// TupleN(Vec<Variable<'a>>), // Last resort - allocates
|
||||||
|
// Record1 (Map.Map N.Name Variable) Variable,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unify_args<'a>(arg1: &'a Variable<'a>, arg2: Variable) -> Result<Vec<Variable<'a>>, Vec<Variable<'a>>> {
|
||||||
|
guarded_unify(arg1, arg2)
|
||||||
|
// case subUnify arg1 arg2 of
|
||||||
|
// Unify k ->
|
||||||
|
// k vars
|
||||||
|
// (\vs () -> unifyArgs vs context others1 others2 ok err)
|
||||||
|
// (\vs () -> unifyArgs vs context others1 others2 err err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn guarded_unify<'a>(utable: UTable<'a>, left: Variable<'a>, right: Variable<'a>) -> Result<(), ()> {
|
||||||
|
if utable.unioned(left, right) {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
let left_descriptor = utable.probe_key(left);
|
||||||
|
let right_descriptor = utable.probe_key(right);
|
||||||
|
|
||||||
|
actually_unify(left, left_descriptor, right, right_descriptor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn unify_structure<'a>(utable: &'a mut UTable<'a>, flat_type: &'a FlatType<'a>, var: &'a VarContent<'a>, other: &'a VarContent<'a>) -> &'a VarContent<'a> {
|
||||||
|
match other {
|
||||||
|
Wildcard => var,
|
||||||
|
RigidVar(_) => Mismatch,
|
||||||
|
FlexUnion(union) => unify_flex_union_with_flat_type(flex_union, flat_type),
|
||||||
|
RigidUnion(_) => Mismatch,
|
||||||
|
Structure(other_flat_type) =>
|
||||||
|
match (flat_type, other) {
|
||||||
|
(FlatType::Function(my_arg, my_return),
|
||||||
|
FlatType::Function(other_arg, other_return)) => {
|
||||||
|
guarded_unify(utable, my_arg, other_arg);
|
||||||
|
guarded_unify(utable, my_returned, other_returned);
|
||||||
|
},
|
||||||
|
(FlatType::Tuple2(my_first, my_second),
|
||||||
|
FlatType::Tuple2(other_first, other_second)) => {
|
||||||
|
guarded_unify(utable, my_first, other_first);
|
||||||
|
guarded_unify(utable, my_second, other_second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Mismatch =>
|
||||||
|
other
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unify_flex_union_with_flat_type<'a>(utable: &'a mut UTable<'a>, flex_union: &'a VarUnion<'a>, flat_type: &'a FlatType<'a>) -> &'a VarContent<'a> {
|
||||||
|
if var_union_contains(flex_union, flat_type) {
|
||||||
|
// This will use the UnifyValue trait to unify the values.
|
||||||
|
utable.union(var1, var2);
|
||||||
|
} else {
|
||||||
|
Mismatch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
type ExpectedType<'a> = Type<'a>;
|
||||||
|
|
||||||
|
pub enum Constraint<'a> {
|
||||||
|
True,
|
||||||
|
Equal(Type<'a>, ExpectedType<'a>),
|
||||||
|
Batch(Vec<Constraint<'a>>),
|
||||||
|
}
|
||||||
|
|
||||||
pub fn infer_type<'a>(expr: Expr<'a>) -> Result<Type<'a>, Problem> {
|
pub fn infer_type<'a>(expr: Expr<'a>) -> Result<Type<'a>, Problem> {
|
||||||
Err(Problem::Mismatch)
|
Err(Problem::Mismatch)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct State {
|
||||||
|
errors: Vec<String>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl<'a> UnifyValue for Variable<'a> {
|
||||||
|
// We return our own Mismatch variant to track errors.
|
||||||
|
type Error = ena::unify::NoError;
|
||||||
|
|
||||||
|
fn unify_values(value1: &'a Variable<'a>, value2: &'a Variable<'a>) -> Result<Variable<'a>, ena::unify::NoError> {
|
||||||
|
// TODO unify 'em
|
||||||
|
|
||||||
|
// TODO problem: Elm's unification mutates and looks things up as it goes.
|
||||||
|
// I can see these possible ways to proceed:
|
||||||
|
// (1) Try to have the table's values contain a mutable reference to the table itself.
|
||||||
|
// This sounds like a mistake.
|
||||||
|
// (2) Implement unification without mutating as we go.
|
||||||
|
// Might be too slow, and might not even work.
|
||||||
|
// Like, what if I need to look something up in the middle?
|
||||||
|
// (3) Make a custom fork of ena that supports Elm's way.
|
||||||
|
// (3a) Change the unify_values function to accept the table itself, so it can be
|
||||||
|
// passed in and used during unification
|
||||||
|
// (3b) Change the unify_values function to accept the table itself, so it can be
|
||||||
|
// passed in and used during unification. I'm not super confident this would work.
|
||||||
|
//
|
||||||
|
// Possibly before doing any of this, I should look at ena's examples/tests
|
||||||
|
|
||||||
|
// TODO also I'm pretty sure in this implementation,
|
||||||
|
// I'm supposed to let them take care of the rank.
|
||||||
|
Ok(Variable {content, rank: min(rank1, rank2)})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn type_to_var(rank: u8, typ: Type) -> Variable {
|
||||||
|
match typ {
|
||||||
|
Type::CallOperator(op, left_type, right_type) => {
|
||||||
|
let left_var = type_to_var(left_type);
|
||||||
|
let right_var = type_to_var(right_type);
|
||||||
|
|
||||||
|
// TODO should we match on op to hardcode the types we expect?
|
||||||
|
let flat_type = FlatType::Function(left_var, right_var);
|
||||||
|
let content = Structure(flat_type);
|
||||||
|
|
||||||
|
utable.new_key(Variable {rank, content})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
pub fn unify(utable: Table, left_var: Variable, right_var: Variable) -> Result<(), ()>{
|
||||||
|
let left_content = utable.probe_value(left_var);
|
||||||
|
let right_content = utable.probe_value(right_var);
|
||||||
|
|
||||||
|
if left_content == right_content {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Ok(actually_unify(left, left_desc, right, right_desc))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn solve(rank: u8, state: State, constraint: Constraint) {
|
||||||
|
match constraint {
|
||||||
|
True =>
|
||||||
|
state
|
||||||
|
|
||||||
|
Equal(actual_type, expectation) => {
|
||||||
|
let actual_var = type_to_var(rank, actual_type)
|
||||||
|
let expected_var = type_to_var(rank, expectation)
|
||||||
|
let answer = unify(actual_var, expected_var)
|
||||||
|
|
||||||
|
match answer {
|
||||||
|
Ok vars ->
|
||||||
|
panic!("TODO abc");
|
||||||
|
// do introduce rank pools vars
|
||||||
|
// return state
|
||||||
|
|
||||||
|
// UF.modify var $ \(Descriptor content _ mark copy) ->
|
||||||
|
// Descriptor content rank mark copy
|
||||||
|
|
||||||
|
// Unify.Err vars actualType expectedType ->
|
||||||
|
|
||||||
|
// panic!("TODO xyz");
|
||||||
|
// do introduce rank pools vars
|
||||||
|
// return $ addError state $
|
||||||
|
// Error.BadExpr region category actualType $
|
||||||
|
// Error.typeReplace expectation expectedType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
282
src/unify.rs
282
src/unify.rs
|
@ -1,8 +1,10 @@
|
||||||
use std::collections::BTreeSet;
|
use std::collections::BTreeSet;
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
use self::Type::*;
|
||||||
|
|
||||||
pub type Ident<'a> = &'a str;
|
pub type Name<'a> = &'a str;
|
||||||
|
|
||||||
pub type Field<'a> = &'a str;
|
pub type ModuleName<'a> = &'a str;
|
||||||
|
|
||||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
|
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
|
||||||
pub enum Type<'a> {
|
pub enum Type<'a> {
|
||||||
|
@ -14,9 +16,9 @@ pub enum Type<'a> {
|
||||||
Number,
|
Number,
|
||||||
Symbol(&'a str),
|
Symbol(&'a str),
|
||||||
Array(Box<Type<'a>>),
|
Array(Box<Type<'a>>),
|
||||||
Record(Vec<(Field<'a>, Type<'a>)>),
|
Function(Box<Type<'a>>, Box<Type<'a>>),
|
||||||
|
Record(BTreeMap<Name<'a>, Type<'a>>),
|
||||||
Tuple(Vec<Type<'a>>),
|
Tuple(Vec<Type<'a>>),
|
||||||
Assignment(Ident<'a>, Box<Type<'a>>),
|
|
||||||
Union(BTreeSet<Type<'a>>),
|
Union(BTreeSet<Type<'a>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,39 +27,73 @@ pub enum Type<'a> {
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub enum Expr<'a> {
|
pub enum Expr<'a> {
|
||||||
Literal(&'a Literal<'a>),
|
// Variables
|
||||||
Assignment(Ident<'a>, Box<&'a Expr<'a>>),
|
Declaration(&'a Pattern<'a>, Box<&'a Expr<'a>>, Box<Expr<'a>>),
|
||||||
If(Box<&'a Expr<'a>> /* Conditional */, Box<&'a Expr<'a>> /* True branch */, Box<&'a Expr<'a>> /* False branch */),
|
LookupLocal(&'a Name<'a>),
|
||||||
|
LookupGlobal(&'a ModuleName<'a>, &'a Name<'a>),
|
||||||
|
|
||||||
|
// Scalars
|
||||||
|
Symbol(&'a str),
|
||||||
|
String(&'a str),
|
||||||
|
Char(char),
|
||||||
|
HexOctalBinary(i64), // : Int
|
||||||
|
FractionalNumber(f64), // : Float
|
||||||
|
WholeNumber(i64), // : Int | Float
|
||||||
|
|
||||||
|
// Collections
|
||||||
|
Array(Vec<Expr<'a>>),
|
||||||
|
Record(Vec<(&'a Name<'a>, &'a Expr<'a>)>),
|
||||||
|
Tuple(Vec<&'a Expr<'a>>),
|
||||||
|
LookupName(Name<'a>, Box<&'a Expr<'a>>),
|
||||||
// TODO add record update
|
// TODO add record update
|
||||||
// TODO add conditional
|
|
||||||
// TODO add function
|
// Functions
|
||||||
|
Function(&'a Pattern<'a>, &'a Expr<'a>),
|
||||||
|
Call(Box<&'a Expr<'a>>, Box<&'a Expr<'a>>),
|
||||||
|
CallOperator(&'a Operator, Box<&'a Expr<'a>>, Box<&'a Expr<'a>>),
|
||||||
|
|
||||||
|
// Conditionals
|
||||||
|
If(Box<&'a Expr<'a>> /* Conditional */, Box<&'a Expr<'a>> /* True branch */, Box<&'a Expr<'a>> /* False branch */),
|
||||||
|
Case(Box<&'a Expr<'a>>, Vec<(&'a Pattern<'a>, &'a Expr<'a>)>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, PartialEq)]
|
#[derive(Debug, PartialEq)]
|
||||||
pub enum Literal<'a> {
|
pub enum Operator {
|
||||||
|
Plus, Minus, Star, Caret, Percent, FloatDivision, IntDivision,
|
||||||
|
GT, GTE, LT, LTE,
|
||||||
|
EQ, NE, And, Or,
|
||||||
|
QuestionMark, Or
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub enum Pattern<'a> {
|
||||||
|
Name(&'a Name<'a>), // `foo =`
|
||||||
|
As(&'a Name<'a>, &'a Pattern<'a>), // `<pattern> as foo`
|
||||||
|
Type(&'a Type<'a>),
|
||||||
|
Symbol(&'a str),
|
||||||
String(&'a str),
|
String(&'a str),
|
||||||
Char(char),
|
Char(char),
|
||||||
Number(&'a str),
|
WholeNumber(&'a str),
|
||||||
|
FractionalNumber(&'a str),
|
||||||
HexOctalBinary(&'a str),
|
HexOctalBinary(&'a str),
|
||||||
Symbol(&'a str),
|
Tuple(Vec<Pattern<'a>>),
|
||||||
Array(Vec<Expr<'a>>),
|
Record(Vec<(Name<'a>, Option<Pattern<'a>>)>), // { a = 5, b : Int as x, c }
|
||||||
Record(Vec<(Field<'a>, &'a Expr<'a>)>),
|
|
||||||
Tuple(Vec<&'a Expr<'a>>)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
pub fn infer<'a>(expr: &Expr<'a>) -> Result<Type<'a>, UnificationProblem> {
|
pub fn infer<'a>(expr: &Expr<'a>) -> Result<Type<'a>, UnificationProblem> {
|
||||||
match expr {
|
match expr {
|
||||||
Expr::Literal(Literal::String(_)) => Ok(Type::String),
|
Expr::String(_) => Ok(String),
|
||||||
Expr::Literal(Literal::HexOctalBinary(_)) => Ok(Type::Int),
|
Expr::Char(_) => Ok(Char),
|
||||||
Expr::Literal(Literal::Char(_)) => Ok(Type::Char),
|
Expr::HexOctalBinary(_) => Ok(Int),
|
||||||
Expr::Literal(Literal::Number(_)) => Ok(Type::Number),
|
Expr::FractionalNumber(_) => Ok(Float),
|
||||||
Expr::Literal(Literal::Symbol(sym)) => Ok(Type::Symbol(sym)),
|
Expr::WholeNumber(_) => Ok(Number),
|
||||||
Expr::Literal(Literal::Array(elem_exprs)) => {
|
Expr::Symbol(sym) => Ok(Symbol(sym)),
|
||||||
|
Expr::Array(elem_exprs) => {
|
||||||
let elem_type;
|
let elem_type;
|
||||||
|
|
||||||
if elem_exprs.is_empty() {
|
if elem_exprs.is_empty() {
|
||||||
elem_type = Type::Unbound;
|
elem_type = Unbound;
|
||||||
} else {
|
} else {
|
||||||
let mut unified_type = BTreeSet::new();
|
let mut unified_type = BTreeSet::new();
|
||||||
|
|
||||||
|
@ -70,24 +106,24 @@ pub fn infer<'a>(expr: &Expr<'a>) -> Result<Type<'a>, UnificationProblem> {
|
||||||
// No point in storing a union of 1.
|
// No point in storing a union of 1.
|
||||||
elem_type = unified_type.into_iter().next().unwrap()
|
elem_type = unified_type.into_iter().next().unwrap()
|
||||||
} else {
|
} else {
|
||||||
elem_type = Type::Union(unified_type)
|
elem_type = Union(unified_type)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Type::Array(Box::new(elem_type)))
|
Ok(Array(Box::new(elem_type)))
|
||||||
},
|
},
|
||||||
Expr::Literal(Literal::Record(fields)) => {
|
Expr::Record(fields) => {
|
||||||
let mut rec_type: Vec<(&'a str, Type<'a>)> = Vec::new();
|
let mut rec_type: BTreeMap<&'a Name<'a>, Type<'a>> = BTreeMap::new();
|
||||||
|
|
||||||
for (field, subexpr) in fields {
|
for (field, subexpr) in fields {
|
||||||
let field_type = infer(subexpr)?;
|
let field_type = infer(subexpr)?;
|
||||||
|
|
||||||
rec_type.push((&field, field_type));
|
rec_type.insert(&field, field_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Type::Record(rec_type))
|
Ok(Record(rec_type))
|
||||||
},
|
},
|
||||||
Expr::Literal(Literal::Tuple(exprs)) => {
|
Expr::Tuple(exprs) => {
|
||||||
let mut tuple_type: Vec<Type<'a>> = Vec::new();
|
let mut tuple_type: Vec<Type<'a>> = Vec::new();
|
||||||
|
|
||||||
for subexpr in exprs {
|
for subexpr in exprs {
|
||||||
|
@ -96,7 +132,7 @@ pub fn infer<'a>(expr: &Expr<'a>) -> Result<Type<'a>, UnificationProblem> {
|
||||||
tuple_type.push(field_type);
|
tuple_type.push(field_type);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Type::Tuple(tuple_type))
|
Ok(Tuple(tuple_type))
|
||||||
},
|
},
|
||||||
Expr::If(box cond, expr_if_true, expr_if_false) => {
|
Expr::If(box cond, expr_if_true, expr_if_false) => {
|
||||||
let cond_type = infer(&cond)?;
|
let cond_type = infer(&cond)?;
|
||||||
|
@ -122,29 +158,197 @@ pub fn infer<'a>(expr: &Expr<'a>) -> Result<Type<'a>, UnificationProblem> {
|
||||||
// but we can pull it back out of the set
|
// but we can pull it back out of the set
|
||||||
Ok(unified_type.into_iter().next().unwrap())
|
Ok(unified_type.into_iter().next().unwrap())
|
||||||
} else {
|
} else {
|
||||||
Ok(Type::Union(unified_type))
|
Ok(Union(unified_type))
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Expr::Assignment(ident, subexpr) => {
|
Call(func, arg) => {
|
||||||
Ok(Type::Assignment(ident, Box::new(infer(subexpr)?)))
|
|
||||||
|
},
|
||||||
|
CallOperator(op, left_expr, right_expr) => {
|
||||||
|
let left = &(infer(left_expr)?);
|
||||||
|
let right = &(infer(right_expr)?);
|
||||||
|
|
||||||
|
match op {
|
||||||
|
Operator::EQ | Operator::NE | Operator::And | Operator::Or => {
|
||||||
|
if types_match(left, right) {
|
||||||
|
conform_to_bool(left)
|
||||||
|
} else {
|
||||||
|
Err(UnificationProblem::TypeMismatch)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Operator::Plus | Operator::Minus | Operator::Star
|
||||||
|
| Operator::GT | Operator::LT | Operator::GTE | Operator::LTE
|
||||||
|
| Operator::Caret | Operator::Percent => {
|
||||||
|
if types_match(left, right) {
|
||||||
|
conform_to_number(left)
|
||||||
|
} else {
|
||||||
|
Err(UnificationProblem::TypeMismatch)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Operator::FloatDivision => {
|
||||||
|
if matches_float_type(left) && matches_float_type(right) {
|
||||||
|
Ok(&Float)
|
||||||
|
} else {
|
||||||
|
Err(UnificationProblem::TypeMismatch)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Operator::IntDivision => {
|
||||||
|
if matches_int_type(left) && matches_int_type(right) {
|
||||||
|
Ok(&Int)
|
||||||
|
} else {
|
||||||
|
Err(UnificationProblem::TypeMismatch)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Operator::CombineStrings => {
|
||||||
|
if matches_string_type(left) && matches_string_type(right) {
|
||||||
|
Ok(&String)
|
||||||
|
} else {
|
||||||
|
Err(UnificationProblem::TypeMismatch)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Operator::QuestionMark => {
|
||||||
|
if types_match(left, right) {
|
||||||
|
conform_to_optional(left)
|
||||||
|
} else {
|
||||||
|
Err(UnificationProblem::TypeMismatch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Expr::Declaration(pattern, let_expr, in_expr) => {
|
||||||
|
// Think of this as a let..in even though syntactically it's not.
|
||||||
|
// We need to type-check the let-binding, but the type of the
|
||||||
|
// *expression* we're expanding is only affected by the in-block.
|
||||||
|
check_pattern(&pattern, &let_expr)?;
|
||||||
|
|
||||||
|
infer(in_expr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn types_match<'a>(first: &'a Type<'a>, second: &'a Type<'a>) -> bool {
|
||||||
|
match (first, second) {
|
||||||
|
(Type::Union(first_types), Type::Union(second_types)) => {
|
||||||
|
// If any type is not directly present in the other union,
|
||||||
|
// it must at least match *some* type in the other union
|
||||||
|
first_types.difference(second_types).into_iter().all(|not_in_second_type| {
|
||||||
|
second_types.iter().any(|second_type| types_match(second_type, not_in_second_type))
|
||||||
|
}) &&
|
||||||
|
second_types.difference(first_types).into_iter().all(|not_in_first_type| {
|
||||||
|
first_types.iter().any(|first_type| types_match(first_type, not_in_first_type))
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
// Happy path: try these first, since we expect them to succeed.
|
||||||
|
// These are sorted based on a vague guess of how often they will be used in practice.
|
||||||
|
(Type::Symbol(sym_one), Type::Symbol(sym_two)) => sym_one == sym_two,
|
||||||
|
(Type::String, Type::String) => true,
|
||||||
|
(Type::Unbound, _) | (_, Type::Unbound)=> true,
|
||||||
|
(Type::Array(box elem_type_one), Type::Array(box elem_type_two)) => {
|
||||||
|
types_match(elem_type_one, elem_type_two)
|
||||||
|
},
|
||||||
|
(Type::Number, Type::Number) => true,
|
||||||
|
(Type::Number, other) => matches_number_type(other),
|
||||||
|
(other, Type::Number) => matches_number_type(other),
|
||||||
|
(Type::Int, Type::Int) => true,
|
||||||
|
(Type::Float, Type::Float) => true,
|
||||||
|
(Type::Tuple(first_elems), Type::Tuple(second_elems)) => {
|
||||||
|
// TODO verify that the elems and their types match up
|
||||||
|
// TODO write some scenarios to understand these better -
|
||||||
|
// like, what happens if you have a function that takes
|
||||||
|
// a lambda whose argument takes an open record,
|
||||||
|
// and you pass a lambda whose argument takes *fewer* fields?
|
||||||
|
// that should work! the function is gonna pass it a lambda that
|
||||||
|
// has more fields than it needs.
|
||||||
|
// I think there's an element of directionality here that I'm
|
||||||
|
// disregarding. Maybe this function shouldn't commute.
|
||||||
|
},
|
||||||
|
(Type::Function(first_arg), Type::Function(second_arg)) => {
|
||||||
|
// TODO verify that the elems and their types match up
|
||||||
|
},
|
||||||
|
(Type::Record(first_fields), Type::Record(second_fields)) => {
|
||||||
|
// TODO verify that the fields and their types match up
|
||||||
|
// TODO what should happen if one is a superset of the other? fail?
|
||||||
|
},
|
||||||
|
(Type::Char, Type::Char) => true,
|
||||||
|
|
||||||
|
// Unhappy path - expect these to fail, so check them last
|
||||||
|
(Type::Union(first_types), _) => {
|
||||||
|
first_types.iter().all(|typ| types_match(typ, second))
|
||||||
|
},
|
||||||
|
(_, Type::Union(second_types)) => {
|
||||||
|
second_types.iter().all(|typ| types_match(first, typ))
|
||||||
|
},
|
||||||
|
(Type::String, _) | (_, Type::String) => false,
|
||||||
|
(Type::Char, _) | (_, Type::Char) => false,
|
||||||
|
(Type::Int, _) | (_, Type::Int) => false,
|
||||||
|
(Type::Float, _) | (_, Type::Float) => false,
|
||||||
|
(Type::Symbol(_), _) | (_, Type::Symbol(_)) => false,
|
||||||
|
(Type::Array(_), _) | (_, Type::Array(_)) => false,
|
||||||
|
(Type::Record(_), _) | (_, Type::Record(_)) => false,
|
||||||
|
(Type::Tuple(_), _) | (_, Type::Tuple(_)) => false,
|
||||||
|
(Type::Function(_, _), _) | (_, Type::Function(_, _)) => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
fn check_pattern<'a>(pattern: &'a Pattern<'a>, expr: &'a Expr<'a>) -> Result<(), UnificationProblem> {
|
||||||
|
let expr_type = infer(expr)?;
|
||||||
|
|
||||||
|
panic!("TODO check the pattern's type against expr_type, then write some tests for funky record pattern cases - this is our first real unification! Next one will be field access, ooooo - gonna want lots of tests for that")
|
||||||
|
}
|
||||||
|
|
||||||
const TRUE_SYMBOL_STR: &'static str = "True";
|
const TRUE_SYMBOL_STR: &'static str = "True";
|
||||||
const FALSE_SYMBOL_STR: &'static str = "False";
|
const FALSE_SYMBOL_STR: &'static str = "False";
|
||||||
|
|
||||||
|
pub fn matches_string_type<'a>(candidate: &Type<'a>) -> bool {
|
||||||
|
match candidate {
|
||||||
|
Unbound | String => true,
|
||||||
|
Type::Union(types) => {
|
||||||
|
types.iter().all(|typ| matches_string_type(typ))
|
||||||
|
},
|
||||||
|
_ => Err(UnificationProblem::TypeMismatch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn matches_bool_type<'a>(candidate: &Type<'a>) -> bool {
|
pub fn matches_bool_type<'a>(candidate: &Type<'a>) -> bool {
|
||||||
match candidate {
|
match candidate {
|
||||||
Type::Symbol(str) => {
|
Type::Unbound => true,
|
||||||
str == &TRUE_SYMBOL_STR || str == &FALSE_SYMBOL_STR
|
Type::Symbol(str) => str == &TRUE_SYMBOL_STR || str == &FALSE_SYMBOL_STR,
|
||||||
}
|
|
||||||
Type::Union(types) => {
|
Type::Union(types) => {
|
||||||
types.len() <= 2 && types.iter().all(|typ| matches_bool_type(typ))
|
types.iter().all(|typ| matches_bool_type(typ))
|
||||||
}
|
}
|
||||||
_ => {
|
_ => false
|
||||||
false
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn matches_number_type<'a>(candidate: &Type<'a>) -> bool {
|
||||||
|
match candidate {
|
||||||
|
Type::Unbound | Type::Int | Type::Float | Type::Number => true,
|
||||||
|
Type::Union(types) => {
|
||||||
|
types.iter().all(|typ| matches_number_type(typ))
|
||||||
}
|
}
|
||||||
|
_ => false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn matches_int_type<'a>(candidate: &Type<'a>) -> bool {
|
||||||
|
match candidate {
|
||||||
|
Type::Unbound | Type::Int => true,
|
||||||
|
Type::Union(types) => {
|
||||||
|
types.iter().all(|typ| matches_int_type(typ))
|
||||||
|
}
|
||||||
|
_ => false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn matches_float_type<'a>(candidate: &Type<'a>) -> bool {
|
||||||
|
match candidate {
|
||||||
|
Type::Unbound | Type::Float => true,
|
||||||
|
Type::Union(types) => {
|
||||||
|
types.iter().all(|typ| matches_float_type(typ))
|
||||||
|
}
|
||||||
|
_ => false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -71,4 +71,16 @@ mod tests {
|
||||||
fn int<'a>() -> Box<&'a Expr<'a>> { Box::new(&HexOctalBinary(0x12)) }
|
fn int<'a>() -> Box<&'a Expr<'a>> { Box::new(&HexOctalBinary(0x12)) }
|
||||||
fn float<'a>() -> Box<&'a Expr<'a>> { Box::new(&FractionalNumber(3.1)) }
|
fn float<'a>() -> Box<&'a Expr<'a>> { Box::new(&FractionalNumber(3.1)) }
|
||||||
fn num<'a>() -> Box<&'a Expr<'a>> { Box::new(&WholeNumber(5)) }
|
fn num<'a>() -> Box<&'a Expr<'a>> { Box::new(&WholeNumber(5)) }
|
||||||
|
|
||||||
|
// TODO test unions that ought to be equivalent, but only after
|
||||||
|
// a reduction of some sort, e.g.
|
||||||
|
//
|
||||||
|
// ((a|b)|c) vs (a|(b|c))
|
||||||
|
//
|
||||||
|
// ((a|z)|(b|z)) vs (a|b|z)
|
||||||
|
//
|
||||||
|
// ideally, we fix these when constructing unions
|
||||||
|
// e.g. if a user puts this in as an annotation, reduce it immediately
|
||||||
|
// and when we're inferring unions, always infer them flat.
|
||||||
|
// This way we can avoid checking recursively.
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue