diff --git a/crates/pypi-types/src/lenient_requirement.rs b/crates/pypi-types/src/lenient_requirement.rs index 867cccb9d..6dcacfed3 100644 --- a/crates/pypi-types/src/lenient_requirement.rs +++ b/crates/pypi-types/src/lenient_requirement.rs @@ -1,11 +1,105 @@ -use pep440_rs::{Pep440Error, VersionSpecifiers}; -use serde::{de, Deserialize, Deserializer, Serialize}; use std::str::FromStr; + +use once_cell::sync::Lazy; +use regex::Regex; +use serde::{de, Deserialize, Deserializer, Serialize}; use tracing::warn; +use pep440_rs::{Pep440Error, VersionSpecifiers}; +use pep508_rs::{Pep508Error, Requirement}; + +/// Ex) `>=7.2.0<8.0.0` +static MISSING_COMMA: Lazy = Lazy::new(|| Regex::new(r"(\d)([<>=~^!])").unwrap()); +/// Ex) `!=~5.0` +static NOT_EQUAL_TILDE: Lazy = Lazy::new(|| Regex::new(r"!=~((?:\d\.)*\d)").unwrap()); +/// Ex) `>=1.9.*` +static GREATER_THAN_STAR: Lazy = Lazy::new(|| Regex::new(r">=(\d+\.\d+)\.\*").unwrap()); +/// Ex) `!=3.0*` +static MISSING_DOT: Lazy = Lazy::new(|| Regex::new(r"(\d\.\d)+\*").unwrap()); +/// Ex) `>=3.6,` +static TRAILING_COMMA: Lazy = Lazy::new(|| Regex::new(r"(\d\.\d)+,$").unwrap()); + +/// Like [`Requirement`], but attempts to correct some common errors in user-provided requirements. +#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] +pub struct LenientRequirement(Requirement); + +impl FromStr for LenientRequirement { + type Err = Pep508Error; + + fn from_str(s: &str) -> Result { + match Requirement::from_str(s) { + Ok(requirement) => Ok(Self(requirement)), + Err(err) => { + // Given `elasticsearch-dsl (>=7.2.0<8.0.0)`, rewrite to `elasticsearch-dsl (>=7.2.0,<8.0.0)`. + let patched = MISSING_COMMA.replace_all(s, r"$1,$2"); + if patched != s { + if let Ok(requirement) = Requirement::from_str(&patched) { + warn!( + "Inserting missing comma into invalid requirement (before: `{s}`; after: `{patched}`)", + ); + return Ok(Self(requirement)); + } + } + + // Given `jupyter-core (!=~5.0,>=4.12)`, rewrite to `jupyter-core (!=5.0.*,>=4.12)`. + let patched = NOT_EQUAL_TILDE.replace_all(s, r"!=${1}.*"); + if patched != s { + if let Ok(requirement) = Requirement::from_str(&patched) { + warn!( + "Adding wildcard after invalid tilde operator (before: `{s}`; after: `{patched}`)", + ); + return Ok(Self(requirement)); + } + } + + // Given `torch (>=1.9.*)`, rewrite to `torch (>=1.9)`. + let patched = GREATER_THAN_STAR.replace_all(s, r">=${1}"); + if patched != s { + if let Ok(requirement) = Requirement::from_str(&patched) { + warn!( + "Removing star after greater equal operator (before: `{s}`; after: `{patched}`)", + ); + return Ok(Self(requirement)); + } + } + + // Given `pyzmq (!=3.0*)`, rewrite to `pyzmq (!=3.0.*)`. + let patched = MISSING_DOT.replace_all(s, r"${1}.*"); + if patched != s { + if let Ok(requirement) = Requirement::from_str(&patched) { + warn!( + "Inserting missing dot into invalid requirement (before: `{s}`; after: `{patched}`)", + ); + return Ok(Self(requirement)); + } + } + + // Given `pyzmq (>=3.6,)`, rewrite to `pyzmq (>=3.6)` + let patched = TRAILING_COMMA.replace_all(s, r"${1}"); + if patched != s { + if let Ok(requirement) = Requirement::from_str(&patched) { + warn!( + "Removing trailing comma from invalid requirement (before: `{s}`; after: `{patched}`)", + ); + return Ok(Self(requirement)); + } + } + + Err(err) + } + } + } +} + +impl From for Requirement { + fn from(requirement: LenientRequirement) -> Self { + requirement.0 + } +} + /// Like [`VersionSpecifiers`], but attempts to correct some common errors in user-provided requirements. /// -/// We turn `>=3.x.*` into `>=3.x` +/// For example, we turn `>=3.x.*` into `>=3.x`. #[derive(Debug, Clone, Serialize, Eq, PartialEq)] pub struct LenientVersionSpecifiers(VersionSpecifiers); @@ -16,45 +110,61 @@ impl FromStr for LenientVersionSpecifiers { match VersionSpecifiers::from_str(s) { Ok(specifiers) => Ok(Self(specifiers)), Err(err) => { - // Given `>=3.5.*`, rewrite to `>=3.5`. - let patched = match s { - ">=3.12.*" => Some(">=3.12"), - ">=3.11.*" => Some(">=3.11"), - ">=3.10.*" => Some(">=3.10"), - ">=3.9.*" => Some(">=3.9"), - ">=3.8.*" => Some(">=3.8"), - ">=3.7.*" => Some(">=3.7"), - ">=3.6.*" => Some(">=3.6"), - ">=3.5.*" => Some(">=3.5"), - ">=3.4.*" => Some(">=3.4"), - ">=3.3.*" => Some(">=3.3"), - ">=3.2.*" => Some(">=3.2"), - ">=3.1.*" => Some(">=3.1"), - ">=3.0.*" => Some(">=3.0"), - ">=3.12," => Some(">=3.12"), - ">=3.11," => Some(">=3.11"), - ">=3.10," => Some(">=3.10"), - ">=3.9," => Some(">=3.9"), - ">=3.8," => Some(">=3.8"), - ">=3.7," => Some(">=3.7"), - ">=3.6," => Some(">=3.6"), - ">=3.5," => Some(">=3.5"), - ">=3.4," => Some(">=3.4"), - ">=3.3," => Some(">=3.3"), - ">=3.2," => Some(">=3.2"), - ">=3.1," => Some(">=3.1"), - ">=3.0," => Some(">=3.0"), - ">=2.7,!=3.0*,!=3.1*,!=3.2*" => Some(">=2.7,!=3.0.*,!=3.1.*,!=3.2.*"), - _ => None, - }; - if let Some(patched) = patched { - if let Ok(specifier) = VersionSpecifiers::from_str(patched) { + // Given `>=7.2.0<8.0.0`, rewrite to `>=7.2.0,<8.0.0`. + let patched = MISSING_COMMA.replace_all(s, r"$1,$2"); + if patched != s { + if let Ok(specifiers) = VersionSpecifiers::from_str(&patched) { warn!( - "Correcting invalid wildcard bound on version specifier (before: `{s}`; after: `{patched}`)", + "Inserting missing comma into invalid specifier (before: `{s}`; after: `{patched}`)", ); - return Ok(Self(specifier)); + return Ok(Self(specifiers)); } } + + // Given `!=~5.0,>=4.12`, rewrite to `!=5.0.*,>=4.12`. + let patched = NOT_EQUAL_TILDE.replace_all(s, r"!=${1}.*"); + if patched != s { + if let Ok(specifiers) = VersionSpecifiers::from_str(&patched) { + warn!( + "Adding wildcard after invalid tilde operator (before: `{s}`; after: `{patched}`)", + ); + return Ok(Self(specifiers)); + } + } + + // Given `>=1.9.*`, rewrite to `>=1.9`. + let patched = GREATER_THAN_STAR.replace_all(s, r">=${1}"); + if patched != s { + if let Ok(specifiers) = VersionSpecifiers::from_str(&patched) { + warn!( + "Removing star after greater equal operator (before: `{s}`; after: `{patched}`)", + ); + return Ok(Self(specifiers)); + } + } + + // Given `!=3.0*`, rewrite to `!=3.0.*`. + let patched = MISSING_DOT.replace_all(s, r"${1}.*"); + if patched != s { + if let Ok(specifiers) = VersionSpecifiers::from_str(&patched) { + warn!( + "Inserting missing dot into invalid specifier (before: `{s}`; after: `{patched}`)", + ); + return Ok(Self(specifiers)); + } + } + + // Given `>=3.6,`, rewrite to `>=3.6` + let patched = TRAILING_COMMA.replace_all(s, r"${1}"); + if patched != s { + if let Ok(specifiers) = VersionSpecifiers::from_str(&patched) { + warn!( + "Removing trailing comma from invalid specifier (before: `{s}`; after: `{patched}`)", + ); + return Ok(Self(specifiers)); + } + } + Err(err) } } @@ -76,3 +186,118 @@ impl<'de> Deserialize<'de> for LenientVersionSpecifiers { Self::from_str(&s).map_err(de::Error::custom) } } + +#[cfg(test)] +mod tests { + use pep440_rs::VersionSpecifiers; + use std::str::FromStr; + + use crate::LenientVersionSpecifiers; + use pep508_rs::Requirement; + + use super::LenientRequirement; + + #[test] + fn requirement_missing_comma() { + let actual: Requirement = LenientRequirement::from_str("elasticsearch-dsl (>=7.2.0<8.0.0)") + .unwrap() + .into(); + let expected: Requirement = + Requirement::from_str("elasticsearch-dsl (>=7.2.0,<8.0.0)").unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn requirement_not_equal_tile() { + let actual: Requirement = LenientRequirement::from_str("jupyter-core (!=~5.0,>=4.12)") + .unwrap() + .into(); + let expected: Requirement = Requirement::from_str("jupyter-core (!=5.0.*,>=4.12)").unwrap(); + assert_eq!(actual, expected); + + let actual: Requirement = LenientRequirement::from_str("jupyter-core (!=~5,>=4.12)") + .unwrap() + .into(); + let expected: Requirement = Requirement::from_str("jupyter-core (!=5.*,>=4.12)").unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn requirement_greater_than_star() { + let actual: Requirement = LenientRequirement::from_str("torch (>=1.9.*)") + .unwrap() + .into(); + let expected: Requirement = Requirement::from_str("torch (>=1.9)").unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn requirement_missing_dot() { + let actual: Requirement = + LenientRequirement::from_str("pyzmq (>=2.7,!=3.0*,!=3.1*,!=3.2*)") + .unwrap() + .into(); + let expected: Requirement = + Requirement::from_str("pyzmq (>=2.7,!=3.0.*,!=3.1.*,!=3.2.*)").unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn requirement_trailing_comma() { + let actual: Requirement = LenientRequirement::from_str("pyzmq >=3.6,").unwrap().into(); + let expected: Requirement = Requirement::from_str("pyzmq >=3.6").unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn specifier_missing_comma() { + let actual: VersionSpecifiers = LenientVersionSpecifiers::from_str(">=7.2.0<8.0.0") + .unwrap() + .into(); + let expected: VersionSpecifiers = VersionSpecifiers::from_str(">=7.2.0,<8.0.0").unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn specifier_not_equal_tile() { + let actual: VersionSpecifiers = LenientVersionSpecifiers::from_str("!=~5.0,>=4.12") + .unwrap() + .into(); + let expected: VersionSpecifiers = VersionSpecifiers::from_str("!=5.0.*,>=4.12").unwrap(); + assert_eq!(actual, expected); + + let actual: VersionSpecifiers = LenientVersionSpecifiers::from_str("!=~5,>=4.12") + .unwrap() + .into(); + let expected: VersionSpecifiers = VersionSpecifiers::from_str("!=5.*,>=4.12").unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn specifier_greater_than_star() { + let actual: VersionSpecifiers = LenientVersionSpecifiers::from_str(">=1.9.*") + .unwrap() + .into(); + let expected: VersionSpecifiers = VersionSpecifiers::from_str(">=1.9").unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn specifier_missing_dot() { + let actual: VersionSpecifiers = + LenientVersionSpecifiers::from_str(">=2.7,!=3.0*,!=3.1*,!=3.2*") + .unwrap() + .into(); + let expected: VersionSpecifiers = + VersionSpecifiers::from_str(">=2.7,!=3.0.*,!=3.1.*,!=3.2.*").unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn specifier_trailing_comma() { + let actual: VersionSpecifiers = + LenientVersionSpecifiers::from_str(">=3.6,").unwrap().into(); + let expected: VersionSpecifiers = VersionSpecifiers::from_str(">=3.6").unwrap(); + assert_eq!(actual, expected); + } +} diff --git a/crates/pypi-types/src/metadata.rs b/crates/pypi-types/src/metadata.rs index 72b406a2e..0423623b9 100644 --- a/crates/pypi-types/src/metadata.rs +++ b/crates/pypi-types/src/metadata.rs @@ -5,17 +5,15 @@ use std::io; use std::str::FromStr; use mailparse::{MailHeaderMap, MailParseError}; -use once_cell::sync::Lazy; -use regex::Regex; use serde::{Deserialize, Serialize}; use thiserror::Error; -use tracing::warn; -use crate::lenient_requirement::LenientVersionSpecifiers; use pep440_rs::{Pep440Error, Version, VersionSpecifiers}; use pep508_rs::{Pep508Error, Requirement}; use puffin_normalize::{ExtraName, InvalidNameError, PackageName}; +use crate::lenient_requirement::{LenientRequirement, LenientVersionSpecifiers}; + /// Python Package Metadata 2.1 as specified in /// /// @@ -210,155 +208,3 @@ impl Metadata21 { }) } } - -/// Ex) `>=7.2.0<8.0.0` -static MISSING_COMMA: Lazy = Lazy::new(|| Regex::new(r"(\d)([<>=~^!])").unwrap()); -/// Ex) `!=~5.0` -static NOT_EQUAL_TILDE: Lazy = Lazy::new(|| Regex::new(r"!=~((?:\d\.)*\d)").unwrap()); -/// Ex) `>=1.9.*` -static GREATER_THAN_STAR: Lazy = Lazy::new(|| Regex::new(r">=(\d+\.\d+)\.\*").unwrap()); -/// Ex) `!=3.0*` -static MISSING_DOT: Lazy = Lazy::new(|| Regex::new(r"(\d\.\d)+\*").unwrap()); -/// Ex) `>=3.6,` -static TRAILING_COMMA: Lazy = Lazy::new(|| Regex::new(r",\)").unwrap()); - -/// Like [`Requirement`], but attempts to correct some common errors in user-provided requirements. -#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] -struct LenientRequirement(Requirement); - -impl FromStr for LenientRequirement { - type Err = Pep508Error; - - fn from_str(s: &str) -> Result { - match Requirement::from_str(s) { - Ok(requirement) => Ok(Self(requirement)), - Err(err) => { - // Given `elasticsearch-dsl (>=7.2.0<8.0.0)`, rewrite to `elasticsearch-dsl (>=7.2.0,<8.0.0)`. - let patched = MISSING_COMMA.replace_all(s, r"$1,$2"); - if patched != s { - if let Ok(requirement) = Requirement::from_str(&patched) { - warn!( - "Inserting missing comma into invalid requirement (before: `{s}`; after: `{patched}`)", - ); - return Ok(Self(requirement)); - } - } - - // Given `jupyter-core (!=~5.0,>=4.12)`, rewrite to `jupyter-core (!=5.0.*,>=4.12)`. - let patched = NOT_EQUAL_TILDE.replace_all(s, r"!=${1}.*"); - if patched != s { - if let Ok(requirement) = Requirement::from_str(&patched) { - warn!( - "Adding wildcard after invalid tilde operator (before: `{s}`; after: `{patched}`)", - ); - return Ok(Self(requirement)); - } - } - - // Given `torch (>=1.9.*)`, rewrite to `torch (>=1.9)`. - let patched = GREATER_THAN_STAR.replace_all(s, r">=${1}"); - if patched != s { - if let Ok(requirement) = Requirement::from_str(&patched) { - warn!( - "Removing star after greater equal operator (before: `{s}`; after: `{patched}`)", - ); - return Ok(Self(requirement)); - } - } - - // Given `pyzmq (!=3.0*)`, rewrite to `pyzmq (!=3.0.*)`. - let patched = MISSING_DOT.replace_all(s, r"${1}.*"); - if patched != s { - if let Ok(requirement) = Requirement::from_str(&patched) { - warn!( - "Inserting missing dot into invalid requirement (before: `{s}`; after: `{patched}`)", - ); - return Ok(Self(requirement)); - } - } - - // Given `pyzmq (>=3.6,)`, rewrite to `pyzmq (>=3.6)` - let patched = TRAILING_COMMA.replace_all(s, r")"); - if patched != s { - if let Ok(requirement) = Requirement::from_str(&patched) { - warn!( - "Removing trailing comma from invalid requirement (before: `{s}`; after: `{patched}`)", - ); - return Ok(Self(requirement)); - } - } - - Err(err) - } - } - } -} - -impl From for Requirement { - fn from(requirement: LenientRequirement) -> Self { - requirement.0 - } -} - -#[cfg(test)] -mod tests { - use std::str::FromStr; - - use pep508_rs::Requirement; - - use super::LenientRequirement; - - #[test] - fn missing_comma() { - let actual: Requirement = LenientRequirement::from_str("elasticsearch-dsl (>=7.2.0<8.0.0)") - .unwrap() - .into(); - let expected: Requirement = - Requirement::from_str("elasticsearch-dsl (>=7.2.0,<8.0.0)").unwrap(); - assert_eq!(actual, expected); - } - - #[test] - fn not_equal_tile() { - let actual: Requirement = LenientRequirement::from_str("jupyter-core (!=~5.0,>=4.12)") - .unwrap() - .into(); - let expected: Requirement = Requirement::from_str("jupyter-core (!=5.0.*,>=4.12)").unwrap(); - assert_eq!(actual, expected); - - let actual: Requirement = LenientRequirement::from_str("jupyter-core (!=~5,>=4.12)") - .unwrap() - .into(); - let expected: Requirement = Requirement::from_str("jupyter-core (!=5.*,>=4.12)").unwrap(); - assert_eq!(actual, expected); - } - - #[test] - fn greater_than_star() { - let actual: Requirement = LenientRequirement::from_str("torch (>=1.9.*)") - .unwrap() - .into(); - let expected: Requirement = Requirement::from_str("torch (>=1.9)").unwrap(); - assert_eq!(actual, expected); - } - - #[test] - fn missing_dot() { - let actual: Requirement = - LenientRequirement::from_str("pyzmq (>=2.7,!=3.0*,!=3.1*,!=3.2*)") - .unwrap() - .into(); - let expected: Requirement = - Requirement::from_str("pyzmq (>=2.7,!=3.0.*,!=3.1.*,!=3.2.*)").unwrap(); - assert_eq!(actual, expected); - } - - #[test] - fn trailing_comma() { - let actual: Requirement = LenientRequirement::from_str("pyzmq (>=3.6,)") - .unwrap() - .into(); - let expected: Requirement = Requirement::from_str("pyzmq (>=3.6)").unwrap(); - assert_eq!(actual, expected); - } -}