From ec273c3d12b7393f6b81e793ce1c7abd59e3dc67 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Fri, 3 Mar 2023 10:23:20 +0100 Subject: [PATCH] Split pattern inference into more functions --- crates/hir-def/src/resolver.rs | 2 +- crates/hir-ty/src/infer.rs | 40 +-- crates/hir-ty/src/infer/expr.rs | 47 ++-- crates/hir-ty/src/infer/pat.rs | 300 +++++++++++++++-------- crates/hir-ty/src/tests/patterns.rs | 12 +- crates/ide/src/inlay_hints/adjustment.rs | 5 +- 6 files changed, 234 insertions(+), 172 deletions(-) diff --git a/crates/hir-def/src/resolver.rs b/crates/hir-def/src/resolver.rs index 36d8b24e9c..fdb236c111 100644 --- a/crates/hir-def/src/resolver.rs +++ b/crates/hir-def/src/resolver.rs @@ -85,7 +85,7 @@ pub enum ResolveValueResult { Partial(TypeNs, usize), } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum ValueNs { ImplSelf(ImplId), LocalBinding(PatId), diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 6790be64c5..f229bf2f64 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -144,44 +144,6 @@ impl Default for BindingMode { } } -/// Used to generalize patterns and assignee expressions. -trait PatLike: Into + Copy { - type BindingMode: Copy; - - fn infer( - this: &mut InferenceContext<'_>, - id: Self, - expected_ty: &Ty, - default_bm: Self::BindingMode, - ) -> Ty; -} - -impl PatLike for ExprId { - type BindingMode = (); - - fn infer( - this: &mut InferenceContext<'_>, - id: Self, - expected_ty: &Ty, - _: Self::BindingMode, - ) -> Ty { - this.infer_assignee_expr(id, expected_ty) - } -} - -impl PatLike for PatId { - type BindingMode = BindingMode; - - fn infer( - this: &mut InferenceContext<'_>, - id: Self, - expected_ty: &Ty, - default_bm: Self::BindingMode, - ) -> Ty { - this.infer_pat(id, expected_ty, default_bm) - } -} - #[derive(Debug)] pub(crate) struct InferOk { value: T, @@ -581,7 +543,7 @@ impl<'a> InferenceContext<'a> { let ty = self.insert_type_vars(ty); let ty = self.normalize_associated_types_in(ty); - self.infer_pat(*pat, &ty, BindingMode::default()); + self.infer_top_pat(*pat, &ty); } let error_ty = &TypeRef::Error; let return_ty = if data.has_async_kw() { diff --git a/crates/hir-ty/src/infer/expr.rs b/crates/hir-ty/src/infer/expr.rs index e169cbef49..a186ae836d 100644 --- a/crates/hir-ty/src/infer/expr.rs +++ b/crates/hir-ty/src/infer/expr.rs @@ -25,7 +25,9 @@ use syntax::ast::RangeOp; use crate::{ autoderef::{self, Autoderef}, consteval, - infer::{coerce::CoerceMany, find_continuable, BreakableKind}, + infer::{ + coerce::CoerceMany, find_continuable, pat::contains_explicit_ref_binding, BreakableKind, + }, lower::{ const_or_path_to_chalk, generic_arg_to_chalk, lower_to_chalk_mutability, ParamLoweringMode, }, @@ -39,8 +41,8 @@ use crate::{ }; use super::{ - coerce::auto_deref_adjust_steps, find_breakable, BindingMode, BreakableContext, Diverges, - Expectation, InferenceContext, InferenceDiagnostic, TypeMismatch, + coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges, Expectation, + InferenceContext, InferenceDiagnostic, TypeMismatch, }; impl<'a> InferenceContext<'a> { @@ -111,7 +113,7 @@ impl<'a> InferenceContext<'a> { } &Expr::Let { pat, expr } => { let input_ty = self.infer_expr(expr, &Expectation::none()); - self.infer_pat(pat, &input_ty, BindingMode::default()); + self.infer_top_pat(pat, &input_ty); self.result.standard_types.bool_.clone() } Expr::Block { statements, tail, label, id: _ } => { @@ -223,7 +225,7 @@ impl<'a> InferenceContext<'a> { let pat_ty = self.resolve_associated_type(into_iter_ty, self.resolve_iterator_item()); - self.infer_pat(pat, &pat_ty, BindingMode::default()); + self.infer_top_pat(pat, &pat_ty); self.with_breakable_ctx(BreakableKind::Loop, self.err_ty(), label, |this| { this.infer_expr(body, &Expectation::HasType(TyBuilder::unit())); }); @@ -298,7 +300,7 @@ impl<'a> InferenceContext<'a> { // Now go through the argument patterns for (arg_pat, arg_ty) in args.iter().zip(sig_tys) { - self.infer_pat(*arg_pat, &arg_ty, BindingMode::default()); + self.infer_top_pat(*arg_pat, &arg_ty); } let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); @@ -395,7 +397,8 @@ impl<'a> InferenceContext<'a> { for arm in arms.iter() { self.diverges = Diverges::Maybe; - let _pat_ty = self.infer_pat(arm.pat, &input_ty, BindingMode::default()); + let input_ty = self.resolve_ty_shallow(&input_ty); + let _pat_ty = self.infer_top_pat(arm.pat, &input_ty); if let Some(guard_expr) = arm.guard { self.infer_expr( guard_expr, @@ -1142,27 +1145,33 @@ impl<'a> InferenceContext<'a> { let decl_ty = type_ref .as_ref() .map(|tr| self.make_ty(tr)) - .unwrap_or_else(|| self.err_ty()); + .unwrap_or_else(|| self.table.new_type_var()); - // Always use the declared type when specified - let mut ty = decl_ty.clone(); - - if let Some(expr) = initializer { - let actual_ty = - self.infer_expr_coerce(*expr, &Expectation::has_type(decl_ty.clone())); - if decl_ty.is_unknown() { - ty = actual_ty; + let ty = if let Some(expr) = initializer { + let ty = if contains_explicit_ref_binding(&self.body, *pat) { + self.infer_expr(*expr, &Expectation::has_type(decl_ty.clone())) + } else { + self.infer_expr_coerce(*expr, &Expectation::has_type(decl_ty.clone())) + }; + if type_ref.is_some() { + decl_ty + } else { + ty } - } + } else { + decl_ty + }; + + self.infer_top_pat(*pat, &ty); if let Some(expr) = else_branch { + let previous_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); self.infer_expr_coerce( *expr, &Expectation::HasType(self.result.standard_types.never.clone()), ); + self.diverges = previous_diverges; } - - self.infer_pat(*pat, &ty, BindingMode::default()); } Statement::Expr { expr, .. } => { self.infer_expr(*expr, &Expectation::none()); diff --git a/crates/hir-ty/src/infer/pat.rs b/crates/hir-ty/src/infer/pat.rs index 8b381f0d1f..6481e0b7a7 100644 --- a/crates/hir-ty/src/infer/pat.rs +++ b/crates/hir-ty/src/infer/pat.rs @@ -4,7 +4,8 @@ use std::iter::repeat_with; use chalk_ir::Mutability; use hir_def::{ - expr::{BindingAnnotation, Expr, Literal, Pat, PatId}, + body::Body, + expr::{BindingAnnotation, Expr, ExprId, ExprOrPatId, Literal, Pat, PatId, RecordFieldPat}, path::Path, }; use hir_expand::name::Name; @@ -17,7 +18,43 @@ use crate::{ static_lifetime, Interner, Scalar, Substitution, Ty, TyBuilder, TyExt, TyKind, }; -use super::PatLike; +/// Used to generalize patterns and assignee expressions. +pub(super) trait PatLike: Into + Copy { + type BindingMode: Copy; + + fn infer( + this: &mut InferenceContext<'_>, + id: Self, + expected_ty: &Ty, + default_bm: Self::BindingMode, + ) -> Ty; +} + +impl PatLike for ExprId { + type BindingMode = (); + + fn infer( + this: &mut InferenceContext<'_>, + id: Self, + expected_ty: &Ty, + (): Self::BindingMode, + ) -> Ty { + this.infer_assignee_expr(id, expected_ty) + } +} + +impl PatLike for PatId { + type BindingMode = BindingMode; + + fn infer( + this: &mut InferenceContext<'_>, + id: Self, + expected_ty: &Ty, + default_bm: Self::BindingMode, + ) -> Ty { + this.infer_pat(id, expected_ty, default_bm) + } +} impl<'a> InferenceContext<'a> { /// Infers type for tuple struct pattern or its corresponding assignee expression. @@ -110,6 +147,7 @@ impl<'a> InferenceContext<'a> { ellipsis: Option, subs: &[T], ) -> Ty { + let expected = self.resolve_ty_shallow(expected); let expectations = match expected.as_tuple() { Some(parameters) => &*parameters.as_slice(Interner), _ => &[], @@ -143,12 +181,11 @@ impl<'a> InferenceContext<'a> { .intern(Interner) } - pub(super) fn infer_pat( - &mut self, - pat: PatId, - expected: &Ty, - mut default_bm: BindingMode, - ) -> Ty { + pub(super) fn infer_top_pat(&mut self, pat: PatId, expected: &Ty) -> Ty { + self.infer_pat(pat, expected, BindingMode::default()) + } + + fn infer_pat(&mut self, pat: PatId, expected: &Ty, mut default_bm: BindingMode) -> Ty { let mut expected = self.resolve_ty_shallow(expected); if is_non_ref_pat(self.body, pat) { @@ -183,25 +220,17 @@ impl<'a> InferenceContext<'a> { self.infer_tuple_pat_like(&expected, default_bm, *ellipsis, args) } Pat::Or(pats) => { - if let Some((first_pat, rest)) = pats.split_first() { - let ty = self.infer_pat(*first_pat, &expected, default_bm); - for pat in rest { - self.infer_pat(*pat, &expected, default_bm); - } - ty - } else { - self.err_ty() + for pat in pats.iter() { + self.infer_pat(*pat, &expected, default_bm); } + expected.clone() } - Pat::Ref { pat, mutability } => { - let mutability = lower_to_chalk_mutability(*mutability); - let expectation = match expected.as_reference() { - Some((inner_ty, _lifetime, exp_mut)) => inner_ty.clone(), - _ => self.result.standard_types.unknown.clone(), - }; - let subty = self.infer_pat(*pat, &expectation, default_bm); - TyKind::Ref(mutability, static_lifetime(), subty).intern(Interner) - } + &Pat::Ref { pat, mutability } => self.infer_ref_pat( + pat, + lower_to_chalk_mutability(mutability), + &expected, + default_bm, + ), Pat::TupleStruct { path: p, args: subpats, ellipsis } => self .infer_tuple_struct_pat_like( p.as_deref(), @@ -221,91 +250,17 @@ impl<'a> InferenceContext<'a> { self.infer_path(&resolver, path, pat.into()).unwrap_or_else(|| self.err_ty()) } Pat::Bind { mode, name: _, subpat } => { - let mode = if mode == &BindingAnnotation::Unannotated { - default_bm - } else { - BindingMode::convert(*mode) - }; - self.result.pat_binding_modes.insert(pat, mode); - - let inner_ty = match subpat { - Some(subpat) => self.infer_pat(*subpat, &expected, default_bm), - None => expected, - }; - let inner_ty = self.insert_type_vars_shallow(inner_ty); - - let bound_ty = match mode { - BindingMode::Ref(mutability) => { - TyKind::Ref(mutability, static_lifetime(), inner_ty.clone()) - .intern(Interner) - } - BindingMode::Move => inner_ty.clone(), - }; - self.write_pat_ty(pat, bound_ty); - return inner_ty; + return self.infer_bind_pat(pat, *mode, default_bm, *subpat, &expected); } Pat::Slice { prefix, slice, suffix } => { - let elem_ty = match expected.kind(Interner) { - TyKind::Array(st, _) | TyKind::Slice(st) => st.clone(), - _ => self.err_ty(), - }; - - for &pat_id in prefix.iter().chain(suffix.iter()) { - self.infer_pat(pat_id, &elem_ty, default_bm); - } - - if let &Some(slice_pat_id) = slice { - let rest_pat_ty = match expected.kind(Interner) { - TyKind::Array(_, length) => { - let len = try_const_usize(length); - let len = len.and_then(|len| { - len.checked_sub((prefix.len() + suffix.len()) as u128) - }); - TyKind::Array( - elem_ty.clone(), - usize_const(self.db, len, self.resolver.krate()), - ) - } - _ => TyKind::Slice(elem_ty.clone()), - } - .intern(Interner); - self.infer_pat(slice_pat_id, &rest_pat_ty, default_bm); - } - - match expected.kind(Interner) { - TyKind::Array(_, const_) => TyKind::Array(elem_ty, const_.clone()), - _ => TyKind::Slice(elem_ty), - } - .intern(Interner) + self.infer_slice_pat(&expected, prefix, slice, suffix, default_bm) } Pat::Wild => expected.clone(), Pat::Range { start, end } => { let start_ty = self.infer_expr(*start, &Expectation::has_type(expected.clone())); self.infer_expr(*end, &Expectation::has_type(start_ty)) } - &Pat::Lit(expr) => { - // FIXME: using `Option` here is a workaround until we can use if-let chains in stable. - let mut pat_ty = None; - - // Like slice patterns, byte string patterns can denote both `&[u8; N]` and `&[u8]`. - if let Expr::Literal(Literal::ByteString(_)) = self.body[expr] { - if let Some((inner, ..)) = expected.as_reference() { - let inner = self.resolve_ty_shallow(inner); - if matches!(inner.kind(Interner), TyKind::Slice(_)) { - let elem_ty = TyKind::Scalar(Scalar::Uint(UintTy::U8)).intern(Interner); - let slice_ty = TyKind::Slice(elem_ty).intern(Interner); - let ty = TyKind::Ref(Mutability::Not, static_lifetime(), slice_ty) - .intern(Interner); - self.write_expr_ty(expr, ty.clone()); - pat_ty = Some(ty); - } - } - } - - pat_ty.unwrap_or_else(|| { - self.infer_expr(expr, &Expectation::has_type(expected.clone())) - }) - } + &Pat::Lit(expr) => self.infer_lit_pat(expr, &expected), Pat::Box { inner } => match self.resolve_boxed_box() { Some(box_adt) => { let (inner_ty, alloc_ty) = match expected.as_adt() { @@ -341,6 +296,109 @@ impl<'a> InferenceContext<'a> { self.write_pat_ty(pat, ty.clone()); ty } + + fn infer_ref_pat( + &mut self, + pat: PatId, + mutability: Mutability, + expected: &Ty, + default_bm: BindingMode, + ) -> Ty { + let expectation = match expected.as_reference() { + Some((inner_ty, _lifetime, _exp_mut)) => inner_ty.clone(), + _ => self.result.standard_types.unknown.clone(), + }; + let subty = self.infer_pat(pat, &expectation, default_bm); + TyKind::Ref(mutability, static_lifetime(), subty).intern(Interner) + } + + fn infer_bind_pat( + &mut self, + pat: PatId, + mode: BindingAnnotation, + default_bm: BindingMode, + subpat: Option, + expected: &Ty, + ) -> Ty { + let mode = if mode == BindingAnnotation::Unannotated { + default_bm + } else { + BindingMode::convert(mode) + }; + self.result.pat_binding_modes.insert(pat, mode); + + let inner_ty = match subpat { + Some(subpat) => self.infer_pat(subpat, &expected, default_bm), + None => expected.clone(), + }; + let inner_ty = self.insert_type_vars_shallow(inner_ty); + + let bound_ty = match mode { + BindingMode::Ref(mutability) => { + TyKind::Ref(mutability, static_lifetime(), inner_ty.clone()).intern(Interner) + } + BindingMode::Move => inner_ty.clone(), + }; + self.write_pat_ty(pat, bound_ty); + return inner_ty; + } + + fn infer_slice_pat( + &mut self, + expected: &Ty, + prefix: &[PatId], + slice: &Option, + suffix: &[PatId], + default_bm: BindingMode, + ) -> Ty { + let elem_ty = match expected.kind(Interner) { + TyKind::Array(st, _) | TyKind::Slice(st) => st.clone(), + _ => self.err_ty(), + }; + + for &pat_id in prefix.iter().chain(suffix.iter()) { + self.infer_pat(pat_id, &elem_ty, default_bm); + } + + if let &Some(slice_pat_id) = slice { + let rest_pat_ty = match expected.kind(Interner) { + TyKind::Array(_, length) => { + let len = try_const_usize(length); + let len = + len.and_then(|len| len.checked_sub((prefix.len() + suffix.len()) as u128)); + TyKind::Array(elem_ty.clone(), usize_const(self.db, len, self.resolver.krate())) + } + _ => TyKind::Slice(elem_ty.clone()), + } + .intern(Interner); + self.infer_pat(slice_pat_id, &rest_pat_ty, default_bm); + } + + match expected.kind(Interner) { + TyKind::Array(_, const_) => TyKind::Array(elem_ty, const_.clone()), + _ => TyKind::Slice(elem_ty), + } + .intern(Interner) + } + + fn infer_lit_pat(&mut self, expr: ExprId, expected: &Ty) -> Ty { + // Like slice patterns, byte string patterns can denote both `&[u8; N]` and `&[u8]`. + if let Expr::Literal(Literal::ByteString(_)) = self.body[expr] { + if let Some((inner, ..)) = expected.as_reference() { + let inner = self.resolve_ty_shallow(inner); + if matches!(inner.kind(Interner), TyKind::Slice(_)) { + let elem_ty = TyKind::Scalar(Scalar::Uint(UintTy::U8)).intern(Interner); + let slice_ty = TyKind::Slice(elem_ty).intern(Interner); + let ty = + TyKind::Ref(Mutability::Not, static_lifetime(), slice_ty).intern(Interner); + self.write_expr_ty(expr, ty.clone()); + return ty; + } + } + } + + self.infer_expr(expr, &Expectation::has_type(expected.clone())) + } } fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool { @@ -365,3 +423,41 @@ fn is_non_ref_pat(body: &hir_def::body::Body, pat: PatId) -> bool { Pat::Wild | Pat::Bind { .. } | Pat::Ref { .. } | Pat::Box { .. } | Pat::Missing => false, } } + +pub(super) fn contains_explicit_ref_binding(body: &Body, pat_id: PatId) -> bool { + let mut res = false; + walk_pats(body, pat_id, &mut |pat| { + res |= matches!(pat, Pat::Bind { mode: BindingAnnotation::Ref, .. }) + }); + res +} + +fn walk_pats(body: &Body, pat_id: PatId, f: &mut impl FnMut(&Pat)) { + let pat = &body[pat_id]; + f(pat); + match pat { + Pat::Range { .. } + | Pat::Lit(..) + | Pat::Path(..) + | Pat::ConstBlock(..) + | Pat::Wild + | Pat::Missing => {} + &Pat::Bind { subpat, .. } => { + if let Some(subpat) = subpat { + walk_pats(body, subpat, f); + } + } + Pat::Or(args) | Pat::Tuple { args, .. } | Pat::TupleStruct { args, .. } => { + args.iter().copied().for_each(|p| walk_pats(body, p, f)); + } + Pat::Ref { pat, .. } => walk_pats(body, *pat, f), + Pat::Slice { prefix, slice, suffix } => { + let total_iter = prefix.iter().chain(slice.iter()).chain(suffix.iter()); + total_iter.copied().for_each(|p| walk_pats(body, p, f)); + } + Pat::Record { args, .. } => { + args.iter().for_each(|RecordFieldPat { pat, .. }| walk_pats(body, *pat, f)); + } + Pat::Box { inner } => walk_pats(body, *inner, f), + } +} diff --git a/crates/hir-ty/src/tests/patterns.rs b/crates/hir-ty/src/tests/patterns.rs index aa1b2a1d9b..be67329fee 100644 --- a/crates/hir-ty/src/tests/patterns.rs +++ b/crates/hir-ty/src/tests/patterns.rs @@ -953,9 +953,9 @@ fn main() { 42..51 'true | ()': bool 49..51 '()': () 57..59 '{}': () - 68..80 '(() | true,)': ((),) + 68..80 '(() | true,)': (bool,) 69..71 '()': () - 69..78 '() | true': () + 69..78 '() | true': bool 74..78 'true': bool 74..78 'true': bool 84..86 '{}': () @@ -964,19 +964,15 @@ fn main() { 96..102 '_ | ()': bool 100..102 '()': () 108..110 '{}': () - 119..128 '(() | _,)': ((),) + 119..128 '(() | _,)': (bool,) 120..122 '()': () - 120..126 '() | _': () + 120..126 '() | _': bool 125..126 '_': bool 132..134 '{}': () 49..51: expected bool, got () - 68..80: expected (bool,), got ((),) 69..71: expected bool, got () - 69..78: expected bool, got () 100..102: expected bool, got () - 119..128: expected (bool,), got ((),) 120..122: expected bool, got () - 120..126: expected bool, got () "#]], ); } diff --git a/crates/ide/src/inlay_hints/adjustment.rs b/crates/ide/src/inlay_hints/adjustment.rs index 188eb7f977..729780fa0c 100644 --- a/crates/ide/src/inlay_hints/adjustment.rs +++ b/crates/ide/src/inlay_hints/adjustment.rs @@ -606,14 +606,13 @@ fn a() { } #[test] - fn bug() { + fn let_stmt_explicit_ty() { check_with_config( InlayHintsConfig { adjustment_hints: AdjustmentHints::Always, ..DISABLED_CONFIG }, r#" fn main() { - // These should be identical, but they are not... - let () = return; + //^^^^^^ let (): () = return; //^^^^^^ }