[red-knot] Make the VERSIONS parser use ModuleName as its key type (#11968)

This commit is contained in:
Alex Waygood 2024-06-21 16:46:45 +01:00 committed by GitHub
parent 8de0cd6565
commit da79bac33c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 45 additions and 38 deletions

View file

@ -11,7 +11,7 @@ use crate::Db;
/// A module name, e.g. `foo.bar`. /// A module name, e.g. `foo.bar`.
/// ///
/// Always normalized to the absolute form (never a relative module name, i.e., never `.foo`). /// Always normalized to the absolute form (never a relative module name, i.e., never `.foo`).
#[derive(Clone, Debug, Eq, PartialEq, Hash)] #[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub struct ModuleName(smol_str::SmolStr); pub struct ModuleName(smol_str::SmolStr);
impl ModuleName { impl ModuleName {

View file

@ -5,9 +5,8 @@ use std::ops::{RangeFrom, RangeInclusive};
use std::str::FromStr; use std::str::FromStr;
use rustc_hash::FxHashMap; use rustc_hash::FxHashMap;
use smol_str::SmolStr;
use ruff_python_stdlib::identifiers::is_identifier; use crate::module::ModuleName;
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct TypeshedVersionsParseError { pub struct TypeshedVersionsParseError {
@ -82,7 +81,7 @@ impl fmt::Display for TypeshedVersionsParseErrorKind {
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug, PartialEq, Eq)]
pub struct TypeshedVersions(FxHashMap<SmolStr, PyVersionRange>); pub struct TypeshedVersions(FxHashMap<ModuleName, PyVersionRange>);
impl TypeshedVersions { impl TypeshedVersions {
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
@ -93,24 +92,22 @@ impl TypeshedVersions {
self.0.is_empty() self.0.is_empty()
} }
pub fn contains_module(&self, module_name: impl Into<SmolStr>) -> bool { pub fn contains_module(&self, module_name: &ModuleName) -> bool {
self.0.contains_key(&module_name.into()) self.0.contains_key(module_name)
} }
pub fn module_exists_on_version( pub fn module_exists_on_version(
&self, &self,
module: impl Into<SmolStr>, module: ModuleName,
version: impl Into<PyVersion>, version: impl Into<PyVersion>,
) -> bool { ) -> bool {
let version = version.into(); let version = version.into();
let mut module: Option<SmolStr> = Some(module.into()); let mut module: Option<ModuleName> = Some(module);
while let Some(module_to_try) = module { while let Some(module_to_try) = module {
if let Some(range) = self.0.get(&module_to_try) { if let Some(range) = self.0.get(&module_to_try) {
return range.contains(version); return range.contains(version);
} }
module = module_to_try module = module_to_try.parent();
.rsplit_once('.')
.map(|(parent, _)| SmolStr::new(parent));
} }
false false
} }
@ -149,15 +146,14 @@ impl FromStr for TypeshedVersions {
}); });
}; };
let module_name = SmolStr::new(module_name); let Some(module_name) = ModuleName::new(module_name) else {
if !module_name.split('.').all(is_identifier) {
return Err(TypeshedVersionsParseError { return Err(TypeshedVersionsParseError {
line_number, line_number,
reason: TypeshedVersionsParseErrorKind::InvalidModuleName( reason: TypeshedVersionsParseErrorKind::InvalidModuleName(
module_name.to_string(), module_name.to_string(),
), ),
}); });
} };
match PyVersionRange::from_str(rest) { match PyVersionRange::from_str(rest) {
Ok(version) => map.insert(module_name, version), Ok(version) => map.insert(module_name, version),
@ -176,7 +172,7 @@ impl FromStr for TypeshedVersions {
impl fmt::Display for TypeshedVersions { impl fmt::Display for TypeshedVersions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let sorted_items: BTreeMap<&SmolStr, &PyVersionRange> = self.0.iter().collect(); let sorted_items: BTreeMap<&ModuleName, &PyVersionRange> = self.0.iter().collect();
for (module_name, range) in sorted_items { for (module_name, range) in sorted_items {
writeln!(f, "{module_name}: {range}")?; writeln!(f, "{module_name}: {range}")?;
} }
@ -331,16 +327,22 @@ mod tests {
assert!(versions.len() > 100); assert!(versions.len() > 100);
assert!(versions.len() < 1000); assert!(versions.len() < 1000);
assert!(versions.contains_module("asyncio")); let asyncio = ModuleName::new_static("asyncio").unwrap();
assert!(versions.module_exists_on_version("asyncio", SupportedPyVersion::Py310)); let asyncio_staggered = ModuleName::new_static("asyncio.staggered").unwrap();
let audioop = ModuleName::new_static("audioop").unwrap();
assert!(versions.contains_module("asyncio.staggered")); assert!(versions.contains_module(&asyncio));
assert!(versions.module_exists_on_version("asyncio.staggered", SupportedPyVersion::Py38)); assert!(versions.module_exists_on_version(asyncio, SupportedPyVersion::Py310));
assert!(!versions.module_exists_on_version("asyncio.staggered", SupportedPyVersion::Py37));
assert!(versions.contains_module("audioop")); assert!(versions.contains_module(&asyncio_staggered));
assert!(versions.module_exists_on_version("audioop", SupportedPyVersion::Py312)); assert!(
assert!(!versions.module_exists_on_version("audioop", SupportedPyVersion::Py313)); versions.module_exists_on_version(asyncio_staggered.clone(), SupportedPyVersion::Py38)
);
assert!(!versions.module_exists_on_version(asyncio_staggered, SupportedPyVersion::Py37));
assert!(versions.contains_module(&audioop));
assert!(versions.module_exists_on_version(audioop.clone(), SupportedPyVersion::Py312));
assert!(!versions.module_exists_on_version(audioop, SupportedPyVersion::Py313));
} }
#[test] #[test]
@ -368,24 +370,29 @@ foo: 3.8- # trailing comment
"### "###
); );
assert!(parsed_versions.contains_module("foo")); let foo = ModuleName::new_static("foo").unwrap();
assert!(!parsed_versions.module_exists_on_version("foo", SupportedPyVersion::Py37)); let bar = ModuleName::new_static("bar").unwrap();
assert!(parsed_versions.module_exists_on_version("foo", SupportedPyVersion::Py38)); let bar_baz = ModuleName::new_static("bar.baz").unwrap();
assert!(parsed_versions.module_exists_on_version("foo", SupportedPyVersion::Py311)); let spam = ModuleName::new_static("spam").unwrap();
assert!(parsed_versions.contains_module("bar")); assert!(parsed_versions.contains_module(&foo));
assert!(parsed_versions.module_exists_on_version("bar", SupportedPyVersion::Py37)); assert!(!parsed_versions.module_exists_on_version(foo.clone(), SupportedPyVersion::Py37));
assert!(parsed_versions.module_exists_on_version("bar", SupportedPyVersion::Py310)); assert!(parsed_versions.module_exists_on_version(foo.clone(), SupportedPyVersion::Py38));
assert!(!parsed_versions.module_exists_on_version("bar", SupportedPyVersion::Py311)); assert!(parsed_versions.module_exists_on_version(foo, SupportedPyVersion::Py311));
assert!(parsed_versions.contains_module("bar.baz")); assert!(parsed_versions.contains_module(&bar));
assert!(parsed_versions.module_exists_on_version("bar.baz", SupportedPyVersion::Py37)); assert!(parsed_versions.module_exists_on_version(bar.clone(), SupportedPyVersion::Py37));
assert!(parsed_versions.module_exists_on_version("bar.baz", SupportedPyVersion::Py39)); assert!(parsed_versions.module_exists_on_version(bar.clone(), SupportedPyVersion::Py310));
assert!(!parsed_versions.module_exists_on_version("bar.baz", SupportedPyVersion::Py310)); assert!(!parsed_versions.module_exists_on_version(bar, SupportedPyVersion::Py311));
assert!(!parsed_versions.contains_module("spam")); assert!(parsed_versions.contains_module(&bar_baz));
assert!(!parsed_versions.module_exists_on_version("spam", SupportedPyVersion::Py37)); assert!(parsed_versions.module_exists_on_version(bar_baz.clone(), SupportedPyVersion::Py37));
assert!(!parsed_versions.module_exists_on_version("spam", SupportedPyVersion::Py313)); assert!(parsed_versions.module_exists_on_version(bar_baz.clone(), SupportedPyVersion::Py39));
assert!(!parsed_versions.module_exists_on_version(bar_baz, SupportedPyVersion::Py310));
assert!(!parsed_versions.contains_module(&spam));
assert!(!parsed_versions.module_exists_on_version(spam.clone(), SupportedPyVersion::Py37));
assert!(!parsed_versions.module_exists_on_version(spam, SupportedPyVersion::Py313));
} }
#[test] #[test]