diff --git a/crates/ide_assists/src/handlers/extract_function.rs b/crates/ide_assists/src/handlers/extract_function.rs index 6311afc1f5..a2dba915ce 100644 --- a/crates/ide_assists/src/handlers/extract_function.rs +++ b/crates/ide_assists/src/handlers/extract_function.rs @@ -642,6 +642,10 @@ fn vars_used_in_body(ctx: &AssistContext, body: &FunctionBody) -> Vec { .collect() } +fn body_contains_await(body: &FunctionBody) -> bool { + body.descendants().any(|d| matches!(d.kind(), SyntaxKind::AWAIT_EXPR)) +} + /// find `self` param, that was not defined inside `body` /// /// It should skip `self` params from impls inside `body` @@ -1123,9 +1127,10 @@ fn format_function( let params = make_param_list(ctx, module, fun); let ret_ty = make_ret_ty(ctx, module, fun); let body = make_body(ctx, old_indent, new_indent, fun); + let async_kw = if body_contains_await(&fun.body) { "async " } else { "" }; match ctx.config.snippet_cap { - Some(_) => format_to!(fn_def, "\n\n{}fn $0{}{}", new_indent, fun.name, params), - None => format_to!(fn_def, "\n\n{}fn {}{}", new_indent, fun.name, params), + Some(_) => format_to!(fn_def, "\n\n{}{}fn $0{}{}", new_indent, async_kw, fun.name, params), + None => format_to!(fn_def, "\n\n{}{}fn {}{}", new_indent, async_kw, fun.name, params), } if let Some(ret_ty) = ret_ty { format_to!(fn_def, " {}", ret_ty); @@ -3565,4 +3570,60 @@ fn $0fun_name(n: i32) -> i32 { }", ); } + + #[test] + fn extract_with_await() { + check_assist( + extract_function, + r#"fn main() { + $0some_function().await;$0 +} + +async fn some_function() { + +} +"#, + r#" +fn main() { + fun_name(); +} + +async fn $0fun_name() { + some_function().await; +} + +async fn some_function() { + +} +"#, + ); + } + + #[test] + fn extract_with_await_in_args() { + check_assist( + extract_function, + r#"fn main() { + $0function_call("a", some_function().await);$0 +} + +async fn some_function() { + +} +"#, + r#" +fn main() { + fun_name(); +} + +async fn $0fun_name() { + function_call("a", some_function().await); +} + +async fn some_function() { + +} +"#, + ); + } }