diff --git a/crates/red_knot_module_resolver/src/module.rs b/crates/red_knot_module_resolver/src/module.rs index 507ee12b88..45ad78145c 100644 --- a/crates/red_knot_module_resolver/src/module.rs +++ b/crates/red_knot_module_resolver/src/module.rs @@ -11,7 +11,7 @@ use crate::Db; /// A module name, e.g. `foo.bar`. /// /// Always normalized to the absolute form (never a relative module name, i.e., never `.foo`). -#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[derive(Clone, Debug, Eq, PartialEq, Hash, PartialOrd, Ord)] pub struct ModuleName(smol_str::SmolStr); impl ModuleName { diff --git a/crates/red_knot_module_resolver/src/typeshed/versions.rs b/crates/red_knot_module_resolver/src/typeshed/versions.rs index 4653ed7377..e247fbc9dc 100644 --- a/crates/red_knot_module_resolver/src/typeshed/versions.rs +++ b/crates/red_knot_module_resolver/src/typeshed/versions.rs @@ -5,9 +5,8 @@ use std::ops::{RangeFrom, RangeInclusive}; use std::str::FromStr; use rustc_hash::FxHashMap; -use smol_str::SmolStr; -use ruff_python_stdlib::identifiers::is_identifier; +use crate::module::ModuleName; #[derive(Debug, PartialEq, Eq)] pub struct TypeshedVersionsParseError { @@ -82,7 +81,7 @@ impl fmt::Display for TypeshedVersionsParseErrorKind { } #[derive(Debug, PartialEq, Eq)] -pub struct TypeshedVersions(FxHashMap); +pub struct TypeshedVersions(FxHashMap); impl TypeshedVersions { pub fn len(&self) -> usize { @@ -93,24 +92,22 @@ impl TypeshedVersions { self.0.is_empty() } - pub fn contains_module(&self, module_name: impl Into) -> bool { - self.0.contains_key(&module_name.into()) + pub fn contains_module(&self, module_name: &ModuleName) -> bool { + self.0.contains_key(module_name) } pub fn module_exists_on_version( &self, - module: impl Into, + module: ModuleName, version: impl Into, ) -> bool { let version = version.into(); - let mut module: Option = Some(module.into()); + let mut module: Option = Some(module); while let Some(module_to_try) = module { if let Some(range) = self.0.get(&module_to_try) { return range.contains(version); } - module = module_to_try - .rsplit_once('.') - .map(|(parent, _)| SmolStr::new(parent)); + module = module_to_try.parent(); } false } @@ -149,15 +146,14 @@ impl FromStr for TypeshedVersions { }); }; - let module_name = SmolStr::new(module_name); - if !module_name.split('.').all(is_identifier) { + let Some(module_name) = ModuleName::new(module_name) else { return Err(TypeshedVersionsParseError { line_number, reason: TypeshedVersionsParseErrorKind::InvalidModuleName( module_name.to_string(), ), }); - } + }; match PyVersionRange::from_str(rest) { Ok(version) => map.insert(module_name, version), @@ -176,7 +172,7 @@ impl FromStr for TypeshedVersions { impl fmt::Display for TypeshedVersions { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let sorted_items: BTreeMap<&SmolStr, &PyVersionRange> = self.0.iter().collect(); + let sorted_items: BTreeMap<&ModuleName, &PyVersionRange> = self.0.iter().collect(); for (module_name, range) in sorted_items { writeln!(f, "{module_name}: {range}")?; } @@ -331,16 +327,22 @@ mod tests { assert!(versions.len() > 100); assert!(versions.len() < 1000); - assert!(versions.contains_module("asyncio")); - assert!(versions.module_exists_on_version("asyncio", SupportedPyVersion::Py310)); + let asyncio = ModuleName::new_static("asyncio").unwrap(); + let asyncio_staggered = ModuleName::new_static("asyncio.staggered").unwrap(); + let audioop = ModuleName::new_static("audioop").unwrap(); - assert!(versions.contains_module("asyncio.staggered")); - assert!(versions.module_exists_on_version("asyncio.staggered", SupportedPyVersion::Py38)); - assert!(!versions.module_exists_on_version("asyncio.staggered", SupportedPyVersion::Py37)); + assert!(versions.contains_module(&asyncio)); + assert!(versions.module_exists_on_version(asyncio, SupportedPyVersion::Py310)); - assert!(versions.contains_module("audioop")); - assert!(versions.module_exists_on_version("audioop", SupportedPyVersion::Py312)); - assert!(!versions.module_exists_on_version("audioop", SupportedPyVersion::Py313)); + assert!(versions.contains_module(&asyncio_staggered)); + assert!( + versions.module_exists_on_version(asyncio_staggered.clone(), SupportedPyVersion::Py38) + ); + assert!(!versions.module_exists_on_version(asyncio_staggered, SupportedPyVersion::Py37)); + + assert!(versions.contains_module(&audioop)); + assert!(versions.module_exists_on_version(audioop.clone(), SupportedPyVersion::Py312)); + assert!(!versions.module_exists_on_version(audioop, SupportedPyVersion::Py313)); } #[test] @@ -368,24 +370,29 @@ foo: 3.8- # trailing comment "### ); - assert!(parsed_versions.contains_module("foo")); - assert!(!parsed_versions.module_exists_on_version("foo", SupportedPyVersion::Py37)); - assert!(parsed_versions.module_exists_on_version("foo", SupportedPyVersion::Py38)); - assert!(parsed_versions.module_exists_on_version("foo", SupportedPyVersion::Py311)); + let foo = ModuleName::new_static("foo").unwrap(); + let bar = ModuleName::new_static("bar").unwrap(); + let bar_baz = ModuleName::new_static("bar.baz").unwrap(); + let spam = ModuleName::new_static("spam").unwrap(); - assert!(parsed_versions.contains_module("bar")); - assert!(parsed_versions.module_exists_on_version("bar", SupportedPyVersion::Py37)); - assert!(parsed_versions.module_exists_on_version("bar", SupportedPyVersion::Py310)); - assert!(!parsed_versions.module_exists_on_version("bar", SupportedPyVersion::Py311)); + assert!(parsed_versions.contains_module(&foo)); + assert!(!parsed_versions.module_exists_on_version(foo.clone(), SupportedPyVersion::Py37)); + assert!(parsed_versions.module_exists_on_version(foo.clone(), SupportedPyVersion::Py38)); + assert!(parsed_versions.module_exists_on_version(foo, SupportedPyVersion::Py311)); - assert!(parsed_versions.contains_module("bar.baz")); - assert!(parsed_versions.module_exists_on_version("bar.baz", SupportedPyVersion::Py37)); - assert!(parsed_versions.module_exists_on_version("bar.baz", SupportedPyVersion::Py39)); - assert!(!parsed_versions.module_exists_on_version("bar.baz", SupportedPyVersion::Py310)); + assert!(parsed_versions.contains_module(&bar)); + assert!(parsed_versions.module_exists_on_version(bar.clone(), SupportedPyVersion::Py37)); + assert!(parsed_versions.module_exists_on_version(bar.clone(), SupportedPyVersion::Py310)); + assert!(!parsed_versions.module_exists_on_version(bar, SupportedPyVersion::Py311)); - assert!(!parsed_versions.contains_module("spam")); - assert!(!parsed_versions.module_exists_on_version("spam", SupportedPyVersion::Py37)); - assert!(!parsed_versions.module_exists_on_version("spam", SupportedPyVersion::Py313)); + assert!(parsed_versions.contains_module(&bar_baz)); + assert!(parsed_versions.module_exists_on_version(bar_baz.clone(), SupportedPyVersion::Py37)); + assert!(parsed_versions.module_exists_on_version(bar_baz.clone(), SupportedPyVersion::Py39)); + assert!(!parsed_versions.module_exists_on_version(bar_baz, SupportedPyVersion::Py310)); + + assert!(!parsed_versions.contains_module(&spam)); + assert!(!parsed_versions.module_exists_on_version(spam.clone(), SupportedPyVersion::Py37)); + assert!(!parsed_versions.module_exists_on_version(spam, SupportedPyVersion::Py313)); } #[test]