Trying out ena

This commit is contained in:
Richard Feldman 2019-02-01 21:07:52 -05:00
parent 338be03bdd
commit 5635561fca
8 changed files with 535 additions and 44 deletions

View file

@ -51,7 +51,7 @@ pub fn literal_to_string<'a>(literal: &'a Literal<'a>) -> String {
elem_strings.push(literal_to_string(eval(elem_expr)));
}
format!("[ {} ]", elem_strings.join(", "))
format!("[{}]", elem_strings.join(", "))
},
}
}

View file

@ -1,6 +1,9 @@
#![feature(box_patterns)]
pub mod unify;
pub mod interpret;
pub mod repl;
// pub mod unify;
// pub mod interpret;
// pub mod repl;
pub mod solve;
extern crate ena;

View file

@ -1,9 +1,19 @@
use self::Type::*;
use std::collections::BTreeSet;
use self::VarContent::*;
use self::Operator::*;
use ena::unify::UnificationTable;
use ena::unify::UnifyValue;
use ena::unify::InPlace;
pub type Name<'a> = &'a str;
pub type ModuleName<'a> = &'a str;
type UTable<'a> = UnificationTable<InPlace<Variable<'a>>>;
type TypeUnion<'a> = BTreeSet<Type<'a>>;
type VarUnion<'a> = BTreeSet<VarContent<'a>>;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Type<'a> {
Symbol(&'a str),
@ -11,6 +21,7 @@ pub enum Type<'a> {
Float,
Number,
Function(Box<Type<'a>>, Box<Type<'a>>),
CallOperator(Operator, Box<&'a Type<'a>>, Box<&'a Type<'a>>),
}
@ -34,6 +45,267 @@ pub enum Problem {
Mismatch
}
/// A unification variable: its current content plus a rank.
///
/// Used as the type parameter of `UTable` and given a `UnifyValue` impl
/// further down in this file.
// NOTE(review): this derives `Clone`, but `VarContent` only derives
// `Debug, PartialEq` — confirm which set of derives is intended.
#[derive(Debug, PartialEq, Clone)]
pub struct Variable<'a> {
// What this variable is currently known to be.
content: VarContent<'a>,
// presumably the generalization rank from Elm's solver — TODO confirm
rank: u8
}
/// The possible contents of a `Variable`.
#[derive(Debug, PartialEq)]
enum VarContent<'a> {
// Unconstrained: `actually_unify` defers to the other side.
Wildcard,
// A named variable: `unify_rigid` only accepts a `Wildcard` opposite it.
RigidVar(&'a Name<'a>),
// A union handled by `unify_flex_union` / `unify_flex_unions`.
FlexUnion(TypeUnion<'a>),
// A union handled by `unify_rigid_union`.
RigidUnion(TypeUnion<'a>),
// A concrete type shape (function, tuple, ...), see `FlatType`.
Structure(FlatType<'a>),
// Unification failed; `actually_unify` propagates this value.
Mismatch
}
fn unify_rigid<'a>(named: &'a VarContent<'a>, other: &'a VarContent<'a>) -> &'a VarContent<'a> {
match other {
Wildcard => named,
RigidVar(_) => Mismatch,
FlexUnion(_) => Mismatch,
RigidUnion(_) => Mismatch,
Mismatch => other
}
}
fn unify_rigid_union<'a>(rigid_union: &'a VarUnion<'a>, var: &'a VarContent<'a>, other: &'a VarContent<'a>) -> &'a VarContent<'a> {
match other {
Wildcard => var,
RigidVar(_) => Mismatch,
FlexUnion(flex_union) => {
// a flex union can conform to a rigid one, as long as
// as the rigid union contains all the flex union's options
if rigid_union.is_subset(flex_union) {
var
} else {
Mismatch
}
},
RigidUnion(_) => Mismatch,
Mismatch => other
}
}
/// Unifies the content of a flex union with `other`.
///
/// * `flex_union` - the options of the flex union.
/// * `var` - the content that owns `flex_union`.
/// * `other` - the content being unified against.
///
/// A `Wildcard` yields `var`; a rigid union yields `other` when it can absorb
/// this flex union; two flex unions are delegated to `unify_flex_unions`; a
/// structure is delegated to `unify_flex_union_with_flat_type`; a `Mismatch`
/// is propagated.
fn unify_flex_union<'a>(flex_union: &'a VarUnion<'a>, var: &'a VarContent<'a>, other: &'a VarContent<'a>) -> &'a VarContent<'a> {
    match other {
        Wildcard => var,
        RigidVar(_) => &Mismatch,
        RigidUnion(rigid_union) => {
            // A flex union can conform to a rigid one, as long as the rigid
            // union contains all of the flex union's options, i.e. the flex
            // union is a subset of the rigid one.
            if flex_union.is_subset(rigid_union) {
                other
            } else {
                &Mismatch
            }
        }
        FlexUnion(other_union) => unify_flex_unions(flex_union, var, other_union, other),
        // NOTE(review): `unify_flex_union_with_flat_type` is declared with a
        // leading `utable` parameter that this function has no access to.
        Structure(flat_type) => unify_flex_union_with_flat_type(flex_union, flat_type),
        // Mismatches propagate.
        Mismatch => other,
    }
}
// Unifies two flex unions by keeping only the options they have in common.
//
// Equal unions return `my_var` unchanged; otherwise the intersection is taken
// and an empty intersection is a `Mismatch`.
//
// NOTE(review): as written this cannot type-check against the declared
// `&'a VarContent<'a>` return type: `Mismatch` and `FlexUnion(..)` are
// returned by value, and the freshly built union has no owner that outlives
// the call. The signature probably needs to return an owned value.
fn unify_flex_unions<'a>(my_union: &'a VarUnion<'a>, my_var: &'a VarContent<'a>, other_union: &'a VarUnion<'a>, other_var: &'a VarContent<'a>) -> &'a VarContent<'a> {
// Prioritize not allocating a new BTreeSet if possible.
if my_union == other_union {
return my_var;
}
// NOTE(review): `BTreeSet::intersection` yields a lazy iterator; confirm that
// `is_empty` is actually available on it.
let types_in_common = my_union.intersection(other_union);
if types_in_common.is_empty() {
Mismatch
} else {
// NOTE(review): the iterator yields references, so collecting into an owned
// `VarUnion` would need the elements cloned — verify.
let unified_union: VarUnion<'a> = types_in_common.into_iter().collect();
FlexUnion(unified_union)
}
}
/// Dispatches unification on the shape of `first`.
///
/// Returns the content both sides agree on, or a `Mismatch`.
fn actually_unify<'a>(first: &'a VarContent<'a>, second: &'a VarContent<'a>) -> &'a VarContent<'a> {
    match first {
        // Wildcard types defer to whatever the other type happens to be.
        Wildcard => second,
        FlexUnion(union) => unify_flex_union(union, first, second),
        // The name itself is not needed here. (`RigidVar(Name)` would have
        // introduced a fresh binding called `Name`, not matched the alias.)
        RigidVar(_) => unify_rigid(first, second),
        RigidUnion(union) => unify_rigid_union(union, first, second),
        // NOTE(review): `unify_structure` is declared with a leading `utable`
        // parameter that this function has no access to.
        Structure(flat_type) => unify_structure(flat_type, first, second),
        // Mismatches propagate.
        Mismatch => first,
    }
}
type CanonicalModuleName = String;
/// The concrete type shapes a `VarContent::Structure` can hold.
// `VarContent` derives `Debug` and `PartialEq` and stores a `FlatType`, so
// this enum has to provide the same traits.
// NOTE(review): `FlatType` holds `Variable`s by value, `Variable` holds a
// `VarContent`, and `VarContent::Structure` holds a `FlatType`. That cycle
// has no indirection, so some link will need a `Box` or a table key.
#[derive(Debug, PartialEq)]
enum FlatType<'a> {
    // Argument and return variables.
    Function(Variable<'a>, Variable<'a>),

    // Apply a higher-kinded type constructor by name
    // e.g. apply `Array` to the variable `Int` to form `Array Int`
    // ApplyTypeConstructor(CanonicalModuleName, Name, &'a Variable<'a>)

    // First and second element variables.
    Tuple2(Variable<'a>, Variable<'a>),
    // Tuple3(Variable<'a>, Variable<'a>, Variable<'a>),
    // TupleN(Vec<Variable<'a>>), // Last resort - allocates
    // Record1 (Map.Map N.Name Variable) Variable,
}
// Unifies one pair of arguments by delegating to `guarded_unify`.
//
// NOTE(review): `guarded_unify` is declared with three parameters (`utable`,
// `left`, `right`) and returns `Result<(), ()>`, while this stub passes two
// arguments and declares `Result<Vec<_>, Vec<_>>` — reconcile before use.
fn unify_args<'a>(arg1: &'a Variable<'a>, arg2: Variable) -> Result<Vec<Variable<'a>>, Vec<Variable<'a>>> {
guarded_unify(arg1, arg2)
// The lines below look like the code being ported (presumably from the Elm
// compiler — TODO confirm); they are kept as a reference only.
// case subUnify arg1 arg2 of
// Unify k ->
// k vars
// (\vs () -> unifyArgs vs context others1 others2 ok err)
// (\vs () -> unifyArgs vs context others1 others2 err err)
}
// Unifies `left` and `right` unless the table reports them as already
// unioned, in which case there is nothing to do.
//
// NOTE(review): `utable` is taken by value and is not `mut`; confirm against
// the `ena` API whether `unioned` and the probe calls need `&mut` access.
fn guarded_unify<'a>(utable: UTable<'a>, left: Variable<'a>, right: Variable<'a>) -> Result<(), ()> {
if utable.unioned(left, right) {
Ok(())
} else {
// NOTE(review): `unify` in this file uses `probe_value` for the same lookup;
// confirm that `probe_key` exists in the `ena` version in use.
let left_descriptor = utable.probe_key(left);
let right_descriptor = utable.probe_key(right);
// NOTE(review): `actually_unify` is declared with two `&VarContent`
// parameters and returns `&VarContent`, not `Result<(), ()>`.
actually_unify(left, left_descriptor, right, right_descriptor)
}
}
/// Unifies a structural content (`flat_type`, owned by `var`) with `other`.
///
/// * `utable` - the unification table, used for the nested variables.
/// * `flat_type` - the shape stored in `var`.
/// * `var` - the content that owns `flat_type`; returned on success.
/// * `other` - the content being unified against.
///
/// Two structures unify only when they have the same shape and all of their
/// nested variables unify. Rigid contents never unify with a structure, and
/// an existing `Mismatch` is propagated.
// NOTE(review): `guarded_unify` currently takes the table and both variables
// by value, whereas this function holds references; the call sites below keep
// the original argument order and need that signature reconciled.
pub fn unify_structure<'a>(utable: &'a mut UTable<'a>, flat_type: &'a FlatType<'a>, var: &'a VarContent<'a>, other: &'a VarContent<'a>) -> &'a VarContent<'a> {
    match other {
        Wildcard => var,
        RigidVar(_) | RigidUnion(_) => &Mismatch,
        FlexUnion(union) => unify_flex_union_with_flat_type(utable, union, flat_type),
        Structure(other_flat_type) => {
            // Compare the two *shapes*, not the shape against the content.
            match (flat_type, other_flat_type) {
                (FlatType::Function(my_arg, my_return),
                 FlatType::Function(other_arg, other_return)) => {
                    // Both unifications always run, as in the original.
                    let arg_result = guarded_unify(utable, my_arg, other_arg);
                    let return_result = guarded_unify(utable, my_return, other_return);

                    if arg_result.is_ok() && return_result.is_ok() {
                        var
                    } else {
                        &Mismatch
                    }
                }
                (FlatType::Tuple2(my_first, my_second),
                 FlatType::Tuple2(other_first, other_second)) => {
                    let first_result = guarded_unify(utable, my_first, other_first);
                    let second_result = guarded_unify(utable, my_second, other_second);

                    if first_result.is_ok() && second_result.is_ok() {
                        var
                    } else {
                        &Mismatch
                    }
                }
                // Different shapes can never unify.
                (FlatType::Function(..), FlatType::Tuple2(..))
                | (FlatType::Tuple2(..), FlatType::Function(..)) => &Mismatch,
            }
        }
        // Mismatches propagate.
        Mismatch => other,
    }
}
// Unifies a flex union with a structural type: a `Mismatch` unless
// `var_union_contains` reports that the union contains `flat_type`.
//
// NOTE(review): `var1` and `var2` are not bound anywhere in this function,
// and the `if` branch ends in `;` so it does not produce the declared
// `&VarContent` result. `Mismatch` is also returned by value.
fn unify_flex_union_with_flat_type<'a>(utable: &'a mut UTable<'a>, flex_union: &'a VarUnion<'a>, flat_type: &'a FlatType<'a>) -> &'a VarContent<'a> {
if var_union_contains(flex_union, flat_type) {
// This will use the UnifyValue trait to unify the values.
utable.union(var1, var2);
} else {
Mismatch
}
}
type ExpectedType<'a> = Type<'a>;
/// A type constraint handed to `solve`.
pub enum Constraint<'a> {
// Nothing to check; `solve` hands the state back for this case.
True,
// The first type must unify with the expected type.
Equal(Type<'a>, ExpectedType<'a>),
// A group of constraints.
// NOTE(review): `solve` has no arm for this variant yet.
Batch(Vec<Constraint<'a>>),
}
/// Placeholder: ignores `expr` and always returns `Err(Problem::Mismatch)`.
pub fn infer_type<'a>(expr: Expr<'a>) -> Result<Type<'a>, Problem> {
Err(Problem::Mismatch)
}
/// State passed to `solve`.
struct State {
// Error messages — presumably accumulated while solving; TODO confirm,
// nothing in the visible code pushes to this yet.
errors: Vec<String>
}
// Hooks `Variable` into `ena` as a table value. The body is a placeholder.
//
// NOTE(review): `content`, `rank1`, `rank2` and `min` are not bound or
// imported anywhere in this impl, and the parameters are written as
// `&'a Variable<'a>` — verify that this matches the trait's declared
// `unify_values` signature.
impl<'a> UnifyValue for Variable<'a> {
// We return our own Mismatch variant to track errors.
type Error = ena::unify::NoError;
fn unify_values(value1: &'a Variable<'a>, value2: &'a Variable<'a>) -> Result<Variable<'a>, ena::unify::NoError> {
// TODO unify 'em
// TODO problem: Elm's unification mutates and looks things up as it goes.
// I can see these possible ways to proceed:
// (1) Try to have the table's values contain a mutable reference to the table itself.
// This sounds like a mistake.
// (2) Implement unification without mutating as we go.
// Might be too slow, and might not even work.
// Like, what if I need to look something up in the middle?
// (3) Make a custom fork of ena that supports Elm's way.
// NOTE(review): (3a) and (3b) below describe the same change; one of them
// was probably meant to say something different.
// (3a) Change the unify_values function to accept the table itself, so it can be
// passed in and used during unification
// (3b) Change the unify_values function to accept the table itself, so it can be
// passed in and used during unification. I'm not super confident this would work.
//
// Possibly before doing any of this, I should look at ena's examples/tests
// TODO also I'm pretty sure in this implementation,
// I'm supposed to let them take care of the rank.
Ok(Variable {content, rank: min(rank1, rank2)})
}
}
// Converts a `Type` into a unification variable at the given rank.
// Only `Type::CallOperator` is handled so far; it becomes a
// `FlatType::Function` built from the two operand variables.
//
// NOTE(review): the recursive calls omit the `rank` argument and pass boxed
// references where a `Type` is declared; `utable` is not in scope; and the
// result of `utable.new_key(..)` is returned where a `Variable` is declared.
fn type_to_var(rank: u8, typ: Type) -> Variable {
match typ {
Type::CallOperator(op, left_type, right_type) => {
let left_var = type_to_var(left_type);
let right_var = type_to_var(right_type);
// TODO should we match on op to hardcode the types we expect?
let flat_type = FlatType::Function(left_var, right_var);
let content = Structure(flat_type);
utable.new_key(Variable {rank, content})
}
}
}
/// Unifies two variables through the table.
///
/// Looks up both values and returns `Ok(())` straight away when they are
/// already equal; otherwise defers to `actually_unify`.
// NOTE(review): `Table` is not defined in the visible code (the alias near the
// top of the file is `UTable`); `left`, `left_desc`, `right` and `right_desc`
// are not bound in this function; and `actually_unify` is declared with two
// parameters. `solve` also calls this function with two arguments.
pub fn unify(utable: Table, left_var: Variable, right_var: Variable) -> Result<(), ()>{
let left_content = utable.probe_value(left_var);
let right_content = utable.probe_value(right_var);
if left_content == right_content {
Ok(())
} else {
Ok(actually_unify(left, left_desc, right, right_desc))
}
}
/// Solves one constraint at the given rank.
///
/// `True` hands the state back; `Equal` converts both types to variables with
/// `type_to_var` and unifies them.
// NOTE(review): this body is a half-finished port and is not valid Rust yet:
// statements are missing `;`, match arms are missing `,`, `Ok vars ->` is not
// Rust syntax, the `True` arm yields `state` although no return type is
// declared, `unify` is called without a table, and `Constraint::Batch` has no
// arm. The commented lines look like the code being ported (presumably from
// the Elm compiler — TODO confirm).
pub fn solve(rank: u8, state: State, constraint: Constraint) {
match constraint {
True =>
state
Equal(actual_type, expectation) => {
let actual_var = type_to_var(rank, actual_type)
let expected_var = type_to_var(rank, expectation)
let answer = unify(actual_var, expected_var)
match answer {
Ok vars ->
panic!("TODO abc");
// do introduce rank pools vars
// return state
// UF.modify var $ \(Descriptor content _ mark copy) ->
// Descriptor content rank mark copy
// Unify.Err vars actualType expectedType ->
// panic!("TODO xyz");
// do introduce rank pools vars
// return $ addError state $
// Error.BadExpr region category actualType $
// Error.typeReplace expectation expectedType
}
}
}
}

