Fix inference of AsyncFnX return type

This commit is contained in:
Chayim Refael Friedman 2025-05-27 06:44:34 +03:00
parent 7fa66d67a7
commit 2a7f18bbda
5 changed files with 81 additions and 11 deletions

View file

@ -259,7 +259,7 @@ impl chalk_solve::RustIrDatabase<Interner> for ChalkContext<'_> {
}
fn well_known_trait_id(
&self,
well_known_trait: rust_ir::WellKnownTrait,
well_known_trait: WellKnownTrait,
) -> Option<chalk_ir::TraitId<Interner>> {
let lang_attr = lang_item_from_well_known_trait(well_known_trait);
let trait_ = lang_attr.resolve_trait(self.db, self.krate)?;

View file

@ -1463,6 +1463,8 @@ impl HirDisplay for Ty {
}
if f.closure_style == ClosureStyle::RANotation || !sig.ret().is_unit() {
write!(f, " -> ")?;
// FIXME: We display `AsyncFn` as `-> impl Future`, but this is hard to fix because
// we don't have a trait environment here, required to normalize `<Ret as Future>::Output`.
sig.ret().hir_fmt(f)?;
}
} else {

View file

@ -38,7 +38,7 @@ use crate::{
infer::{BreakableKind, CoerceMany, Diverges, coerce::CoerceNever},
make_binders,
mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem},
to_chalk_trait_id,
to_assoc_type_id, to_chalk_trait_id,
traits::FnTrait,
utils::{self, elaborate_clause_supertraits},
};
@ -245,7 +245,7 @@ impl InferenceContext<'_> {
}
fn deduce_closure_kind_from_predicate_clauses(
&self,
&mut self,
expected_ty: &Ty,
clauses: impl DoubleEndedIterator<Item = WhereClause>,
closure_kind: ClosureKind,
@ -378,7 +378,7 @@ impl InferenceContext<'_> {
}
fn deduce_sig_from_projection(
&self,
&mut self,
closure_kind: ClosureKind,
projection_ty: &ProjectionTy,
projected_ty: &Ty,
@ -392,13 +392,16 @@ impl InferenceContext<'_> {
// For now, we only do signature deduction based off of the `Fn` and `AsyncFn` traits,
// for closures and async closures, respectively.
match closure_kind {
ClosureKind::Closure | ClosureKind::Async
if self.fn_trait_kind_from_trait_id(trait_).is_some() =>
{
self.extract_sig_from_projection(projection_ty, projected_ty)
let fn_trait_kind = self.fn_trait_kind_from_trait_id(trait_)?;
if !matches!(closure_kind, ClosureKind::Closure | ClosureKind::Async) {
return None;
}
_ => None,
if fn_trait_kind.is_async() {
// If the expected trait is `AsyncFn(...) -> X`, we don't know what the return type is,
// but we do know it must implement `Future<Output = X>`.
self.extract_async_fn_sig_from_projection(projection_ty, projected_ty)
} else {
self.extract_sig_from_projection(projection_ty, projected_ty)
}
}
@ -424,6 +427,39 @@ impl InferenceContext<'_> {
)))
}
fn extract_async_fn_sig_from_projection(
&mut self,
projection_ty: &ProjectionTy,
projected_ty: &Ty,
) -> Option<FnSubst<Interner>> {
let arg_param_ty = projection_ty.substitution.as_slice(Interner)[1].assert_ty_ref(Interner);
let TyKind::Tuple(_, input_tys) = arg_param_ty.kind(Interner) else {
return None;
};
let ret_param_future_output = projected_ty;
let ret_param_future = self.table.new_type_var();
let future_output =
LangItem::FutureOutput.resolve_type_alias(self.db, self.resolver.krate())?;
let future_projection = crate::AliasTy::Projection(crate::ProjectionTy {
associated_ty_id: to_assoc_type_id(future_output),
substitution: Substitution::from1(Interner, ret_param_future.clone()),
});
self.table.register_obligation(
crate::AliasEq { alias: future_projection, ty: ret_param_future_output.clone() }
.cast(Interner),
);
Some(FnSubst(Substitution::from_iter(
Interner,
input_tys.iter(Interner).map(|t| t.cast(Interner)).chain(Some(GenericArg::new(
Interner,
chalk_ir::GenericArgData::Ty(ret_param_future),
))),
)))
}
fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option<FnTrait> {
FnTrait::from_lang_item(self.db.lang_attr(trait_id.into())?)
}

View file

@ -4903,3 +4903,30 @@ fn main() {
"#]],
);
}
#[test]
fn async_fn_return_type() {
check_infer(
r#"
//- minicore: async_fn
fn foo<F: AsyncFn() -> R, R>(_: F) -> R {
loop {}
}
fn main() {
foo(async move || ());
}
"#,
expect![[r#"
29..30 '_': F
40..55 '{ loop {} }': R
46..53 'loop {}': !
51..53 '{}': ()
67..97 '{ ...()); }': ()
73..76 'foo': fn foo<impl AsyncFn() -> impl Future<Output = ()>, ()>(impl AsyncFn() -> impl Future<Output = ()>)
73..94 'foo(as...|| ())': ()
77..93 'async ... || ()': impl AsyncFn() -> impl Future<Output = ()>
91..93 '()': ()
"#]],
);
}

View file

@ -291,4 +291,9 @@ impl FnTrait {
pub fn get_id(self, db: &dyn HirDatabase, krate: Crate) -> Option<TraitId> {
self.lang_item().resolve_trait(db, krate)
}
#[inline]
pub(crate) fn is_async(self) -> bool {
matches!(self, FnTrait::AsyncFn | FnTrait::AsyncFnMut | FnTrait::AsyncFnOnce)
}
}