From a5c46e5a88a5f26d96406d43eaaa58acf646a5b6 Mon Sep 17 00:00:00 2001 From: Ryan Levick Date: Thu, 12 Aug 2021 16:05:32 +0200 Subject: [PATCH] Factor out return type handling for both function and method --- .../src/handlers/generate_function.rs | 176 ++++++++---------- 1 file changed, 75 insertions(+), 101 deletions(-) diff --git a/crates/ide_assists/src/handlers/generate_function.rs b/crates/ide_assists/src/handlers/generate_function.rs index ae33a96a0a..0a19e49b78 100644 --- a/crates/ide_assists/src/handlers/generate_function.rs +++ b/crates/ide_assists/src/handlers/generate_function.rs @@ -171,7 +171,7 @@ struct FunctionTemplate { insert_offset: TextSize, leading_ws: String, fn_def: ast::Fn, - ret_type: ast::RetType, + ret_type: Option, should_focus_tail_expr: bool, trailing_ws: String, file: FileId, @@ -183,7 +183,11 @@ impl FunctionTemplate { let f = match cap { Some(cap) => { let cursor = if self.should_focus_tail_expr { - self.ret_type.syntax() + if let Some(ref ret_type) = self.ret_type { + ret_type.syntax() + } else { + self.tail_expr.syntax() + } } else { self.tail_expr.syntax() }; @@ -201,7 +205,7 @@ struct FunctionBuilder { fn_name: ast::Name, type_params: Option, params: ast::ParamList, - ret_type: ast::RetType, + ret_type: Option, should_focus_tail_expr: bool, file: FileId, needs_pub: bool, @@ -235,33 +239,8 @@ impl FunctionBuilder { let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast); let is_async = await_expr.is_some(); - // should_focus_tail_expr intends to express a rough level of confidence about - // the correctness of the return type. - // - // If we are able to infer some return type, and that return type is not unit, we - // don't want to render the snippet. The assumption here is in this situation the - // return type is just as likely to be correct as any other part of the generated - // function. - // - // In the case where the return type is inferred as unit it is likely that the - // user does in fact intend for this generated function to return some non unit - // type, but that the current state of their code doesn't allow that return type - // to be accurately inferred. - let (ret_ty, should_focus_tail_expr) = { - match ctx.sema.type_of_expr(&ast::Expr::CallExpr(call.clone())).map(TypeInfo::original) - { - Some(ty) if ty.is_unknown() || ty.is_unit() => (make::ty_unit(), true), - Some(ty) => { - let rendered = ty.display_source_code(ctx.db(), target_module.into()); - match rendered { - Ok(rendered) => (make::ty(&rendered), false), - Err(_) => (make::ty_unit(), true), - } - } - None => (make::ty_unit(), true), - } - }; - let ret_type = make::ret_type(ret_ty); + let (ret_type, should_focus_tail_expr) = + make_return_type(ctx, &ast::Expr::CallExpr(call.clone()), target_module); Some(Self { target, @@ -305,36 +284,8 @@ impl FunctionBuilder { let await_expr = call.syntax().parent().and_then(ast::AwaitExpr::cast); let is_async = await_expr.is_some(); - // should_render_snippet intends to express a rough level of confidence about - // the correctness of the return type. - // - // If we are able to infer some return type, and that return type is not unit, we - // don't want to render the snippet. The assumption here is in this situation the - // return type is just as likely to be correct as any other part of the generated - // function. - // - // In the case where the return type is inferred as unit it is likely that the - // user does in fact intend for this generated function to return some non unit - // type, but that the current state of their code doesn't allow that return type - // to be accurately inferred. - let (ret_ty, should_render_snippet) = { - match ctx - .sema - .type_of_expr(&ast::Expr::MethodCallExpr(call.clone())) - .map(TypeInfo::original) - { - Some(ty) if ty.is_unknown() || ty.is_unit() => (make::ty_unit(), true), - Some(ty) => { - let rendered = ty.display_source_code(ctx.db(), target_module.into()); - match rendered { - Ok(rendered) => (make::ty(&rendered), false), - Err(_) => (make::ty_unit(), true), - } - } - None => (make::ty_unit(), true), - } - }; - let ret_type = make::ret_type(ret_ty); + let (ret_type, should_focus_tail_expr) = + make_return_type(ctx, &ast::Expr::MethodCallExpr(call.clone()), target_module); Some(Self { target, @@ -342,7 +293,7 @@ impl FunctionBuilder { type_params, params, ret_type, - should_focus_tail_expr: should_render_snippet, + should_focus_tail_expr, file, needs_pub, is_async, @@ -359,7 +310,7 @@ impl FunctionBuilder { self.type_params, self.params, fn_body, - Some(self.ret_type), + self.ret_type, self.is_async, ); let leading_ws; @@ -386,7 +337,7 @@ impl FunctionBuilder { insert_offset, leading_ws, // PANIC: we guarantee we always create a function with a return type - ret_type: fn_def.ret_type().unwrap(), + ret_type: fn_def.ret_type(), // PANIC: we guarantee we always create a function body with a tail expr tail_expr: fn_def.body().unwrap().tail_expr().unwrap(), should_focus_tail_expr: self.should_focus_tail_expr, @@ -397,6 +348,29 @@ impl FunctionBuilder { } } +fn make_return_type( + ctx: &AssistContext, + call: &ast::Expr, + target_module: Module, +) -> (Option, bool) { + let (ret_ty, should_focus_tail_expr) = { + match ctx.sema.type_of_expr(call).map(TypeInfo::original) { + Some(ty) if ty.is_unit() => (None, false), + Some(ty) if ty.is_unknown() => (Some(make::ty_unit()), true), + None => (Some(make::ty_unit()), true), + Some(ty) => { + let rendered = ty.display_source_code(ctx.db(), target_module.into()); + match rendered { + Ok(rendered) => (Some(make::ty(&rendered)), false), + Err(_) => (Some(make::ty_unit()), true), + } + } + } + }; + let ret_type = ret_ty.map(|rt| make::ret_type(rt)); + (ret_type, should_focus_tail_expr) +} + enum GeneratedFunctionTarget { BehindItem(SyntaxNode), InEmptyItemList(SyntaxNode), @@ -825,8 +799,8 @@ fn foo() { bar("bar") } -fn bar(arg: &str) ${0:-> ()} { - todo!() +fn bar(arg: &str) { + ${0:todo!()} } "#, ) @@ -846,8 +820,8 @@ fn foo() { bar('x') } -fn bar(arg: char) ${0:-> ()} { - todo!() +fn bar(arg: char) { + ${0:todo!()} } "#, ) @@ -867,8 +841,8 @@ fn foo() { bar(42) } -fn bar(arg: i32) ${0:-> ()} { - todo!() +fn bar(arg: i32) { + ${0:todo!()} } ", ) @@ -888,8 +862,8 @@ fn foo() { bar(42 as u8) } -fn bar(arg: u8) ${0:-> ()} { - todo!() +fn bar(arg: u8) { + ${0:todo!()} } ", ) @@ -913,8 +887,8 @@ fn foo() { bar(x as u8) } -fn bar(x: u8) ${0:-> ()} { - todo!() +fn bar(x: u8) { + ${0:todo!()} } ", ) @@ -936,8 +910,8 @@ fn foo() { bar(worble) } -fn bar(worble: ()) ${0:-> ()} { - todo!() +fn bar(worble: ()) { + ${0:todo!()} } ", ) @@ -965,8 +939,8 @@ fn baz() { bar(foo()) } -fn bar(foo: impl Foo) ${0:-> ()} { - todo!() +fn bar(foo: impl Foo) { + ${0:todo!()} } ", ) @@ -992,8 +966,8 @@ fn foo() { bar(&baz()) } -fn bar(baz: &Baz) ${0:-> ()} { - todo!() +fn bar(baz: &Baz) { + ${0:todo!()} } ", ) @@ -1021,8 +995,8 @@ fn foo() { bar(Baz::baz()) } -fn bar(baz: Baz::Bof) ${0:-> ()} { - todo!() +fn bar(baz: Baz::Bof) { + ${0:todo!()} } ", ) @@ -1043,8 +1017,8 @@ fn foo(t: T) { bar(t) } -fn bar(t: T) ${0:-> ()} { - todo!() +fn bar(t: T) { + ${0:todo!()} } ", ) @@ -1097,8 +1071,8 @@ fn foo() { bar(closure) } -fn bar(closure: ()) ${0:-> ()} { - todo!() +fn bar(closure: ()) { + ${0:todo!()} } ", ) @@ -1118,8 +1092,8 @@ fn foo() { bar(baz) } -fn bar(baz: ()) ${0:-> ()} { - todo!() +fn bar(baz: ()) { + ${0:todo!()} } ", ) @@ -1143,8 +1117,8 @@ fn foo() { bar(baz(), baz()) } -fn bar(baz_1: Baz, baz_2: Baz) ${0:-> ()} { - todo!() +fn bar(baz_1: Baz, baz_2: Baz) { + ${0:todo!()} } ", ) @@ -1168,8 +1142,8 @@ fn foo() { bar(baz(), baz(), "foo", "bar") } -fn bar(baz_1: Baz, baz_2: Baz, arg_1: &str, arg_2: &str) ${0:-> ()} { - todo!() +fn bar(baz_1: Baz, baz_2: Baz, arg_1: &str, arg_2: &str) { + ${0:todo!()} } "#, ) @@ -1188,8 +1162,8 @@ fn foo() { ", r" mod bar { - pub(crate) fn my_fn() ${0:-> ()} { - todo!() + pub(crate) fn my_fn() { + ${0:todo!()} } } @@ -1224,8 +1198,8 @@ fn bar() { baz(foo) } -fn baz(foo: foo::Foo) ${0:-> ()} { - todo!() +fn baz(foo: foo::Foo) { + ${0:todo!()} } "#, ) @@ -1248,8 +1222,8 @@ fn foo() { mod bar { fn something_else() {} - pub(crate) fn my_fn() ${0:-> ()} { - todo!() + pub(crate) fn my_fn() { + ${0:todo!()} } } @@ -1276,8 +1250,8 @@ fn foo() { r" mod bar { mod baz { - pub(crate) fn my_fn() ${0:-> ()} { - todo!() + pub(crate) fn my_fn() { + ${0:todo!()} } } } @@ -1305,8 +1279,8 @@ fn main() { r" -pub(crate) fn bar() ${0:-> ()} { - todo!() +pub(crate) fn bar() { + ${0:todo!()} }", ) }