Auto merge of #16979 - Nadrieril:contiguous-enum-id, r=Veykril

pattern analysis: Use contiguous indices for enum variants

The main blocker to using the in-tree version of the `pattern_analysis` crate is that rustc requires enum indices to be contiguous because it uses `IndexVec`/`BitSet` for performance. Currently we swap these out for `FxHashMap`/`FxHashSet` when the `rustc` feature is off, but we can't do that if we use the in-tree crate.

This PR solves the problem by using contiguous indices on the r-a side too.
This commit is contained in:
bors 2024-04-01 12:15:23 +00:00
commit c82d168b79

View file

@ -3,7 +3,7 @@
use std::fmt; use std::fmt;
use tracing::debug; use tracing::debug;
use hir_def::{DefWithBodyId, EnumVariantId, HasModule, LocalFieldId, ModuleId, VariantId}; use hir_def::{DefWithBodyId, EnumId, EnumVariantId, HasModule, LocalFieldId, ModuleId, VariantId};
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use rustc_pattern_analysis::{ use rustc_pattern_analysis::{
constructor::{Constructor, ConstructorSet, VariantVisibility}, constructor::{Constructor, ConstructorSet, VariantVisibility},
@ -36,6 +36,24 @@ pub(crate) type WitnessPat<'p> = rustc_pattern_analysis::pat::WitnessPat<MatchCh
#[derive(Copy, Clone, Debug, PartialEq, Eq)] #[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum Void {} pub(crate) enum Void {}
/// An index type for enum variants. This ranges from 0 to `variants.len()`, whereas `EnumVariantId`
/// can take arbitrary large values (and hence mustn't be used with `IndexVec`/`BitSet`).
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub(crate) struct EnumVariantContiguousIndex(usize);
impl EnumVariantContiguousIndex {
fn from_enum_variant_id(db: &dyn HirDatabase, target_evid: EnumVariantId) -> Self {
// Find the index of this variant in the list of variants.
use hir_def::Lookup;
let i = target_evid.lookup(db.upcast()).index as usize;
EnumVariantContiguousIndex(i)
}
fn to_enum_variant_id(self, db: &dyn HirDatabase, eid: EnumId) -> EnumVariantId {
db.enum_data(eid).variants[self.0].0
}
}
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct MatchCheckCtx<'p> { pub(crate) struct MatchCheckCtx<'p> {
module: ModuleId, module: ModuleId,
@ -89,9 +107,18 @@ impl<'p> MatchCheckCtx<'p> {
} }
} }
fn variant_id_for_adt(ctor: &Constructor<Self>, adt: hir_def::AdtId) -> Option<VariantId> { fn variant_id_for_adt(
db: &'p dyn HirDatabase,
ctor: &Constructor<Self>,
adt: hir_def::AdtId,
) -> Option<VariantId> {
match ctor { match ctor {
&Variant(id) => Some(id.into()), Variant(id) => {
let hir_def::AdtId::EnumId(eid) = adt else {
panic!("bad constructor {ctor:?} for adt {adt:?}")
};
Some(id.to_enum_variant_id(db, eid).into())
}
Struct | UnionField => match adt { Struct | UnionField => match adt {
hir_def::AdtId::EnumId(_) => None, hir_def::AdtId::EnumId(_) => None,
hir_def::AdtId::StructId(id) => Some(id.into()), hir_def::AdtId::StructId(id) => Some(id.into()),
@ -175,19 +202,24 @@ impl<'p> MatchCheckCtx<'p> {
ctor = Struct; ctor = Struct;
arity = 1; arity = 1;
} }
&TyKind::Adt(adt, _) => { &TyKind::Adt(AdtId(adt), _) => {
ctor = match pat.kind.as_ref() { ctor = match pat.kind.as_ref() {
PatKind::Leaf { .. } if matches!(adt.0, hir_def::AdtId::UnionId(_)) => { PatKind::Leaf { .. } if matches!(adt, hir_def::AdtId::UnionId(_)) => {
UnionField UnionField
} }
PatKind::Leaf { .. } => Struct, PatKind::Leaf { .. } => Struct,
PatKind::Variant { enum_variant, .. } => Variant(*enum_variant), PatKind::Variant { enum_variant, .. } => {
Variant(EnumVariantContiguousIndex::from_enum_variant_id(
self.db,
*enum_variant,
))
}
_ => { _ => {
never!(); never!();
Wildcard Wildcard
} }
}; };
let variant = Self::variant_id_for_adt(&ctor, adt.0).unwrap(); let variant = Self::variant_id_for_adt(self.db, &ctor, adt).unwrap();
arity = variant.variant_data(self.db.upcast()).fields().len(); arity = variant.variant_data(self.db.upcast()).fields().len();
} }
_ => { _ => {
@ -239,7 +271,7 @@ impl<'p> MatchCheckCtx<'p> {
PatKind::Deref { subpattern: subpatterns.next().unwrap() } PatKind::Deref { subpattern: subpatterns.next().unwrap() }
} }
TyKind::Adt(adt, substs) => { TyKind::Adt(adt, substs) => {
let variant = Self::variant_id_for_adt(pat.ctor(), adt.0).unwrap(); let variant = Self::variant_id_for_adt(self.db, pat.ctor(), adt.0).unwrap();
let subpatterns = self let subpatterns = self
.list_variant_fields(pat.ty(), variant) .list_variant_fields(pat.ty(), variant)
.zip(subpatterns) .zip(subpatterns)
@ -277,7 +309,7 @@ impl<'p> MatchCheckCtx<'p> {
impl<'p> PatCx for MatchCheckCtx<'p> { impl<'p> PatCx for MatchCheckCtx<'p> {
type Error = (); type Error = ();
type Ty = Ty; type Ty = Ty;
type VariantIdx = EnumVariantId; type VariantIdx = EnumVariantContiguousIndex;
type StrLit = Void; type StrLit = Void;
type ArmData = (); type ArmData = ();
type PatData = PatData<'p>; type PatData = PatData<'p>;
@ -303,7 +335,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
// patterns. If we're here we can assume this is a box pattern. // patterns. If we're here we can assume this is a box pattern.
1 1
} else { } else {
let variant = Self::variant_id_for_adt(ctor, adt).unwrap(); let variant = Self::variant_id_for_adt(self.db, ctor, adt).unwrap();
variant.variant_data(self.db.upcast()).fields().len() variant.variant_data(self.db.upcast()).fields().len()
} }
} }
@ -343,7 +375,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
let subst_ty = substs.at(Interner, 0).assert_ty_ref(Interner).clone(); let subst_ty = substs.at(Interner, 0).assert_ty_ref(Interner).clone();
single(subst_ty) single(subst_ty)
} else { } else {
let variant = Self::variant_id_for_adt(ctor, adt).unwrap(); let variant = Self::variant_id_for_adt(self.db, ctor, adt).unwrap();
let (adt, _) = ty.as_adt().unwrap(); let (adt, _) = ty.as_adt().unwrap();
let adt_is_local = let adt_is_local =
@ -421,7 +453,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
ConstructorSet::NoConstructors ConstructorSet::NoConstructors
} else { } else {
let mut variants = FxHashMap::default(); let mut variants = FxHashMap::default();
for &(variant, _) in enum_data.variants.iter() { for (i, &(variant, _)) in enum_data.variants.iter().enumerate() {
let is_uninhabited = let is_uninhabited =
is_enum_variant_uninhabited_from(variant, subst, cx.module, cx.db); is_enum_variant_uninhabited_from(variant, subst, cx.module, cx.db);
let visibility = if is_uninhabited { let visibility = if is_uninhabited {
@ -429,7 +461,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
} else { } else {
VariantVisibility::Visible VariantVisibility::Visible
}; };
variants.insert(variant, visibility); variants.insert(EnumVariantContiguousIndex(i), visibility);
} }
ConstructorSet::Variants { ConstructorSet::Variants {
@ -453,10 +485,10 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
f: &mut fmt::Formatter<'_>, f: &mut fmt::Formatter<'_>,
pat: &rustc_pattern_analysis::pat::DeconstructedPat<Self>, pat: &rustc_pattern_analysis::pat::DeconstructedPat<Self>,
) -> fmt::Result { ) -> fmt::Result {
let variant =
pat.ty().as_adt().and_then(|(adt, _)| Self::variant_id_for_adt(pat.ctor(), adt));
let db = pat.data().db; let db = pat.data().db;
let variant =
pat.ty().as_adt().and_then(|(adt, _)| Self::variant_id_for_adt(db, pat.ctor(), adt));
if let Some(variant) = variant { if let Some(variant) = variant {
match variant { match variant {
VariantId::EnumVariantId(v) => { VariantId::EnumVariantId(v) => {