diff --git a/crates/hir-def/src/body.rs b/crates/hir-def/src/body.rs index 7f9b9476dc..bd87eaa221 100644 --- a/crates/hir-def/src/body.rs +++ b/crates/hir-def/src/body.rs @@ -353,8 +353,9 @@ impl Body { let _p = profile::span("body_with_source_map_query"); let mut params = None; - let (file_id, module, body) = match def { + let (file_id, module, body, is_async_fn) = match def { DefWithBodyId::FunctionId(f) => { + let data = db.function_data(f); let f = f.lookup(db); let src = f.source(db); params = src.value.param_list().map(|param_list| { @@ -371,27 +372,33 @@ impl Body { }), ) }); - (src.file_id, f.module(db), src.value.body().map(ast::Expr::from)) + ( + src.file_id, + f.module(db), + src.value.body().map(ast::Expr::from), + data.has_async_kw(), + ) } DefWithBodyId::ConstId(c) => { let c = c.lookup(db); let src = c.source(db); - (src.file_id, c.module(db), src.value.body()) + (src.file_id, c.module(db), src.value.body(), false) } DefWithBodyId::StaticId(s) => { let s = s.lookup(db); let src = s.source(db); - (src.file_id, s.module(db), src.value.body()) + (src.file_id, s.module(db), src.value.body(), false) } DefWithBodyId::VariantId(v) => { let e = v.parent.lookup(db); let src = v.parent.child_source(db); let variant = &src.value[v.local_id]; - (src.file_id, e.container, variant.expr()) + (src.file_id, e.container, variant.expr(), false) } }; let expander = Expander::new(db, file_id, module); - let (mut body, source_map) = Body::new(db, expander, params, body, module.krate); + let (mut body, source_map) = + Body::new(db, expander, params, body, module.krate, is_async_fn); body.shrink_to_fit(); (Arc::new(body), Arc::new(source_map)) @@ -421,8 +428,9 @@ impl Body { params: Option<(ast::ParamList, impl Iterator)>, body: Option, krate: CrateId, + is_async_fn: bool, ) -> (Body, BodySourceMap) { - lower::lower(db, expander, params, body, krate) + lower::lower(db, expander, params, body, krate, is_async_fn) } fn shrink_to_fit(&mut self) { diff --git a/crates/hir-def/src/body/lower.rs b/crates/hir-def/src/body/lower.rs index 886d71ebed..5362737583 100644 --- a/crates/hir-def/src/body/lower.rs +++ b/crates/hir-def/src/body/lower.rs @@ -84,6 +84,7 @@ pub(super) fn lower( params: Option<(ast::ParamList, impl Iterator)>, body: Option, krate: CrateId, + is_async_fn: bool, ) -> (Body, BodySourceMap) { ExprCollector { db, @@ -105,7 +106,7 @@ pub(super) fn lower( is_lowering_assignee_expr: false, is_lowering_generator: false, } - .collect(params, body) + .collect(params, body, is_async_fn) } struct ExprCollector<'a> { @@ -141,6 +142,7 @@ impl ExprCollector<'_> { mut self, param_list: Option<(ast::ParamList, impl Iterator)>, body: Option, + is_async_fn: bool, ) -> (Body, BodySourceMap) { if let Some((param_list, mut attr_enabled)) = param_list { if let Some(self_param) = @@ -170,7 +172,25 @@ impl ExprCollector<'_> { } }; - self.body.body_expr = self.collect_expr_opt(body); + self.body.body_expr = if is_async_fn { + self.current_try_block = + Some(self.alloc_label_desugared(Label { name: Name::generate_new_name() })); + let expr = self.collect_expr_opt(body); + let expr = self.alloc_expr_desugared(Expr::Block { + id: None, + statements: Box::new([]), + tail: Some(expr), + label: self.current_try_block, + }); + let expr = self.alloc_expr_desugared(Expr::Async { + id: None, + statements: Box::new([]), + tail: Some(expr), + }); + expr + } else { + self.collect_expr_opt(body) + }; (self.body, self.source_map) } diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 08ba80cdff..90f67e449d 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -459,7 +459,6 @@ pub(crate) struct InferenceContext<'a> { resume_yield_tys: Option<(Ty, Ty)>, diverges: Diverges, breakables: Vec, - is_async_fn: bool, } #[derive(Clone, Debug)] @@ -527,7 +526,6 @@ impl<'a> InferenceContext<'a> { resolver, diverges: Diverges::Maybe, breakables: Vec::new(), - is_async_fn: false, } } @@ -639,9 +637,6 @@ impl<'a> InferenceContext<'a> { self.infer_top_pat(*pat, &ty); } let return_ty = &*data.ret_type; - if data.has_async_kw() { - self.is_async_fn = true; - } let ctx = crate::lower::TyLoweringContext::new(self.db, &self.resolver) .with_impl_trait_mode(ImplTraitLoweringMode::Opaque); diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 23ef32db22..035f61fc18 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -293,7 +293,6 @@ impl<'a> InferenceContext<'a> { // FIXME: lift these out into a struct let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); - let prev_is_async_fn = mem::replace(&mut self.is_async_fn, false); let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone()); let prev_ret_coercion = mem::replace(&mut self.return_coercion, Some(CoerceMany::new(ret_ty))); @@ -307,7 +306,6 @@ impl<'a> InferenceContext<'a> { self.diverges = prev_diverges; self.return_ty = prev_ret_ty; self.return_coercion = prev_ret_coercion; - self.is_async_fn = prev_is_async_fn; self.resume_yield_tys = prev_resume_yield_tys; ty @@ -963,11 +961,7 @@ impl<'a> InferenceContext<'a> { .as_mut() .expect("infer_return called outside function body") .expected_ty(); - let return_expr_ty = if self.is_async_fn { - self.infer_async_block(expr, &None, &[], &Some(expr)) - } else { - self.infer_expr_inner(expr, &Expectation::HasType(ret_ty)) - }; + let return_expr_ty = self.infer_expr_inner(expr, &Expectation::HasType(ret_ty)); let mut coerce_many = self.return_coercion.take().unwrap(); coerce_many.coerce(self, Some(expr), &return_expr_ty); self.return_coercion = Some(coerce_many); diff --git a/crates/hir-ty/src/tests/simple.rs b/crates/hir-ty/src/tests/simple.rs index 17663ad38b..0c037a39ec 100644 --- a/crates/hir-ty/src/tests/simple.rs +++ b/crates/hir-ty/src/tests/simple.rs @@ -2094,6 +2094,24 @@ async fn main() { "#]], ) } + +#[test] +fn async_fn_and_try_operator() { + check_no_mismatches( + r#" +//- minicore: future, result, fn, try, from +async fn foo() -> Result<(), ()> { + Ok(()) +} + +async fn bar() -> Result<(), ()> { + let x = foo().await?; + Ok(x) +} + "#, + ) +} + #[test] fn async_block_early_return() { check_infer(