diff --git a/crates/hir/src/semantics.rs b/crates/hir/src/semantics.rs index 763f53031e..c78b59826c 100644 --- a/crates/hir/src/semantics.rs +++ b/crates/hir/src/semantics.rs @@ -14,6 +14,7 @@ use hir_def::{ hir::Expr, lower::LowerCtx, nameres::MacroSubNs, + path::ModPath, resolver::{self, HasResolver, Resolver, TypeNs}, type_ref::Mutability, AsMacroCall, DefWithBodyId, FunctionId, MacroId, TraitId, VariantId, @@ -46,9 +47,9 @@ use crate::{ source_analyzer::{resolve_hir_path, SourceAnalyzer}, Access, Adjust, Adjustment, Adt, AutoBorrow, BindingMode, BuiltinAttr, Callable, Const, ConstParam, Crate, DeriveHelper, Enum, Field, Function, HasSource, HirFileId, Impl, InFile, - Label, LifetimeParam, Local, Macro, Module, ModuleDef, Name, OverloadedDeref, Path, ScopeDef, - Static, Struct, ToolModule, Trait, TraitAlias, TupleField, Type, TypeAlias, TypeParam, Union, - Variant, VariantDef, + ItemInNs, Label, LifetimeParam, Local, Macro, Module, ModuleDef, Name, OverloadedDeref, Path, + ScopeDef, Static, Struct, ToolModule, Trait, TraitAlias, TupleField, Type, TypeAlias, + TypeParam, Union, Variant, VariantDef, }; const CONTINUE_NO_BREAKS: ControlFlow = ControlFlow::Continue(()); @@ -1384,6 +1385,16 @@ impl<'db> SemanticsImpl<'db> { self.analyze(path.syntax())?.resolve_path(self.db, path) } + pub fn resolve_mod_path( + &self, + scope: &SyntaxNode, + path: &ModPath, + ) -> Option> { + let analyze = self.analyze(scope)?; + let items = analyze.resolver.resolve_module_path_in_items(self.db.upcast(), path); + Some(items.iter_items().map(|(item, _)| item.into())) + } + fn resolve_variant(&self, record_lit: ast::RecordExpr) -> Option { self.analyze(record_lit.syntax())?.resolve_variant(self.db, record_lit) } diff --git a/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs b/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs index b68ed00f77..8f0e9b4fe0 100644 --- a/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs +++ b/crates/ide-assists/src/handlers/wrap_return_type_in_result.rs @@ -1,12 +1,14 @@ use std::iter; +use hir::HasSource; use ide_db::{ famous_defs::FamousDefs, syntax_helpers::node_ext::{for_each_tail_expr, walk_expr}, }; +use itertools::Itertools; use syntax::{ - ast::{self, make, Expr}, - match_ast, ted, AstNode, + ast::{self, make, Expr, HasGenericParams}, + match_ast, ted, AstNode, ToSmolStr, }; use crate::{AssistContext, AssistId, AssistKind, Assists}; @@ -39,25 +41,22 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext< }; let type_ref = &ret_type.ty()?; - let ty = ctx.sema.resolve_type(type_ref)?.as_adt(); - let result_enum = + let core_result = FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate()).core_result_Result()?; - if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == result_enum) { + let ty = ctx.sema.resolve_type(type_ref)?.as_adt(); + if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == core_result) { + // The return type is already wrapped in a Result cov_mark::hit!(wrap_return_type_in_result_simple_return_type_already_result); return None; } - let new_result_ty = - make::ext::ty_result(type_ref.clone(), make::ty_placeholder()).clone_for_update(); - let generic_args = new_result_ty.syntax().descendants().find_map(ast::GenericArgList::cast)?; - let last_genarg = generic_args.generic_args().last()?; - acc.add( AssistId("wrap_return_type_in_result", AssistKind::RefactorRewrite), "Wrap return type in Result", type_ref.syntax().text_range(), |edit| { + let new_result_ty = result_type(ctx, &core_result, type_ref).clone_for_update(); let body = edit.make_mut(ast::Expr::BlockExpr(body)); let mut exprs_to_wrap = Vec::new(); @@ -81,16 +80,72 @@ pub(crate) fn wrap_return_type_in_result(acc: &mut Assists, ctx: &AssistContext< } let old_result_ty = edit.make_mut(type_ref.clone()); - ted::replace(old_result_ty.syntax(), new_result_ty.syntax()); - if let Some(cap) = ctx.config.snippet_cap { - edit.add_placeholder_snippet(cap, last_genarg); + // Add a placeholder snippet at the first generic argument that doesn't equal the return type. + // This is normally the error type, but that may not be the case when we inserted a type alias. + let args = new_result_ty.syntax().descendants().find_map(ast::GenericArgList::cast); + let error_type_arg = args.and_then(|list| { + list.generic_args().find(|arg| match arg { + ast::GenericArg::TypeArg(_) => arg.syntax().text() != type_ref.syntax().text(), + ast::GenericArg::LifetimeArg(_) => false, + _ => true, + }) + }); + if let Some(error_type_arg) = error_type_arg { + if let Some(cap) = ctx.config.snippet_cap { + edit.add_placeholder_snippet(cap, error_type_arg); + } } }, ) } +fn result_type( + ctx: &AssistContext<'_>, + core_result: &hir::Enum, + ret_type: &ast::Type, +) -> ast::Type { + // Try to find a Result type alias in the current scope (shadowing the default). + let result_path = hir::ModPath::from_segments( + hir::PathKind::Plain, + iter::once(hir::Name::new_symbol_root(hir::sym::Result.clone())), + ); + let alias = ctx.sema.resolve_mod_path(ret_type.syntax(), &result_path).and_then(|def| { + def.filter_map(|def| match def.as_module_def()? { + hir::ModuleDef::TypeAlias(alias) => { + let enum_ty = alias.ty(ctx.db()).as_adt()?.as_enum()?; + (&enum_ty == core_result).then_some(alias) + } + _ => None, + }) + .find_map(|alias| { + let mut inserted_ret_type = false; + let generic_params = alias + .source(ctx.db())? + .value + .generic_param_list()? + .generic_params() + .map(|param| match param { + // Replace the very first type parameter with the functions return type. + ast::GenericParam::TypeParam(_) if !inserted_ret_type => { + inserted_ret_type = true; + ret_type.to_smolstr() + } + ast::GenericParam::LifetimeParam(_) => make::lifetime("'_").to_smolstr(), + _ => make::ty_placeholder().to_smolstr(), + }) + .join(", "); + + let name = alias.name(ctx.db()); + let name = name.as_str(); + Some(make::ty(&format!("{name}<{generic_params}>"))) + }) + }); + // If there is no applicable alias in scope use the default Result type. + alias.unwrap_or_else(|| make::ext::ty_result(ret_type.clone(), make::ty_placeholder())) +} + fn tail_cb_impl(acc: &mut Vec, e: &ast::Expr) { match e { Expr::BreakExpr(break_expr) => { @@ -998,4 +1053,216 @@ fn foo(the_field: u32) -> Result { "#, ); } + + #[test] + fn wrap_return_type_in_local_result_type() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result = core::result::Result; + +fn foo() -> i3$02 { + return 42i32; +} +"#, + r#" +type Result = core::result::Result; + +fn foo() -> Result { + return Ok(42i32); +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result2 = core::result::Result; + +fn foo() -> i3$02 { + return 42i32; +} +"#, + r#" +type Result2 = core::result::Result; + +fn foo() -> Result { + return Ok(42i32); +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_imported_local_result_type() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +mod some_module { + pub type Result = core::result::Result; +} + +use some_module::Result; + +fn foo() -> i3$02 { + return 42i32; +} +"#, + r#" +mod some_module { + pub type Result = core::result::Result; +} + +use some_module::Result; + +fn foo() -> Result { + return Ok(42i32); +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +mod some_module { + pub type Result = core::result::Result; +} + +use some_module::*; + +fn foo() -> i3$02 { + return 42i32; +} +"#, + r#" +mod some_module { + pub type Result = core::result::Result; +} + +use some_module::*; + +fn foo() -> Result { + return Ok(42i32); +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_local_result_type_from_function_body() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +fn foo() -> i3$02 { + type Result = core::result::Result; + 0 +} +"#, + r#" +fn foo() -> Result { + type Result = core::result::Result; + Ok(0) +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_local_result_type_already_using_alias() { + check_assist_not_applicable( + wrap_return_type_in_result, + r#" +//- minicore: result +pub type Result = core::result::Result; + +fn foo() -> Result { + return Ok(42i32); +} +"#, + ); + } + + #[test] + fn wrap_return_type_in_local_result_type_multiple_generics() { + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result = core::result::Result; + +fn foo() -> i3$02 { + 0 +} +"#, + r#" +type Result = core::result::Result; + +fn foo() -> Result { + Ok(0) +} +"#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result = core::result::Result, ()>; + +fn foo() -> i3$02 { + 0 +} + "#, + r#" +type Result = core::result::Result, ()>; + +fn foo() -> Result { + Ok(0) +} + "#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result<'a, T, E> = core::result::Result, &'a ()>; + +fn foo() -> i3$02 { + 0 +} + "#, + r#" +type Result<'a, T, E> = core::result::Result, &'a ()>; + +fn foo() -> Result<'_, i32, ${0:_}> { + Ok(0) +} + "#, + ); + + check_assist( + wrap_return_type_in_result, + r#" +//- minicore: result +type Result = core::result::Result, Bar>; + +fn foo() -> i3$02 { + 0 +} + "#, + r#" +type Result = core::result::Result, Bar>; + +fn foo() -> Result { + Ok(0) +} + "#, + ); + } }