diff --git a/crates/ra_hir/src/diagnostics.rs b/crates/ra_hir/src/diagnostics.rs index 301109cb8d..475dd5766e 100644 --- a/crates/ra_hir/src/diagnostics.rs +++ b/crates/ra_hir/src/diagnostics.rs @@ -143,3 +143,31 @@ impl AstDiagnostic for MissingFields { ast::RecordFieldList::cast(node).unwrap() } } + +#[derive(Debug)] +pub struct MissingOkInTailExpr { + pub file: HirFileId, + pub expr: AstPtr, +} + +impl Diagnostic for MissingOkInTailExpr { + fn message(&self) -> String { + "wrap return expression in Ok".to_string() + } + fn source(&self) -> Source { + Source { file_id: self.file, ast: self.expr.into() } + } + fn as_any(&self) -> &(dyn Any + Send + 'static) { + self + } +} + +impl AstDiagnostic for MissingOkInTailExpr { + type AST = ast::Expr; + + fn ast(&self, db: &impl HirDatabase) -> Self::AST { + let root = db.parse_or_expand(self.file).unwrap(); + let node = self.source().ast.to_node(&root); + ast::Expr::cast(node).unwrap() + } +} diff --git a/crates/ra_hir/src/expr/validation.rs b/crates/ra_hir/src/expr/validation.rs index 62f7d41f5d..5d9d59ff89 100644 --- a/crates/ra_hir/src/expr/validation.rs +++ b/crates/ra_hir/src/expr/validation.rs @@ -6,11 +6,14 @@ use ra_syntax::ast::{AstNode, RecordLit}; use super::{Expr, ExprId, RecordLitField}; use crate::{ adt::AdtDef, - diagnostics::{DiagnosticSink, MissingFields}, + diagnostics::{DiagnosticSink, MissingFields, MissingOkInTailExpr}, expr::AstPtr, - ty::InferenceResult, - Function, HasSource, HirDatabase, Name, Path, + name, + path::{PathKind, PathSegment}, + ty::{ApplicationTy, InferenceResult, Ty, TypeCtor}, + Function, HasSource, HirDatabase, ModuleDef, Name, Path, PerNs, Resolution, }; +use ra_syntax::ast; pub(crate) struct ExprValidator<'a, 'b: 'a> { func: Function, @@ -29,11 +32,17 @@ impl<'a, 'b> ExprValidator<'a, 'b> { pub(crate) fn validate_body(&mut self, db: &impl HirDatabase) { let body = self.func.body(db); + for e in body.exprs() { if let (id, Expr::RecordLit { path, fields, spread }) = e { self.validate_record_literal(id, path, fields, *spread, db); } } + + let body_expr = &body[body.body_expr()]; + if let Expr::Block { statements: _, tail: Some(t) } = body_expr { + self.validate_results_in_tail_expr(*t, db); + } } fn validate_record_literal( @@ -87,4 +96,42 @@ impl<'a, 'b> ExprValidator<'a, 'b> { }) } } + + fn validate_results_in_tail_expr(&mut self, id: ExprId, db: &impl HirDatabase) { + let mismatch = match self.infer.type_mismatch_for_expr(id) { + Some(m) => m, + None => return, + }; + + let std_result_path = Path { + kind: PathKind::Abs, + segments: vec![ + PathSegment { name: name::STD, args_and_bindings: None }, + PathSegment { name: name::RESULT_MOD, args_and_bindings: None }, + PathSegment { name: name::RESULT_TYPE, args_and_bindings: None }, + ], + }; + + let resolver = self.func.resolver(db); + let std_result_enum = + match resolver.resolve_path_segments(db, &std_result_path).into_fully_resolved() { + PerNs { types: Some(Resolution::Def(ModuleDef::Enum(e))), .. } => e, + _ => return, + }; + + let std_result_ctor = TypeCtor::Adt(AdtDef::Enum(std_result_enum)); + let params = match &mismatch.expected { + Ty::Apply(ApplicationTy { ctor, parameters }) if ctor == &std_result_ctor => parameters, + _ => return, + }; + + if params.len() == 2 && ¶ms[0] == &mismatch.actual { + let source_map = self.func.body_source_map(db); + let file_id = self.func.source(db).file_id; + + if let Some(expr) = source_map.expr_syntax(id).and_then(|n| n.cast::()) { + self.sink.push(MissingOkInTailExpr { file: file_id, expr }); + } + } + } } diff --git a/crates/ra_hir/src/name.rs b/crates/ra_hir/src/name.rs index 6d14eea8ec..9c4822d917 100644 --- a/crates/ra_hir/src/name.rs +++ b/crates/ra_hir/src/name.rs @@ -120,6 +120,8 @@ pub(crate) const TRY: Name = Name::new(SmolStr::new_inline_from_ascii(3, b"Try") pub(crate) const OK: Name = Name::new(SmolStr::new_inline_from_ascii(2, b"Ok")); pub(crate) const FUTURE_MOD: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"future")); pub(crate) const FUTURE_TYPE: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"Future")); +pub(crate) const RESULT_MOD: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"result")); +pub(crate) const RESULT_TYPE: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"Result")); pub(crate) const OUTPUT: Name = Name::new(SmolStr::new_inline_from_ascii(6, b"Output")); fn resolve_name(text: &SmolStr) -> SmolStr { diff --git a/crates/ra_hir/src/ty/infer.rs b/crates/ra_hir/src/ty/infer.rs index b33de5687c..d94e8154b0 100644 --- a/crates/ra_hir/src/ty/infer.rs +++ b/crates/ra_hir/src/ty/infer.rs @@ -106,6 +106,13 @@ impl Default for BindingMode { } } +/// A mismatch between an expected and an inferred type. +#[derive(Clone, PartialEq, Eq, Debug, Hash)] +pub struct TypeMismatch { + pub expected: Ty, + pub actual: Ty, +} + /// The result of type inference: A mapping from expressions and patterns to types. #[derive(Clone, PartialEq, Eq, Debug, Default)] pub struct InferenceResult { @@ -120,6 +127,7 @@ pub struct InferenceResult { diagnostics: Vec, pub(super) type_of_expr: ArenaMap, pub(super) type_of_pat: ArenaMap, + pub(super) type_mismatches: ArenaMap, } impl InferenceResult { @@ -141,6 +149,9 @@ impl InferenceResult { pub fn assoc_resolutions_for_pat(&self, id: PatId) -> Option { self.assoc_resolutions.get(&id.into()).copied() } + pub fn type_mismatch_for_expr(&self, expr: ExprId) -> Option<&TypeMismatch> { + self.type_mismatches.get(expr) + } pub(crate) fn add_diagnostics( &self, db: &impl HirDatabase, @@ -1345,9 +1356,15 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> { }; // use a new type variable if we got Ty::Unknown here let ty = self.insert_type_vars_shallow(ty); - self.unify(&ty, &expected.ty); + let could_unify = self.unify(&ty, &expected.ty); let ty = self.resolve_ty_as_possible(&mut vec![], ty); self.write_expr_ty(tgt_expr, ty.clone()); + if !could_unify { + self.result.type_mismatches.insert( + tgt_expr, + TypeMismatch { expected: expected.ty.clone(), actual: ty.clone() }, + ); + } ty } diff --git a/crates/ra_ide_api/src/diagnostics.rs b/crates/ra_ide_api/src/diagnostics.rs index c2b959cb3c..1a4882824f 100644 --- a/crates/ra_ide_api/src/diagnostics.rs +++ b/crates/ra_ide_api/src/diagnostics.rs @@ -75,6 +75,19 @@ pub(crate) fn diagnostics(db: &RootDatabase, file_id: FileId) -> Vec severity: Severity::Error, fix: Some(fix), }) + }) + .on::(|d| { + let node = d.ast(db); + let mut builder = TextEditBuilder::default(); + let replacement = format!("Ok({})", node.syntax()); + builder.replace(node.syntax().text_range(), replacement); + let fix = SourceChange::source_file_edit_from("wrap with ok", file_id, builder.finish()); + res.borrow_mut().push(Diagnostic { + range: d.highlight_range(), + message: d.message(), + severity: Severity::Error, + fix: Some(fix), + }) }); if let Some(m) = source_binder::module_from_file_id(db, file_id) { m.diagnostics(db, &mut sink); @@ -171,10 +184,11 @@ fn check_struct_shorthand_initialization( #[cfg(test)] mod tests { use insta::assert_debug_snapshot_matches; + use join_to_string::join; use ra_syntax::SourceFile; use test_utils::assert_eq_text; - use crate::mock_analysis::single_file; + use crate::mock_analysis::{analysis_and_position, single_file}; use super::*; @@ -203,6 +217,48 @@ mod tests { assert_eq_text!(after, &actual); } + /// Takes a multi-file input fixture with annotated cursor positions, + /// and checks that: + /// * a diagnostic is produced + /// * this diagnostic touches the input cursor position + /// * that the contents of the file containing the cursor match `after` after the diagnostic fix is applied + fn check_apply_diagnostic_fix_from_position(fixture: &str, after: &str) { + let (analysis, file_position) = analysis_and_position(fixture); + let diagnostic = analysis.diagnostics(file_position.file_id).unwrap().pop().unwrap(); + let mut fix = diagnostic.fix.unwrap(); + let edit = fix.source_file_edits.pop().unwrap().edit; + let target_file_contents = analysis.file_text(file_position.file_id).unwrap(); + let actual = edit.apply(&target_file_contents); + + // Strip indent and empty lines from `after`, to match the behaviour of + // `parse_fixture` called from `analysis_and_position`. + let margin = fixture + .lines() + .filter(|it| it.trim_start().starts_with("//-")) + .map(|it| it.len() - it.trim_start().len()) + .next() + .expect("empty fixture"); + let after = join(after.lines().filter_map(|line| { + if line.len() > margin { + Some(&line[margin..]) + } else { + None + } + })) + .separator("\n") + .suffix("\n") + .to_string(); + + assert_eq_text!(&after, &actual); + assert!( + diagnostic.range.start() <= file_position.offset + && diagnostic.range.end() >= file_position.offset, + "diagnostic range {} does not touch cursor position {}", + diagnostic.range, + file_position.offset + ); + } + fn check_apply_diagnostic_fix(before: &str, after: &str) { let (analysis, file_id) = single_file(before); let diagnostic = analysis.diagnostics(file_id).unwrap().pop().unwrap(); @@ -212,12 +268,169 @@ mod tests { assert_eq_text!(after, &actual); } + /// Takes a multi-file input fixture with annotated cursor position and checks that no diagnostics + /// apply to the file containing the cursor. + fn check_no_diagnostic_for_target_file(fixture: &str) { + let (analysis, file_position) = analysis_and_position(fixture); + let diagnostics = analysis.diagnostics(file_position.file_id).unwrap(); + assert_eq!(diagnostics.len(), 0); + } + fn check_no_diagnostic(content: &str) { let (analysis, file_id) = single_file(content); let diagnostics = analysis.diagnostics(file_id).unwrap(); assert_eq!(diagnostics.len(), 0); } + #[test] + fn test_wrap_return_type() { + let before = r#" + //- /main.rs + use std::{string::String, result::Result::{self, Ok, Err}}; + + fn div(x: i32, y: i32) -> Result { + if y == 0 { + return Err("div by zero".into()); + } + x / y<|> + } + + //- /std/lib.rs + pub mod string { + pub struct String { } + } + pub mod result { + pub enum Result { Ok(T), Err(E) } + } + "#; + let after = r#" + use std::{string::String, result::Result::{self, Ok, Err}}; + + fn div(x: i32, y: i32) -> Result { + if y == 0 { + return Err("div by zero".into()); + } + Ok(x / y) + } + "#; + check_apply_diagnostic_fix_from_position(before, after); + } + + #[test] + fn test_wrap_return_type_handles_generic_functions() { + let before = r#" + //- /main.rs + use std::result::Result::{self, Ok, Err}; + + fn div(x: T) -> Result { + if x == 0 { + return Err(7); + } + <|>x + } + + //- /std/lib.rs + pub mod result { + pub enum Result { Ok(T), Err(E) } + } + "#; + let after = r#" + use std::result::Result::{self, Ok, Err}; + + fn div(x: T) -> Result { + if x == 0 { + return Err(7); + } + Ok(x) + } + "#; + check_apply_diagnostic_fix_from_position(before, after); + } + + #[test] + fn test_wrap_return_type_handles_type_aliases() { + let before = r#" + //- /main.rs + use std::{string::String, result::Result::{self, Ok, Err}}; + + type MyResult = Result; + + fn div(x: i32, y: i32) -> MyResult { + if y == 0 { + return Err("div by zero".into()); + } + x <|>/ y + } + + //- /std/lib.rs + pub mod string { + pub struct String { } + } + pub mod result { + pub enum Result { Ok(T), Err(E) } + } + "#; + let after = r#" + use std::{string::String, result::Result::{self, Ok, Err}}; + + type MyResult = Result; + fn div(x: i32, y: i32) -> MyResult { + if y == 0 { + return Err("div by zero".into()); + } + Ok(x / y) + } + "#; + check_apply_diagnostic_fix_from_position(before, after); + } + + #[test] + fn test_wrap_return_type_not_applicable_when_expr_type_does_not_match_ok_type() { + let content = r#" + //- /main.rs + use std::{string::String, result::Result::{self, Ok, Err}}; + + fn foo() -> Result { + 0<|> + } + + //- /std/lib.rs + pub mod string { + pub struct String { } + } + pub mod result { + pub enum Result { Ok(T), Err(E) } + } + "#; + check_no_diagnostic_for_target_file(content); + } + + #[test] + fn test_wrap_return_type_not_applicable_when_return_type_is_not_result() { + let content = r#" + //- /main.rs + use std::{string::String, result::Result::{self, Ok, Err}}; + + enum SomeOtherEnum { + Ok(i32), + Err(String), + } + + fn foo() -> SomeOtherEnum { + 0<|> + } + + //- /std/lib.rs + pub mod string { + pub struct String { } + } + pub mod result { + pub enum Result { Ok(T), Err(E) } + } + "#; + check_no_diagnostic_for_target_file(content); + } + #[test] fn test_fill_struct_fields_empty() { let before = r" diff --git a/crates/ra_syntax/src/ptr.rs b/crates/ra_syntax/src/ptr.rs index d24660ac3a..992034ef0f 100644 --- a/crates/ra_syntax/src/ptr.rs +++ b/crates/ra_syntax/src/ptr.rs @@ -31,6 +31,13 @@ impl SyntaxNodePtr { pub fn kind(self) -> SyntaxKind { self.kind } + + pub fn cast(self) -> Option> { + if !N::can_cast(self.kind()) { + return None; + } + Some(AstPtr { raw: self, _ty: PhantomData }) + } } /// Like `SyntaxNodePtr`, but remembers the type of node