mirror of
				https://github.com/rust-lang/rust-analyzer.git
				synced 2025-10-31 03:54:42 +00:00 
			
		
		
		
	Implement parameter variance inference
This commit is contained in:
		
							parent
							
								
									17b3662755
								
							
						
					
					
						commit
						eee2761140
					
				
					 8 changed files with 1271 additions and 37 deletions
				
			
		|  | @ -950,11 +950,18 @@ pub(crate) fn fn_def_datum_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Ar | ||||||
| 
 | 
 | ||||||
| pub(crate) fn fn_def_variance_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Variances { | pub(crate) fn fn_def_variance_query(db: &dyn HirDatabase, fn_def_id: FnDefId) -> Variances { | ||||||
|     let callable_def: CallableDefId = from_chalk(db, fn_def_id); |     let callable_def: CallableDefId = from_chalk(db, fn_def_id); | ||||||
|     let generic_params = |  | ||||||
|         generics(db.upcast(), GenericDefId::from_callable(db.upcast(), callable_def)); |  | ||||||
|     Variances::from_iter( |     Variances::from_iter( | ||||||
|         Interner, |         Interner, | ||||||
|         std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()), |         db.variances_of(GenericDefId::from_callable(db.upcast(), callable_def)) | ||||||
|  |             .as_deref() | ||||||
|  |             .unwrap_or_default() | ||||||
|  |             .iter() | ||||||
|  |             .map(|v| match v { | ||||||
|  |                 crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant, | ||||||
|  |                 crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant, | ||||||
|  |                 crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant, | ||||||
|  |                 crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant, | ||||||
|  |             }), | ||||||
|     ) |     ) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -962,10 +969,14 @@ pub(crate) fn adt_variance_query( | ||||||
|     db: &dyn HirDatabase, |     db: &dyn HirDatabase, | ||||||
|     chalk_ir::AdtId(adt_id): AdtId, |     chalk_ir::AdtId(adt_id): AdtId, | ||||||
| ) -> Variances { | ) -> Variances { | ||||||
|     let generic_params = generics(db.upcast(), adt_id.into()); |  | ||||||
|     Variances::from_iter( |     Variances::from_iter( | ||||||
|         Interner, |         Interner, | ||||||
|         std::iter::repeat(chalk_ir::Variance::Invariant).take(generic_params.len()), |         db.variances_of(adt_id.into()).as_deref().unwrap_or_default().iter().map(|v| match v { | ||||||
|  |             crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant, | ||||||
|  |             crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant, | ||||||
|  |             crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant, | ||||||
|  |             crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant, | ||||||
|  |         }), | ||||||
|     ) |     ) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -271,6 +271,10 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> { | ||||||
|     #[ra_salsa::invoke(chalk_db::adt_variance_query)] |     #[ra_salsa::invoke(chalk_db::adt_variance_query)] | ||||||
|     fn adt_variance(&self, adt_id: chalk_db::AdtId) -> chalk_db::Variances; |     fn adt_variance(&self, adt_id: chalk_db::AdtId) -> chalk_db::Variances; | ||||||
| 
 | 
 | ||||||
|  |     #[ra_salsa::invoke(crate::variance::variances_of)] | ||||||
|  |     #[ra_salsa::cycle(crate::variance::variances_of_cycle)] | ||||||
|  |     fn variances_of(&self, def: GenericDefId) -> Option<Arc<[crate::variance::Variance]>>; | ||||||
|  | 
 | ||||||
|     #[ra_salsa::invoke(chalk_db::associated_ty_value_query)] |     #[ra_salsa::invoke(chalk_db::associated_ty_value_query)] | ||||||
|     fn associated_ty_value( |     fn associated_ty_value( | ||||||
|         &self, |         &self, | ||||||
|  |  | ||||||
|  | @ -132,6 +132,14 @@ impl Generics { | ||||||
|         self.params.len() |         self.params.len() | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |     pub(crate) fn len_self_lifetimes(&self) -> usize { | ||||||
|  |         self.params.len_lifetimes() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     pub(crate) fn has_trait_self(&self) -> bool { | ||||||
|  |         self.params.trait_self_param().is_some() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     /// (parent total, self param, type params, const params, impl trait list, lifetimes)
 |     /// (parent total, self param, type params, const params, impl trait list, lifetimes)
 | ||||||
|     pub(crate) fn provenance_split(&self) -> (usize, bool, usize, usize, usize, usize) { |     pub(crate) fn provenance_split(&self) -> (usize, bool, usize, usize, usize, usize) { | ||||||
|         let mut self_param = false; |         let mut self_param = false; | ||||||
|  |  | ||||||
|  | @ -50,6 +50,7 @@ pub mod traits; | ||||||
| mod test_db; | mod test_db; | ||||||
| #[cfg(test)] | #[cfg(test)] | ||||||
| mod tests; | mod tests; | ||||||
|  | mod variance; | ||||||
| 
 | 
 | ||||||
| use std::hash::Hash; | use std::hash::Hash; | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -127,7 +127,15 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour | ||||||
|             None => continue, |             None => continue, | ||||||
|         }; |         }; | ||||||
|         let def_map = module.def_map(&db); |         let def_map = module.def_map(&db); | ||||||
|         visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it)); |         visit_module(&db, &def_map, module.local_id, &mut |it| { | ||||||
|  |             defs.push(match it { | ||||||
|  |                 ModuleDefId::FunctionId(it) => it.into(), | ||||||
|  |                 ModuleDefId::EnumVariantId(it) => it.into(), | ||||||
|  |                 ModuleDefId::ConstId(it) => it.into(), | ||||||
|  |                 ModuleDefId::StaticId(it) => it.into(), | ||||||
|  |                 _ => return, | ||||||
|  |             }) | ||||||
|  |         }); | ||||||
|     } |     } | ||||||
|     defs.sort_by_key(|def| match def { |     defs.sort_by_key(|def| match def { | ||||||
|         DefWithBodyId::FunctionId(it) => { |         DefWithBodyId::FunctionId(it) => { | ||||||
|  | @ -375,7 +383,15 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { | ||||||
|     let def_map = module.def_map(&db); |     let def_map = module.def_map(&db); | ||||||
| 
 | 
 | ||||||
|     let mut defs: Vec<DefWithBodyId> = Vec::new(); |     let mut defs: Vec<DefWithBodyId> = Vec::new(); | ||||||
|     visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it)); |     visit_module(&db, &def_map, module.local_id, &mut |it| { | ||||||
|  |         defs.push(match it { | ||||||
|  |             ModuleDefId::FunctionId(it) => it.into(), | ||||||
|  |             ModuleDefId::EnumVariantId(it) => it.into(), | ||||||
|  |             ModuleDefId::ConstId(it) => it.into(), | ||||||
|  |             ModuleDefId::StaticId(it) => it.into(), | ||||||
|  |             _ => return, | ||||||
|  |         }) | ||||||
|  |     }); | ||||||
|     defs.sort_by_key(|def| match def { |     defs.sort_by_key(|def| match def { | ||||||
|         DefWithBodyId::FunctionId(it) => { |         DefWithBodyId::FunctionId(it) => { | ||||||
|             let loc = it.lookup(&db); |             let loc = it.lookup(&db); | ||||||
|  | @ -405,11 +421,11 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String { | ||||||
|     buf |     buf | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| fn visit_module( | pub(crate) fn visit_module( | ||||||
|     db: &TestDB, |     db: &TestDB, | ||||||
|     crate_def_map: &DefMap, |     crate_def_map: &DefMap, | ||||||
|     module_id: LocalModuleId, |     module_id: LocalModuleId, | ||||||
|     cb: &mut dyn FnMut(DefWithBodyId), |     cb: &mut dyn FnMut(ModuleDefId), | ||||||
| ) { | ) { | ||||||
|     visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb); |     visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb); | ||||||
|     for impl_id in crate_def_map[module_id].scope.impls() { |     for impl_id in crate_def_map[module_id].scope.impls() { | ||||||
|  | @ -417,18 +433,18 @@ fn visit_module( | ||||||
|         for &item in impl_data.items.iter() { |         for &item in impl_data.items.iter() { | ||||||
|             match item { |             match item { | ||||||
|                 AssocItemId::FunctionId(it) => { |                 AssocItemId::FunctionId(it) => { | ||||||
|                     let def = it.into(); |                     let body = db.body(it.into()); | ||||||
|                     cb(def); |                     cb(it.into()); | ||||||
|                     let body = db.body(def); |  | ||||||
|                     visit_body(db, &body, cb); |                     visit_body(db, &body, cb); | ||||||
|                 } |                 } | ||||||
|                 AssocItemId::ConstId(it) => { |                 AssocItemId::ConstId(it) => { | ||||||
|                     let def = it.into(); |                     let body = db.body(it.into()); | ||||||
|                     cb(def); |                     cb(it.into()); | ||||||
|                     let body = db.body(def); |  | ||||||
|                     visit_body(db, &body, cb); |                     visit_body(db, &body, cb); | ||||||
|                 } |                 } | ||||||
|                 AssocItemId::TypeAliasId(_) => (), |                 AssocItemId::TypeAliasId(it) => { | ||||||
|  |                     cb(it.into()); | ||||||
|  |                 } | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | @ -437,33 +453,27 @@ fn visit_module( | ||||||
|         db: &TestDB, |         db: &TestDB, | ||||||
|         crate_def_map: &DefMap, |         crate_def_map: &DefMap, | ||||||
|         scope: &ItemScope, |         scope: &ItemScope, | ||||||
|         cb: &mut dyn FnMut(DefWithBodyId), |         cb: &mut dyn FnMut(ModuleDefId), | ||||||
|     ) { |     ) { | ||||||
|         for decl in scope.declarations() { |         for decl in scope.declarations() { | ||||||
|  |             cb(decl); | ||||||
|             match decl { |             match decl { | ||||||
|                 ModuleDefId::FunctionId(it) => { |                 ModuleDefId::FunctionId(it) => { | ||||||
|                     let def = it.into(); |                     let body = db.body(it.into()); | ||||||
|                     cb(def); |  | ||||||
|                     let body = db.body(def); |  | ||||||
|                     visit_body(db, &body, cb); |                     visit_body(db, &body, cb); | ||||||
|                 } |                 } | ||||||
|                 ModuleDefId::ConstId(it) => { |                 ModuleDefId::ConstId(it) => { | ||||||
|                     let def = it.into(); |                     let body = db.body(it.into()); | ||||||
|                     cb(def); |  | ||||||
|                     let body = db.body(def); |  | ||||||
|                     visit_body(db, &body, cb); |                     visit_body(db, &body, cb); | ||||||
|                 } |                 } | ||||||
|                 ModuleDefId::StaticId(it) => { |                 ModuleDefId::StaticId(it) => { | ||||||
|                     let def = it.into(); |                     let body = db.body(it.into()); | ||||||
|                     cb(def); |  | ||||||
|                     let body = db.body(def); |  | ||||||
|                     visit_body(db, &body, cb); |                     visit_body(db, &body, cb); | ||||||
|                 } |                 } | ||||||
|                 ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => { |                 ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => { | ||||||
|                     db.enum_data(it).variants.iter().for_each(|&(it, _)| { |                     db.enum_data(it).variants.iter().for_each(|&(it, _)| { | ||||||
|                         let def = it.into(); |                         let body = db.body(it.into()); | ||||||
|                         cb(def); |                         cb(it.into()); | ||||||
|                         let body = db.body(def); |  | ||||||
|                         visit_body(db, &body, cb); |                         visit_body(db, &body, cb); | ||||||
|                     }); |                     }); | ||||||
|                 } |                 } | ||||||
|  | @ -473,7 +483,7 @@ fn visit_module( | ||||||
|                         match item { |                         match item { | ||||||
|                             AssocItemId::FunctionId(it) => cb(it.into()), |                             AssocItemId::FunctionId(it) => cb(it.into()), | ||||||
|                             AssocItemId::ConstId(it) => cb(it.into()), |                             AssocItemId::ConstId(it) => cb(it.into()), | ||||||
|                             AssocItemId::TypeAliasId(_) => (), |                             AssocItemId::TypeAliasId(it) => cb(it.into()), | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|  | @ -483,7 +493,7 @@ fn visit_module( | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) { |     fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(ModuleDefId)) { | ||||||
|         for (_, def_map) in body.blocks(db) { |         for (_, def_map) in body.blocks(db) { | ||||||
|             for (mod_id, _) in def_map.modules() { |             for (mod_id, _) in def_map.modules() { | ||||||
|                 visit_module(db, &def_map, mod_id, cb); |                 visit_module(db, &def_map, mod_id, cb); | ||||||
|  | @ -553,7 +563,13 @@ fn salsa_bug() { | ||||||
|     let module = db.module_for_file(pos.file_id); |     let module = db.module_for_file(pos.file_id); | ||||||
|     let crate_def_map = module.def_map(&db); |     let crate_def_map = module.def_map(&db); | ||||||
|     visit_module(&db, &crate_def_map, module.local_id, &mut |def| { |     visit_module(&db, &crate_def_map, module.local_id, &mut |def| { | ||||||
|         db.infer(def); |         db.infer(match def { | ||||||
|  |             ModuleDefId::FunctionId(it) => it.into(), | ||||||
|  |             ModuleDefId::EnumVariantId(it) => it.into(), | ||||||
|  |             ModuleDefId::ConstId(it) => it.into(), | ||||||
|  |             ModuleDefId::StaticId(it) => it.into(), | ||||||
|  |             _ => return, | ||||||
|  |         }); | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     let new_text = " |     let new_text = " | ||||||
|  | @ -586,6 +602,12 @@ fn salsa_bug() { | ||||||
|     let module = db.module_for_file(pos.file_id); |     let module = db.module_for_file(pos.file_id); | ||||||
|     let crate_def_map = module.def_map(&db); |     let crate_def_map = module.def_map(&db); | ||||||
|     visit_module(&db, &crate_def_map, module.local_id, &mut |def| { |     visit_module(&db, &crate_def_map, module.local_id, &mut |def| { | ||||||
|         db.infer(def); |         db.infer(match def { | ||||||
|  |             ModuleDefId::FunctionId(it) => it.into(), | ||||||
|  |             ModuleDefId::EnumVariantId(it) => it.into(), | ||||||
|  |             ModuleDefId::ConstId(it) => it.into(), | ||||||
|  |             ModuleDefId::StaticId(it) => it.into(), | ||||||
|  |             _ => return, | ||||||
|  |         }); | ||||||
|     }); |     }); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -24,6 +24,13 @@ fn check_closure_captures(ra_fixture: &str, expect: Expect) { | ||||||
| 
 | 
 | ||||||
|     let mut captures_info = Vec::new(); |     let mut captures_info = Vec::new(); | ||||||
|     for def in defs { |     for def in defs { | ||||||
|  |         let def = match def { | ||||||
|  |             hir_def::ModuleDefId::FunctionId(it) => it.into(), | ||||||
|  |             hir_def::ModuleDefId::EnumVariantId(it) => it.into(), | ||||||
|  |             hir_def::ModuleDefId::ConstId(it) => it.into(), | ||||||
|  |             hir_def::ModuleDefId::StaticId(it) => it.into(), | ||||||
|  |             _ => continue, | ||||||
|  |         }; | ||||||
|         let infer = db.infer(def); |         let infer = db.infer(def); | ||||||
|         let db = &db; |         let db = &db; | ||||||
|         captures_info.extend(infer.closure_info.iter().flat_map(|(closure_id, (captures, _))| { |         captures_info.extend(infer.closure_info.iter().flat_map(|(closure_id, (captures, _))| { | ||||||
|  |  | ||||||
|  | @ -1,4 +1,5 @@ | ||||||
| use base_db::SourceDatabaseFileInputExt as _; | use base_db::SourceDatabaseFileInputExt as _; | ||||||
|  | use hir_def::ModuleDefId; | ||||||
| use test_fixture::WithFixture; | use test_fixture::WithFixture; | ||||||
| 
 | 
 | ||||||
| use crate::{db::HirDatabase, test_db::TestDB}; | use crate::{db::HirDatabase, test_db::TestDB}; | ||||||
|  | @ -19,7 +20,9 @@ fn foo() -> i32 { | ||||||
|             let module = db.module_for_file(pos.file_id.file_id()); |             let module = db.module_for_file(pos.file_id.file_id()); | ||||||
|             let crate_def_map = module.def_map(&db); |             let crate_def_map = module.def_map(&db); | ||||||
|             visit_module(&db, &crate_def_map, module.local_id, &mut |def| { |             visit_module(&db, &crate_def_map, module.local_id, &mut |def| { | ||||||
|                 db.infer(def); |                 if let ModuleDefId::FunctionId(it) = def { | ||||||
|  |                     db.infer(it.into()); | ||||||
|  |                 } | ||||||
|             }); |             }); | ||||||
|         }); |         }); | ||||||
|         assert!(format!("{events:?}").contains("infer")) |         assert!(format!("{events:?}").contains("infer")) | ||||||
|  | @ -39,7 +42,9 @@ fn foo() -> i32 { | ||||||
|             let module = db.module_for_file(pos.file_id.file_id()); |             let module = db.module_for_file(pos.file_id.file_id()); | ||||||
|             let crate_def_map = module.def_map(&db); |             let crate_def_map = module.def_map(&db); | ||||||
|             visit_module(&db, &crate_def_map, module.local_id, &mut |def| { |             visit_module(&db, &crate_def_map, module.local_id, &mut |def| { | ||||||
|                 db.infer(def); |                 if let ModuleDefId::FunctionId(it) = def { | ||||||
|  |                     db.infer(it.into()); | ||||||
|  |                 } | ||||||
|             }); |             }); | ||||||
|         }); |         }); | ||||||
|         assert!(!format!("{events:?}").contains("infer"), "{events:#?}") |         assert!(!format!("{events:?}").contains("infer"), "{events:#?}") | ||||||
|  | @ -66,7 +71,9 @@ fn baz() -> i32 { | ||||||
|             let module = db.module_for_file(pos.file_id.file_id()); |             let module = db.module_for_file(pos.file_id.file_id()); | ||||||
|             let crate_def_map = module.def_map(&db); |             let crate_def_map = module.def_map(&db); | ||||||
|             visit_module(&db, &crate_def_map, module.local_id, &mut |def| { |             visit_module(&db, &crate_def_map, module.local_id, &mut |def| { | ||||||
|                 db.infer(def); |                 if let ModuleDefId::FunctionId(it) = def { | ||||||
|  |                     db.infer(it.into()); | ||||||
|  |                 } | ||||||
|             }); |             }); | ||||||
|         }); |         }); | ||||||
|         assert!(format!("{events:?}").contains("infer")) |         assert!(format!("{events:?}").contains("infer")) | ||||||
|  | @ -91,7 +98,9 @@ fn baz() -> i32 { | ||||||
|             let module = db.module_for_file(pos.file_id.file_id()); |             let module = db.module_for_file(pos.file_id.file_id()); | ||||||
|             let crate_def_map = module.def_map(&db); |             let crate_def_map = module.def_map(&db); | ||||||
|             visit_module(&db, &crate_def_map, module.local_id, &mut |def| { |             visit_module(&db, &crate_def_map, module.local_id, &mut |def| { | ||||||
|                 db.infer(def); |                 if let ModuleDefId::FunctionId(it) = def { | ||||||
|  |                     db.infer(it.into()); | ||||||
|  |                 } | ||||||
|             }); |             }); | ||||||
|         }); |         }); | ||||||
|         assert!(format!("{events:?}").matches("infer").count() == 1, "{events:#?}") |         assert!(format!("{events:?}").matches("infer").count() == 1, "{events:#?}") | ||||||
|  |  | ||||||
							
								
								
									
										1172
									
								
								crates/hir-ty/src/variance.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										1172
									
								
								crates/hir-ty/src/variance.rs
									
										
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load diff
											
										
									
								
							
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Lukas Wirth
						Lukas Wirth