Impl external expected types for type inference

This commit is contained in:
oxalica 2023-02-03 02:18:31 +08:00
parent e0707ad935
commit c003625948
4 changed files with 193 additions and 70 deletions

View file

@ -3,7 +3,7 @@ mod def;
mod diagnostic;
mod ide;
mod text_edit;
mod ty;
pub(crate) mod ty;
#[cfg(test)]
mod tests;

View file

@ -38,6 +38,8 @@ enum Ty {
List(TyVar),
Lambda(TyVar, TyVar),
Attrset(Attrset),
External(super::Ty),
}
impl Ty {
@ -66,6 +68,14 @@ impl InferenceResult {
}
pub(crate) fn infer_query(db: &dyn TyDatabase, file: FileId) -> Arc<InferenceResult> {
infer_with(db, file, None)
}
pub(crate) fn infer_with(
db: &dyn TyDatabase,
file: FileId,
expect_ty: Option<super::Ty>,
) -> Arc<InferenceResult> {
let module = db.module(file);
let nameres = db.name_resolution(file);
let table = UnionFind::new(module.names().len() + module.exprs().len(), |_| Ty::Unknown);
@ -74,7 +84,10 @@ pub(crate) fn infer_query(db: &dyn TyDatabase, file: FileId) -> Arc<InferenceRes
nameres: &nameres,
table,
};
ctx.infer_expr(module.entry_expr());
let ty = ctx.infer_expr(module.entry_expr());
if let Some(expect_ty) = expect_ty {
ctx.unify_var_ty(ty, Ty::External(expect_ty));
}
Arc::new(ctx.finish())
}
@ -101,10 +114,23 @@ impl<'db> InferCtx<'db> {
TyVar(self.module.names().len() as u32 + u32::from(i.into_raw()))
}
fn import_external(&mut self, ty: super::Ty) -> TyVar {
let ty = match ty {
super::Ty::Unknown => Ty::Unknown,
super::Ty::Bool => Ty::Bool,
super::Ty::Int => Ty::Int,
super::Ty::Float => Ty::Float,
super::Ty::String => Ty::String,
super::Ty::Path => Ty::Path,
super::Ty::List(_) | super::Ty::Lambda(..) | super::Ty::Attrset(_) => Ty::External(ty),
};
TyVar(self.table.push(ty))
}
fn infer_expr(&mut self, e: ExprId) -> TyVar {
let ty = self.infer_expr_inner(e);
let placeholder_ty = self.ty_for_expr(e);
self.unify(placeholder_ty, ty);
self.unify_var(placeholder_ty, ty);
ty
}
@ -134,11 +160,11 @@ impl<'db> InferCtx<'db> {
let param_ty = self.new_ty_var();
if let Some(name) = *name {
self.unify(param_ty, self.ty_for_name(name));
self.unify_var(param_ty, self.ty_for_name(name));
}
if let Some(pat) = pat {
self.unify_kind(param_ty, Ty::Attrset(Attrset::default()));
self.unify_var_ty(param_ty, Ty::Attrset(Attrset::default()));
for &(name, default_expr) in pat.fields.iter() {
// Always infer default_expr.
let default_ty = default_expr.map(|e| self.infer_expr(e));
@ -148,12 +174,12 @@ impl<'db> InferCtx<'db> {
};
let name_ty = self.ty_for_name(name);
if let Some(default_ty) = default_ty {
self.unify(name_ty, default_ty);
self.unify_var(name_ty, default_ty);
}
let field_text = self.module[name].text.clone();
let param_field_ty =
self.infer_set_field(param_ty, field_text, AttrSource::Name(name));
self.unify(param_field_ty, name_ty);
self.unify_var(param_field_ty, name_ty);
}
}
@ -170,10 +196,10 @@ impl<'db> InferCtx<'db> {
}
&Expr::IfThenElse(cond, then, else_) => {
let cond_ty = self.infer_expr(cond);
self.unify_kind(cond_ty, Ty::Bool);
self.unify_var_ty(cond_ty, Ty::Bool);
let then_ty = self.infer_expr(then);
let else_ty = self.infer_expr(else_);
self.unify(then_ty, else_ty);
self.unify_var(then_ty, else_ty);
then_ty
}
&Expr::Binary(op, lhs, rhs) => {
@ -188,15 +214,15 @@ impl<'db> InferCtx<'db> {
match op {
BinaryOpKind::Equal | BinaryOpKind::NotEqual => Ty::Bool.intern(self),
BinaryOpKind::Imply | BinaryOpKind::Or | BinaryOpKind::And => {
self.unify_kind(lhs_ty, Ty::Bool);
self.unify_kind(rhs_ty, Ty::Bool);
self.unify_var_ty(lhs_ty, Ty::Bool);
self.unify_var_ty(rhs_ty, Ty::Bool);
Ty::Bool.intern(self)
}
BinaryOpKind::Less
| BinaryOpKind::Greater
| BinaryOpKind::LessEqual
| BinaryOpKind::GreaterEqual => {
self.unify(lhs_ty, rhs_ty);
self.unify_var(lhs_ty, rhs_ty);
Ty::Bool.intern(self)
}
// TODO: Polymorphism.
@ -205,19 +231,19 @@ impl<'db> InferCtx<'db> {
| BinaryOpKind::Mul
| BinaryOpKind::Div => {
// TODO: Arguments have type: int | float.
self.unify(lhs_ty, rhs_ty);
self.unify_var(lhs_ty, rhs_ty);
lhs_ty
}
BinaryOpKind::Update => {
self.unify_kind(lhs_ty, Ty::Attrset(Attrset::default()));
self.unify_kind(rhs_ty, Ty::Attrset(Attrset::default()));
self.unify(lhs_ty, rhs_ty);
self.unify_var_ty(lhs_ty, Ty::Attrset(Attrset::default()));
self.unify_var_ty(rhs_ty, Ty::Attrset(Attrset::default()));
self.unify_var(lhs_ty, rhs_ty);
lhs_ty
}
BinaryOpKind::Concat => {
let ret_ty = Ty::List(self.new_ty_var()).intern(self);
self.unify(lhs_ty, ret_ty);
self.unify(rhs_ty, ret_ty);
self.unify_var(lhs_ty, ret_ty);
self.unify_var(rhs_ty, ret_ty);
ret_ty
}
}
@ -227,7 +253,7 @@ impl<'db> InferCtx<'db> {
match op {
None => self.new_ty_var(),
Some(UnaryOpKind::Not) => {
self.unify_kind(arg_ty, Ty::Bool);
self.unify_var_ty(arg_ty, Ty::Bool);
Ty::Bool.intern(self)
}
// TODO: The argument is int | bool.
@ -238,9 +264,9 @@ impl<'db> InferCtx<'db> {
let param_ty = self.new_ty_var();
let ret_ty = self.new_ty_var();
let lam_ty = self.infer_expr(lam);
self.unify_kind(lam_ty, Ty::Lambda(param_ty, ret_ty));
self.unify_var_ty(lam_ty, Ty::Lambda(param_ty, ret_ty));
let arg_ty = self.infer_expr(arg);
self.unify(arg_ty, param_ty);
self.unify_var(arg_ty, param_ty);
ret_ty
}
Expr::HasAttr(set_expr, path) => {
@ -248,7 +274,7 @@ impl<'db> InferCtx<'db> {
self.infer_expr(*set_expr);
for &attr in path.iter() {
let attr_ty = self.infer_expr(attr);
self.unify_kind(attr_ty, Ty::String);
self.unify_var_ty(attr_ty, Ty::String);
}
Ty::Bool.intern(self)
}
@ -256,20 +282,20 @@ impl<'db> InferCtx<'db> {
let set_ty = self.infer_expr(*set_expr);
let ret_ty = path.iter().fold(set_ty, |set_ty, &attr| {
let attr_ty = self.infer_expr(attr);
self.unify_kind(attr_ty, Ty::String);
self.unify_var_ty(attr_ty, Ty::String);
match &self.module[attr] {
Expr::Literal(Literal::String(key)) => {
self.infer_set_field(set_ty, key.clone(), AttrSource::Unknown)
}
_ => {
self.unify_kind(set_ty, Ty::Attrset(Attrset::default()));
self.unify_var_ty(set_ty, Ty::Attrset(Attrset::default()));
self.new_ty_var()
}
}
});
if let Some(default_expr) = *default_expr {
let default_ty = self.infer_expr(default_expr);
self.unify(ret_ty, default_ty);
self.unify_var(ret_ty, default_ty);
}
ret_ty
}
@ -277,7 +303,7 @@ impl<'db> InferCtx<'db> {
for &part in parts.iter() {
let ty = self.infer_expr(part);
// FIXME: Parts are coerce-able to string.
self.unify_kind(ty, Ty::String);
self.unify_var_ty(ty, Ty::String);
}
Ty::Path.intern(self)
}
@ -285,7 +311,7 @@ impl<'db> InferCtx<'db> {
for &part in parts.iter() {
let ty = self.infer_expr(part);
// FIXME: Parts are coerce-able to string.
self.unify_kind(ty, Ty::String);
self.unify_var_ty(ty, Ty::String);
}
Ty::String.intern(self)
}
@ -294,7 +320,7 @@ impl<'db> InferCtx<'db> {
let ret_ty = Ty::List(expect_elem_ty).intern(self);
for &elem in elems.iter() {
let elem_ty = self.infer_expr(elem);
self.unify(elem_ty, expect_elem_ty);
self.unify_var(elem_ty, expect_elem_ty);
}
ret_ty
}
@ -330,14 +356,14 @@ impl<'db> InferCtx<'db> {
self.infer_set_field(from_ty, name_text.clone(), AttrSource::Name(name))
}
};
self.unify(name_ty, value_ty);
self.unify_var(name_ty, value_ty);
let src = AttrSource::Name(name);
fields.insert(name_text, (value_ty, src));
}
for &(k, v) in bindings.dynamics.iter() {
let name_ty = self.infer_expr(k);
self.unify_kind(name_ty, Ty::String);
self.unify_var_ty(name_ty, Ty::String);
self.infer_expr(v);
}
@ -357,53 +383,44 @@ impl<'db> InferCtx<'db> {
ent.insert((next_ty, src));
}
},
Ty::External(super::Ty::Attrset(set)) => {
if let Some(ty) = set.get(&field).cloned() {
return self.import_external(ty);
}
}
k @ Ty::Unknown => {
*k = Ty::Attrset(Attrset([(field, (next_ty, src))].into_iter().collect()));
}
Ty::Bool
| Ty::Int
| Ty::Float
| Ty::String
| Ty::Path
| Ty::List(_)
| Ty::Lambda(_, _) => {}
_ => {}
}
self.new_ty_var()
}
/// Unify a type in table with an expected kind.
fn unify_kind(&mut self, a: TyVar, b: Ty) {
match (self.table.get_mut(a.0), b) {
(a @ Ty::Unknown, b) => *a = b,
(&mut Ty::List(a), Ty::List(b)) => self.unify(a, b),
(&mut Ty::Lambda(a1, a2), Ty::Lambda(b1, b2)) => {
self.unify(a1, b1);
self.unify(a2, b2);
}
(Ty::Attrset(_), Ty::Attrset(b)) => {
assert!(b.0.is_empty(), "Never unify_kind an non-empty set");
}
_ => {}
}
fn unify_var_ty(&mut self, var: TyVar, rhs: Ty) {
let lhs = mem::replace(self.table.get_mut(var.0), Ty::Unknown);
let ret = self.unify(lhs, rhs);
*self.table.get_mut(var.0) = ret;
}
fn unify(&mut self, a: TyVar, b: TyVar) {
let (i, other) = self.table.unify(a.0, b.0);
let other = match other {
Some(other) => other,
None => return,
};
let mut a = mem::replace(self.table.get_mut(i), Ty::Unknown);
match (&mut a, other) {
(a @ Ty::Unknown, b) => *a = b,
(&mut Ty::List(a), Ty::List(b)) => {
self.unify(a, b);
fn unify_var(&mut self, lhs: TyVar, rhs: TyVar) {
let (var, rhs) = self.table.unify(lhs.0, rhs.0);
let Some(rhs) = rhs else { return };
self.unify_var_ty(TyVar(var), rhs);
}
fn unify(&mut self, lhs: Ty, rhs: Ty) -> Ty {
match (lhs, rhs) {
(Ty::Unknown, other) | (other, Ty::Unknown) => other,
(Ty::List(a), Ty::List(b)) => {
self.unify_var(a, b);
Ty::List(a)
}
(&mut Ty::Lambda(a1, a2), Ty::Lambda(b1, b2)) => {
self.unify(a1, b1);
self.unify(a2, b2);
(Ty::Lambda(arg1, ret1), Ty::Lambda(arg2, ret2)) => {
self.unify_var(arg1, arg2);
self.unify_var(ret1, ret2);
Ty::Lambda(arg1, ret1)
}
(Ty::Attrset(a), Ty::Attrset(b)) => {
(Ty::Attrset(mut a), Ty::Attrset(b)) => {
for (field, (ty2, src2)) in b.0 {
match a.0.entry(field) {
Entry::Vacant(ent) => {
@ -412,14 +429,34 @@ impl<'db> InferCtx<'db> {
Entry::Occupied(mut ent) => {
let (ty1, src1) = ent.get_mut();
src1.unify(src2);
self.unify(*ty1, ty2);
self.unify_var(*ty1, ty2);
}
}
}
Ty::Attrset(a)
}
_ => {}
(Ty::External(external), local) | (local, Ty::External(external)) => {
match (local, &external) {
(Ty::Lambda(arg1, ret1), super::Ty::Lambda(arg2, ret2)) => {
let arg2 = self.import_external(super::Ty::clone(arg2));
let ret2 = self.import_external(super::Ty::clone(ret2));
self.unify_var(arg1, arg2);
self.unify_var(ret1, ret2);
}
(Ty::Attrset(a), super::Ty::Attrset(b)) => {
for (field, (ty, _)) in &a.0 {
if let Some(field_ty) = b.get(field) {
let var = self.import_external(field_ty.clone());
self.unify_var(*ty, var);
}
}
}
_ => {}
}
Ty::External(external)
}
(lhs, _) => lhs,
}
*self.table.get_mut(i) = a;
}
fn finish(mut self) -> InferenceResult {
@ -494,6 +531,7 @@ impl<'a> Collector<'a> {
.collect();
super::Ty::Attrset(super::Attrset(set).into())
}
Ty::External(ty) => ty,
}
}
}

View file

@ -1,3 +1,31 @@
#[rustfmt::skip]
#[macro_export]
macro_rules! ty {
(?) => { $crate::ty::Ty::Unknown };
(bool) => { $crate::ty::Ty::Int };
(int) => { $crate::ty::Ty::Int };
(float) => { $crate::ty::Ty::Float };
(string) => { $crate::ty::Ty::String };
(path) => { $crate::ty::Ty::Path };
// TODO: More precise type for derivations.
(derivation) => {
$crate::ty::Ty::Attrset(::std::sync::Arc::new($crate::ty::Attrset::default()))
};
(($($inner:tt)*)) => {{ ty!($($inner)*) }};
([$($inner:tt)*]) => { $crate::ty::Ty::List(::std::arc::Arc::new($ty!($($inner)*)))};
({ $($key:literal : $ty:tt),* $(,)? }) => {{
$crate::ty::Ty::Attrset(::std::sync::Arc::new($crate::ty::Attrset::from_internal([
$(($key, ty!($ty)),)*
])))
}};
($arg:tt -> $($ret:tt)*) => {
$crate::ty::Ty::Lambda(
::std::sync::Arc::new(ty!($arg)),
::std::sync::Arc::new(ty!($($ret)*)),
)
};
}
mod fmt;
mod infer;
mod union_find;
@ -65,6 +93,26 @@ impl std::fmt::Debug for Ty {
pub struct Attrset(Box<[(SmolStr, Ty, AttrSource)]>);
impl Attrset {
/// Build an Attrset for internal type schemas.
///
/// # Panics
/// The given iterator must have no duplicated fields, or it'll panic.
#[track_caller]
// FIXME: Currently this is only used in tests.
#[cfg_attr(not(test), allow(dead_code))]
fn from_internal(iter: impl IntoIterator<Item = (&'static str, Ty)>) -> Self {
let mut set = iter
.into_iter()
.map(|(name, ty)| (SmolStr::from(name), ty, AttrSource::Unknown))
.collect::<Box<[_]>>();
set.sort_by(|(lhs, ..), (rhs, ..)| lhs.cmp(rhs));
assert!(
set.windows(2).all(|w| w[0].0 != w[1].0),
"Duplicated fields",
);
Self(set)
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}

View file

@ -2,6 +2,9 @@ use crate::tests::TestDB;
use crate::{DefDatabase, TyDatabase};
use expect_test::{expect, Expect};
use super::Ty;
#[track_caller]
fn check(src: &str, expect: Expect) {
let (db, file) = TestDB::single_file(src).unwrap();
let module = db.module(file);
@ -11,10 +14,16 @@ fn check(src: &str, expect: Expect) {
expect.assert_eq(&got);
}
#[track_caller]
fn check_all(src: &str, expect: Expect) {
check_all_expect(src, None, expect);
}
#[track_caller]
fn check_all_expect(src: &str, expect_ty: impl Into<Option<Ty>>, expect: Expect) {
let (db, file) = TestDB::single_file(src).unwrap();
let module = db.module(file);
let infer = db.infer(file);
let infer = super::infer::infer_with(&db, file, expect_ty.into());
let got = module
.names()
.map(|(i, name)| format!("{}: {}\n", name.text, infer.ty_for_name(i).debug()))
@ -140,3 +149,31 @@ fn select() {
"#]],
);
}
#[test]
fn external() {
check_all_expect(
"let a = a; in a",
ty!(int),
expect![[r#"
a: int
: int
"#]],
);
check_all_expect(
"{ stdenv }: stdenv.mkDerivation {
name = undefined;
}",
ty!({
"stdenv": {
"mkDerivation": ({ "name": string } -> derivation),
},
} -> derivation),
expect![[r#"
stdenv: { mkDerivation: { name: string } { } }
name: string
: { stdenv: { mkDerivation: { name: string } { } } } { }
"#]],
);
}