Merge pull request #18594 from ChayimFriedman2/async-closures

feat: Support `AsyncFnX` traits
This commit is contained in:
Lukas Wirth 2024-12-06 12:48:47 +00:00 committed by GitHub
commit abc7147bb7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 289 additions and 66 deletions

View file

@ -1287,8 +1287,8 @@ impl InferenceContext<'_> {
tgt_expr: ExprId,
) {
match fn_x {
FnTrait::FnOnce => (),
FnTrait::FnMut => {
FnTrait::FnOnce | FnTrait::AsyncFnOnce => (),
FnTrait::FnMut | FnTrait::AsyncFnMut => {
if let TyKind::Ref(Mutability::Mut, lt, inner) = derefed_callee.kind(Interner) {
if adjustments
.last()
@ -1312,7 +1312,7 @@ impl InferenceContext<'_> {
));
}
}
FnTrait::Fn => {
FnTrait::Fn | FnTrait::AsyncFn => {
if !matches!(derefed_callee.kind(Interner), TyKind::Ref(Mutability::Not, _, _)) {
adjustments.push(Adjustment::borrow(
Mutability::Not,

View file

@ -794,69 +794,75 @@ impl<'a> InferenceTable<'a> {
ty: &Ty,
num_args: usize,
) -> Option<(FnTrait, Vec<Ty>, Ty)> {
let krate = self.trait_env.krate;
let fn_once_trait = FnTrait::FnOnce.get_id(self.db, krate)?;
let trait_data = self.db.trait_data(fn_once_trait);
let output_assoc_type =
trait_data.associated_type_by_name(&Name::new_symbol_root(sym::Output.clone()))?;
for (fn_trait_name, output_assoc_name, subtraits) in [
(FnTrait::FnOnce, sym::Output.clone(), &[FnTrait::Fn, FnTrait::FnMut][..]),
(FnTrait::AsyncFnMut, sym::CallRefFuture.clone(), &[FnTrait::AsyncFn]),
(FnTrait::AsyncFnOnce, sym::CallOnceFuture.clone(), &[]),
] {
let krate = self.trait_env.krate;
let fn_trait = fn_trait_name.get_id(self.db, krate)?;
let trait_data = self.db.trait_data(fn_trait);
let output_assoc_type =
trait_data.associated_type_by_name(&Name::new_symbol_root(output_assoc_name))?;
let mut arg_tys = Vec::with_capacity(num_args);
let arg_ty = TyBuilder::tuple(num_args)
.fill(|it| {
let arg = match it {
ParamKind::Type => self.new_type_var(),
ParamKind::Lifetime => unreachable!("Tuple with lifetime parameter"),
ParamKind::Const(_) => unreachable!("Tuple with const parameter"),
};
arg_tys.push(arg.clone());
arg.cast(Interner)
})
.build();
let mut arg_tys = Vec::with_capacity(num_args);
let arg_ty = TyBuilder::tuple(num_args)
.fill(|it| {
let arg = match it {
ParamKind::Type => self.new_type_var(),
ParamKind::Lifetime => unreachable!("Tuple with lifetime parameter"),
ParamKind::Const(_) => unreachable!("Tuple with const parameter"),
};
arg_tys.push(arg.clone());
arg.cast(Interner)
})
.build();
let b = TyBuilder::trait_ref(self.db, fn_once_trait);
if b.remaining() != 2 {
return None;
}
let mut trait_ref = b.push(ty.clone()).push(arg_ty).build();
let b = TyBuilder::trait_ref(self.db, fn_trait);
if b.remaining() != 2 {
return None;
}
let mut trait_ref = b.push(ty.clone()).push(arg_ty).build();
let projection = {
TyBuilder::assoc_type_projection(
let projection = TyBuilder::assoc_type_projection(
self.db,
output_assoc_type,
Some(trait_ref.substitution.clone()),
)
.build()
};
.fill_with_unknown()
.build();
let trait_env = self.trait_env.env.clone();
let obligation = InEnvironment {
goal: trait_ref.clone().cast(Interner),
environment: trait_env.clone(),
};
let canonical = self.canonicalize(obligation.clone());
if self.db.trait_solve(krate, self.trait_env.block, canonical.cast(Interner)).is_some() {
self.register_obligation(obligation.goal);
let return_ty = self.normalize_projection_ty(projection);
for fn_x in [FnTrait::Fn, FnTrait::FnMut, FnTrait::FnOnce] {
let fn_x_trait = fn_x.get_id(self.db, krate)?;
trait_ref.trait_id = to_chalk_trait_id(fn_x_trait);
let obligation: chalk_ir::InEnvironment<chalk_ir::Goal<Interner>> = InEnvironment {
goal: trait_ref.clone().cast(Interner),
environment: trait_env.clone(),
};
let canonical = self.canonicalize(obligation.clone());
if self
.db
.trait_solve(krate, self.trait_env.block, canonical.cast(Interner))
.is_some()
{
return Some((fn_x, arg_tys, return_ty));
let trait_env = self.trait_env.env.clone();
let obligation = InEnvironment {
goal: trait_ref.clone().cast(Interner),
environment: trait_env.clone(),
};
let canonical = self.canonicalize(obligation.clone());
if self.db.trait_solve(krate, self.trait_env.block, canonical.cast(Interner)).is_some()
{
self.register_obligation(obligation.goal);
let return_ty = self.normalize_projection_ty(projection);
for &fn_x in subtraits {
let fn_x_trait = fn_x.get_id(self.db, krate)?;
trait_ref.trait_id = to_chalk_trait_id(fn_x_trait);
let obligation: chalk_ir::InEnvironment<chalk_ir::Goal<Interner>> =
InEnvironment {
goal: trait_ref.clone().cast(Interner),
environment: trait_env.clone(),
};
let canonical = self.canonicalize(obligation.clone());
if self
.db
.trait_solve(krate, self.trait_env.block, canonical.cast(Interner))
.is_some()
{
return Some((fn_x, arg_tys, return_ty));
}
}
return Some((fn_trait_name, arg_tys, return_ty));
}
unreachable!("It should at least implement FnOnce at this point");
} else {
None
}
None
}
pub(super) fn insert_type_vars<T>(&mut self, ty: T) -> T

View file

@ -2023,11 +2023,11 @@ pub fn mir_body_for_closure_query(
ctx.result.locals.alloc(Local { ty: infer[*root].clone() });
let closure_local = ctx.result.locals.alloc(Local {
ty: match kind {
FnTrait::FnOnce => infer[expr].clone(),
FnTrait::FnMut => {
FnTrait::FnOnce | FnTrait::AsyncFnOnce => infer[expr].clone(),
FnTrait::FnMut | FnTrait::AsyncFnMut => {
TyKind::Ref(Mutability::Mut, error_lifetime(), infer[expr].clone()).intern(Interner)
}
FnTrait::Fn => {
FnTrait::Fn | FnTrait::AsyncFn => {
TyKind::Ref(Mutability::Not, error_lifetime(), infer[expr].clone()).intern(Interner)
}
},
@ -2055,8 +2055,10 @@ pub fn mir_body_for_closure_query(
let mut err = None;
let closure_local = ctx.result.locals.iter().nth(1).unwrap().0;
let closure_projection = match kind {
FnTrait::FnOnce => vec![],
FnTrait::FnMut | FnTrait::Fn => vec![ProjectionElem::Deref],
FnTrait::FnOnce | FnTrait::AsyncFnOnce => vec![],
FnTrait::FnMut | FnTrait::Fn | FnTrait::AsyncFnMut | FnTrait::AsyncFn => {
vec![ProjectionElem::Deref]
}
};
ctx.result.walk_places(|p, store| {
if let Some(it) = upvar_map.get(&p.local) {

View file

@ -4834,3 +4834,53 @@ fn bar(v: *const ()) {
"#]],
);
}
#[test]
fn async_fn_traits() {
check_infer(
r#"
//- minicore: async_fn
async fn foo<T: AsyncFn(u32) -> i32>(a: T) {
let fut1 = a(0);
fut1.await;
}
async fn bar<T: AsyncFnMut(u32) -> i32>(mut b: T) {
let fut2 = b(0);
fut2.await;
}
async fn baz<T: AsyncFnOnce(u32) -> i32>(c: T) {
let fut3 = c(0);
fut3.await;
}
"#,
expect![[r#"
37..38 'a': T
43..83 '{ ...ait; }': ()
43..83 '{ ...ait; }': impl Future<Output = ()>
53..57 'fut1': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
60..61 'a': T
60..64 'a(0)': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
62..63 '0': u32
70..74 'fut1': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
70..80 'fut1.await': i32
124..129 'mut b': T
134..174 '{ ...ait; }': ()
134..174 '{ ...ait; }': impl Future<Output = ()>
144..148 'fut2': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
151..152 'b': T
151..155 'b(0)': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
153..154 '0': u32
161..165 'fut2': AsyncFnMut::CallRefFuture<'?, T, (u32,)>
161..171 'fut2.await': i32
216..217 'c': T
222..262 '{ ...ait; }': ()
222..262 '{ ...ait; }': impl Future<Output = ()>
232..236 'fut3': AsyncFnOnce::CallOnceFuture<T, (u32,)>
239..240 'c': T
239..243 'c(0)': AsyncFnOnce::CallOnceFuture<T, (u32,)>
241..242 '0': u32
249..253 'fut3': AsyncFnOnce::CallOnceFuture<T, (u32,)>
249..259 'fut3.await': i32
"#]],
);
}

View file

@ -220,6 +220,10 @@ pub enum FnTrait {
FnOnce,
FnMut,
Fn,
AsyncFnOnce,
AsyncFnMut,
AsyncFn,
}
impl fmt::Display for FnTrait {
@ -228,6 +232,9 @@ impl fmt::Display for FnTrait {
FnTrait::FnOnce => write!(f, "FnOnce"),
FnTrait::FnMut => write!(f, "FnMut"),
FnTrait::Fn => write!(f, "Fn"),
FnTrait::AsyncFnOnce => write!(f, "AsyncFnOnce"),
FnTrait::AsyncFnMut => write!(f, "AsyncFnMut"),
FnTrait::AsyncFn => write!(f, "AsyncFn"),
}
}
}
@ -238,6 +245,9 @@ impl FnTrait {
FnTrait::FnOnce => "call_once",
FnTrait::FnMut => "call_mut",
FnTrait::Fn => "call",
FnTrait::AsyncFnOnce => "async_call_once",
FnTrait::AsyncFnMut => "async_call_mut",
FnTrait::AsyncFn => "async_call",
}
}
@ -246,6 +256,9 @@ impl FnTrait {
FnTrait::FnOnce => LangItem::FnOnce,
FnTrait::FnMut => LangItem::FnMut,
FnTrait::Fn => LangItem::Fn,
FnTrait::AsyncFnOnce => LangItem::AsyncFnOnce,
FnTrait::AsyncFnMut => LangItem::AsyncFnMut,
FnTrait::AsyncFn => LangItem::AsyncFn,
}
}
@ -254,15 +267,19 @@ impl FnTrait {
LangItem::FnOnce => Some(FnTrait::FnOnce),
LangItem::FnMut => Some(FnTrait::FnMut),
LangItem::Fn => Some(FnTrait::Fn),
LangItem::AsyncFnOnce => Some(FnTrait::AsyncFnOnce),
LangItem::AsyncFnMut => Some(FnTrait::AsyncFnMut),
LangItem::AsyncFn => Some(FnTrait::AsyncFn),
_ => None,
}
}
pub const fn to_chalk_ir(self) -> rust_ir::ClosureKind {
// Chalk doesn't support async fn traits.
match self {
FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce,
FnTrait::FnMut => rust_ir::ClosureKind::FnMut,
FnTrait::Fn => rust_ir::ClosureKind::Fn,
FnTrait::AsyncFnOnce | FnTrait::FnOnce => rust_ir::ClosureKind::FnOnce,
FnTrait::AsyncFnMut | FnTrait::FnMut => rust_ir::ClosureKind::FnMut,
FnTrait::AsyncFn | FnTrait::Fn => rust_ir::ClosureKind::Fn,
}
}
@ -271,6 +288,9 @@ impl FnTrait {
FnTrait::FnOnce => Name::new_symbol_root(sym::call_once.clone()),
FnTrait::FnMut => Name::new_symbol_root(sym::call_mut.clone()),
FnTrait::Fn => Name::new_symbol_root(sym::call.clone()),
FnTrait::AsyncFnOnce => Name::new_symbol_root(sym::async_call_once.clone()),
FnTrait::AsyncFnMut => Name::new_symbol_root(sym::async_call_mut.clone()),
FnTrait::AsyncFn => Name::new_symbol_root(sym::async_call.clone()),
}
}