diff --git a/Cargo.lock b/Cargo.lock index 5bf946b34c..51cf1825d6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -695,6 +695,7 @@ name = "ra_hir" version = "0.1.0" dependencies = [ "arrayvec 0.4.9 (registry+https://github.com/rust-lang/crates.io-index)", + "flexi_logger 0.10.3 (registry+https://github.com/rust-lang/crates.io-index)", "id-arena 2.0.0 (registry+https://github.com/rust-lang/crates.io-index)", "log 0.4.6 (registry+https://github.com/rust-lang/crates.io-index)", "parking_lot 0.6.4 (registry+https://github.com/rust-lang/crates.io-index)", diff --git a/crates/ra_analysis/src/db.rs b/crates/ra_analysis/src/db.rs index 94729d2967..780a84291a 100644 --- a/crates/ra_analysis/src/db.rs +++ b/crates/ra_analysis/src/db.rs @@ -93,6 +93,8 @@ salsa::database_storage! { fn item_map() for hir::db::ItemMapQuery; fn fn_syntax() for hir::db::FnSyntaxQuery; fn submodules() for hir::db::SubmodulesQuery; + fn infer() for hir::db::InferQuery; + fn type_for_def() for hir::db::TypeForDefQuery; } } } diff --git a/crates/ra_analysis/src/imp.rs b/crates/ra_analysis/src/imp.rs index b01382808e..40996bfd73 100644 --- a/crates/ra_analysis/src/imp.rs +++ b/crates/ra_analysis/src/imp.rs @@ -5,7 +5,8 @@ use std::{ use ra_editor::{self, find_node_at_offset, FileSymbol, LineIndex, LocalEdit}; use ra_syntax::{ - ast::{self, ArgListOwner, Expr, NameOwner}, + ast::{self, ArgListOwner, Expr, NameOwner, FnDef}, + algo::find_covering_node, AstNode, SourceFileNode, SyntaxKind::*, SyntaxNodeRef, TextRange, TextUnit, @@ -510,6 +511,23 @@ impl AnalysisImpl { Ok(None) } + pub fn type_of(&self, file_id: FileId, range: TextRange) -> Cancelable> { + let file = self.db.source_file(file_id); + let syntax = file.syntax(); + let node = find_covering_node(syntax, range); + let parent_fn = node.ancestors().filter_map(FnDef::cast).next(); + let parent_fn = if let Some(p) = parent_fn { + p + } else { + return Ok(None); + }; + let function = ctry!(source_binder::function_from_source( + &*self.db, file_id, parent_fn + )?); + let infer = function.infer(&*self.db)?; + Ok(infer.type_of_node(node).map(|t| t.to_string())) + } + fn index_resolve(&self, name_ref: ast::NameRef) -> Cancelable> { let name = name_ref.text(); let mut query = Query::new(name.to_string()); diff --git a/crates/ra_analysis/src/lib.rs b/crates/ra_analysis/src/lib.rs index 85df9c089c..8308981405 100644 --- a/crates/ra_analysis/src/lib.rs +++ b/crates/ra_analysis/src/lib.rs @@ -366,6 +366,9 @@ impl Analysis { ) -> Cancelable)>> { self.imp.resolve_callable(position) } + pub fn type_of(&self, file_id: FileId, range: TextRange) -> Cancelable> { + self.imp.type_of(file_id, range) + } } pub struct LibraryData { diff --git a/crates/ra_hir/Cargo.toml b/crates/ra_hir/Cargo.toml index 61650cee9d..594176337d 100644 --- a/crates/ra_hir/Cargo.toml +++ b/crates/ra_hir/Cargo.toml @@ -16,3 +16,6 @@ ra_syntax = { path = "../ra_syntax" } ra_editor = { path = "../ra_editor" } ra_db = { path = "../ra_db" } test_utils = { path = "../test_utils" } + +[dev-dependencies] +flexi_logger = "0.10.0" diff --git a/crates/ra_hir/src/db.rs b/crates/ra_hir/src/db.rs index 62cf9ab17c..d94f75857f 100644 --- a/crates/ra_hir/src/db.rs +++ b/crates/ra_hir/src/db.rs @@ -14,6 +14,7 @@ use crate::{ function::FnId, module::{ModuleId, ModuleTree, ModuleSource, nameres::{ItemMap, InputModuleItems}}, + ty::{InferenceResult, Ty}, }; salsa::query_group! { @@ -30,6 +31,16 @@ pub trait HirDatabase: SyntaxDatabase use fn query_definitions::fn_syntax; } + fn infer(fn_id: FnId) -> Cancelable> { + type InferQuery; + use fn query_definitions::infer; + } + + fn type_for_def(def_id: DefId) -> Cancelable { + type TypeForDefQuery; + use fn query_definitions::type_for_def; + } + fn file_items(file_id: FileId) -> Arc { type SourceFileItemsQuery; use fn query_definitions::file_items; diff --git a/crates/ra_hir/src/function.rs b/crates/ra_hir/src/function.rs index 2925beb16b..d36477b48d 100644 --- a/crates/ra_hir/src/function.rs +++ b/crates/ra_hir/src/function.rs @@ -5,12 +5,13 @@ use std::{ sync::Arc, }; +use ra_db::Cancelable; use ra_syntax::{ TextRange, TextUnit, ast::{self, AstNode, DocCommentsOwner, NameOwner}, }; -use crate::{ DefId, HirDatabase }; +use crate::{ DefId, HirDatabase, ty::InferenceResult, Module }; pub use self::scope::FnScopes; @@ -18,7 +19,7 @@ pub use self::scope::FnScopes; pub struct FnId(pub(crate) DefId); pub struct Function { - fn_id: FnId, + pub(crate) fn_id: FnId, } impl Function { @@ -27,6 +28,10 @@ impl Function { Function { fn_id } } + pub fn syntax(&self, db: &impl HirDatabase) -> ast::FnDefNode { + db.fn_syntax(self.fn_id) + } + pub fn scopes(&self, db: &impl HirDatabase) -> Arc { db.fn_scopes(self.fn_id) } @@ -35,6 +40,15 @@ impl Function { let syntax = db.fn_syntax(self.fn_id); FnSignatureInfo::new(syntax.borrowed()) } + + pub fn infer(&self, db: &impl HirDatabase) -> Cancelable> { + db.infer(self.fn_id) + } + + pub fn module(&self, db: &impl HirDatabase) -> Cancelable { + let loc = self.fn_id.0.loc(db); + Module::new(db, loc.source_root_id, loc.module_id) + } } #[derive(Debug, Clone)] diff --git a/crates/ra_hir/src/lib.rs b/crates/ra_hir/src/lib.rs index f56214b47a..a0d99a84df 100644 --- a/crates/ra_hir/src/lib.rs +++ b/crates/ra_hir/src/lib.rs @@ -25,10 +25,11 @@ pub mod source_binder; mod krate; mod module; mod function; +mod ty; use std::ops::Index; -use ra_syntax::{SyntaxNodeRef, SyntaxNode}; +use ra_syntax::{SyntaxNodeRef, SyntaxNode, SyntaxKind}; use ra_db::{LocationIntener, SourceRootId, FileId, Cancelable}; use crate::{ @@ -66,6 +67,23 @@ pub struct DefLoc { source_item_id: SourceItemId, } +impl DefKind { + pub(crate) fn for_syntax_kind(kind: SyntaxKind) -> Option { + match kind { + SyntaxKind::FN_DEF => Some(DefKind::Function), + SyntaxKind::MODULE => Some(DefKind::Module), + // These define items, but don't have their own DefKinds yet: + SyntaxKind::STRUCT_DEF => Some(DefKind::Item), + SyntaxKind::ENUM_DEF => Some(DefKind::Item), + SyntaxKind::TRAIT_DEF => Some(DefKind::Item), + SyntaxKind::TYPE_DEF => Some(DefKind::Item), + SyntaxKind::CONST_DEF => Some(DefKind::Item), + SyntaxKind::STATIC_DEF => Some(DefKind::Item), + _ => None, + } + } +} + impl DefId { pub(crate) fn loc(self, db: &impl AsRef>) -> DefLoc { db.as_ref().id2loc(self) diff --git a/crates/ra_hir/src/mock.rs b/crates/ra_hir/src/mock.rs index 9423e65714..b5a9971707 100644 --- a/crates/ra_hir/src/mock.rs +++ b/crates/ra_hir/src/mock.rs @@ -8,7 +8,7 @@ use test_utils::{parse_fixture, CURSOR_MARKER, extract_offset}; use crate::{db, DefId, DefLoc}; -const WORKSPACE: SourceRootId = SourceRootId(0); +pub const WORKSPACE: SourceRootId = SourceRootId(0); #[derive(Debug)] pub(crate) struct MockDatabase { @@ -24,6 +24,15 @@ impl MockDatabase { (db, source_root) } + pub(crate) fn with_single_file(text: &str) -> (MockDatabase, SourceRoot, FileId) { + let mut db = MockDatabase::default(); + let mut source_root = SourceRoot::default(); + let file_id = db.add_file(&mut source_root, "/main.rs", text); + db.query_mut(ra_db::SourceRootQuery) + .set(WORKSPACE, Arc::new(source_root.clone())); + (db, source_root, file_id) + } + pub(crate) fn with_position(fixture: &str) -> (MockDatabase, FilePosition) { let (db, _, position) = MockDatabase::from_fixture(fixture); let position = position.expect("expected a marker ( <|> )"); @@ -182,6 +191,8 @@ salsa::database_storage! { fn item_map() for db::ItemMapQuery; fn fn_syntax() for db::FnSyntaxQuery; fn submodules() for db::SubmodulesQuery; + fn infer() for db::InferQuery; + fn type_for_def() for db::TypeForDefQuery; } } } diff --git a/crates/ra_hir/src/module.rs b/crates/ra_hir/src/module.rs index cd31e8cfe6..8911199530 100644 --- a/crates/ra_hir/src/module.rs +++ b/crates/ra_hir/src/module.rs @@ -2,6 +2,7 @@ pub(super) mod imp; pub(super) mod nameres; use std::sync::Arc; +use log; use ra_syntax::{ algo::generate, diff --git a/crates/ra_hir/src/module/nameres.rs b/crates/ra_hir/src/module/nameres.rs index 39e891cda5..0b152a4062 100644 --- a/crates/ra_hir/src/module/nameres.rs +++ b/crates/ra_hir/src/module/nameres.rs @@ -272,13 +272,13 @@ where } } } - // Populate explicitelly declared items, except modules + // Populate explicitly declared items, except modules for item in input.items.iter() { if item.kind == MODULE { continue; } let def_loc = DefLoc { - kind: DefKind::Item, + kind: DefKind::for_syntax_kind(item.kind).unwrap_or(DefKind::Item), source_root_id: self.source_root, module_id, source_item_id: SourceItemId { diff --git a/crates/ra_hir/src/query_definitions.rs b/crates/ra_hir/src/query_definitions.rs index efaeb1525a..b654af9204 100644 --- a/crates/ra_hir/src/query_definitions.rs +++ b/crates/ra_hir/src/query_definitions.rs @@ -11,7 +11,7 @@ use ra_syntax::{ use ra_db::{SourceRootId, FileId, Cancelable,}; use crate::{ - SourceFileItems, SourceItemId, DefKind, + SourceFileItems, SourceItemId, DefKind, Function, DefId, db::HirDatabase, function::{FnScopes, FnId}, module::{ @@ -19,6 +19,7 @@ use crate::{ imp::Submodule, nameres::{InputModuleItems, ItemMap, Resolver}, }, + ty::{self, InferenceResult, Ty} }; /// Resolve `FnId` to the corresponding `SyntaxNode` @@ -35,6 +36,15 @@ pub(super) fn fn_scopes(db: &impl HirDatabase, fn_id: FnId) -> Arc { Arc::new(res) } +pub(super) fn infer(db: &impl HirDatabase, fn_id: FnId) -> Cancelable> { + let function = Function { fn_id }; + ty::infer(db, function).map(Arc::new) +} + +pub(super) fn type_for_def(db: &impl HirDatabase, def_id: DefId) -> Cancelable { + ty::type_for_def(db, def_id) +} + pub(super) fn file_items(db: &impl HirDatabase, file_id: FileId) -> Arc { let mut res = SourceFileItems::new(file_id); let source_file = db.source_file(file_id); diff --git a/crates/ra_hir/src/ty.rs b/crates/ra_hir/src/ty.rs new file mode 100644 index 0000000000..c759d4c8b1 --- /dev/null +++ b/crates/ra_hir/src/ty.rs @@ -0,0 +1,601 @@ +mod primitive; +#[cfg(test)] +mod tests; + +use std::sync::Arc; +use std::fmt; + +use log; +use rustc_hash::{FxHashMap}; + +use ra_db::{LocalSyntaxPtr, Cancelable}; +use ra_syntax::{ + SmolStr, + ast::{self, AstNode, LoopBodyOwner, ArgListOwner}, + SyntaxNodeRef +}; + +use crate::{Def, DefId, FnScopes, Module, Function, Path, db::HirDatabase}; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub enum Ty { + /// The primitive boolean type. Written as `bool`. + Bool, + + /// The primitive character type; holds a Unicode scalar value + /// (a non-surrogate code point). Written as `char`. + Char, + + /// A primitive signed integer type. For example, `i32`. + Int(primitive::IntTy), + + /// A primitive unsigned integer type. For example, `u32`. + Uint(primitive::UintTy), + + /// A primitive floating-point type. For example, `f64`. + Float(primitive::FloatTy), + + // Structures, enumerations and unions. + // Adt(AdtDef, Substs), + /// The pointee of a string slice. Written as `str`. + Str, + + // An array with the given length. Written as `[T; n]`. + // Array(Ty, ty::Const), + /// The pointee of an array slice. Written as `[T]`. + Slice(TyRef), + + // A raw pointer. Written as `*mut T` or `*const T` + // RawPtr(TypeAndMut<'tcx>), + + // A reference; a pointer with an associated lifetime. Written as + // `&'a mut T` or `&'a T`. + // Ref(Ty<'tcx>, hir::Mutability), + /// A pointer to a function. Written as `fn() -> i32`. + /// + /// For example the type of `bar` here: + /// + /// ```rust + /// fn foo() -> i32 { 1 } + /// let bar: fn() -> i32 = foo; + /// ``` + FnPtr(Arc), + + // A trait, defined with `dyn trait`. + // Dynamic(), + /// The anonymous type of a closure. Used to represent the type of + /// `|a| a`. + // Closure(DefId, ClosureSubsts<'tcx>), + + /// The anonymous type of a generator. Used to represent the type of + /// `|a| yield a`. + // Generator(DefId, GeneratorSubsts<'tcx>, hir::GeneratorMovability), + + /// A type representin the types stored inside a generator. + /// This should only appear in GeneratorInteriors. + // GeneratorWitness(Binder<&'tcx List>>), + + /// The never type `!` + Never, + + /// A tuple type. For example, `(i32, bool)`. + Tuple(Vec), + + // The projection of an associated type. For example, + // `>::N`. + // Projection(ProjectionTy), + + // Opaque (`impl Trait`) type found in a return type. + // The `DefId` comes either from + // * the `impl Trait` ast::Ty node, + // * or the `existential type` declaration + // The substitutions are for the generics of the function in question. + // Opaque(DefId, Substs), + + // A type parameter; for example, `T` in `fn f(x: T) {} + // Param(ParamTy), + + // A placeholder type - universally quantified higher-ranked type. + // Placeholder(ty::PlaceholderType), + + // A type variable used during type checking. + // Infer(InferTy), + /// A placeholder for a type which could not be computed; this is + /// propagated to avoid useless error messages. + Unknown, +} + +type TyRef = Arc; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug)] +pub struct FnSig { + input: Vec, + output: Ty, +} + +impl Ty { + pub fn new(_db: &impl HirDatabase, node: ast::TypeRef) -> Cancelable { + use ra_syntax::ast::TypeRef::*; + Ok(match node { + ParenType(_inner) => Ty::Unknown, // TODO + TupleType(_inner) => Ty::Unknown, // TODO + NeverType(..) => Ty::Never, + PathType(inner) => { + let path = if let Some(p) = inner.path() { + p + } else { + return Ok(Ty::Unknown); + }; + if path.qualifier().is_none() { + let name = path + .segment() + .and_then(|s| s.name_ref()) + .map(|n| n.text()) + .unwrap_or(SmolStr::new("")); + if let Some(int_ty) = primitive::IntTy::from_string(&name) { + Ty::Int(int_ty) + } else if let Some(uint_ty) = primitive::UintTy::from_string(&name) { + Ty::Uint(uint_ty) + } else if let Some(float_ty) = primitive::FloatTy::from_string(&name) { + Ty::Float(float_ty) + } else { + // TODO + Ty::Unknown + } + } else { + // TODO + Ty::Unknown + } + } + PointerType(_inner) => Ty::Unknown, // TODO + ArrayType(_inner) => Ty::Unknown, // TODO + SliceType(_inner) => Ty::Unknown, // TODO + ReferenceType(_inner) => Ty::Unknown, // TODO + PlaceholderType(_inner) => Ty::Unknown, // TODO + FnPointerType(_inner) => Ty::Unknown, // TODO + ForType(_inner) => Ty::Unknown, // TODO + ImplTraitType(_inner) => Ty::Unknown, // TODO + DynTraitType(_inner) => Ty::Unknown, // TODO + }) + } + + pub fn unit() -> Self { + Ty::Tuple(Vec::new()) + } +} + +impl fmt::Display for Ty { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Ty::Bool => write!(f, "bool"), + Ty::Char => write!(f, "char"), + Ty::Int(t) => write!(f, "{}", t.ty_to_string()), + Ty::Uint(t) => write!(f, "{}", t.ty_to_string()), + Ty::Float(t) => write!(f, "{}", t.ty_to_string()), + Ty::Str => write!(f, "str"), + Ty::Slice(t) => write!(f, "[{}]", t), + Ty::Never => write!(f, "!"), + Ty::Tuple(ts) => { + write!(f, "(")?; + for t in ts { + write!(f, "{},", t)?; + } + write!(f, ")") + } + Ty::FnPtr(sig) => { + write!(f, "fn(")?; + for t in &sig.input { + write!(f, "{},", t)?; + } + write!(f, ") -> {}", sig.output) + } + Ty::Unknown => write!(f, "[unknown]"), + } + } +} + +pub fn type_for_fn(db: &impl HirDatabase, f: Function) -> Cancelable { + let syntax = f.syntax(db); + let node = syntax.borrowed(); + // TODO we ignore type parameters for now + let input = node + .param_list() + .map(|pl| { + pl.params() + .map(|p| { + p.type_ref() + .map(|t| Ty::new(db, t)) + .unwrap_or(Ok(Ty::Unknown)) + }) + .collect() + }) + .unwrap_or_else(|| Ok(Vec::new()))?; + let output = node + .ret_type() + .and_then(|rt| rt.type_ref()) + .map(|t| Ty::new(db, t)) + .unwrap_or(Ok(Ty::Unknown))?; + let sig = FnSig { input, output }; + Ok(Ty::FnPtr(Arc::new(sig))) +} + +// TODO this should probably be per namespace (i.e. types vs. values), since for +// a tuple struct `struct Foo(Bar)`, Foo has function type as a value, but +// defines the struct type Foo when used in the type namespace. rustc has a +// separate DefId for the constructor, but with the current DefId approach, that +// seems complicated. +pub fn type_for_def(db: &impl HirDatabase, def_id: DefId) -> Cancelable { + let def = def_id.resolve(db)?; + match def { + Def::Module(..) => { + log::debug!("trying to get type for module {:?}", def_id); + Ok(Ty::Unknown) + } + Def::Function(f) => type_for_fn(db, f), + Def::Item => { + log::debug!("trying to get type for item of unknown type {:?}", def_id); + Ok(Ty::Unknown) + } + } +} + +#[derive(Clone, PartialEq, Eq, Debug)] +pub struct InferenceResult { + type_of: FxHashMap, +} + +impl InferenceResult { + pub fn type_of_node(&self, node: SyntaxNodeRef) -> Option { + self.type_of.get(&LocalSyntaxPtr::new(node)).cloned() + } +} + +#[derive(Clone, Debug)] +pub struct InferenceContext<'a, D: HirDatabase> { + db: &'a D, + scopes: Arc, + module: Module, + // TODO unification tables... + type_of: FxHashMap, +} + +impl<'a, D: HirDatabase> InferenceContext<'a, D> { + fn new(db: &'a D, scopes: Arc, module: Module) -> Self { + InferenceContext { + type_of: FxHashMap::default(), + db, + scopes, + module, + } + } + + fn write_ty(&mut self, node: SyntaxNodeRef, ty: Ty) { + self.type_of.insert(LocalSyntaxPtr::new(node), ty); + } + + fn unify(&mut self, ty1: &Ty, ty2: &Ty) -> Option { + if *ty1 == Ty::Unknown { + return Some(ty2.clone()); + } + if *ty2 == Ty::Unknown { + return Some(ty1.clone()); + } + if ty1 == ty2 { + return Some(ty1.clone()); + } + // TODO implement actual unification + return None; + } + + fn unify_with_coercion(&mut self, ty1: &Ty, ty2: &Ty) -> Option { + // TODO implement coercion + self.unify(ty1, ty2) + } + + fn infer_path_expr(&mut self, expr: ast::PathExpr) -> Cancelable> { + let ast_path = ctry!(expr.path()); + let path = ctry!(Path::from_ast(ast_path)); + if path.is_ident() { + // resolve locally + let name = ctry!(ast_path.segment().and_then(|s| s.name_ref())); + if let Some(scope_entry) = self.scopes.resolve_local_name(name) { + let ty = ctry!(self.type_of.get(&scope_entry.ptr())); + return Ok(Some(ty.clone())); + }; + }; + + // resolve in module + let resolved = ctry!(self.module.resolve_path(self.db, path)?); + let ty = self.db.type_for_def(resolved)?; + // TODO we will need to add type variables for type parameters etc. here + Ok(Some(ty)) + } + + fn infer_expr(&mut self, expr: ast::Expr) -> Cancelable { + let ty = match expr { + ast::Expr::IfExpr(e) => { + if let Some(condition) = e.condition() { + if let Some(e) = condition.expr() { + // TODO if no pat, this should be bool + self.infer_expr(e)?; + } + // TODO write type for pat + }; + let if_ty = if let Some(block) = e.then_branch() { + self.infer_block(block)? + } else { + Ty::Unknown + }; + let else_ty = if let Some(block) = e.else_branch() { + self.infer_block(block)? + } else { + Ty::Unknown + }; + if let Some(ty) = self.unify(&if_ty, &else_ty) { + ty + } else { + // TODO report diagnostic + Ty::Unknown + } + } + ast::Expr::BlockExpr(e) => { + if let Some(block) = e.block() { + self.infer_block(block)? + } else { + Ty::Unknown + } + } + ast::Expr::LoopExpr(e) => { + if let Some(block) = e.loop_body() { + self.infer_block(block)?; + }; + // TODO never, or the type of the break param + Ty::Unknown + } + ast::Expr::WhileExpr(e) => { + if let Some(condition) = e.condition() { + if let Some(e) = condition.expr() { + // TODO if no pat, this should be bool + self.infer_expr(e)?; + } + // TODO write type for pat + }; + if let Some(block) = e.loop_body() { + // TODO + self.infer_block(block)?; + }; + // TODO always unit? + Ty::Unknown + } + ast::Expr::ForExpr(e) => { + if let Some(expr) = e.iterable() { + self.infer_expr(expr)?; + } + if let Some(_pat) = e.pat() { + // TODO write type for pat + } + if let Some(block) = e.loop_body() { + self.infer_block(block)?; + } + // TODO always unit? + Ty::Unknown + } + ast::Expr::LambdaExpr(e) => { + let _body_ty = if let Some(body) = e.body() { + self.infer_expr(body)? + } else { + Ty::Unknown + }; + Ty::Unknown + } + ast::Expr::CallExpr(e) => { + let callee_ty = if let Some(e) = e.expr() { + self.infer_expr(e)? + } else { + Ty::Unknown + }; + if let Some(arg_list) = e.arg_list() { + for arg in arg_list.args() { + // TODO unify / expect argument type + self.infer_expr(arg)?; + } + } + match callee_ty { + Ty::FnPtr(sig) => sig.output.clone(), + _ => { + // not callable + // TODO report an error? + Ty::Unknown + } + } + } + ast::Expr::MethodCallExpr(e) => { + let _receiver_ty = if let Some(e) = e.expr() { + self.infer_expr(e)? + } else { + Ty::Unknown + }; + if let Some(arg_list) = e.arg_list() { + for arg in arg_list.args() { + // TODO unify / expect argument type + self.infer_expr(arg)?; + } + } + Ty::Unknown + } + ast::Expr::MatchExpr(e) => { + let _ty = if let Some(match_expr) = e.expr() { + self.infer_expr(match_expr)? + } else { + Ty::Unknown + }; + if let Some(match_arm_list) = e.match_arm_list() { + for arm in match_arm_list.arms() { + // TODO type the bindings in pat + // TODO type the guard + let _ty = if let Some(e) = arm.expr() { + self.infer_expr(e)? + } else { + Ty::Unknown + }; + } + // TODO unify all the match arm types + Ty::Unknown + } else { + Ty::Unknown + } + } + ast::Expr::TupleExpr(_e) => Ty::Unknown, + ast::Expr::ArrayExpr(_e) => Ty::Unknown, + ast::Expr::PathExpr(e) => self.infer_path_expr(e)?.unwrap_or(Ty::Unknown), + ast::Expr::ContinueExpr(_e) => Ty::Never, + ast::Expr::BreakExpr(_e) => Ty::Never, + ast::Expr::ParenExpr(e) => { + if let Some(e) = e.expr() { + self.infer_expr(e)? + } else { + Ty::Unknown + } + } + ast::Expr::Label(_e) => Ty::Unknown, + ast::Expr::ReturnExpr(e) => { + if let Some(e) = e.expr() { + // TODO unify with return type + self.infer_expr(e)?; + }; + Ty::Never + } + ast::Expr::MatchArmList(_) | ast::Expr::MatchArm(_) | ast::Expr::MatchGuard(_) => { + // Can this even occur outside of a match expression? + Ty::Unknown + } + ast::Expr::StructLit(_e) => Ty::Unknown, + ast::Expr::NamedFieldList(_) | ast::Expr::NamedField(_) => { + // Can this even occur outside of a struct literal? + Ty::Unknown + } + ast::Expr::IndexExpr(_e) => Ty::Unknown, + ast::Expr::FieldExpr(_e) => Ty::Unknown, + ast::Expr::TryExpr(e) => { + let _inner_ty = if let Some(e) = e.expr() { + self.infer_expr(e)? + } else { + Ty::Unknown + }; + Ty::Unknown + } + ast::Expr::CastExpr(e) => { + let _inner_ty = if let Some(e) = e.expr() { + self.infer_expr(e)? + } else { + Ty::Unknown + }; + let cast_ty = e + .type_ref() + .map(|t| Ty::new(self.db, t)) + .unwrap_or(Ok(Ty::Unknown))?; + // TODO do the coercion... + cast_ty + } + ast::Expr::RefExpr(e) => { + let _inner_ty = if let Some(e) = e.expr() { + self.infer_expr(e)? + } else { + Ty::Unknown + }; + Ty::Unknown + } + ast::Expr::PrefixExpr(e) => { + let _inner_ty = if let Some(e) = e.expr() { + self.infer_expr(e)? + } else { + Ty::Unknown + }; + Ty::Unknown + } + ast::Expr::RangeExpr(_e) => Ty::Unknown, + ast::Expr::BinExpr(_e) => Ty::Unknown, + ast::Expr::Literal(_e) => Ty::Unknown, + }; + self.write_ty(expr.syntax(), ty.clone()); + Ok(ty) + } + + fn infer_block(&mut self, node: ast::Block) -> Cancelable { + for stmt in node.statements() { + match stmt { + ast::Stmt::LetStmt(stmt) => { + let decl_ty = if let Some(type_ref) = stmt.type_ref() { + Ty::new(self.db, type_ref)? + } else { + Ty::Unknown + }; + let ty = if let Some(expr) = stmt.initializer() { + // TODO pass expectation + let expr_ty = self.infer_expr(expr)?; + self.unify_with_coercion(&expr_ty, &decl_ty) + .unwrap_or(decl_ty) + } else { + decl_ty + }; + + if let Some(pat) = stmt.pat() { + self.write_ty(pat.syntax(), ty); + }; + } + ast::Stmt::ExprStmt(expr_stmt) => { + if let Some(expr) = expr_stmt.expr() { + self.infer_expr(expr)?; + } + } + } + } + let ty = if let Some(expr) = node.expr() { + self.infer_expr(expr)? + } else { + Ty::unit() + }; + self.write_ty(node.syntax(), ty.clone()); + Ok(ty) + } +} + +pub fn infer(db: &impl HirDatabase, function: Function) -> Cancelable { + let scopes = function.scopes(db); + let module = function.module(db)?; + let mut ctx = InferenceContext::new(db, scopes, module); + + let syntax = function.syntax(db); + let node = syntax.borrowed(); + + if let Some(param_list) = node.param_list() { + for param in param_list.params() { + let pat = if let Some(pat) = param.pat() { + pat + } else { + continue; + }; + if let Some(type_ref) = param.type_ref() { + let ty = Ty::new(db, type_ref)?; + ctx.type_of.insert(LocalSyntaxPtr::new(pat.syntax()), ty); + } else { + // TODO self param + ctx.type_of + .insert(LocalSyntaxPtr::new(pat.syntax()), Ty::Unknown); + }; + } + } + + // TODO get Ty for node.ret_type() and pass that to infer_block as expectation + // (see Expectation in rustc_typeck) + + if let Some(block) = node.body() { + ctx.infer_block(block)?; + } + + // TODO 'resolve' the types: replace inference variables by their inferred results + + Ok(InferenceResult { + type_of: ctx.type_of, + }) +} diff --git a/crates/ra_hir/src/ty/primitive.rs b/crates/ra_hir/src/ty/primitive.rs new file mode 100644 index 0000000000..ad79b17e41 --- /dev/null +++ b/crates/ra_hir/src/ty/primitive.rs @@ -0,0 +1,130 @@ +use std::fmt; + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Copy)] +pub enum IntTy { + Isize, + I8, + I16, + I32, + I64, + I128, +} + +impl fmt::Debug for IntTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} + +impl fmt::Display for IntTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.ty_to_string()) + } +} + +impl IntTy { + pub fn ty_to_string(&self) -> &'static str { + match *self { + IntTy::Isize => "isize", + IntTy::I8 => "i8", + IntTy::I16 => "i16", + IntTy::I32 => "i32", + IntTy::I64 => "i64", + IntTy::I128 => "i128", + } + } + + pub fn from_string(s: &str) -> Option { + match s { + "isize" => Some(IntTy::Isize), + "i8" => Some(IntTy::I8), + "i16" => Some(IntTy::I16), + "i32" => Some(IntTy::I32), + "i64" => Some(IntTy::I64), + "i128" => Some(IntTy::I128), + _ => None, + } + } +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Copy)] +pub enum UintTy { + Usize, + U8, + U16, + U32, + U64, + U128, +} + +impl UintTy { + pub fn ty_to_string(&self) -> &'static str { + match *self { + UintTy::Usize => "usize", + UintTy::U8 => "u8", + UintTy::U16 => "u16", + UintTy::U32 => "u32", + UintTy::U64 => "u64", + UintTy::U128 => "u128", + } + } + + pub fn from_string(s: &str) -> Option { + match s { + "usize" => Some(UintTy::Usize), + "u8" => Some(UintTy::U8), + "u16" => Some(UintTy::U16), + "u32" => Some(UintTy::U32), + "u64" => Some(UintTy::U64), + "u128" => Some(UintTy::U128), + _ => None, + } + } +} + +impl fmt::Debug for UintTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} + +impl fmt::Display for UintTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.ty_to_string()) + } +} + +#[derive(Clone, PartialEq, Eq, Hash, Copy, PartialOrd, Ord)] +pub enum FloatTy { + F32, + F64, +} + +impl fmt::Debug for FloatTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + fmt::Display::fmt(self, f) + } +} + +impl fmt::Display for FloatTy { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.ty_to_string()) + } +} + +impl FloatTy { + pub fn ty_to_string(self) -> &'static str { + match self { + FloatTy::F32 => "f32", + FloatTy::F64 => "f64", + } + } + + pub fn from_string(s: &str) -> Option { + match s { + "f32" => Some(FloatTy::F32), + "f64" => Some(FloatTy::F64), + _ => None, + } + } +} diff --git a/crates/ra_hir/src/ty/tests.rs b/crates/ra_hir/src/ty/tests.rs new file mode 100644 index 0000000000..b6c02cd80c --- /dev/null +++ b/crates/ra_hir/src/ty/tests.rs @@ -0,0 +1,134 @@ +use std::fmt::Write; +use std::path::{PathBuf, Path}; +use std::fs; + +use ra_db::{SyntaxDatabase}; +use ra_syntax::ast::{self, AstNode}; +use test_utils::{project_dir, assert_eq_text, read_text}; + +use crate::{ + source_binder, + mock::MockDatabase, +}; + +// These tests compare the inference results for all expressions in a file +// against snapshots of the current results. If you change something and these +// tests fail expectedly, you can update the comparison files by deleting them +// and running the tests again. Similarly, to add a new test, just write the +// test here in the same pattern and it will automatically write the snapshot. + +#[test] +fn infer_basics() { + check_inference( + r#" +fn test(a: u32, b: isize, c: !, d: &str) { + a; + b; + c; + d; + 1usize; + 1isize; + "test"; + 1.0f32; +}"#, + "0001_basics.txt", + ); +} + +#[test] +fn infer_let() { + check_inference( + r#" +fn test() { + let a = 1isize; + let b: usize = 1; + let c = b; +} +}"#, + "0002_let.txt", + ); +} + +#[test] +fn infer_paths() { + check_inference( + r#" +fn a() -> u32 { 1 } + +mod b { + fn c() -> u32 { 1 } +} + +fn test() { + a(); + b::c(); +} +}"#, + "0003_paths.txt", + ); +} + +fn infer(content: &str) -> String { + let (db, _, file_id) = MockDatabase::with_single_file(content); + let source_file = db.source_file(file_id); + let mut acc = String::new(); + for fn_def in source_file + .syntax() + .descendants() + .filter_map(ast::FnDef::cast) + { + let func = source_binder::function_from_source(&db, file_id, fn_def) + .unwrap() + .unwrap(); + let inference_result = func.infer(&db).unwrap(); + for (syntax_ptr, ty) in &inference_result.type_of { + let node = syntax_ptr.resolve(&source_file); + write!( + acc, + "{} '{}': {}\n", + syntax_ptr.range(), + ellipsize(node.text().to_string().replace("\n", " "), 15), + ty + ) + .unwrap(); + } + } + acc +} + +fn check_inference(content: &str, data_file: impl AsRef) { + let data_file_path = test_data_dir().join(data_file); + let result = infer(content); + + if !data_file_path.exists() { + println!("File with expected result doesn't exist, creating...\n"); + println!("{}\n{}", content, result); + fs::write(&data_file_path, &result).unwrap(); + panic!("File {:?} with expected result was created", data_file_path); + } + + let expected = read_text(&data_file_path); + assert_eq_text!(&expected, &result); +} + +fn ellipsize(mut text: String, max_len: usize) -> String { + if text.len() <= max_len { + return text; + } + let ellipsis = "..."; + let e_len = ellipsis.len(); + let mut prefix_len = (max_len - e_len) / 2; + while !text.is_char_boundary(prefix_len) { + prefix_len += 1; + } + let mut suffix_len = max_len - e_len - prefix_len; + while !text.is_char_boundary(text.len() - suffix_len) { + suffix_len += 1; + } + text.replace_range(prefix_len..text.len() - suffix_len, ellipsis); + text +} + +fn test_data_dir() -> PathBuf { + project_dir().join("crates/ra_hir/src/ty/tests/data") +} diff --git a/crates/ra_hir/src/ty/tests/data/0001_basics.txt b/crates/ra_hir/src/ty/tests/data/0001_basics.txt new file mode 100644 index 0000000000..0c46f243a8 --- /dev/null +++ b/crates/ra_hir/src/ty/tests/data/0001_basics.txt @@ -0,0 +1,13 @@ +[33; 34) 'd': [unknown] +[88; 94) '1isize': [unknown] +[48; 49) 'a': u32 +[55; 56) 'b': isize +[112; 118) '1.0f32': [unknown] +[76; 82) '1usize': [unknown] +[9; 10) 'a': u32 +[27; 28) 'c': ! +[62; 63) 'c': ! +[17; 18) 'b': isize +[100; 106) '"test"': [unknown] +[42; 121) '{ ...f32; }': () +[69; 70) 'd': [unknown] diff --git a/crates/ra_hir/src/ty/tests/data/0002_let.txt b/crates/ra_hir/src/ty/tests/data/0002_let.txt new file mode 100644 index 0000000000..2d0d1f57b3 --- /dev/null +++ b/crates/ra_hir/src/ty/tests/data/0002_let.txt @@ -0,0 +1,7 @@ +[21; 22) 'a': [unknown] +[52; 53) '1': [unknown] +[11; 71) '{ ...= b; }': () +[63; 64) 'c': usize +[25; 31) '1isize': [unknown] +[41; 42) 'b': usize +[67; 68) 'b': usize diff --git a/crates/ra_hir/src/ty/tests/data/0003_paths.txt b/crates/ra_hir/src/ty/tests/data/0003_paths.txt new file mode 100644 index 0000000000..dcb5456ae3 --- /dev/null +++ b/crates/ra_hir/src/ty/tests/data/0003_paths.txt @@ -0,0 +1,9 @@ +[15; 20) '{ 1 }': [unknown] +[17; 18) '1': [unknown] +[50; 51) '1': [unknown] +[48; 53) '{ 1 }': [unknown] +[82; 88) 'b::c()': u32 +[67; 91) '{ ...c(); }': () +[73; 74) 'a': fn() -> u32 +[73; 76) 'a()': u32 +[82; 86) 'b::c': fn() -> u32 diff --git a/crates/ra_syntax/src/ast/generated.rs b/crates/ra_syntax/src/ast/generated.rs index bf056131ef..c735338619 100644 --- a/crates/ra_syntax/src/ast/generated.rs +++ b/crates/ra_syntax/src/ast/generated.rs @@ -523,7 +523,15 @@ impl> CastExprNode { } -impl<'a> CastExpr<'a> {} +impl<'a> CastExpr<'a> { + pub fn expr(self) -> Option> { + super::child_opt(self) + } + + pub fn type_ref(self) -> Option> { + super::child_opt(self) + } +} // Char #[derive(Debug, Clone, Copy,)] @@ -1553,6 +1561,10 @@ impl<'a> LetStmt<'a> { super::child_opt(self) } + pub fn type_ref(self) -> Option> { + super::child_opt(self) + } + pub fn initializer(self) -> Option> { super::child_opt(self) } @@ -2312,6 +2324,10 @@ impl<'a> Param<'a> { pub fn pat(self) -> Option> { super::child_opt(self) } + + pub fn type_ref(self) -> Option> { + super::child_opt(self) + } } // ParamList @@ -2394,7 +2410,11 @@ impl> ParenExprNode { } -impl<'a> ParenExpr<'a> {} +impl<'a> ParenExpr<'a> { + pub fn expr(self) -> Option> { + super::child_opt(self) + } +} // ParenType #[derive(Debug, Clone, Copy,)] @@ -2681,7 +2701,11 @@ impl> PathTypeNode { } -impl<'a> PathType<'a> {} +impl<'a> PathType<'a> { + pub fn path(self) -> Option> { + super::child_opt(self) + } +} // PlaceholderPat #[derive(Debug, Clone, Copy,)] @@ -2829,7 +2853,11 @@ impl> PrefixExprNode { } -impl<'a> PrefixExpr<'a> {} +impl<'a> PrefixExpr<'a> { + pub fn expr(self) -> Option> { + super::child_opt(self) + } +} // RangeExpr #[derive(Debug, Clone, Copy,)] @@ -2940,7 +2968,11 @@ impl> RefExprNode { } -impl<'a> RefExpr<'a> {} +impl<'a> RefExpr<'a> { + pub fn expr(self) -> Option> { + super::child_opt(self) + } +} // RefPat #[derive(Debug, Clone, Copy,)] @@ -3051,7 +3083,11 @@ impl> RetTypeNode { } -impl<'a> RetType<'a> {} +impl<'a> RetType<'a> { + pub fn type_ref(self) -> Option> { + super::child_opt(self) + } +} // ReturnExpr #[derive(Debug, Clone, Copy,)] @@ -3088,7 +3124,11 @@ impl> ReturnExprNode { } -impl<'a> ReturnExpr<'a> {} +impl<'a> ReturnExpr<'a> { + pub fn expr(self) -> Option> { + super::child_opt(self) + } +} // SelfParam #[derive(Debug, Clone, Copy,)] @@ -3578,7 +3618,11 @@ impl> TryExprNode { } -impl<'a> TryExpr<'a> {} +impl<'a> TryExpr<'a> { + pub fn expr(self) -> Option> { + super::child_opt(self) + } +} // TupleExpr #[derive(Debug, Clone, Copy,)] diff --git a/crates/ra_syntax/src/grammar.ron b/crates/ra_syntax/src/grammar.ron index eed67637e0..e3b9032a0c 100644 --- a/crates/ra_syntax/src/grammar.ron +++ b/crates/ra_syntax/src/grammar.ron @@ -254,7 +254,7 @@ Grammar( ], options: [ "ParamList", ["body", "Block"], "RetType" ], ), - "RetType": (), + "RetType": (options: ["TypeRef"]), "StructDef": ( traits: [ "NameOwner", @@ -304,7 +304,7 @@ Grammar( "ParenType": (), "TupleType": (), "NeverType": (), - "PathType": (), + "PathType": (options: ["Path"]), "PointerType": (), "ArrayType": (), "SliceType": (), @@ -346,7 +346,7 @@ Grammar( "TupleExpr": (), "ArrayExpr": (), - "ParenExpr": (), + "ParenExpr": (options: ["Expr"]), "PathExpr": (options: ["Path"]), "LambdaExpr": ( options: [ @@ -377,7 +377,7 @@ Grammar( "BlockExpr": ( options: [ "Block" ] ), - "ReturnExpr": (), + "ReturnExpr": (options: ["Expr"]), "MatchExpr": ( options: [ "Expr", "MatchArmList" ], ), @@ -405,10 +405,10 @@ Grammar( ), "IndexExpr": (), "FieldExpr": (), - "TryExpr": (), - "CastExpr": (), - "RefExpr": (), - "PrefixExpr": (), + "TryExpr": (options: ["Expr"]), + "CastExpr": (options: ["Expr", "TypeRef"]), + "RefExpr": (options: ["Expr"]), + "PrefixExpr": (options: ["Expr"]), "RangeExpr": (), "BinExpr": (), "String": (), @@ -499,6 +499,7 @@ Grammar( ), "LetStmt": ( options: [ ["pat", "Pat"], + ["type_ref", "TypeRef"], ["initializer", "Expr"], ]), "Condition": ( @@ -521,7 +522,7 @@ Grammar( ), "SelfParam": (), "Param": ( - options: [ "Pat" ], + options: [ "Pat", "TypeRef" ], ), "UseItem": ( options: [ "UseTree" ] diff --git a/crates/ra_syntax/tests/test.rs b/crates/ra_syntax/tests/test.rs index 4266864bdf..2235dc401d 100644 --- a/crates/ra_syntax/tests/test.rs +++ b/crates/ra_syntax/tests/test.rs @@ -1,14 +1,13 @@ extern crate ra_syntax; -#[macro_use] extern crate test_utils; extern crate walkdir; use std::{ fmt::Write, - fs, - path::{Path, PathBuf, Component}, + path::{PathBuf, Component}, }; +use test_utils::{project_dir, dir_tests, read_text, collect_tests}; use ra_syntax::{ utils::{check_fuzz_invariants, dump_tree}, SourceFileNode, @@ -16,7 +15,7 @@ use ra_syntax::{ #[test] fn lexer_tests() { - dir_tests(&["lexer"], |text, _| { + dir_tests(&test_data_dir(), &["lexer"], |text, _| { let tokens = ra_syntax::tokenize(text); dump_tokens(&tokens, text) }) @@ -24,33 +23,41 @@ fn lexer_tests() { #[test] fn parser_tests() { - dir_tests(&["parser/inline/ok", "parser/ok"], |text, path| { - let file = SourceFileNode::parse(text); - let errors = file.errors(); - assert_eq!( - &*errors, - &[] as &[ra_syntax::SyntaxError], - "There should be no errors in the file {:?}", - path.display() - ); - dump_tree(file.syntax()) - }); - dir_tests(&["parser/err", "parser/inline/err"], |text, path| { - let file = SourceFileNode::parse(text); - let errors = file.errors(); - assert_ne!( - &*errors, - &[] as &[ra_syntax::SyntaxError], - "There should be errors in the file {:?}", - path.display() - ); - dump_tree(file.syntax()) - }); + dir_tests( + &test_data_dir(), + &["parser/inline/ok", "parser/ok"], + |text, path| { + let file = SourceFileNode::parse(text); + let errors = file.errors(); + assert_eq!( + &*errors, + &[] as &[ra_syntax::SyntaxError], + "There should be no errors in the file {:?}", + path.display() + ); + dump_tree(file.syntax()) + }, + ); + dir_tests( + &test_data_dir(), + &["parser/err", "parser/inline/err"], + |text, path| { + let file = SourceFileNode::parse(text); + let errors = file.errors(); + assert_ne!( + &*errors, + &[] as &[ra_syntax::SyntaxError], + "There should be errors in the file {:?}", + path.display() + ); + dump_tree(file.syntax()) + }, + ); } #[test] fn parser_fuzz_tests() { - for (_, text) in collect_tests(&["parser/fuzz-failures"]) { + for (_, text) in collect_tests(&test_data_dir(), &["parser/fuzz-failures"]) { check_fuzz_invariants(&text) } } @@ -92,102 +99,6 @@ fn self_hosting_parsing() { "self_hosting_parsing found too few files - is it running in the right directory?" ) } -/// Read file and normalize newlines. -/// -/// `rustc` seems to always normalize `\r\n` newlines to `\n`: -/// -/// ``` -/// let s = " -/// "; -/// assert_eq!(s.as_bytes(), &[10]); -/// ``` -/// -/// so this should always be correct. -fn read_text(path: &Path) -> String { - fs::read_to_string(path) - .expect(&format!("File at {:?} should be valid", path)) - .replace("\r\n", "\n") -} - -fn dir_tests(paths: &[&str], f: F) -where - F: Fn(&str, &Path) -> String, -{ - for (path, input_code) in collect_tests(paths) { - let parse_tree = f(&input_code, &path); - let path = path.with_extension("txt"); - if !path.exists() { - println!("\nfile: {}", path.display()); - println!("No .txt file with expected result, creating...\n"); - println!("{}\n{}", input_code, parse_tree); - fs::write(&path, &parse_tree).unwrap(); - panic!("No expected result") - } - let expected = read_text(&path); - let expected = expected.as_str(); - let parse_tree = parse_tree.as_str(); - assert_equal_text(expected, parse_tree, &path); - } -} - -const REWRITE: bool = false; - -fn assert_equal_text(expected: &str, actual: &str, path: &Path) { - if expected == actual { - return; - } - let dir = project_dir(); - let pretty_path = path.strip_prefix(&dir).unwrap_or_else(|_| path); - if expected.trim() == actual.trim() { - println!("whitespace difference, rewriting"); - println!("file: {}\n", pretty_path.display()); - fs::write(path, actual).unwrap(); - return; - } - if REWRITE { - println!("rewriting {}", pretty_path.display()); - fs::write(path, actual).unwrap(); - return; - } - assert_eq_text!(expected, actual, "file: {}", pretty_path.display()); -} - -fn collect_tests(paths: &[&str]) -> Vec<(PathBuf, String)> { - paths - .iter() - .flat_map(|path| { - let path = test_data_dir().join(path); - test_from_dir(&path).into_iter() - }) - .map(|path| { - let text = read_text(&path); - (path, text) - }) - .collect() -} - -fn test_from_dir(dir: &Path) -> Vec { - let mut acc = Vec::new(); - for file in fs::read_dir(&dir).unwrap() { - let file = file.unwrap(); - let path = file.path(); - if path.extension().unwrap_or_default() == "rs" { - acc.push(path); - } - } - acc.sort(); - acc -} - -fn project_dir() -> PathBuf { - let dir = env!("CARGO_MANIFEST_DIR"); - PathBuf::from(dir) - .parent() - .unwrap() - .parent() - .unwrap() - .to_owned() -} fn test_data_dir() -> PathBuf { project_dir().join("crates/ra_syntax/tests/data") diff --git a/crates/test_utils/src/lib.rs b/crates/test_utils/src/lib.rs index beb936c616..012b1d0b40 100644 --- a/crates/test_utils/src/lib.rs +++ b/crates/test_utils/src/lib.rs @@ -1,4 +1,6 @@ use std::fmt; +use std::fs; +use std::path::{Path, PathBuf}; use itertools::Itertools; use text_unit::{TextRange, TextUnit}; @@ -262,3 +264,100 @@ pub fn find_mismatch<'a>(expected: &'a Value, actual: &'a Value) -> Option<(&'a _ => Some((expected, actual)), } } + +pub fn dir_tests(test_data_dir: &Path, paths: &[&str], f: F) +where + F: Fn(&str, &Path) -> String, +{ + for (path, input_code) in collect_tests(test_data_dir, paths) { + let parse_tree = f(&input_code, &path); + let path = path.with_extension("txt"); + if !path.exists() { + println!("\nfile: {}", path.display()); + println!("No .txt file with expected result, creating...\n"); + println!("{}\n{}", input_code, parse_tree); + fs::write(&path, &parse_tree).unwrap(); + panic!("No expected result") + } + let expected = read_text(&path); + let expected = expected.as_str(); + let parse_tree = parse_tree.as_str(); + assert_equal_text(expected, parse_tree, &path); + } +} + +pub fn collect_tests(test_data_dir: &Path, paths: &[&str]) -> Vec<(PathBuf, String)> { + paths + .iter() + .flat_map(|path| { + let path = test_data_dir.to_owned().join(path); + test_from_dir(&path).into_iter() + }) + .map(|path| { + let text = read_text(&path); + (path, text) + }) + .collect() +} + +fn test_from_dir(dir: &Path) -> Vec { + let mut acc = Vec::new(); + for file in fs::read_dir(&dir).unwrap() { + let file = file.unwrap(); + let path = file.path(); + if path.extension().unwrap_or_default() == "rs" { + acc.push(path); + } + } + acc.sort(); + acc +} + +pub fn project_dir() -> PathBuf { + let dir = env!("CARGO_MANIFEST_DIR"); + PathBuf::from(dir) + .parent() + .unwrap() + .parent() + .unwrap() + .to_owned() +} + +/// Read file and normalize newlines. +/// +/// `rustc` seems to always normalize `\r\n` newlines to `\n`: +/// +/// ``` +/// let s = " +/// "; +/// assert_eq!(s.as_bytes(), &[10]); +/// ``` +/// +/// so this should always be correct. +pub fn read_text(path: &Path) -> String { + fs::read_to_string(path) + .expect(&format!("File at {:?} should be valid", path)) + .replace("\r\n", "\n") +} + +const REWRITE: bool = false; + +fn assert_equal_text(expected: &str, actual: &str, path: &Path) { + if expected == actual { + return; + } + let dir = project_dir(); + let pretty_path = path.strip_prefix(&dir).unwrap_or_else(|_| path); + if expected.trim() == actual.trim() { + println!("whitespace difference, rewriting"); + println!("file: {}\n", pretty_path.display()); + fs::write(path, actual).unwrap(); + return; + } + if REWRITE { + println!("rewriting {}", pretty_path.display()); + fs::write(path, actual).unwrap(); + return; + } + assert_eq_text!(expected, actual, "file: {}", pretty_path.display()); +}