View file

@ -1,8 +1,10 @@
use std::collections::BTreeSet;
use std::collections::BTreeMap;
use self::Type::*;
pub type Ident<'a> = &'a str;
pub type Name<'a> = &'a str;
pub type Field<'a> = &'a str;
pub type ModuleName<'a> = &'a str;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Type<'a> {
@ -14,9 +16,9 @@ pub enum Type<'a> {
Number,
Symbol(&'a str),
Array(Box<Type<'a>>),
Record(Vec<(Field<'a>, Type<'a>)>),
Function(Box<Type<'a>>, Box<Type<'a>>),
Record(BTreeMap<Name<'a>, Type<'a>>),
Tuple(Vec<Type<'a>>),
Assignment(Ident<'a>, Box<Type<'a>>),
Union(BTreeSet<Type<'a>>),
}
@ -25,39 +27,73 @@ pub enum Type<'a> {
#[derive(Debug, PartialEq)]
pub enum Expr<'a> {
Literal(&'a Literal<'a>),
Assignment(Ident<'a>, Box<&'a Expr<'a>>),
If(Box<&'a Expr<'a>> /* Conditional */, Box<&'a Expr<'a>> /* True branch */, Box<&'a Expr<'a>> /* False branch */),
// Variables
Declaration(&'a Pattern<'a>, Box<&'a Expr<'a>>, Box<Expr<'a>>),
LookupLocal(&'a Name<'a>),
LookupGlobal(&'a ModuleName<'a>, &'a Name<'a>),
// Scalars
Symbol(&'a str),
String(&'a str),
Char(char),
HexOctalBinary(i64), // : Int
FractionalNumber(f64), // : Float
WholeNumber(i64), // : Int | Float
// Collections
Array(Vec<Expr<'a>>),
Record(Vec<(&'a Name<'a>, &'a Expr<'a>)>),
Tuple(Vec<&'a Expr<'a>>),
LookupName(Name<'a>, Box<&'a Expr<'a>>),
// TODO add record update
// TODO add conditional
// TODO add function
// Functions
Function(&'a Pattern<'a>, &'a Expr<'a>),
Call(Box<&'a Expr<'a>>, Box<&'a Expr<'a>>),
CallOperator(&'a Operator, Box<&'a Expr<'a>>, Box<&'a Expr<'a>>),
// Conditionals
If(Box<&'a Expr<'a>> /* Conditional */, Box<&'a Expr<'a>> /* True branch */, Box<&'a Expr<'a>> /* False branch */),
Case(Box<&'a Expr<'a>>, Vec<(&'a Pattern<'a>, &'a Expr<'a>)>),
}
#[derive(Debug, PartialEq)]
pub enum Literal<'a> {
/// Binary operators that `infer` knows how to type-check.
pub enum Operator {
    // Arithmetic
    Plus, Minus, Star, Caret, Percent, FloatDivision, IntDivision,
    // Comparison
    GT, GTE, LT, LTE,
    // Equality and boolean logic. (`Or` used to be listed a second time
    // below, which is a duplicate-definition error.)
    EQ, NE, And, Or,
    QuestionMark,
    // Matched as `Operator::CombineStrings` in `infer`, so it must exist here.
    CombineStrings,
}
#[derive(Debug, PartialEq)]
pub enum Pattern<'a> {
Name(&'a Name<'a>), // `foo =`
As(&'a Name<'a>, &'a Pattern<'a>), // `<pattern> as foo`
Type(&'a Type<'a>),
Symbol(&'a str),
String(&'a str),
Char(char),
Number(&'a str),
WholeNumber(&'a str),
FractionalNumber(&'a str),
HexOctalBinary(&'a str),
Symbol(&'a str),
Array(Vec<Expr<'a>>),
Record(Vec<(Field<'a>, &'a Expr<'a>)>),
Tuple(Vec<&'a Expr<'a>>)
Tuple(Vec<Pattern<'a>>),
Record(Vec<(Name<'a>, Option<Pattern<'a>>)>), // { a = 5, b : Int as x, c }
}
pub fn infer<'a>(expr: &Expr<'a>) -> Result<Type<'a>, UnificationProblem> {
match expr {
Expr::Literal(Literal::String(_)) => Ok(Type::String),
Expr::Literal(Literal::HexOctalBinary(_)) => Ok(Type::Int),
Expr::Literal(Literal::Char(_)) => Ok(Type::Char),
Expr::Literal(Literal::Number(_)) => Ok(Type::Number),
Expr::Literal(Literal::Symbol(sym)) => Ok(Type::Symbol(sym)),
Expr::Literal(Literal::Array(elem_exprs)) => {
Expr::String(_) => Ok(String),
Expr::Char(_) => Ok(Char),
Expr::HexOctalBinary(_) => Ok(Int),
Expr::FractionalNumber(_) => Ok(Float),
Expr::WholeNumber(_) => Ok(Number),
Expr::Symbol(sym) => Ok(Symbol(sym)),
Expr::Array(elem_exprs) => {
let elem_type;
if elem_exprs.is_empty() {
elem_type = Type::Unbound;
elem_type = Unbound;
} else {
let mut unified_type = BTreeSet::new();
@ -70,24 +106,24 @@ pub fn infer<'a>(expr: &Expr<'a>) -> Result<Type<'a>, UnificationProblem> {
// No point in storing a union of 1.
elem_type = unified_type.into_iter().next().unwrap()
} else {
elem_type = Type::Union(unified_type)
elem_type = Union(unified_type)
}
}
Ok(Type::Array(Box::new(elem_type)))
Ok(Array(Box::new(elem_type)))
},
Expr::Literal(Literal::Record(fields)) => {
let mut rec_type: Vec<(&'a str, Type<'a>)> = Vec::new();
Expr::Record(fields) => {
let mut rec_type: BTreeMap<&'a Name<'a>, Type<'a>> = BTreeMap::new();
for (field, subexpr) in fields {
let field_type = infer(subexpr)?;
rec_type.push((&field, field_type));
rec_type.insert(&field, field_type);
}
Ok(Type::Record(rec_type))
Ok(Record(rec_type))
},
Expr::Literal(Literal::Tuple(exprs)) => {
Expr::Tuple(exprs) => {
let mut tuple_type: Vec<Type<'a>> = Vec::new();
for subexpr in exprs {
@ -96,7 +132,7 @@ pub fn infer<'a>(expr: &Expr<'a>) -> Result<Type<'a>, UnificationProblem> {
tuple_type.push(field_type);
}
Ok(Type::Tuple(tuple_type))
Ok(Tuple(tuple_type))
},
Expr::If(box cond, expr_if_true, expr_if_false) => {
let cond_type = infer(&cond)?;
@ -122,29 +158,197 @@ pub fn infer<'a>(expr: &Expr<'a>) -> Result<Type<'a>, UnificationProblem> {
// but we can pull it back out of the set
Ok(unified_type.into_iter().next().unwrap())
} else {
Ok(Type::Union(unified_type))
Ok(Union(unified_type))
}
},
Expr::Assignment(ident, subexpr) => {
Ok(Type::Assignment(ident, Box::new(infer(subexpr)?)))
Call(func, arg) => {
},
CallOperator(op, left_expr, right_expr) => {
let left = &(infer(left_expr)?);
let right = &(infer(right_expr)?);
match op {
Operator::EQ | Operator::NE | Operator::And | Operator::Or => {
if types_match(left, right) {
conform_to_bool(left)
} else {
Err(UnificationProblem::TypeMismatch)
}
},
Operator::Plus | Operator::Minus | Operator::Star
| Operator::GT | Operator::LT | Operator::GTE | Operator::LTE
| Operator::Caret | Operator::Percent => {
if types_match(left, right) {
conform_to_number(left)
} else {
Err(UnificationProblem::TypeMismatch)
}
},
Operator::FloatDivision => {
if matches_float_type(left) && matches_float_type(right) {
Ok(&Float)
} else {
Err(UnificationProblem::TypeMismatch)
}
},
Operator::IntDivision => {
if matches_int_type(left) && matches_int_type(right) {
Ok(&Int)
} else {
Err(UnificationProblem::TypeMismatch)
}
},
Operator::CombineStrings => {
if matches_string_type(left) && matches_string_type(right) {
Ok(&String)
} else {
Err(UnificationProblem::TypeMismatch)
}
},
Operator::QuestionMark => {
if types_match(left, right) {
conform_to_optional(left)
} else {
Err(UnificationProblem::TypeMismatch)
}
}
}
},
Expr::Declaration(pattern, let_expr, in_expr) => {
// Think of this as a let..in even though syntactically it's not.
// We need to type-check the let-binding, but the type of the
// *expression* we're expanding is only affected by the in-block.
check_pattern(&pattern, &let_expr)?;
infer(in_expr)
}
}
}
/// Reports whether two types are compatible with each other.
///
/// `Unbound` matches anything, `Number` matches any numeric type, and unions
/// are compared member by member. Any other combination of different kinds
/// does not match.
///
/// # Panics
///
/// Comparing two tuples, two functions or two records is not implemented yet
/// and panics via `unimplemented!`.
fn types_match<'a>(first: &'a Type<'a>, second: &'a Type<'a>) -> bool {
    match (first, second) {
        (Type::Union(first_types), Type::Union(second_types)) => {
            // If any type is not directly present in the other union,
            // it must at least match *some* type in the other union
            first_types.difference(second_types).all(|not_in_second_type| {
                second_types.iter().any(|second_type| types_match(second_type, not_in_second_type))
            }) &&
            second_types.difference(first_types).all(|not_in_first_type| {
                first_types.iter().any(|first_type| types_match(first_type, not_in_first_type))
            })
        },

        // Happy path: try these first, since we expect them to succeed.
        // These are sorted based on a vague guess of how often they will be used in practice.
        (Type::Symbol(sym_one), Type::Symbol(sym_two)) => sym_one == sym_two,
        (Type::String, Type::String) => true,
        (Type::Unbound, _) | (_, Type::Unbound) => true,
        (Type::Array(box elem_type_one), Type::Array(box elem_type_two)) => {
            types_match(elem_type_one, elem_type_two)
        },
        (Type::Number, Type::Number) => true,
        (Type::Number, other) => matches_number_type(other),
        (other, Type::Number) => matches_number_type(other),
        (Type::Int, Type::Int) => true,
        (Type::Float, Type::Float) => true,
        (Type::Tuple(_first_elems), Type::Tuple(_second_elems)) => {
            // TODO verify that the elems and their types match up
            // TODO write some scenarios to understand these better -
            // like, what happens if you have a function that takes
            // a lambda whose argument takes an open record,
            // and you pass a lambda whose argument takes *fewer* fields?
            // that should work! the function is gonna pass it a lambda that
            // has more fields than it needs.
            // I think there's an element of directionality here that I'm
            // disregarding. Maybe this function shouldn't commute.
            unimplemented!("types_match: comparing two tuples")
        },
        // `Type::Function` has two fields (argument and return type).
        (Type::Function(_, _), Type::Function(_, _)) => {
            // TODO verify that the elems and their types match up
            unimplemented!("types_match: comparing two functions")
        },
        (Type::Record(_first_fields), Type::Record(_second_fields)) => {
            // TODO verify that the fields and their types match up
            // TODO what should happen if one is a superset of the other? fail?
            unimplemented!("types_match: comparing two records")
        },
        (Type::Char, Type::Char) => true,

        // Unhappy path - expect these to fail, so check them last
        (Type::Union(first_types), _) => {
            first_types.iter().all(|typ| types_match(typ, second))
        },
        (_, Type::Union(second_types)) => {
            second_types.iter().all(|typ| types_match(first, typ))
        },
        (Type::String, _) | (_, Type::String) => false,
        (Type::Char, _) | (_, Type::Char) => false,
        (Type::Int, _) | (_, Type::Int) => false,
        (Type::Float, _) | (_, Type::Float) => false,
        (Type::Symbol(_), _) | (_, Type::Symbol(_)) => false,
        (Type::Array(_), _) | (_, Type::Array(_)) => false,
        (Type::Record(_), _) | (_, Type::Record(_)) => false,
        (Type::Tuple(_), _) | (_, Type::Tuple(_)) => false,
        (Type::Function(_, _), _) | (_, Type::Function(_, _)) => false,
    }
}
// Checks a pattern against the expression bound to it.
//
// Currently only infers the expression's type (propagating any inference
// error with `?`) and then panics: the actual pattern check is still a TODO.
fn check_pattern<'a>(pattern: &'a Pattern<'a>, expr: &'a Expr<'a>) -> Result<(), UnificationProblem> {
let expr_type = infer(expr)?;
panic!("TODO check the pattern's type against expr_type, then write some tests for funky record pattern cases - this is our first real unification! Next one will be field access, ooooo - gonna want lots of tests for that")
}
const TRUE_SYMBOL_STR: &'static str = "True";
const FALSE_SYMBOL_STR: &'static str = "False";
pub fn matches_string_type<'a>(candidate: &Type<'a>) -> bool {
match candidate {
Unbound | String => true,
Type::Union(types) => {
types.iter().all(|typ| matches_string_type(typ))
},
_ => Err(UnificationProblem::TypeMismatch)
}
}
pub fn matches_bool_type<'a>(candidate: &Type<'a>) -> bool {
match candidate {
Type::Symbol(str) => {
str == &TRUE_SYMBOL_STR || str == &FALSE_SYMBOL_STR
}
Type::Unbound => true,
Type::Symbol(str) => str == &TRUE_SYMBOL_STR || str == &FALSE_SYMBOL_STR,
Type::Union(types) => {
types.len() <= 2 && types.iter().all(|typ| matches_bool_type(typ))
types.iter().all(|typ| matches_bool_type(typ))
}
_ => {
false
_ => false
}
}
/// Reports whether `candidate` can be used where a number is expected.
///
/// `Unbound`, `Int`, `Float` and `Number` qualify, as does a union whose
/// members all qualify.
pub fn matches_number_type<'a>(candidate: &Type<'a>) -> bool {
    // A union is numeric only when every one of its members is.
    if let Type::Union(members) = candidate {
        return members.iter().all(|member| matches_number_type(member));
    }

    match candidate {
        Type::Unbound | Type::Int | Type::Float | Type::Number => true,
        _ => false,
    }
}
/// Reports whether `candidate` can be used where an integer is expected.
///
/// `Unbound` and `Int` qualify, as does a union whose members all qualify.
pub fn matches_int_type<'a>(candidate: &Type<'a>) -> bool {
    match candidate {
        Type::Union(members) => {
            // Stop at the first member that is not an integer type.
            for member in members.iter() {
                if !matches_int_type(member) {
                    return false;
                }
            }
            true
        }
        Type::Unbound | Type::Int => true,
        _ => false,
    }
}
/// Reports whether `candidate` can be used where a float is expected.
///
/// `Unbound` and `Float` qualify, as does a union whose members all qualify.
pub fn matches_float_type<'a>(candidate: &Type<'a>) -> bool {
    match candidate {
        Type::Unbound | Type::Float => true,
        Type::Union(members) => {
            // Equivalent to "all members match": no member may fail the check.
            let has_non_float = members.iter().any(|member| !matches_float_type(member));
            !has_non_float
        }
        _ => false,
    }
}

View file

@ -71,4 +71,16 @@ mod tests {
// Test helper: a boxed reference to the expression `HexOctalBinary(0x12)`.
fn int<'a>() -> Box<&'a Expr<'a>> { Box::new(&HexOctalBinary(0x12)) }
// Test helper: a boxed reference to the expression `FractionalNumber(3.1)`.
fn float<'a>() -> Box<&'a Expr<'a>> { Box::new(&FractionalNumber(3.1)) }
// Test helper: a boxed reference to the expression `WholeNumber(5)`.
fn num<'a>() -> Box<&'a Expr<'a>> { Box::new(&WholeNumber(5)) }
// TODO test unions that ought to be equivalent, but only after
// a reduction of some sort, e.g.
//
// ((a|b)|c) vs (a|(b|c))
//
// ((a|z)|(b|z)) vs (a|b|z)
//
// ideally, we fix these when constructing unions
// e.g. if a user puts this in as an annotation, reduce it immediately
// and when we're inferring unions, always infer them flat.
// This way we can avoid checking recursively.
}