diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index f229bf2f64..d3e7f6e7dc 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -66,8 +66,10 @@ pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc { + ctx.collect_fn(f); + } DefWithBodyId::ConstId(c) => ctx.collect_const(&db.const_data(c)), - DefWithBodyId::FunctionId(f) => ctx.collect_fn(f), DefWithBodyId::StaticId(s) => ctx.collect_static(&db.static_data(s)), DefWithBodyId::VariantId(v) => { ctx.return_ty = TyBuilder::builtin(match db.enum_data(v.parent).variant_body_type() { @@ -392,9 +394,12 @@ pub(crate) struct InferenceContext<'a> { /// currently within one. /// /// We might consider using a nested inference context for checking - /// closures, but currently this is the only field that will change there, - /// so it doesn't make sense. + /// closures so we can swap all shared things out at once. return_ty: Ty, + /// If `Some`, this stores coercion information for returned + /// expressions. If `None`, this is in a context where return is + /// inappropriate, such as a const expression. + return_coercion: Option, /// The resume type and the yield type, respectively, of the generator being inferred. resume_yield_tys: Option<(Ty, Ty)>, diverges: Diverges, @@ -462,6 +467,7 @@ impl<'a> InferenceContext<'a> { trait_env, return_ty: TyKind::Error.intern(Interner), // set in collect_* calls resume_yield_tys: None, + return_coercion: None, db, owner, body, @@ -595,10 +601,19 @@ impl<'a> InferenceContext<'a> { }; self.return_ty = self.normalize_associated_types_in(return_ty); + self.return_coercion = Some(CoerceMany::new(self.return_ty.clone())); } fn infer_body(&mut self) { - self.infer_expr_coerce(self.body.body_expr, &Expectation::has_type(self.return_ty.clone())); + match self.return_coercion { + Some(_) => self.infer_return(self.body.body_expr), + None => { + _ = self.infer_expr_coerce( + self.body.body_expr, + &Expectation::has_type(self.return_ty.clone()), + ) + } + } } fn write_expr_ty(&mut self, expr: ExprId, ty: Ty) { diff --git a/crates/hir-ty/src/infer/coerce.rs b/crates/hir-ty/src/infer/coerce.rs index 3293534a06..8bce47d71c 100644 --- a/crates/hir-ty/src/infer/coerce.rs +++ b/crates/hir-ty/src/infer/coerce.rs @@ -50,11 +50,44 @@ fn success( #[derive(Clone, Debug)] pub(super) struct CoerceMany { expected_ty: Ty, + final_ty: Option, } impl CoerceMany { pub(super) fn new(expected: Ty) -> Self { - CoerceMany { expected_ty: expected } + CoerceMany { expected_ty: expected, final_ty: None } + } + + /// Returns the "expected type" with which this coercion was + /// constructed. This represents the "downward propagated" type + /// that was given to us at the start of typing whatever construct + /// we are typing (e.g., the match expression). + /// + /// Typically, this is used as the expected type when + /// type-checking each of the alternative expressions whose types + /// we are trying to merge. + pub(super) fn expected_ty(&self) -> Ty { + self.expected_ty.clone() + } + + /// Returns the current "merged type", representing our best-guess + /// at the LUB of the expressions we've seen so far (if any). This + /// isn't *final* until you call `self.complete()`, which will return + /// the merged type. + pub(super) fn merged_ty(&self) -> Ty { + self.final_ty.clone().unwrap_or_else(|| self.expected_ty.clone()) + } + + pub(super) fn complete(self, ctx: &mut InferenceContext<'_>) -> Ty { + if let Some(final_ty) = self.final_ty { + final_ty + } else { + ctx.result.standard_types.never.clone() + } + } + + pub(super) fn coerce_forced_unit(&mut self, ctx: &mut InferenceContext<'_>) { + self.coerce(ctx, None, &ctx.result.standard_types.unit.clone()) } /// Merge two types from different branches, with possible coercion. @@ -76,25 +109,25 @@ impl CoerceMany { // Special case: two function types. Try to coerce both to // pointers to have a chance at getting a match. See // https://github.com/rust-lang/rust/blob/7b805396bf46dce972692a6846ce2ad8481c5f85/src/librustc_typeck/check/coercion.rs#L877-L916 - let sig = match (self.expected_ty.kind(Interner), expr_ty.kind(Interner)) { + let sig = match (self.merged_ty().kind(Interner), expr_ty.kind(Interner)) { (TyKind::FnDef(..) | TyKind::Closure(..), TyKind::FnDef(..) | TyKind::Closure(..)) => { // FIXME: we're ignoring safety here. To be more correct, if we have one FnDef and one Closure, // we should be coercing the closure to a fn pointer of the safety of the FnDef cov_mark::hit!(coerce_fn_reification); let sig = - self.expected_ty.callable_sig(ctx.db).expect("FnDef without callable sig"); + self.merged_ty().callable_sig(ctx.db).expect("FnDef without callable sig"); Some(sig) } _ => None, }; if let Some(sig) = sig { let target_ty = TyKind::Function(sig.to_fn_ptr()).intern(Interner); - let result1 = ctx.table.coerce_inner(self.expected_ty.clone(), &target_ty); + let result1 = ctx.table.coerce_inner(self.merged_ty(), &target_ty); let result2 = ctx.table.coerce_inner(expr_ty.clone(), &target_ty); if let (Ok(result1), Ok(result2)) = (result1, result2) { ctx.table.register_infer_ok(result1); ctx.table.register_infer_ok(result2); - return self.expected_ty = target_ty; + return self.final_ty = Some(target_ty); } } @@ -102,25 +135,20 @@ impl CoerceMany { // type is a type variable and the new one is `!`, trying it the other // way around first would mean we make the type variable `!`, instead of // just marking it as possibly diverging. - if ctx.coerce(expr, &expr_ty, &self.expected_ty).is_ok() { - /* self.expected_ty is already correct */ - } else if ctx.coerce(expr, &self.expected_ty, &expr_ty).is_ok() { - self.expected_ty = expr_ty; + if let Ok(res) = ctx.coerce(expr, &expr_ty, &self.merged_ty()) { + self.final_ty = Some(res); + } else if let Ok(res) = ctx.coerce(expr, &self.merged_ty(), &expr_ty) { + self.final_ty = Some(res); } else { if let Some(id) = expr { ctx.result.type_mismatches.insert( id.into(), - TypeMismatch { expected: self.expected_ty.clone(), actual: expr_ty }, + TypeMismatch { expected: self.merged_ty().clone(), actual: expr_ty.clone() }, ); } cov_mark::hit!(coerce_merge_fail_fallback); - /* self.expected_ty is already correct */ } } - - pub(super) fn complete(self) -> Ty { - self.expected_ty - } } pub fn could_coerce( diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index 6f20f0dc89..b650afe996 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -60,6 +60,10 @@ impl<'a> InferenceContext<'a> { ty } + pub(crate) fn infer_expr_no_expect(&mut self, tgt_expr: ExprId) -> Ty { + self.infer_expr_inner(tgt_expr, &Expectation::None) + } + /// Infer type of expression with possibly implicit coerce to the expected type. /// Return the type after possible coercion. pub(super) fn infer_expr_coerce(&mut self, expr: ExprId, expected: &Expectation) -> Ty { @@ -99,17 +103,20 @@ impl<'a> InferenceContext<'a> { both_arms_diverge &= mem::replace(&mut self.diverges, Diverges::Maybe); let mut coerce = CoerceMany::new(expected.coercion_target_type(&mut self.table)); coerce.coerce(self, Some(then_branch), &then_ty); - let else_ty = match else_branch { - Some(else_branch) => self.infer_expr_inner(else_branch, expected), - None => TyBuilder::unit(), - }; + match else_branch { + Some(else_branch) => { + let else_ty = self.infer_expr_inner(else_branch, expected); + coerce.coerce(self, Some(else_branch), &else_ty); + } + None => { + coerce.coerce_forced_unit(self); + } + } both_arms_diverge &= self.diverges; - // FIXME: create a synthetic `else {}` so we have something to refer to here instead of None? - coerce.coerce(self, else_branch, &else_ty); self.diverges = condition_diverges | both_arms_diverge; - coerce.complete() + coerce.complete(self) } &Expr::Let { pat, expr } => { let input_ty = self.infer_expr(expr, &Expectation::none()); @@ -172,6 +179,8 @@ impl<'a> InferenceContext<'a> { let ret_ty = self.table.new_type_var(); let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone()); + let prev_ret_coercion = + mem::replace(&mut self.return_coercion, Some(CoerceMany::new(ret_ty.clone()))); let (_, inner_ty) = self.with_breakable_ctx(BreakableKind::Border, self.err_ty(), None, |this| { @@ -180,6 +189,7 @@ impl<'a> InferenceContext<'a> { self.diverges = prev_diverges; self.return_ty = prev_ret_ty; + self.return_coercion = prev_ret_coercion; // Use the first type parameter as the output type of future. // existential type AsyncBlockImplTrait: Future @@ -303,17 +313,21 @@ impl<'a> InferenceContext<'a> { self.infer_top_pat(*arg_pat, &arg_ty); } + // FIXME: lift these out into a struct let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone()); + let prev_ret_coercion = + mem::replace(&mut self.return_coercion, Some(CoerceMany::new(ret_ty.clone()))); let prev_resume_yield_tys = mem::replace(&mut self.resume_yield_tys, resume_yield_tys); self.with_breakable_ctx(BreakableKind::Border, self.err_ty(), None, |this| { - this.infer_expr_coerce(*body, &Expectation::has_type(ret_ty)); + this.infer_return(*body); }); self.diverges = prev_diverges; self.return_ty = prev_ret_ty; + self.return_coercion = prev_ret_coercion; self.resume_yield_tys = prev_resume_yield_tys; ty @@ -413,7 +427,7 @@ impl<'a> InferenceContext<'a> { self.diverges = matchee_diverges | all_arms_diverge; - coerce.complete() + coerce.complete(self) } Expr::Path(p) => { // FIXME this could be more efficient... @@ -431,7 +445,12 @@ impl<'a> InferenceContext<'a> { } Expr::Break { expr, label } => { let val_ty = if let Some(expr) = *expr { - self.infer_expr(expr, &Expectation::none()) + let opt_coerce_to = find_breakable(&mut self.breakables, label.as_ref()) + .map(|ctxt| ctxt.coerce.expected_ty()); + self.infer_expr_inner( + expr, + &Expectation::HasType(opt_coerce_to.unwrap_or_else(|| self.err_ty())), + ) } else { TyBuilder::unit() }; @@ -461,15 +480,7 @@ impl<'a> InferenceContext<'a> { } self.result.standard_types.never.clone() } - Expr::Return { expr } => { - if let Some(expr) = expr { - self.infer_expr_coerce(*expr, &Expectation::has_type(self.return_ty.clone())); - } else { - let unit = TyBuilder::unit(); - let _ = self.coerce(Some(tgt_expr), &unit, &self.return_ty.clone()); - } - self.result.standard_types.never.clone() - } + &Expr::Return { expr } => self.infer_expr_return(expr), Expr::Yield { expr } => { if let Some((resume_ty, yield_ty)) = self.resume_yield_tys.clone() { if let Some(expr) = expr { @@ -486,7 +497,7 @@ impl<'a> InferenceContext<'a> { } Expr::Yeet { expr } => { if let &Some(expr) = expr { - self.infer_expr_inner(expr, &Expectation::None); + self.infer_expr_no_expect(expr); } self.result.standard_types.never.clone() } @@ -614,7 +625,7 @@ impl<'a> InferenceContext<'a> { Expr::Cast { expr, type_ref } => { let cast_ty = self.make_ty(type_ref); // FIXME: propagate the "castable to" expectation - let _inner_ty = self.infer_expr_inner(*expr, &Expectation::None); + let _inner_ty = self.infer_expr_no_expect(*expr); // FIXME check the cast... cast_ty } @@ -810,53 +821,7 @@ impl<'a> InferenceContext<'a> { TyKind::Tuple(tys.len(), Substitution::from_iter(Interner, tys)).intern(Interner) } - Expr::Array(array) => { - let elem_ty = - match expected.to_option(&mut self.table).as_ref().map(|t| t.kind(Interner)) { - Some(TyKind::Array(st, _) | TyKind::Slice(st)) => st.clone(), - _ => self.table.new_type_var(), - }; - let mut coerce = CoerceMany::new(elem_ty.clone()); - - let expected = Expectation::has_type(elem_ty.clone()); - let len = match array { - Array::ElementList { elements, .. } => { - for &expr in elements.iter() { - let cur_elem_ty = self.infer_expr_inner(expr, &expected); - coerce.coerce(self, Some(expr), &cur_elem_ty); - } - consteval::usize_const( - self.db, - Some(elements.len() as u128), - self.resolver.krate(), - ) - } - &Array::Repeat { initializer, repeat } => { - self.infer_expr_coerce(initializer, &Expectation::has_type(elem_ty)); - self.infer_expr( - repeat, - &Expectation::HasType( - TyKind::Scalar(Scalar::Uint(UintTy::Usize)).intern(Interner), - ), - ); - - if let Some(g_def) = self.owner.as_generic_def_id() { - let generics = generics(self.db.upcast(), g_def); - consteval::eval_to_const( - repeat, - ParamLoweringMode::Placeholder, - self, - || generics, - DebruijnIndex::INNERMOST, - ) - } else { - consteval::usize_const(self.db, None, self.resolver.krate()) - } - } - }; - - TyKind::Array(coerce.complete(), len).intern(Interner) - } + Expr::Array(array) => self.infer_expr_array(array, expected), Expr::Literal(lit) => match lit { Literal::Bool(..) => self.result.standard_types.bool_.clone(), Literal::String(..) => { @@ -915,6 +880,97 @@ impl<'a> InferenceContext<'a> { ty } + fn infer_expr_array( + &mut self, + array: &Array, + expected: &Expectation, + ) -> chalk_ir::Ty { + let elem_ty = match expected.to_option(&mut self.table).as_ref().map(|t| t.kind(Interner)) { + Some(TyKind::Array(st, _) | TyKind::Slice(st)) => st.clone(), + _ => self.table.new_type_var(), + }; + + let krate = self.resolver.krate(); + + let expected = Expectation::has_type(elem_ty.clone()); + let (elem_ty, len) = match array { + Array::ElementList { elements, .. } if elements.is_empty() => { + (elem_ty, consteval::usize_const(self.db, Some(0), krate)) + } + Array::ElementList { elements, .. } => { + let mut coerce = CoerceMany::new(elem_ty.clone()); + for &expr in elements.iter() { + let cur_elem_ty = self.infer_expr_inner(expr, &expected); + coerce.coerce(self, Some(expr), &cur_elem_ty); + } + ( + coerce.complete(self), + consteval::usize_const(self.db, Some(elements.len() as u128), krate), + ) + } + &Array::Repeat { initializer, repeat } => { + self.infer_expr_coerce(initializer, &Expectation::has_type(elem_ty.clone())); + self.infer_expr( + repeat, + &Expectation::HasType( + TyKind::Scalar(Scalar::Uint(UintTy::Usize)).intern(Interner), + ), + ); + + ( + elem_ty, + if let Some(g_def) = self.owner.as_generic_def_id() { + let generics = generics(self.db.upcast(), g_def); + consteval::eval_to_const( + repeat, + ParamLoweringMode::Placeholder, + self, + || generics, + DebruijnIndex::INNERMOST, + ) + } else { + consteval::usize_const(self.db, None, krate) + }, + ) + } + }; + + TyKind::Array(elem_ty, len).intern(Interner) + } + + pub(super) fn infer_return(&mut self, expr: ExprId) { + let ret_ty = self + .return_coercion + .as_mut() + .expect("infer_return called outside function body") + .expected_ty(); + let return_expr_ty = self.infer_expr_inner(expr, &Expectation::HasType(ret_ty)); + let mut coerce_many = self.return_coercion.take().unwrap(); + coerce_many.coerce(self, Some(expr), &return_expr_ty); + self.return_coercion = Some(coerce_many); + } + + fn infer_expr_return(&mut self, expr: Option) -> Ty { + match self.return_coercion { + Some(_) => { + if let Some(expr) = expr { + self.infer_return(expr); + } else { + let mut coerce = self.return_coercion.take().unwrap(); + coerce.coerce_forced_unit(self); + self.return_coercion = Some(coerce); + } + } + None => { + // FIXME: diagnose return outside of function + if let Some(expr) = expr { + self.infer_expr_no_expect(expr); + } + } + } + self.result.standard_types.never.clone() + } + fn infer_expr_box(&mut self, inner_expr: ExprId, expected: &Expectation) -> Ty { if let Some(box_id) = self.resolve_boxed_box() { let table = &mut self.table; @@ -1666,6 +1722,6 @@ impl<'a> InferenceContext<'a> { }); let res = cb(self); let ctx = self.breakables.pop().expect("breakable stack broken"); - (ctx.may_break.then(|| ctx.coerce.complete()), res) + (ctx.may_break.then(|| ctx.coerce.complete(self)), res) } }