Implement the call argument checking order hack for closures

This commit is contained in:
Florian Diebold 2019-09-24 23:04:33 +02:00
parent a0aeb6e7ad
commit 6a86706650
3 changed files with 108 additions and 12 deletions

View file

@ -790,11 +790,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
};
self.unify(&expected_receiver_ty, &actual_receiver_ty);
let param_iter = param_tys.into_iter().chain(repeat(Ty::Unknown));
for (arg, param_ty) in args.iter().zip(param_iter) {
let param_ty = self.normalize_associated_types_in(param_ty);
self.infer_expr(*arg, &Expectation::has_type(param_ty));
}
self.check_call_arguments(args, &param_tys);
let ret_ty = self.normalize_associated_types_in(ret_ty);
ret_ty
}
@ -928,11 +924,7 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
}
};
self.register_obligations_for_call(&callee_ty);
let param_iter = param_tys.into_iter().chain(repeat(Ty::Unknown));
for (arg, param_ty) in args.iter().zip(param_iter) {
let param_ty = self.normalize_associated_types_in(param_ty);
self.infer_expr(*arg, &Expectation::has_type(param_ty));
}
self.check_call_arguments(args, &param_tys);
let ret_ty = self.normalize_associated_types_in(ret_ty);
ret_ty
}
@ -1274,6 +1266,30 @@ impl<'a, D: HirDatabase> InferenceContext<'a, D> {
ty
}
fn check_call_arguments(&mut self, args: &[ExprId], param_tys: &[Ty]) {
// Quoting https://github.com/rust-lang/rust/blob/6ef275e6c3cb1384ec78128eceeb4963ff788dca/src/librustc_typeck/check/mod.rs#L3325 --
// We do this in a pretty awful way: first we type-check any arguments
// that are not closures, then we type-check the closures. This is so
// that we have more information about the types of arguments when we
// type-check the functions. This isn't really the right way to do this.
for &check_closures in &[false, true] {
let param_iter = param_tys.iter().cloned().chain(repeat(Ty::Unknown));
for (&arg, param_ty) in args.iter().zip(param_iter) {
let is_closure = match &self.body[arg] {
Expr::Lambda { .. } => true,
_ => false,
};
if is_closure != check_closures {
continue;
}
let param_ty = self.normalize_associated_types_in(param_ty);
self.infer_expr(arg, &Expectation::has_type(param_ty));
}
}
}
fn collect_const(&mut self, data: &ConstData) {
self.return_ty = self.make_ty(data.type_ref());
}

View file

@ -4078,6 +4078,86 @@ fn test<F: FnOnce(u32) -> u64>(f: F) {
);
}
#[test]
fn closure_as_argument_inference_order() {
assert_snapshot!(
infer(r#"
#[lang = "fn_once"]
trait FnOnce<Args> {
type Output;
}
fn foo1<T, U, F: FnOnce(T) -> U>(x: T, f: F) -> U {}
fn foo2<T, U, F: FnOnce(T) -> U>(f: F, x: T) -> U {}
struct S;
impl S {
fn method(self) -> u64;
fn foo1<T, U, F: FnOnce(T) -> U>(self, x: T, f: F) -> U {}
fn foo2<T, U, F: FnOnce(T) -> U>(self, f: F, x: T) -> U {}
}
fn test() {
let x1 = foo1(S, |s| s.method());
let x2 = foo2(|s| s.method(), S);
let x3 = S.foo1(S, |s| s.method());
let x4 = S.foo2(|s| s.method(), S);
}
"#),
@r###"
[95; 96) 'x': T
[101; 102) 'f': F
[112; 114) '{}': ()
[148; 149) 'f': F
[154; 155) 'x': T
[165; 167) '{}': ()
[202; 206) 'self': S
[254; 258) 'self': S
[260; 261) 'x': T
[266; 267) 'f': F
[277; 279) '{}': ()
[317; 321) 'self': S
[323; 324) 'f': F
[329; 330) 'x': T
[340; 342) '{}': ()
[356; 515) '{ ... S); }': ()
[366; 368) 'x1': u64
[371; 375) 'foo1': fn foo1<S, u64, |S| -> u64>(T, F) -> U
[371; 394) 'foo1(S...hod())': u64
[376; 377) 'S': S
[379; 393) '|s| s.method()': |S| -> u64
[380; 381) 's': S
[383; 384) 's': S
[383; 393) 's.method()': u64
[404; 406) 'x2': u64
[409; 413) 'foo2': fn foo2<S, u64, |S| -> u64>(F, T) -> U
[409; 432) 'foo2(|...(), S)': u64
[414; 428) '|s| s.method()': |S| -> u64
[415; 416) 's': S
[418; 419) 's': S
[418; 428) 's.method()': u64
[430; 431) 'S': S
[442; 444) 'x3': u64
[447; 448) 'S': S
[447; 472) 'S.foo1...hod())': u64
[454; 455) 'S': S
[457; 471) '|s| s.method()': |S| -> u64
[458; 459) 's': S
[461; 462) 's': S
[461; 471) 's.method()': u64
[482; 484) 'x4': u64
[487; 488) 'S': S
[487; 512) 'S.foo2...(), S)': u64
[494; 508) '|s| s.method()': |S| -> u64
[495; 496) 's': S
[498; 499) 's': S
[498; 508) 's.method()': u64
[510; 511) 'S': S
"###
);
}
#[test]
fn unselected_projection_in_trait_env_1() {
let t = type_at(

View file

@ -406,8 +406,8 @@ where
let ty: Ty = from_chalk(self.db, parameters[0].assert_ty_ref().clone());
if let Ty::Apply(ApplicationTy { ctor: TypeCtor::Closure { def, expr }, .. }) = ty {
for fn_trait in
[super::FnTrait::FnOnce, super::FnTrait::FnMut, super::FnTrait::Fn].iter().copied()
for &fn_trait in
[super::FnTrait::FnOnce, super::FnTrait::FnMut, super::FnTrait::Fn].iter()
{
if let Some(actual_trait) = get_fn_trait(self.db, self.krate, fn_trait) {
if trait_ == actual_trait {