mirror of
https://github.com/oxalica/nil.git
synced 2025-12-23 09:19:49 +00:00
Impl external expected types for type inference
This commit is contained in:
parent
e0707ad935
commit
c003625948
4 changed files with 193 additions and 70 deletions
|
|
@ -3,7 +3,7 @@ mod def;
|
|||
mod diagnostic;
|
||||
mod ide;
|
||||
mod text_edit;
|
||||
mod ty;
|
||||
pub(crate) mod ty;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
|
|
|||
|
|
@ -38,6 +38,8 @@ enum Ty {
|
|||
List(TyVar),
|
||||
Lambda(TyVar, TyVar),
|
||||
Attrset(Attrset),
|
||||
|
||||
External(super::Ty),
|
||||
}
|
||||
|
||||
impl Ty {
|
||||
|
|
@ -66,6 +68,14 @@ impl InferenceResult {
|
|||
}
|
||||
|
||||
pub(crate) fn infer_query(db: &dyn TyDatabase, file: FileId) -> Arc<InferenceResult> {
|
||||
infer_with(db, file, None)
|
||||
}
|
||||
|
||||
pub(crate) fn infer_with(
|
||||
db: &dyn TyDatabase,
|
||||
file: FileId,
|
||||
expect_ty: Option<super::Ty>,
|
||||
) -> Arc<InferenceResult> {
|
||||
let module = db.module(file);
|
||||
let nameres = db.name_resolution(file);
|
||||
let table = UnionFind::new(module.names().len() + module.exprs().len(), |_| Ty::Unknown);
|
||||
|
|
@ -74,7 +84,10 @@ pub(crate) fn infer_query(db: &dyn TyDatabase, file: FileId) -> Arc<InferenceRes
|
|||
nameres: &nameres,
|
||||
table,
|
||||
};
|
||||
ctx.infer_expr(module.entry_expr());
|
||||
let ty = ctx.infer_expr(module.entry_expr());
|
||||
if let Some(expect_ty) = expect_ty {
|
||||
ctx.unify_var_ty(ty, Ty::External(expect_ty));
|
||||
}
|
||||
Arc::new(ctx.finish())
|
||||
}
|
||||
|
||||
|
|
@ -101,10 +114,23 @@ impl<'db> InferCtx<'db> {
|
|||
TyVar(self.module.names().len() as u32 + u32::from(i.into_raw()))
|
||||
}
|
||||
|
||||
fn import_external(&mut self, ty: super::Ty) -> TyVar {
|
||||
let ty = match ty {
|
||||
super::Ty::Unknown => Ty::Unknown,
|
||||
super::Ty::Bool => Ty::Bool,
|
||||
super::Ty::Int => Ty::Int,
|
||||
super::Ty::Float => Ty::Float,
|
||||
super::Ty::String => Ty::String,
|
||||
super::Ty::Path => Ty::Path,
|
||||
super::Ty::List(_) | super::Ty::Lambda(..) | super::Ty::Attrset(_) => Ty::External(ty),
|
||||
};
|
||||
TyVar(self.table.push(ty))
|
||||
}
|
||||
|
||||
fn infer_expr(&mut self, e: ExprId) -> TyVar {
|
||||
let ty = self.infer_expr_inner(e);
|
||||
let placeholder_ty = self.ty_for_expr(e);
|
||||
self.unify(placeholder_ty, ty);
|
||||
self.unify_var(placeholder_ty, ty);
|
||||
ty
|
||||
}
|
||||
|
||||
|
|
@ -134,11 +160,11 @@ impl<'db> InferCtx<'db> {
|
|||
let param_ty = self.new_ty_var();
|
||||
|
||||
if let Some(name) = *name {
|
||||
self.unify(param_ty, self.ty_for_name(name));
|
||||
self.unify_var(param_ty, self.ty_for_name(name));
|
||||
}
|
||||
|
||||
if let Some(pat) = pat {
|
||||
self.unify_kind(param_ty, Ty::Attrset(Attrset::default()));
|
||||
self.unify_var_ty(param_ty, Ty::Attrset(Attrset::default()));
|
||||
for &(name, default_expr) in pat.fields.iter() {
|
||||
// Always infer default_expr.
|
||||
let default_ty = default_expr.map(|e| self.infer_expr(e));
|
||||
|
|
@ -148,12 +174,12 @@ impl<'db> InferCtx<'db> {
|
|||
};
|
||||
let name_ty = self.ty_for_name(name);
|
||||
if let Some(default_ty) = default_ty {
|
||||
self.unify(name_ty, default_ty);
|
||||
self.unify_var(name_ty, default_ty);
|
||||
}
|
||||
let field_text = self.module[name].text.clone();
|
||||
let param_field_ty =
|
||||
self.infer_set_field(param_ty, field_text, AttrSource::Name(name));
|
||||
self.unify(param_field_ty, name_ty);
|
||||
self.unify_var(param_field_ty, name_ty);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -170,10 +196,10 @@ impl<'db> InferCtx<'db> {
|
|||
}
|
||||
&Expr::IfThenElse(cond, then, else_) => {
|
||||
let cond_ty = self.infer_expr(cond);
|
||||
self.unify_kind(cond_ty, Ty::Bool);
|
||||
self.unify_var_ty(cond_ty, Ty::Bool);
|
||||
let then_ty = self.infer_expr(then);
|
||||
let else_ty = self.infer_expr(else_);
|
||||
self.unify(then_ty, else_ty);
|
||||
self.unify_var(then_ty, else_ty);
|
||||
then_ty
|
||||
}
|
||||
&Expr::Binary(op, lhs, rhs) => {
|
||||
|
|
@ -188,15 +214,15 @@ impl<'db> InferCtx<'db> {
|
|||
match op {
|
||||
BinaryOpKind::Equal | BinaryOpKind::NotEqual => Ty::Bool.intern(self),
|
||||
BinaryOpKind::Imply | BinaryOpKind::Or | BinaryOpKind::And => {
|
||||
self.unify_kind(lhs_ty, Ty::Bool);
|
||||
self.unify_kind(rhs_ty, Ty::Bool);
|
||||
self.unify_var_ty(lhs_ty, Ty::Bool);
|
||||
self.unify_var_ty(rhs_ty, Ty::Bool);
|
||||
Ty::Bool.intern(self)
|
||||
}
|
||||
BinaryOpKind::Less
|
||||
| BinaryOpKind::Greater
|
||||
| BinaryOpKind::LessEqual
|
||||
| BinaryOpKind::GreaterEqual => {
|
||||
self.unify(lhs_ty, rhs_ty);
|
||||
self.unify_var(lhs_ty, rhs_ty);
|
||||
Ty::Bool.intern(self)
|
||||
}
|
||||
// TODO: Polymorphism.
|
||||
|
|
@ -205,19 +231,19 @@ impl<'db> InferCtx<'db> {
|
|||
| BinaryOpKind::Mul
|
||||
| BinaryOpKind::Div => {
|
||||
// TODO: Arguments have type: int | float.
|
||||
self.unify(lhs_ty, rhs_ty);
|
||||
self.unify_var(lhs_ty, rhs_ty);
|
||||
lhs_ty
|
||||
}
|
||||
BinaryOpKind::Update => {
|
||||
self.unify_kind(lhs_ty, Ty::Attrset(Attrset::default()));
|
||||
self.unify_kind(rhs_ty, Ty::Attrset(Attrset::default()));
|
||||
self.unify(lhs_ty, rhs_ty);
|
||||
self.unify_var_ty(lhs_ty, Ty::Attrset(Attrset::default()));
|
||||
self.unify_var_ty(rhs_ty, Ty::Attrset(Attrset::default()));
|
||||
self.unify_var(lhs_ty, rhs_ty);
|
||||
lhs_ty
|
||||
}
|
||||
BinaryOpKind::Concat => {
|
||||
let ret_ty = Ty::List(self.new_ty_var()).intern(self);
|
||||
self.unify(lhs_ty, ret_ty);
|
||||
self.unify(rhs_ty, ret_ty);
|
||||
self.unify_var(lhs_ty, ret_ty);
|
||||
self.unify_var(rhs_ty, ret_ty);
|
||||
ret_ty
|
||||
}
|
||||
}
|
||||
|
|
@ -227,7 +253,7 @@ impl<'db> InferCtx<'db> {
|
|||
match op {
|
||||
None => self.new_ty_var(),
|
||||
Some(UnaryOpKind::Not) => {
|
||||
self.unify_kind(arg_ty, Ty::Bool);
|
||||
self.unify_var_ty(arg_ty, Ty::Bool);
|
||||
Ty::Bool.intern(self)
|
||||
}
|
||||
// TODO: The argument is int | bool.
|
||||
|
|
@ -238,9 +264,9 @@ impl<'db> InferCtx<'db> {
|
|||
let param_ty = self.new_ty_var();
|
||||
let ret_ty = self.new_ty_var();
|
||||
let lam_ty = self.infer_expr(lam);
|
||||
self.unify_kind(lam_ty, Ty::Lambda(param_ty, ret_ty));
|
||||
self.unify_var_ty(lam_ty, Ty::Lambda(param_ty, ret_ty));
|
||||
let arg_ty = self.infer_expr(arg);
|
||||
self.unify(arg_ty, param_ty);
|
||||
self.unify_var(arg_ty, param_ty);
|
||||
ret_ty
|
||||
}
|
||||
Expr::HasAttr(set_expr, path) => {
|
||||
|
|
@ -248,7 +274,7 @@ impl<'db> InferCtx<'db> {
|
|||
self.infer_expr(*set_expr);
|
||||
for &attr in path.iter() {
|
||||
let attr_ty = self.infer_expr(attr);
|
||||
self.unify_kind(attr_ty, Ty::String);
|
||||
self.unify_var_ty(attr_ty, Ty::String);
|
||||
}
|
||||
Ty::Bool.intern(self)
|
||||
}
|
||||
|
|
@ -256,20 +282,20 @@ impl<'db> InferCtx<'db> {
|
|||
let set_ty = self.infer_expr(*set_expr);
|
||||
let ret_ty = path.iter().fold(set_ty, |set_ty, &attr| {
|
||||
let attr_ty = self.infer_expr(attr);
|
||||
self.unify_kind(attr_ty, Ty::String);
|
||||
self.unify_var_ty(attr_ty, Ty::String);
|
||||
match &self.module[attr] {
|
||||
Expr::Literal(Literal::String(key)) => {
|
||||
self.infer_set_field(set_ty, key.clone(), AttrSource::Unknown)
|
||||
}
|
||||
_ => {
|
||||
self.unify_kind(set_ty, Ty::Attrset(Attrset::default()));
|
||||
self.unify_var_ty(set_ty, Ty::Attrset(Attrset::default()));
|
||||
self.new_ty_var()
|
||||
}
|
||||
}
|
||||
});
|
||||
if let Some(default_expr) = *default_expr {
|
||||
let default_ty = self.infer_expr(default_expr);
|
||||
self.unify(ret_ty, default_ty);
|
||||
self.unify_var(ret_ty, default_ty);
|
||||
}
|
||||
ret_ty
|
||||
}
|
||||
|
|
@ -277,7 +303,7 @@ impl<'db> InferCtx<'db> {
|
|||
for &part in parts.iter() {
|
||||
let ty = self.infer_expr(part);
|
||||
// FIXME: Parts are coerce-able to string.
|
||||
self.unify_kind(ty, Ty::String);
|
||||
self.unify_var_ty(ty, Ty::String);
|
||||
}
|
||||
Ty::Path.intern(self)
|
||||
}
|
||||
|
|
@ -285,7 +311,7 @@ impl<'db> InferCtx<'db> {
|
|||
for &part in parts.iter() {
|
||||
let ty = self.infer_expr(part);
|
||||
// FIXME: Parts are coerce-able to string.
|
||||
self.unify_kind(ty, Ty::String);
|
||||
self.unify_var_ty(ty, Ty::String);
|
||||
}
|
||||
Ty::String.intern(self)
|
||||
}
|
||||
|
|
@ -294,7 +320,7 @@ impl<'db> InferCtx<'db> {
|
|||
let ret_ty = Ty::List(expect_elem_ty).intern(self);
|
||||
for &elem in elems.iter() {
|
||||
let elem_ty = self.infer_expr(elem);
|
||||
self.unify(elem_ty, expect_elem_ty);
|
||||
self.unify_var(elem_ty, expect_elem_ty);
|
||||
}
|
||||
ret_ty
|
||||
}
|
||||
|
|
@ -330,14 +356,14 @@ impl<'db> InferCtx<'db> {
|
|||
self.infer_set_field(from_ty, name_text.clone(), AttrSource::Name(name))
|
||||
}
|
||||
};
|
||||
self.unify(name_ty, value_ty);
|
||||
self.unify_var(name_ty, value_ty);
|
||||
let src = AttrSource::Name(name);
|
||||
fields.insert(name_text, (value_ty, src));
|
||||
}
|
||||
|
||||
for &(k, v) in bindings.dynamics.iter() {
|
||||
let name_ty = self.infer_expr(k);
|
||||
self.unify_kind(name_ty, Ty::String);
|
||||
self.unify_var_ty(name_ty, Ty::String);
|
||||
self.infer_expr(v);
|
||||
}
|
||||
|
||||
|
|
@ -357,53 +383,44 @@ impl<'db> InferCtx<'db> {
|
|||
ent.insert((next_ty, src));
|
||||
}
|
||||
},
|
||||
Ty::External(super::Ty::Attrset(set)) => {
|
||||
if let Some(ty) = set.get(&field).cloned() {
|
||||
return self.import_external(ty);
|
||||
}
|
||||
}
|
||||
k @ Ty::Unknown => {
|
||||
*k = Ty::Attrset(Attrset([(field, (next_ty, src))].into_iter().collect()));
|
||||
}
|
||||
Ty::Bool
|
||||
| Ty::Int
|
||||
| Ty::Float
|
||||
| Ty::String
|
||||
| Ty::Path
|
||||
| Ty::List(_)
|
||||
| Ty::Lambda(_, _) => {}
|
||||
_ => {}
|
||||
}
|
||||
self.new_ty_var()
|
||||
}
|
||||
|
||||
/// Unify a type in table with an expected kind.
|
||||
fn unify_kind(&mut self, a: TyVar, b: Ty) {
|
||||
match (self.table.get_mut(a.0), b) {
|
||||
(a @ Ty::Unknown, b) => *a = b,
|
||||
(&mut Ty::List(a), Ty::List(b)) => self.unify(a, b),
|
||||
(&mut Ty::Lambda(a1, a2), Ty::Lambda(b1, b2)) => {
|
||||
self.unify(a1, b1);
|
||||
self.unify(a2, b2);
|
||||
}
|
||||
(Ty::Attrset(_), Ty::Attrset(b)) => {
|
||||
assert!(b.0.is_empty(), "Never unify_kind an non-empty set");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
fn unify_var_ty(&mut self, var: TyVar, rhs: Ty) {
|
||||
let lhs = mem::replace(self.table.get_mut(var.0), Ty::Unknown);
|
||||
let ret = self.unify(lhs, rhs);
|
||||
*self.table.get_mut(var.0) = ret;
|
||||
}
|
||||
|
||||
fn unify(&mut self, a: TyVar, b: TyVar) {
|
||||
let (i, other) = self.table.unify(a.0, b.0);
|
||||
let other = match other {
|
||||
Some(other) => other,
|
||||
None => return,
|
||||
};
|
||||
let mut a = mem::replace(self.table.get_mut(i), Ty::Unknown);
|
||||
match (&mut a, other) {
|
||||
(a @ Ty::Unknown, b) => *a = b,
|
||||
(&mut Ty::List(a), Ty::List(b)) => {
|
||||
self.unify(a, b);
|
||||
fn unify_var(&mut self, lhs: TyVar, rhs: TyVar) {
|
||||
let (var, rhs) = self.table.unify(lhs.0, rhs.0);
|
||||
let Some(rhs) = rhs else { return };
|
||||
self.unify_var_ty(TyVar(var), rhs);
|
||||
}
|
||||
|
||||
fn unify(&mut self, lhs: Ty, rhs: Ty) -> Ty {
|
||||
match (lhs, rhs) {
|
||||
(Ty::Unknown, other) | (other, Ty::Unknown) => other,
|
||||
(Ty::List(a), Ty::List(b)) => {
|
||||
self.unify_var(a, b);
|
||||
Ty::List(a)
|
||||
}
|
||||
(&mut Ty::Lambda(a1, a2), Ty::Lambda(b1, b2)) => {
|
||||
self.unify(a1, b1);
|
||||
self.unify(a2, b2);
|
||||
(Ty::Lambda(arg1, ret1), Ty::Lambda(arg2, ret2)) => {
|
||||
self.unify_var(arg1, arg2);
|
||||
self.unify_var(ret1, ret2);
|
||||
Ty::Lambda(arg1, ret1)
|
||||
}
|
||||
(Ty::Attrset(a), Ty::Attrset(b)) => {
|
||||
(Ty::Attrset(mut a), Ty::Attrset(b)) => {
|
||||
for (field, (ty2, src2)) in b.0 {
|
||||
match a.0.entry(field) {
|
||||
Entry::Vacant(ent) => {
|
||||
|
|
@ -412,14 +429,34 @@ impl<'db> InferCtx<'db> {
|
|||
Entry::Occupied(mut ent) => {
|
||||
let (ty1, src1) = ent.get_mut();
|
||||
src1.unify(src2);
|
||||
self.unify(*ty1, ty2);
|
||||
self.unify_var(*ty1, ty2);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ty::Attrset(a)
|
||||
}
|
||||
_ => {}
|
||||
(Ty::External(external), local) | (local, Ty::External(external)) => {
|
||||
match (local, &external) {
|
||||
(Ty::Lambda(arg1, ret1), super::Ty::Lambda(arg2, ret2)) => {
|
||||
let arg2 = self.import_external(super::Ty::clone(arg2));
|
||||
let ret2 = self.import_external(super::Ty::clone(ret2));
|
||||
self.unify_var(arg1, arg2);
|
||||
self.unify_var(ret1, ret2);
|
||||
}
|
||||
(Ty::Attrset(a), super::Ty::Attrset(b)) => {
|
||||
for (field, (ty, _)) in &a.0 {
|
||||
if let Some(field_ty) = b.get(field) {
|
||||
let var = self.import_external(field_ty.clone());
|
||||
self.unify_var(*ty, var);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
Ty::External(external)
|
||||
}
|
||||
(lhs, _) => lhs,
|
||||
}
|
||||
*self.table.get_mut(i) = a;
|
||||
}
|
||||
|
||||
fn finish(mut self) -> InferenceResult {
|
||||
|
|
@ -494,6 +531,7 @@ impl<'a> Collector<'a> {
|
|||
.collect();
|
||||
super::Ty::Attrset(super::Attrset(set).into())
|
||||
}
|
||||
Ty::External(ty) => ty,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,31 @@
|
|||
#[rustfmt::skip]
|
||||
#[macro_export]
|
||||
macro_rules! ty {
|
||||
(?) => { $crate::ty::Ty::Unknown };
|
||||
(bool) => { $crate::ty::Ty::Int };
|
||||
(int) => { $crate::ty::Ty::Int };
|
||||
(float) => { $crate::ty::Ty::Float };
|
||||
(string) => { $crate::ty::Ty::String };
|
||||
(path) => { $crate::ty::Ty::Path };
|
||||
// TODO: More precise type for derivations.
|
||||
(derivation) => {
|
||||
$crate::ty::Ty::Attrset(::std::sync::Arc::new($crate::ty::Attrset::default()))
|
||||
};
|
||||
(($($inner:tt)*)) => {{ ty!($($inner)*) }};
|
||||
([$($inner:tt)*]) => { $crate::ty::Ty::List(::std::arc::Arc::new($ty!($($inner)*)))};
|
||||
({ $($key:literal : $ty:tt),* $(,)? }) => {{
|
||||
$crate::ty::Ty::Attrset(::std::sync::Arc::new($crate::ty::Attrset::from_internal([
|
||||
$(($key, ty!($ty)),)*
|
||||
])))
|
||||
}};
|
||||
($arg:tt -> $($ret:tt)*) => {
|
||||
$crate::ty::Ty::Lambda(
|
||||
::std::sync::Arc::new(ty!($arg)),
|
||||
::std::sync::Arc::new(ty!($($ret)*)),
|
||||
)
|
||||
};
|
||||
}
|
||||
|
||||
mod fmt;
|
||||
mod infer;
|
||||
mod union_find;
|
||||
|
|
@ -65,6 +93,26 @@ impl std::fmt::Debug for Ty {
|
|||
pub struct Attrset(Box<[(SmolStr, Ty, AttrSource)]>);
|
||||
|
||||
impl Attrset {
|
||||
/// Build an Attrset for internal type schemas.
|
||||
///
|
||||
/// # Panics
|
||||
/// The given iterator must have no duplicated fields, or it'll panic.
|
||||
#[track_caller]
|
||||
// FIXME: Currently this is only used in tests.
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
fn from_internal(iter: impl IntoIterator<Item = (&'static str, Ty)>) -> Self {
|
||||
let mut set = iter
|
||||
.into_iter()
|
||||
.map(|(name, ty)| (SmolStr::from(name), ty, AttrSource::Unknown))
|
||||
.collect::<Box<[_]>>();
|
||||
set.sort_by(|(lhs, ..), (rhs, ..)| lhs.cmp(rhs));
|
||||
assert!(
|
||||
set.windows(2).all(|w| w[0].0 != w[1].0),
|
||||
"Duplicated fields",
|
||||
);
|
||||
Self(set)
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ use crate::tests::TestDB;
|
|||
use crate::{DefDatabase, TyDatabase};
|
||||
use expect_test::{expect, Expect};
|
||||
|
||||
use super::Ty;
|
||||
|
||||
#[track_caller]
|
||||
fn check(src: &str, expect: Expect) {
|
||||
let (db, file) = TestDB::single_file(src).unwrap();
|
||||
let module = db.module(file);
|
||||
|
|
@ -11,10 +14,16 @@ fn check(src: &str, expect: Expect) {
|
|||
expect.assert_eq(&got);
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn check_all(src: &str, expect: Expect) {
|
||||
check_all_expect(src, None, expect);
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
fn check_all_expect(src: &str, expect_ty: impl Into<Option<Ty>>, expect: Expect) {
|
||||
let (db, file) = TestDB::single_file(src).unwrap();
|
||||
let module = db.module(file);
|
||||
let infer = db.infer(file);
|
||||
let infer = super::infer::infer_with(&db, file, expect_ty.into());
|
||||
let got = module
|
||||
.names()
|
||||
.map(|(i, name)| format!("{}: {}\n", name.text, infer.ty_for_name(i).debug()))
|
||||
|
|
@ -140,3 +149,31 @@ fn select() {
|
|||
"#]],
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn external() {
|
||||
check_all_expect(
|
||||
"let a = a; in a",
|
||||
ty!(int),
|
||||
expect![[r#"
|
||||
a: int
|
||||
: int
|
||||
"#]],
|
||||
);
|
||||
|
||||
check_all_expect(
|
||||
"{ stdenv }: stdenv.mkDerivation {
|
||||
name = undefined;
|
||||
}",
|
||||
ty!({
|
||||
"stdenv": {
|
||||
"mkDerivation": ({ "name": string } -> derivation),
|
||||
},
|
||||
} -> derivation),
|
||||
expect![[r#"
|
||||
stdenv: { mkDerivation: { name: string } → { } }
|
||||
name: string
|
||||
: { stdenv: { mkDerivation: { name: string } → { } } } → { }
|
||||
"#]],
|
||||
);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue