mirror of
				https://github.com/rust-lang/rust-analyzer.git
				synced 2025-10-31 12:04:43 +00:00 
			
		
		
		
	Port closure inference from rustc
This commit is contained in:
		
							parent
							
								
									bec545920a
								
							
						
					
					
						commit
						bc3e9d9fcb
					
				
					 2 changed files with 405 additions and 157 deletions
				
			
		|  | @ -1,24 +1,26 @@ | |||
| //! Inference of closure parameter types based on the closure's expected type.
 | ||||
| 
 | ||||
| use std::{cmp, convert::Infallible, mem}; | ||||
| use std::{cmp, convert::Infallible, mem, ops::ControlFlow}; | ||||
| 
 | ||||
| use chalk_ir::{ | ||||
|     BoundVar, DebruijnIndex, FnSubst, Mutability, TyKind, | ||||
|     cast::Cast, | ||||
|     fold::{FallibleTypeFolder, TypeFoldable}, | ||||
|     fold::{FallibleTypeFolder, Shift, TypeFoldable}, | ||||
|     visit::{TypeSuperVisitable, TypeVisitable, TypeVisitor}, | ||||
| }; | ||||
| use either::Either; | ||||
| use hir_def::{ | ||||
|     DefWithBodyId, FieldId, HasModule, TupleFieldId, TupleId, VariantId, | ||||
|     data::adt::VariantData, | ||||
|     hir::{ | ||||
|         Array, AsmOperand, BinaryOp, BindingId, CaptureBy, Expr, ExprId, ExprOrPatId, Pat, PatId, | ||||
|         Statement, UnaryOp, | ||||
|         Array, AsmOperand, BinaryOp, BindingId, CaptureBy, ClosureKind, Expr, ExprId, ExprOrPatId, | ||||
|         Pat, PatId, Statement, UnaryOp, | ||||
|     }, | ||||
|     lang_item::LangItem, | ||||
|     path::Path, | ||||
|     resolver::ValueNs, | ||||
| }; | ||||
| use hir_def::{Lookup, type_ref::TypeRefId}; | ||||
| use hir_expand::name::Name; | ||||
| use intern::sym; | ||||
| use rustc_hash::FxHashMap; | ||||
|  | @ -28,12 +30,12 @@ use syntax::utils::is_raw_identifier; | |||
| 
 | ||||
| use crate::{ | ||||
|     Adjust, Adjustment, AliasEq, AliasTy, Binders, BindingMode, ChalkTraitId, ClosureId, DynTy, | ||||
|     DynTyExt, FnAbi, FnPointer, FnSig, Interner, OpaqueTy, ProjectionTyExt, Substitution, Ty, | ||||
|     TyExt, WhereClause, | ||||
|     db::{HirDatabase, InternedClosure}, | ||||
|     DynTyExt, FnAbi, FnPointer, FnSig, GenericArg, Interner, OpaqueTy, ProjectionTy, | ||||
|     ProjectionTyExt, Substitution, Ty, TyBuilder, TyExt, WhereClause, | ||||
|     db::{HirDatabase, InternedClosure, InternedCoroutine}, | ||||
|     error_lifetime, from_assoc_type_id, from_chalk_trait_id, from_placeholder_idx, | ||||
|     generics::Generics, | ||||
|     infer::coerce::CoerceNever, | ||||
|     infer::{BreakableKind, CoerceMany, Diverges, coerce::CoerceNever}, | ||||
|     make_binders, | ||||
|     mir::{BorrowKind, MirSpan, MutBorrowKind, ProjectionElem}, | ||||
|     to_chalk_trait_id, | ||||
|  | @ -43,7 +45,106 @@ use crate::{ | |||
| 
 | ||||
| use super::{Expectation, InferenceContext}; | ||||
| 
 | ||||
| #[derive(Debug)] | ||||
| pub(super) struct ClosureSignature { | ||||
|     pub(super) ret_ty: Ty, | ||||
|     pub(super) expected_sig: FnPointer, | ||||
| } | ||||
| 
 | ||||
| impl InferenceContext<'_> { | ||||
|     pub(super) fn infer_closure( | ||||
|         &mut self, | ||||
|         body: &ExprId, | ||||
|         args: &[PatId], | ||||
|         ret_type: &Option<TypeRefId>, | ||||
|         arg_types: &[Option<TypeRefId>], | ||||
|         closure_kind: ClosureKind, | ||||
|         tgt_expr: ExprId, | ||||
|         expected: &Expectation, | ||||
|     ) -> Ty { | ||||
|         assert_eq!(args.len(), arg_types.len()); | ||||
| 
 | ||||
|         let (expected_sig, expected_kind) = match expected.to_option(&mut self.table) { | ||||
|             Some(expected_ty) => self.deduce_closure_signature(&expected_ty, closure_kind), | ||||
|             None => (None, None), | ||||
|         }; | ||||
| 
 | ||||
|         let ClosureSignature { expected_sig: bound_sig, ret_ty: body_ret_ty } = | ||||
|             self.sig_of_closure(body, ret_type, arg_types, closure_kind, expected_sig); | ||||
|         let bound_sig = self.normalize_associated_types_in(bound_sig); | ||||
|         let sig_ty = TyKind::Function(bound_sig.clone()).intern(Interner); | ||||
| 
 | ||||
|         let (id, ty, resume_yield_tys) = match closure_kind { | ||||
|             ClosureKind::Coroutine(_) => { | ||||
|                 let sig_tys = bound_sig.substitution.0.as_slice(Interner); | ||||
|                 // FIXME: report error when there are more than 1 parameter.
 | ||||
|                 let resume_ty = match sig_tys.first() { | ||||
|                     // When `sig_tys.len() == 1` the first type is the return type, not the
 | ||||
|                     // first parameter type.
 | ||||
|                     Some(ty) if sig_tys.len() > 1 => ty.assert_ty_ref(Interner).clone(), | ||||
|                     _ => self.result.standard_types.unit.clone(), | ||||
|                 }; | ||||
|                 let yield_ty = self.table.new_type_var(); | ||||
| 
 | ||||
|                 let subst = TyBuilder::subst_for_coroutine(self.db, self.owner) | ||||
|                     .push(resume_ty.clone()) | ||||
|                     .push(yield_ty.clone()) | ||||
|                     .push(body_ret_ty.clone()) | ||||
|                     .build(); | ||||
| 
 | ||||
|                 let coroutine_id = | ||||
|                     self.db.intern_coroutine(InternedCoroutine(self.owner, tgt_expr)).into(); | ||||
|                 let coroutine_ty = TyKind::Coroutine(coroutine_id, subst).intern(Interner); | ||||
| 
 | ||||
|                 (None, coroutine_ty, Some((resume_ty, yield_ty))) | ||||
|             } | ||||
|             ClosureKind::Closure | ClosureKind::Async => { | ||||
|                 let closure_id = | ||||
|                     self.db.intern_closure(InternedClosure(self.owner, tgt_expr)).into(); | ||||
|                 let closure_ty = TyKind::Closure( | ||||
|                     closure_id, | ||||
|                     TyBuilder::subst_for_closure(self.db, self.owner, sig_ty.clone()), | ||||
|                 ) | ||||
|                 .intern(Interner); | ||||
|                 self.deferred_closures.entry(closure_id).or_default(); | ||||
|                 if let Some(c) = self.current_closure { | ||||
|                     self.closure_dependencies.entry(c).or_default().push(closure_id); | ||||
|                 } | ||||
|                 (Some(closure_id), closure_ty, None) | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         // Eagerly try to relate the closure type with the expected
 | ||||
|         // type, otherwise we often won't have enough information to
 | ||||
|         // infer the body.
 | ||||
|         self.deduce_closure_type_from_expectations(tgt_expr, &ty, &sig_ty, expected, expected_kind); | ||||
| 
 | ||||
|         // Now go through the argument patterns
 | ||||
|         for (arg_pat, arg_ty) in args.iter().zip(bound_sig.substitution.0.as_slice(Interner).iter()) | ||||
|         { | ||||
|             self.infer_top_pat(*arg_pat, arg_ty.assert_ty_ref(Interner), None); | ||||
|         } | ||||
| 
 | ||||
|         // FIXME: lift these out into a struct
 | ||||
|         let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); | ||||
|         let prev_closure = mem::replace(&mut self.current_closure, id); | ||||
|         let prev_ret_ty = mem::replace(&mut self.return_ty, body_ret_ty.clone()); | ||||
|         let prev_ret_coercion = self.return_coercion.replace(CoerceMany::new(body_ret_ty.clone())); | ||||
|         let prev_resume_yield_tys = mem::replace(&mut self.resume_yield_tys, resume_yield_tys); | ||||
| 
 | ||||
|         self.with_breakable_ctx(BreakableKind::Border, None, None, |this| { | ||||
|             this.infer_return(*body); | ||||
|         }); | ||||
| 
 | ||||
|         self.diverges = prev_diverges; | ||||
|         self.return_ty = prev_ret_ty; | ||||
|         self.return_coercion = prev_ret_coercion; | ||||
|         self.current_closure = prev_closure; | ||||
|         self.resume_yield_tys = prev_resume_yield_tys; | ||||
| 
 | ||||
|         self.table.normalize_associated_types_in(ty) | ||||
|     } | ||||
| 
 | ||||
|     // This function handles both closures and coroutines.
 | ||||
|     pub(super) fn deduce_closure_type_from_expectations( | ||||
|         &mut self, | ||||
|  | @ -51,19 +152,21 @@ impl InferenceContext<'_> { | |||
|         closure_ty: &Ty, | ||||
|         sig_ty: &Ty, | ||||
|         expectation: &Expectation, | ||||
|         expected_kind: Option<FnTrait>, | ||||
|     ) { | ||||
|         let expected_ty = match expectation.to_option(&mut self.table) { | ||||
|             Some(ty) => ty, | ||||
|             None => return, | ||||
|         }; | ||||
| 
 | ||||
|         if let TyKind::Closure(closure_id, _) = closure_ty.kind(Interner) { | ||||
|             if let Some(closure_kind) = self.deduce_closure_kind_from_expectations(&expected_ty) { | ||||
|         match (closure_ty.kind(Interner), expected_kind) { | ||||
|             (TyKind::Closure(closure_id, _), Some(closure_kind)) => { | ||||
|                 self.result | ||||
|                     .closure_info | ||||
|                     .entry(*closure_id) | ||||
|                     .or_insert_with(|| (Vec::new(), closure_kind)); | ||||
|             } | ||||
|             _ => {} | ||||
|         } | ||||
| 
 | ||||
|         // Deduction from where-clauses in scope, as well as fn-pointer coercion are handled here.
 | ||||
|  | @ -86,56 +189,146 @@ impl InferenceContext<'_> { | |||
| 
 | ||||
|     // Closure kind deductions are mostly from `rustc_hir_typeck/src/closure.rs`.
 | ||||
|     // Might need to port closure sig deductions too.
 | ||||
|     fn deduce_closure_kind_from_expectations(&mut self, expected_ty: &Ty) -> Option<FnTrait> { | ||||
|     pub(super) fn deduce_closure_signature( | ||||
|         &mut self, | ||||
|         expected_ty: &Ty, | ||||
|         closure_kind: ClosureKind, | ||||
|     ) -> (Option<FnSubst<Interner>>, Option<FnTrait>) { | ||||
|         match expected_ty.kind(Interner) { | ||||
|             TyKind::Alias(AliasTy::Opaque(OpaqueTy { .. })) | TyKind::OpaqueType(..) => { | ||||
|                 let clauses = expected_ty | ||||
|                     .impl_trait_bounds(self.db) | ||||
|                     .into_iter() | ||||
|                     .flatten() | ||||
|                     .map(|b| b.into_value_and_skipped_binders().0); | ||||
|                 self.deduce_closure_kind_from_predicate_clauses(clauses) | ||||
|                 let clauses = expected_ty.impl_trait_bounds(self.db).into_iter().flatten().map( | ||||
|                     |b: chalk_ir::Binders<chalk_ir::WhereClause<Interner>>| { | ||||
|                         b.into_value_and_skipped_binders().0 | ||||
|                     }, | ||||
|                 ); | ||||
|                 self.deduce_closure_kind_from_predicate_clauses(expected_ty, clauses, closure_kind) | ||||
|             } | ||||
|             TyKind::Dyn(dyn_ty) => { | ||||
|                 let sig = | ||||
|                     dyn_ty.bounds.skip_binders().as_slice(Interner).iter().find_map(|bound| { | ||||
|                         if let WhereClause::AliasEq(AliasEq { | ||||
|                             alias: AliasTy::Projection(projection_ty), | ||||
|                             ty: projected_ty, | ||||
|                         }) = bound.skip_binders() | ||||
|                         { | ||||
|                             if let Some(sig) = self.deduce_sig_from_projection( | ||||
|                                 closure_kind, | ||||
|                                 projection_ty, | ||||
|                                 projected_ty, | ||||
|                             ) { | ||||
|                                 return Some(sig); | ||||
|                             } | ||||
|                         } | ||||
|                         None | ||||
|                     }); | ||||
| 
 | ||||
|                 let kind = dyn_ty.principal().and_then(|principal_trait_ref| { | ||||
|                     self.fn_trait_kind_from_trait_id(from_chalk_trait_id( | ||||
|                         principal_trait_ref.skip_binders().skip_binders().trait_id, | ||||
|                     )) | ||||
|                 }); | ||||
| 
 | ||||
|                 (sig, kind) | ||||
|             } | ||||
|             TyKind::Dyn(dyn_ty) => dyn_ty.principal_id().and_then(|trait_id| { | ||||
|                 self.fn_trait_kind_from_trait_id(from_chalk_trait_id(trait_id)) | ||||
|             }), | ||||
|             TyKind::InferenceVar(ty, chalk_ir::TyVariableKind::General) => { | ||||
|                 let clauses = self.clauses_for_self_ty(*ty); | ||||
|                 self.deduce_closure_kind_from_predicate_clauses(clauses.into_iter()) | ||||
|                 self.deduce_closure_kind_from_predicate_clauses( | ||||
|                     expected_ty, | ||||
|                     clauses.into_iter(), | ||||
|                     closure_kind, | ||||
|                 ) | ||||
|             } | ||||
|             TyKind::Function(_) => Some(FnTrait::Fn), | ||||
|             _ => None, | ||||
|             TyKind::Function(fn_ptr) => match closure_kind { | ||||
|                 ClosureKind::Closure => (Some(fn_ptr.substitution.clone()), Some(FnTrait::Fn)), | ||||
|                 ClosureKind::Async | ClosureKind::Coroutine(_) => (None, None), | ||||
|             }, | ||||
|             _ => (None, None), | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn deduce_closure_kind_from_predicate_clauses( | ||||
|         &self, | ||||
|         expected_ty: &Ty, | ||||
|         clauses: impl DoubleEndedIterator<Item = WhereClause>, | ||||
|     ) -> Option<FnTrait> { | ||||
|         closure_kind: ClosureKind, | ||||
|     ) -> (Option<FnSubst<Interner>>, Option<FnTrait>) { | ||||
|         let mut expected_sig = None; | ||||
|         let mut expected_kind = None; | ||||
| 
 | ||||
|         for clause in elaborate_clause_supertraits(self.db, clauses.rev()) { | ||||
|             let trait_id = match clause { | ||||
|                 WhereClause::AliasEq(AliasEq { | ||||
|                     alias: AliasTy::Projection(projection), .. | ||||
|                 }) => Some(projection.trait_(self.db)), | ||||
|                 WhereClause::Implemented(trait_ref) => { | ||||
|                     Some(from_chalk_trait_id(trait_ref.trait_id)) | ||||
|                 } | ||||
|                 _ => None, | ||||
|             }; | ||||
|             if let Some(closure_kind) = | ||||
|                 trait_id.and_then(|trait_id| self.fn_trait_kind_from_trait_id(trait_id)) | ||||
|             if expected_sig.is_none() { | ||||
|                 if let WhereClause::AliasEq(AliasEq { | ||||
|                     alias: AliasTy::Projection(projection), | ||||
|                     ty, | ||||
|                 }) = &clause | ||||
|                 { | ||||
|                 // `FnX`'s variants order is opposite from rustc, so use `cmp::max` instead of `cmp::min`
 | ||||
|                 expected_kind = Some( | ||||
|                     expected_kind | ||||
|                         .map_or_else(|| closure_kind, |current| cmp::max(current, closure_kind)), | ||||
|                 ); | ||||
|                     let inferred_sig = | ||||
|                         self.deduce_sig_from_projection(closure_kind, projection, ty); | ||||
|                     // Make sure that we didn't infer a signature that mentions itself.
 | ||||
|                     // This can happen when we elaborate certain supertrait bounds that
 | ||||
|                     // mention projections containing the `Self` type. See rust-lang/rust#105401.
 | ||||
|                     struct MentionsTy<'a> { | ||||
|                         expected_ty: &'a Ty, | ||||
|                     } | ||||
|                     impl TypeVisitor<Interner> for MentionsTy<'_> { | ||||
|                         type BreakTy = (); | ||||
| 
 | ||||
|                         fn interner(&self) -> Interner { | ||||
|                             Interner | ||||
|                         } | ||||
| 
 | ||||
|                         fn as_dyn( | ||||
|                             &mut self, | ||||
|                         ) -> &mut dyn TypeVisitor<Interner, BreakTy = Self::BreakTy> | ||||
|                         { | ||||
|                             self | ||||
|                         } | ||||
| 
 | ||||
|                         fn visit_ty( | ||||
|                             &mut self, | ||||
|                             t: &Ty, | ||||
|                             db: chalk_ir::DebruijnIndex, | ||||
|                         ) -> ControlFlow<()> { | ||||
|                             if t == self.expected_ty { | ||||
|                                 ControlFlow::Break(()) | ||||
|                             } else { | ||||
|                                 t.super_visit_with(self, db) | ||||
|                             } | ||||
|                         } | ||||
|                     } | ||||
|                     if inferred_sig | ||||
|                         .visit_with( | ||||
|                             &mut MentionsTy { expected_ty }, | ||||
|                             chalk_ir::DebruijnIndex::INNERMOST, | ||||
|                         ) | ||||
|                         .is_continue() | ||||
|                     { | ||||
|                         expected_sig = inferred_sig; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|         expected_kind | ||||
|             let trait_id = match clause { | ||||
|                 WhereClause::AliasEq(AliasEq { | ||||
|                     alias: AliasTy::Projection(projection), .. | ||||
|                 }) => projection.trait_(self.db), | ||||
|                 WhereClause::Implemented(trait_ref) => from_chalk_trait_id(trait_ref.trait_id), | ||||
|                 _ => continue, | ||||
|             }; | ||||
|             if let Some(closure_kind) = self.fn_trait_kind_from_trait_id(trait_id) { | ||||
|                 // always use the closure kind that is more permissive.
 | ||||
|                 match (expected_kind, closure_kind) { | ||||
|                     (None, _) => expected_kind = Some(closure_kind), | ||||
|                     (Some(FnTrait::FnMut), FnTrait::Fn) => expected_kind = Some(FnTrait::Fn), | ||||
|                     (Some(FnTrait::FnOnce), FnTrait::Fn | FnTrait::FnMut) => { | ||||
|                         expected_kind = Some(closure_kind) | ||||
|                     } | ||||
|                     _ => {} | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         (expected_sig, expected_kind) | ||||
|     } | ||||
| 
 | ||||
|     fn deduce_sig_from_dyn_ty(&self, dyn_ty: &DynTy) -> Option<FnPointer> { | ||||
|  | @ -186,9 +379,174 @@ impl InferenceContext<'_> { | |||
|         None | ||||
|     } | ||||
| 
 | ||||
|     fn deduce_sig_from_projection( | ||||
|         &self, | ||||
|         closure_kind: ClosureKind, | ||||
|         projection_ty: &ProjectionTy, | ||||
|         projected_ty: &Ty, | ||||
|     ) -> Option<FnSubst<Interner>> { | ||||
|         let container = | ||||
|             from_assoc_type_id(projection_ty.associated_ty_id).lookup(self.db.upcast()).container; | ||||
|         let trait_ = match container { | ||||
|             hir_def::ItemContainerId::TraitId(trait_) => trait_, | ||||
|             _ => return None, | ||||
|         }; | ||||
| 
 | ||||
|         // For now, we only do signature deduction based off of the `Fn` and `AsyncFn` traits,
 | ||||
|         // for closures and async closures, respectively.
 | ||||
|         match closure_kind { | ||||
|             ClosureKind::Closure | ClosureKind::Async | ||||
|                 if self.fn_trait_kind_from_trait_id(trait_).is_some() => | ||||
|             { | ||||
|                 self.extract_sig_from_projection(projection_ty, projected_ty) | ||||
|             } | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn extract_sig_from_projection( | ||||
|         &self, | ||||
|         projection_ty: &ProjectionTy, | ||||
|         projected_ty: &Ty, | ||||
|     ) -> Option<FnSubst<Interner>> { | ||||
|         let arg_param_ty = projection_ty.substitution.as_slice(Interner)[1].assert_ty_ref(Interner); | ||||
| 
 | ||||
|         let TyKind::Tuple(_, input_tys) = arg_param_ty.kind(Interner) else { | ||||
|             return None; | ||||
|         }; | ||||
| 
 | ||||
|         let ret_param_ty = projected_ty; | ||||
| 
 | ||||
|         Some(FnSubst(Substitution::from_iter( | ||||
|             Interner, | ||||
|             input_tys.iter(Interner).map(|t| t.cast(Interner)).chain(Some(GenericArg::new( | ||||
|                 Interner, | ||||
|                 chalk_ir::GenericArgData::Ty(ret_param_ty.clone()), | ||||
|             ))), | ||||
|         ))) | ||||
|     } | ||||
| 
 | ||||
|     fn fn_trait_kind_from_trait_id(&self, trait_id: hir_def::TraitId) -> Option<FnTrait> { | ||||
|         FnTrait::from_lang_item(self.db.lang_attr(trait_id.into())?) | ||||
|     } | ||||
| 
 | ||||
|     fn supplied_sig_of_closure( | ||||
|         &mut self, | ||||
|         body: &ExprId, | ||||
|         ret_type: &Option<TypeRefId>, | ||||
|         arg_types: &[Option<TypeRefId>], | ||||
|         closure_kind: ClosureKind, | ||||
|     ) -> ClosureSignature { | ||||
|         let mut sig_tys = Vec::with_capacity(arg_types.len() + 1); | ||||
| 
 | ||||
|         // collect explicitly written argument types
 | ||||
|         for arg_type in arg_types.iter() { | ||||
|             let arg_ty = match arg_type { | ||||
|                 Some(type_ref) => self.make_body_ty(*type_ref), | ||||
|                 None => self.table.new_type_var(), | ||||
|             }; | ||||
|             sig_tys.push(arg_ty); | ||||
|         } | ||||
| 
 | ||||
|         // add return type
 | ||||
|         let ret_ty = match ret_type { | ||||
|             Some(type_ref) => self.make_body_ty(*type_ref), | ||||
|             None => self.table.new_type_var(), | ||||
|         }; | ||||
|         if let ClosureKind::Async = closure_kind { | ||||
|             sig_tys.push(self.lower_async_block_type_impl_trait(ret_ty.clone(), *body)); | ||||
|         } else { | ||||
|             sig_tys.push(ret_ty.clone()); | ||||
|         } | ||||
| 
 | ||||
|         let expected_sig = FnPointer { | ||||
|             num_binders: 0, | ||||
|             sig: FnSig { abi: FnAbi::RustCall, safety: chalk_ir::Safety::Safe, variadic: false }, | ||||
|             substitution: FnSubst( | ||||
|                 Substitution::from_iter(Interner, sig_tys.iter().cloned()).shifted_in(Interner), | ||||
|             ), | ||||
|         }; | ||||
| 
 | ||||
|         ClosureSignature { ret_ty, expected_sig } | ||||
|     } | ||||
| 
 | ||||
|     /// The return type is the signature of the closure, and the return type
 | ||||
|     /// *as represented inside the body* (so, for async closures, the `Output` ty)
 | ||||
|     pub(super) fn sig_of_closure( | ||||
|         &mut self, | ||||
|         body: &ExprId, | ||||
|         ret_type: &Option<TypeRefId>, | ||||
|         arg_types: &[Option<TypeRefId>], | ||||
|         closure_kind: ClosureKind, | ||||
|         expected_sig: Option<FnSubst<Interner>>, | ||||
|     ) -> ClosureSignature { | ||||
|         if let Some(e) = expected_sig { | ||||
|             self.sig_of_closure_with_expectation(body, ret_type, arg_types, closure_kind, e) | ||||
|         } else { | ||||
|             self.sig_of_closure_no_expectation(body, ret_type, arg_types, closure_kind) | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn sig_of_closure_no_expectation( | ||||
|         &mut self, | ||||
|         body: &ExprId, | ||||
|         ret_type: &Option<TypeRefId>, | ||||
|         arg_types: &[Option<TypeRefId>], | ||||
|         closure_kind: ClosureKind, | ||||
|     ) -> ClosureSignature { | ||||
|         self.supplied_sig_of_closure(body, ret_type, arg_types, closure_kind) | ||||
|     } | ||||
| 
 | ||||
|     fn sig_of_closure_with_expectation( | ||||
|         &mut self, | ||||
|         body: &ExprId, | ||||
|         ret_type: &Option<TypeRefId>, | ||||
|         arg_types: &[Option<TypeRefId>], | ||||
|         closure_kind: ClosureKind, | ||||
|         expected_sig: FnSubst<Interner>, | ||||
|     ) -> ClosureSignature { | ||||
|         let expected_sig = FnPointer { | ||||
|             num_binders: 0, | ||||
|             sig: FnSig { abi: FnAbi::RustCall, safety: chalk_ir::Safety::Safe, variadic: false }, | ||||
|             substitution: expected_sig, | ||||
|         }; | ||||
| 
 | ||||
|         // If the expected signature does not match the actual arg types,
 | ||||
|         // then just return the expected signature
 | ||||
|         if expected_sig.substitution.0.len(Interner) != arg_types.len() + 1 { | ||||
|             let ret_ty = match ret_type { | ||||
|                 Some(type_ref) => self.make_body_ty(*type_ref), | ||||
|                 None => self.table.new_type_var(), | ||||
|             }; | ||||
|             return ClosureSignature { expected_sig, ret_ty }; | ||||
|         } | ||||
| 
 | ||||
|         self.merge_supplied_sig_with_expectation( | ||||
|             body, | ||||
|             ret_type, | ||||
|             arg_types, | ||||
|             closure_kind, | ||||
|             expected_sig, | ||||
|         ) | ||||
|     } | ||||
| 
 | ||||
|     fn merge_supplied_sig_with_expectation( | ||||
|         &mut self, | ||||
|         body: &ExprId, | ||||
|         ret_type: &Option<TypeRefId>, | ||||
|         arg_types: &[Option<TypeRefId>], | ||||
|         closure_kind: ClosureKind, | ||||
|         expected_sig: FnPointer, | ||||
|     ) -> ClosureSignature { | ||||
|         let supplied_sig = self.supplied_sig_of_closure(body, ret_type, arg_types, closure_kind); | ||||
| 
 | ||||
|         let snapshot = self.table.snapshot(); | ||||
|         if !self.table.unify(&expected_sig.substitution, &supplied_sig.expected_sig.substitution) { | ||||
|             self.table.rollback_to(snapshot); | ||||
|         } | ||||
| 
 | ||||
|         supplied_sig | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| // The below functions handle capture and closure kind (Fn, FnMut, ..)
 | ||||
|  |  | |||
|  | @ -5,13 +5,13 @@ use std::{ | |||
|     mem, | ||||
| }; | ||||
| 
 | ||||
| use chalk_ir::{DebruijnIndex, Mutability, TyVariableKind, cast::Cast, fold::Shift}; | ||||
| use chalk_ir::{DebruijnIndex, Mutability, TyVariableKind, cast::Cast}; | ||||
| use either::Either; | ||||
| use hir_def::{ | ||||
|     BlockId, FieldId, GenericDefId, GenericParamId, ItemContainerId, Lookup, TupleFieldId, TupleId, | ||||
|     hir::{ | ||||
|         ArithOp, Array, AsmOperand, AsmOptions, BinaryOp, ClosureKind, Expr, ExprId, ExprOrPatId, | ||||
|         LabelId, Literal, Pat, PatId, Statement, UnaryOp, | ||||
|         ArithOp, Array, AsmOperand, AsmOptions, BinaryOp, Expr, ExprId, ExprOrPatId, LabelId, | ||||
|         Literal, Pat, PatId, Statement, UnaryOp, | ||||
|     }, | ||||
|     lang_item::{LangItem, LangItemTarget}, | ||||
|     path::{GenericArg, GenericArgs, Path}, | ||||
|  | @ -24,12 +24,10 @@ use syntax::ast::RangeOp; | |||
| 
 | ||||
| use crate::{ | ||||
|     Adjust, Adjustment, AdtId, AutoBorrow, Binders, CallableDefId, CallableSig, DeclContext, | ||||
|     DeclOrigin, FnAbi, FnPointer, FnSig, FnSubst, Interner, Rawness, Scalar, Substitution, | ||||
|     TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind, | ||||
|     DeclOrigin, Interner, Rawness, Scalar, Substitution, TraitEnvironment, TraitRef, Ty, TyBuilder, | ||||
|     TyExt, TyKind, | ||||
|     autoderef::{Autoderef, builtin_deref, deref_by_trait}, | ||||
|     consteval, | ||||
|     db::{InternedClosure, InternedCoroutine}, | ||||
|     error_lifetime, | ||||
|     consteval, error_lifetime, | ||||
|     generics::{Generics, generics}, | ||||
|     infer::{ | ||||
|         BreakableKind, | ||||
|  | @ -378,116 +376,8 @@ impl InferenceContext<'_> { | |||
|                     None => self.result.standard_types.never.clone(), | ||||
|                 } | ||||
|             } | ||||
|             Expr::Closure { body, args, ret_type, arg_types, closure_kind, capture_by: _ } => { | ||||
|                 assert_eq!(args.len(), arg_types.len()); | ||||
| 
 | ||||
|                 let mut sig_tys = Vec::with_capacity(arg_types.len() + 1); | ||||
| 
 | ||||
|                 // collect explicitly written argument types
 | ||||
|                 for arg_type in arg_types.iter() { | ||||
|                     let arg_ty = match arg_type { | ||||
|                         Some(type_ref) => self.make_body_ty(*type_ref), | ||||
|                         None => self.table.new_type_var(), | ||||
|                     }; | ||||
|                     sig_tys.push(arg_ty); | ||||
|                 } | ||||
| 
 | ||||
|                 // add return type
 | ||||
|                 let ret_ty = match ret_type { | ||||
|                     Some(type_ref) => self.make_body_ty(*type_ref), | ||||
|                     None => self.table.new_type_var(), | ||||
|                 }; | ||||
|                 if let ClosureKind::Async = closure_kind { | ||||
|                     sig_tys.push(self.lower_async_block_type_impl_trait(ret_ty.clone(), *body)); | ||||
|                 } else { | ||||
|                     sig_tys.push(ret_ty.clone()); | ||||
|                 } | ||||
| 
 | ||||
|                 let sig_ty = TyKind::Function(FnPointer { | ||||
|                     num_binders: 0, | ||||
|                     sig: FnSig { | ||||
|                         abi: FnAbi::RustCall, | ||||
|                         safety: chalk_ir::Safety::Safe, | ||||
|                         variadic: false, | ||||
|                     }, | ||||
|                     substitution: FnSubst( | ||||
|                         Substitution::from_iter(Interner, sig_tys.iter().cloned()) | ||||
|                             .shifted_in(Interner), | ||||
|                     ), | ||||
|                 }) | ||||
|                 .intern(Interner); | ||||
| 
 | ||||
|                 let (id, ty, resume_yield_tys) = match closure_kind { | ||||
|                     ClosureKind::Coroutine(_) => { | ||||
|                         // FIXME: report error when there are more than 1 parameter.
 | ||||
|                         let resume_ty = match sig_tys.first() { | ||||
|                             // When `sig_tys.len() == 1` the first type is the return type, not the
 | ||||
|                             // first parameter type.
 | ||||
|                             Some(ty) if sig_tys.len() > 1 => ty.clone(), | ||||
|                             _ => self.result.standard_types.unit.clone(), | ||||
|                         }; | ||||
|                         let yield_ty = self.table.new_type_var(); | ||||
| 
 | ||||
|                         let subst = TyBuilder::subst_for_coroutine(self.db, self.owner) | ||||
|                             .push(resume_ty.clone()) | ||||
|                             .push(yield_ty.clone()) | ||||
|                             .push(ret_ty.clone()) | ||||
|                             .build(); | ||||
| 
 | ||||
|                         let coroutine_id = self | ||||
|                             .db | ||||
|                             .intern_coroutine(InternedCoroutine(self.owner, tgt_expr)) | ||||
|                             .into(); | ||||
|                         let coroutine_ty = TyKind::Coroutine(coroutine_id, subst).intern(Interner); | ||||
| 
 | ||||
|                         (None, coroutine_ty, Some((resume_ty, yield_ty))) | ||||
|                     } | ||||
|                     ClosureKind::Closure | ClosureKind::Async => { | ||||
|                         let closure_id = | ||||
|                             self.db.intern_closure(InternedClosure(self.owner, tgt_expr)).into(); | ||||
|                         let closure_ty = TyKind::Closure( | ||||
|                             closure_id, | ||||
|                             TyBuilder::subst_for_closure(self.db, self.owner, sig_ty.clone()), | ||||
|                         ) | ||||
|                         .intern(Interner); | ||||
|                         self.deferred_closures.entry(closure_id).or_default(); | ||||
|                         if let Some(c) = self.current_closure { | ||||
|                             self.closure_dependencies.entry(c).or_default().push(closure_id); | ||||
|                         } | ||||
|                         (Some(closure_id), closure_ty, None) | ||||
|                     } | ||||
|                 }; | ||||
| 
 | ||||
|                 // Eagerly try to relate the closure type with the expected
 | ||||
|                 // type, otherwise we often won't have enough information to
 | ||||
|                 // infer the body.
 | ||||
|                 self.deduce_closure_type_from_expectations(tgt_expr, &ty, &sig_ty, expected); | ||||
| 
 | ||||
|                 // Now go through the argument patterns
 | ||||
|                 for (arg_pat, arg_ty) in args.iter().zip(&sig_tys) { | ||||
|                     self.infer_top_pat(*arg_pat, arg_ty, None); | ||||
|                 } | ||||
| 
 | ||||
|                 // FIXME: lift these out into a struct
 | ||||
|                 let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe); | ||||
|                 let prev_closure = mem::replace(&mut self.current_closure, id); | ||||
|                 let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone()); | ||||
|                 let prev_ret_coercion = self.return_coercion.replace(CoerceMany::new(ret_ty)); | ||||
|                 let prev_resume_yield_tys = | ||||
|                     mem::replace(&mut self.resume_yield_tys, resume_yield_tys); | ||||
| 
 | ||||
|                 self.with_breakable_ctx(BreakableKind::Border, None, None, |this| { | ||||
|                     this.infer_return(*body); | ||||
|                 }); | ||||
| 
 | ||||
|                 self.diverges = prev_diverges; | ||||
|                 self.return_ty = prev_ret_ty; | ||||
|                 self.return_coercion = prev_ret_coercion; | ||||
|                 self.current_closure = prev_closure; | ||||
|                 self.resume_yield_tys = prev_resume_yield_tys; | ||||
| 
 | ||||
|                 ty | ||||
|             } | ||||
|             Expr::Closure { body, args, ret_type, arg_types, closure_kind, capture_by: _ } => self | ||||
|                 .infer_closure(body, args, ret_type, arg_types, *closure_kind, tgt_expr, expected), | ||||
|             Expr::Call { callee, args, .. } => self.infer_call(tgt_expr, *callee, args, expected), | ||||
|             Expr::MethodCall { receiver, args, method_name, generic_args } => self | ||||
|                 .infer_method_call( | ||||
|  | @ -2458,7 +2348,7 @@ impl InferenceContext<'_> { | |||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn with_breakable_ctx<T>( | ||||
|     pub(super) fn with_breakable_ctx<T>( | ||||
|         &mut self, | ||||
|         kind: BreakableKind, | ||||
|         ty: Option<Ty>, | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 jackh726
						jackh726