diff --git a/crates/ide/src/lib.rs b/crates/ide/src/lib.rs index 301380d..296a0d6 100644 --- a/crates/ide/src/lib.rs +++ b/crates/ide/src/lib.rs @@ -3,7 +3,7 @@ mod def; mod diagnostic; mod ide; mod text_edit; -mod ty; +pub(crate) mod ty; #[cfg(test)] mod tests; diff --git a/crates/ide/src/ty/infer.rs b/crates/ide/src/ty/infer.rs index 02ba830..dded443 100644 --- a/crates/ide/src/ty/infer.rs +++ b/crates/ide/src/ty/infer.rs @@ -38,6 +38,8 @@ enum Ty { List(TyVar), Lambda(TyVar, TyVar), Attrset(Attrset), + + External(super::Ty), } impl Ty { @@ -66,6 +68,14 @@ impl InferenceResult { } pub(crate) fn infer_query(db: &dyn TyDatabase, file: FileId) -> Arc { + infer_with(db, file, None) +} + +pub(crate) fn infer_with( + db: &dyn TyDatabase, + file: FileId, + expect_ty: Option, +) -> Arc { let module = db.module(file); let nameres = db.name_resolution(file); let table = UnionFind::new(module.names().len() + module.exprs().len(), |_| Ty::Unknown); @@ -74,7 +84,10 @@ pub(crate) fn infer_query(db: &dyn TyDatabase, file: FileId) -> Arc InferCtx<'db> { TyVar(self.module.names().len() as u32 + u32::from(i.into_raw())) } + fn import_external(&mut self, ty: super::Ty) -> TyVar { + let ty = match ty { + super::Ty::Unknown => Ty::Unknown, + super::Ty::Bool => Ty::Bool, + super::Ty::Int => Ty::Int, + super::Ty::Float => Ty::Float, + super::Ty::String => Ty::String, + super::Ty::Path => Ty::Path, + super::Ty::List(_) | super::Ty::Lambda(..) | super::Ty::Attrset(_) => Ty::External(ty), + }; + TyVar(self.table.push(ty)) + } + fn infer_expr(&mut self, e: ExprId) -> TyVar { let ty = self.infer_expr_inner(e); let placeholder_ty = self.ty_for_expr(e); - self.unify(placeholder_ty, ty); + self.unify_var(placeholder_ty, ty); ty } @@ -134,11 +160,11 @@ impl<'db> InferCtx<'db> { let param_ty = self.new_ty_var(); if let Some(name) = *name { - self.unify(param_ty, self.ty_for_name(name)); + self.unify_var(param_ty, self.ty_for_name(name)); } if let Some(pat) = pat { - self.unify_kind(param_ty, Ty::Attrset(Attrset::default())); + self.unify_var_ty(param_ty, Ty::Attrset(Attrset::default())); for &(name, default_expr) in pat.fields.iter() { // Always infer default_expr. let default_ty = default_expr.map(|e| self.infer_expr(e)); @@ -148,12 +174,12 @@ impl<'db> InferCtx<'db> { }; let name_ty = self.ty_for_name(name); if let Some(default_ty) = default_ty { - self.unify(name_ty, default_ty); + self.unify_var(name_ty, default_ty); } let field_text = self.module[name].text.clone(); let param_field_ty = self.infer_set_field(param_ty, field_text, AttrSource::Name(name)); - self.unify(param_field_ty, name_ty); + self.unify_var(param_field_ty, name_ty); } } @@ -170,10 +196,10 @@ impl<'db> InferCtx<'db> { } &Expr::IfThenElse(cond, then, else_) => { let cond_ty = self.infer_expr(cond); - self.unify_kind(cond_ty, Ty::Bool); + self.unify_var_ty(cond_ty, Ty::Bool); let then_ty = self.infer_expr(then); let else_ty = self.infer_expr(else_); - self.unify(then_ty, else_ty); + self.unify_var(then_ty, else_ty); then_ty } &Expr::Binary(op, lhs, rhs) => { @@ -188,15 +214,15 @@ impl<'db> InferCtx<'db> { match op { BinaryOpKind::Equal | BinaryOpKind::NotEqual => Ty::Bool.intern(self), BinaryOpKind::Imply | BinaryOpKind::Or | BinaryOpKind::And => { - self.unify_kind(lhs_ty, Ty::Bool); - self.unify_kind(rhs_ty, Ty::Bool); + self.unify_var_ty(lhs_ty, Ty::Bool); + self.unify_var_ty(rhs_ty, Ty::Bool); Ty::Bool.intern(self) } BinaryOpKind::Less | BinaryOpKind::Greater | BinaryOpKind::LessEqual | BinaryOpKind::GreaterEqual => { - self.unify(lhs_ty, rhs_ty); + self.unify_var(lhs_ty, rhs_ty); Ty::Bool.intern(self) } // TODO: Polymorphism. @@ -205,19 +231,19 @@ impl<'db> InferCtx<'db> { | BinaryOpKind::Mul | BinaryOpKind::Div => { // TODO: Arguments have type: int | float. - self.unify(lhs_ty, rhs_ty); + self.unify_var(lhs_ty, rhs_ty); lhs_ty } BinaryOpKind::Update => { - self.unify_kind(lhs_ty, Ty::Attrset(Attrset::default())); - self.unify_kind(rhs_ty, Ty::Attrset(Attrset::default())); - self.unify(lhs_ty, rhs_ty); + self.unify_var_ty(lhs_ty, Ty::Attrset(Attrset::default())); + self.unify_var_ty(rhs_ty, Ty::Attrset(Attrset::default())); + self.unify_var(lhs_ty, rhs_ty); lhs_ty } BinaryOpKind::Concat => { let ret_ty = Ty::List(self.new_ty_var()).intern(self); - self.unify(lhs_ty, ret_ty); - self.unify(rhs_ty, ret_ty); + self.unify_var(lhs_ty, ret_ty); + self.unify_var(rhs_ty, ret_ty); ret_ty } } @@ -227,7 +253,7 @@ impl<'db> InferCtx<'db> { match op { None => self.new_ty_var(), Some(UnaryOpKind::Not) => { - self.unify_kind(arg_ty, Ty::Bool); + self.unify_var_ty(arg_ty, Ty::Bool); Ty::Bool.intern(self) } // TODO: The argument is int | bool. @@ -238,9 +264,9 @@ impl<'db> InferCtx<'db> { let param_ty = self.new_ty_var(); let ret_ty = self.new_ty_var(); let lam_ty = self.infer_expr(lam); - self.unify_kind(lam_ty, Ty::Lambda(param_ty, ret_ty)); + self.unify_var_ty(lam_ty, Ty::Lambda(param_ty, ret_ty)); let arg_ty = self.infer_expr(arg); - self.unify(arg_ty, param_ty); + self.unify_var(arg_ty, param_ty); ret_ty } Expr::HasAttr(set_expr, path) => { @@ -248,7 +274,7 @@ impl<'db> InferCtx<'db> { self.infer_expr(*set_expr); for &attr in path.iter() { let attr_ty = self.infer_expr(attr); - self.unify_kind(attr_ty, Ty::String); + self.unify_var_ty(attr_ty, Ty::String); } Ty::Bool.intern(self) } @@ -256,20 +282,20 @@ impl<'db> InferCtx<'db> { let set_ty = self.infer_expr(*set_expr); let ret_ty = path.iter().fold(set_ty, |set_ty, &attr| { let attr_ty = self.infer_expr(attr); - self.unify_kind(attr_ty, Ty::String); + self.unify_var_ty(attr_ty, Ty::String); match &self.module[attr] { Expr::Literal(Literal::String(key)) => { self.infer_set_field(set_ty, key.clone(), AttrSource::Unknown) } _ => { - self.unify_kind(set_ty, Ty::Attrset(Attrset::default())); + self.unify_var_ty(set_ty, Ty::Attrset(Attrset::default())); self.new_ty_var() } } }); if let Some(default_expr) = *default_expr { let default_ty = self.infer_expr(default_expr); - self.unify(ret_ty, default_ty); + self.unify_var(ret_ty, default_ty); } ret_ty } @@ -277,7 +303,7 @@ impl<'db> InferCtx<'db> { for &part in parts.iter() { let ty = self.infer_expr(part); // FIXME: Parts are coerce-able to string. - self.unify_kind(ty, Ty::String); + self.unify_var_ty(ty, Ty::String); } Ty::Path.intern(self) } @@ -285,7 +311,7 @@ impl<'db> InferCtx<'db> { for &part in parts.iter() { let ty = self.infer_expr(part); // FIXME: Parts are coerce-able to string. - self.unify_kind(ty, Ty::String); + self.unify_var_ty(ty, Ty::String); } Ty::String.intern(self) } @@ -294,7 +320,7 @@ impl<'db> InferCtx<'db> { let ret_ty = Ty::List(expect_elem_ty).intern(self); for &elem in elems.iter() { let elem_ty = self.infer_expr(elem); - self.unify(elem_ty, expect_elem_ty); + self.unify_var(elem_ty, expect_elem_ty); } ret_ty } @@ -330,14 +356,14 @@ impl<'db> InferCtx<'db> { self.infer_set_field(from_ty, name_text.clone(), AttrSource::Name(name)) } }; - self.unify(name_ty, value_ty); + self.unify_var(name_ty, value_ty); let src = AttrSource::Name(name); fields.insert(name_text, (value_ty, src)); } for &(k, v) in bindings.dynamics.iter() { let name_ty = self.infer_expr(k); - self.unify_kind(name_ty, Ty::String); + self.unify_var_ty(name_ty, Ty::String); self.infer_expr(v); } @@ -357,53 +383,44 @@ impl<'db> InferCtx<'db> { ent.insert((next_ty, src)); } }, + Ty::External(super::Ty::Attrset(set)) => { + if let Some(ty) = set.get(&field).cloned() { + return self.import_external(ty); + } + } k @ Ty::Unknown => { *k = Ty::Attrset(Attrset([(field, (next_ty, src))].into_iter().collect())); } - Ty::Bool - | Ty::Int - | Ty::Float - | Ty::String - | Ty::Path - | Ty::List(_) - | Ty::Lambda(_, _) => {} + _ => {} } self.new_ty_var() } - /// Unify a type in table with an expected kind. - fn unify_kind(&mut self, a: TyVar, b: Ty) { - match (self.table.get_mut(a.0), b) { - (a @ Ty::Unknown, b) => *a = b, - (&mut Ty::List(a), Ty::List(b)) => self.unify(a, b), - (&mut Ty::Lambda(a1, a2), Ty::Lambda(b1, b2)) => { - self.unify(a1, b1); - self.unify(a2, b2); - } - (Ty::Attrset(_), Ty::Attrset(b)) => { - assert!(b.0.is_empty(), "Never unify_kind an non-empty set"); - } - _ => {} - } + fn unify_var_ty(&mut self, var: TyVar, rhs: Ty) { + let lhs = mem::replace(self.table.get_mut(var.0), Ty::Unknown); + let ret = self.unify(lhs, rhs); + *self.table.get_mut(var.0) = ret; } - fn unify(&mut self, a: TyVar, b: TyVar) { - let (i, other) = self.table.unify(a.0, b.0); - let other = match other { - Some(other) => other, - None => return, - }; - let mut a = mem::replace(self.table.get_mut(i), Ty::Unknown); - match (&mut a, other) { - (a @ Ty::Unknown, b) => *a = b, - (&mut Ty::List(a), Ty::List(b)) => { - self.unify(a, b); + fn unify_var(&mut self, lhs: TyVar, rhs: TyVar) { + let (var, rhs) = self.table.unify(lhs.0, rhs.0); + let Some(rhs) = rhs else { return }; + self.unify_var_ty(TyVar(var), rhs); + } + + fn unify(&mut self, lhs: Ty, rhs: Ty) -> Ty { + match (lhs, rhs) { + (Ty::Unknown, other) | (other, Ty::Unknown) => other, + (Ty::List(a), Ty::List(b)) => { + self.unify_var(a, b); + Ty::List(a) } - (&mut Ty::Lambda(a1, a2), Ty::Lambda(b1, b2)) => { - self.unify(a1, b1); - self.unify(a2, b2); + (Ty::Lambda(arg1, ret1), Ty::Lambda(arg2, ret2)) => { + self.unify_var(arg1, arg2); + self.unify_var(ret1, ret2); + Ty::Lambda(arg1, ret1) } - (Ty::Attrset(a), Ty::Attrset(b)) => { + (Ty::Attrset(mut a), Ty::Attrset(b)) => { for (field, (ty2, src2)) in b.0 { match a.0.entry(field) { Entry::Vacant(ent) => { @@ -412,14 +429,34 @@ impl<'db> InferCtx<'db> { Entry::Occupied(mut ent) => { let (ty1, src1) = ent.get_mut(); src1.unify(src2); - self.unify(*ty1, ty2); + self.unify_var(*ty1, ty2); } } } + Ty::Attrset(a) } - _ => {} + (Ty::External(external), local) | (local, Ty::External(external)) => { + match (local, &external) { + (Ty::Lambda(arg1, ret1), super::Ty::Lambda(arg2, ret2)) => { + let arg2 = self.import_external(super::Ty::clone(arg2)); + let ret2 = self.import_external(super::Ty::clone(ret2)); + self.unify_var(arg1, arg2); + self.unify_var(ret1, ret2); + } + (Ty::Attrset(a), super::Ty::Attrset(b)) => { + for (field, (ty, _)) in &a.0 { + if let Some(field_ty) = b.get(field) { + let var = self.import_external(field_ty.clone()); + self.unify_var(*ty, var); + } + } + } + _ => {} + } + Ty::External(external) + } + (lhs, _) => lhs, } - *self.table.get_mut(i) = a; } fn finish(mut self) -> InferenceResult { @@ -494,6 +531,7 @@ impl<'a> Collector<'a> { .collect(); super::Ty::Attrset(super::Attrset(set).into()) } + Ty::External(ty) => ty, } } } diff --git a/crates/ide/src/ty/mod.rs b/crates/ide/src/ty/mod.rs index bef879a..94e578d 100644 --- a/crates/ide/src/ty/mod.rs +++ b/crates/ide/src/ty/mod.rs @@ -1,3 +1,31 @@ +#[rustfmt::skip] +#[macro_export] +macro_rules! ty { + (?) => { $crate::ty::Ty::Unknown }; + (bool) => { $crate::ty::Ty::Int }; + (int) => { $crate::ty::Ty::Int }; + (float) => { $crate::ty::Ty::Float }; + (string) => { $crate::ty::Ty::String }; + (path) => { $crate::ty::Ty::Path }; + // TODO: More precise type for derivations. + (derivation) => { + $crate::ty::Ty::Attrset(::std::sync::Arc::new($crate::ty::Attrset::default())) + }; + (($($inner:tt)*)) => {{ ty!($($inner)*) }}; + ([$($inner:tt)*]) => { $crate::ty::Ty::List(::std::arc::Arc::new($ty!($($inner)*)))}; + ({ $($key:literal : $ty:tt),* $(,)? }) => {{ + $crate::ty::Ty::Attrset(::std::sync::Arc::new($crate::ty::Attrset::from_internal([ + $(($key, ty!($ty)),)* + ]))) + }}; + ($arg:tt -> $($ret:tt)*) => { + $crate::ty::Ty::Lambda( + ::std::sync::Arc::new(ty!($arg)), + ::std::sync::Arc::new(ty!($($ret)*)), + ) + }; +} + mod fmt; mod infer; mod union_find; @@ -65,6 +93,26 @@ impl std::fmt::Debug for Ty { pub struct Attrset(Box<[(SmolStr, Ty, AttrSource)]>); impl Attrset { + /// Build an Attrset for internal type schemas. + /// + /// # Panics + /// The given iterator must have no duplicated fields, or it'll panic. + #[track_caller] + // FIXME: Currently this is only used in tests. + #[cfg_attr(not(test), allow(dead_code))] + fn from_internal(iter: impl IntoIterator) -> Self { + let mut set = iter + .into_iter() + .map(|(name, ty)| (SmolStr::from(name), ty, AttrSource::Unknown)) + .collect::>(); + set.sort_by(|(lhs, ..), (rhs, ..)| lhs.cmp(rhs)); + assert!( + set.windows(2).all(|w| w[0].0 != w[1].0), + "Duplicated fields", + ); + Self(set) + } + pub fn is_empty(&self) -> bool { self.0.is_empty() } diff --git a/crates/ide/src/ty/tests.rs b/crates/ide/src/ty/tests.rs index a1ae488..aa995b3 100644 --- a/crates/ide/src/ty/tests.rs +++ b/crates/ide/src/ty/tests.rs @@ -2,6 +2,9 @@ use crate::tests::TestDB; use crate::{DefDatabase, TyDatabase}; use expect_test::{expect, Expect}; +use super::Ty; + +#[track_caller] fn check(src: &str, expect: Expect) { let (db, file) = TestDB::single_file(src).unwrap(); let module = db.module(file); @@ -11,10 +14,16 @@ fn check(src: &str, expect: Expect) { expect.assert_eq(&got); } +#[track_caller] fn check_all(src: &str, expect: Expect) { + check_all_expect(src, None, expect); +} + +#[track_caller] +fn check_all_expect(src: &str, expect_ty: impl Into>, expect: Expect) { let (db, file) = TestDB::single_file(src).unwrap(); let module = db.module(file); - let infer = db.infer(file); + let infer = super::infer::infer_with(&db, file, expect_ty.into()); let got = module .names() .map(|(i, name)| format!("{}: {}\n", name.text, infer.ty_for_name(i).debug())) @@ -140,3 +149,31 @@ fn select() { "#]], ); } + +#[test] +fn external() { + check_all_expect( + "let a = a; in a", + ty!(int), + expect![[r#" + a: int + : int + "#]], + ); + + check_all_expect( + "{ stdenv }: stdenv.mkDerivation { + name = undefined; + }", + ty!({ + "stdenv": { + "mkDerivation": ({ "name": string } -> derivation), + }, + } -> derivation), + expect![[r#" + stdenv: { mkDerivation: { name: string } → { } } + name: string + : { stdenv: { mkDerivation: { name: string } → { } } } → { } + "#]], + ); +}