Implement parameter variance inference

This commit is contained in:
Lukas Wirth 2024-12-28 15:08:26 +01:00
parent 17b3662755
commit eee2761140
8 changed files with 1271 additions and 37 deletions

View file

@ -127,7 +127,15 @@ fn check_impl(ra_fixture: &str, allow_none: bool, only_types: bool, display_sour
None => continue,
};
let def_map = module.def_map(&db);
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
visit_module(&db, &def_map, module.local_id, &mut |it| {
defs.push(match it {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
})
});
}
defs.sort_by_key(|def| match def {
DefWithBodyId::FunctionId(it) => {
@ -375,7 +383,15 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
let def_map = module.def_map(&db);
let mut defs: Vec<DefWithBodyId> = Vec::new();
visit_module(&db, &def_map, module.local_id, &mut |it| defs.push(it));
visit_module(&db, &def_map, module.local_id, &mut |it| {
defs.push(match it {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
})
});
defs.sort_by_key(|def| match def {
DefWithBodyId::FunctionId(it) => {
let loc = it.lookup(&db);
@ -405,11 +421,11 @@ fn infer_with_mismatches(content: &str, include_mismatches: bool) -> String {
buf
}
fn visit_module(
pub(crate) fn visit_module(
db: &TestDB,
crate_def_map: &DefMap,
module_id: LocalModuleId,
cb: &mut dyn FnMut(DefWithBodyId),
cb: &mut dyn FnMut(ModuleDefId),
) {
visit_scope(db, crate_def_map, &crate_def_map[module_id].scope, cb);
for impl_id in crate_def_map[module_id].scope.impls() {
@ -417,18 +433,18 @@ fn visit_module(
for &item in impl_data.items.iter() {
match item {
AssocItemId::FunctionId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
cb(it.into());
visit_body(db, &body, cb);
}
AssocItemId::ConstId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
cb(it.into());
visit_body(db, &body, cb);
}
AssocItemId::TypeAliasId(_) => (),
AssocItemId::TypeAliasId(it) => {
cb(it.into());
}
}
}
}
@ -437,33 +453,27 @@ fn visit_module(
db: &TestDB,
crate_def_map: &DefMap,
scope: &ItemScope,
cb: &mut dyn FnMut(DefWithBodyId),
cb: &mut dyn FnMut(ModuleDefId),
) {
for decl in scope.declarations() {
cb(decl);
match decl {
ModuleDefId::FunctionId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
visit_body(db, &body, cb);
}
ModuleDefId::ConstId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
visit_body(db, &body, cb);
}
ModuleDefId::StaticId(it) => {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
visit_body(db, &body, cb);
}
ModuleDefId::AdtId(hir_def::AdtId::EnumId(it)) => {
db.enum_data(it).variants.iter().for_each(|&(it, _)| {
let def = it.into();
cb(def);
let body = db.body(def);
let body = db.body(it.into());
cb(it.into());
visit_body(db, &body, cb);
});
}
@ -473,7 +483,7 @@ fn visit_module(
match item {
AssocItemId::FunctionId(it) => cb(it.into()),
AssocItemId::ConstId(it) => cb(it.into()),
AssocItemId::TypeAliasId(_) => (),
AssocItemId::TypeAliasId(it) => cb(it.into()),
}
}
}
@ -483,7 +493,7 @@ fn visit_module(
}
}
fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(DefWithBodyId)) {
fn visit_body(db: &TestDB, body: &Body, cb: &mut dyn FnMut(ModuleDefId)) {
for (_, def_map) in body.blocks(db) {
for (mod_id, _) in def_map.modules() {
visit_module(db, &def_map, mod_id, cb);
@ -553,7 +563,13 @@ fn salsa_bug() {
let module = db.module_for_file(pos.file_id);
let crate_def_map = module.def_map(&db);
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
db.infer(def);
db.infer(match def {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
});
});
let new_text = "
@ -586,6 +602,12 @@ fn salsa_bug() {
let module = db.module_for_file(pos.file_id);
let crate_def_map = module.def_map(&db);
visit_module(&db, &crate_def_map, module.local_id, &mut |def| {
db.infer(def);
db.infer(match def {
ModuleDefId::FunctionId(it) => it.into(),
ModuleDefId::EnumVariantId(it) => it.into(),
ModuleDefId::ConstId(it) => it.into(),
ModuleDefId::StaticId(it) => it.into(),
_ => return,
});
});
}