Co-locate lenient requirement parsing (#418)

No behavior changes.
This commit is contained in:
Charlie Marsh 2023-11-13 12:46:21 -08:00 committed by GitHub
parent 437d4fb87e
commit 28ec4e79f0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 265 additions and 194 deletions

View file

@ -1,11 +1,105 @@
use pep440_rs::{Pep440Error, VersionSpecifiers};
use serde::{de, Deserialize, Deserializer, Serialize};
use std::str::FromStr;
use once_cell::sync::Lazy;
use regex::Regex;
use serde::{de, Deserialize, Deserializer, Serialize};
use tracing::warn;
use pep440_rs::{Pep440Error, VersionSpecifiers};
use pep508_rs::{Pep508Error, Requirement};
/// Ex) `>=7.2.0<8.0.0`
static MISSING_COMMA: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\d)([<>=~^!])").unwrap());
/// Ex) `!=~5.0`
static NOT_EQUAL_TILDE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!=~((?:\d\.)*\d)").unwrap());
/// Ex) `>=1.9.*`
static GREATER_THAN_STAR: Lazy<Regex> = Lazy::new(|| Regex::new(r">=(\d+\.\d+)\.\*").unwrap());
/// Ex) `!=3.0*`
static MISSING_DOT: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\d\.\d)+\*").unwrap());
/// Ex) `>=3.6,`
static TRAILING_COMMA: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\d\.\d)+,$").unwrap());
/// Like [`Requirement`], but attempts to correct some common errors in user-provided requirements.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct LenientRequirement(Requirement);
impl FromStr for LenientRequirement {
type Err = Pep508Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match Requirement::from_str(s) {
Ok(requirement) => Ok(Self(requirement)),
Err(err) => {
// Given `elasticsearch-dsl (>=7.2.0<8.0.0)`, rewrite to `elasticsearch-dsl (>=7.2.0,<8.0.0)`.
let patched = MISSING_COMMA.replace_all(s, r"$1,$2");
if patched != s {
if let Ok(requirement) = Requirement::from_str(&patched) {
warn!(
"Inserting missing comma into invalid requirement (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(requirement));
}
}
// Given `jupyter-core (!=~5.0,>=4.12)`, rewrite to `jupyter-core (!=5.0.*,>=4.12)`.
let patched = NOT_EQUAL_TILDE.replace_all(s, r"!=${1}.*");
if patched != s {
if let Ok(requirement) = Requirement::from_str(&patched) {
warn!(
"Adding wildcard after invalid tilde operator (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(requirement));
}
}
// Given `torch (>=1.9.*)`, rewrite to `torch (>=1.9)`.
let patched = GREATER_THAN_STAR.replace_all(s, r">=${1}");
if patched != s {
if let Ok(requirement) = Requirement::from_str(&patched) {
warn!(
"Removing star after greater equal operator (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(requirement));
}
}
// Given `pyzmq (!=3.0*)`, rewrite to `pyzmq (!=3.0.*)`.
let patched = MISSING_DOT.replace_all(s, r"${1}.*");
if patched != s {
if let Ok(requirement) = Requirement::from_str(&patched) {
warn!(
"Inserting missing dot into invalid requirement (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(requirement));
}
}
// Given `pyzmq (>=3.6,)`, rewrite to `pyzmq (>=3.6)`
let patched = TRAILING_COMMA.replace_all(s, r"${1}");
if patched != s {
if let Ok(requirement) = Requirement::from_str(&patched) {
warn!(
"Removing trailing comma from invalid requirement (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(requirement));
}
}
Err(err)
}
}
}
}
impl From<LenientRequirement> for Requirement {
fn from(requirement: LenientRequirement) -> Self {
requirement.0
}
}
/// Like [`VersionSpecifiers`], but attempts to correct some common errors in user-provided requirements.
///
/// We turn `>=3.x.*` into `>=3.x`
/// For example, we turn `>=3.x.*` into `>=3.x`.
#[derive(Debug, Clone, Serialize, Eq, PartialEq)]
pub struct LenientVersionSpecifiers(VersionSpecifiers);
@ -16,45 +110,61 @@ impl FromStr for LenientVersionSpecifiers {
match VersionSpecifiers::from_str(s) {
Ok(specifiers) => Ok(Self(specifiers)),
Err(err) => {
// Given `>=3.5.*`, rewrite to `>=3.5`.
let patched = match s {
">=3.12.*" => Some(">=3.12"),
">=3.11.*" => Some(">=3.11"),
">=3.10.*" => Some(">=3.10"),
">=3.9.*" => Some(">=3.9"),
">=3.8.*" => Some(">=3.8"),
">=3.7.*" => Some(">=3.7"),
">=3.6.*" => Some(">=3.6"),
">=3.5.*" => Some(">=3.5"),
">=3.4.*" => Some(">=3.4"),
">=3.3.*" => Some(">=3.3"),
">=3.2.*" => Some(">=3.2"),
">=3.1.*" => Some(">=3.1"),
">=3.0.*" => Some(">=3.0"),
">=3.12," => Some(">=3.12"),
">=3.11," => Some(">=3.11"),
">=3.10," => Some(">=3.10"),
">=3.9," => Some(">=3.9"),
">=3.8," => Some(">=3.8"),
">=3.7," => Some(">=3.7"),
">=3.6," => Some(">=3.6"),
">=3.5," => Some(">=3.5"),
">=3.4," => Some(">=3.4"),
">=3.3," => Some(">=3.3"),
">=3.2," => Some(">=3.2"),
">=3.1," => Some(">=3.1"),
">=3.0," => Some(">=3.0"),
">=2.7,!=3.0*,!=3.1*,!=3.2*" => Some(">=2.7,!=3.0.*,!=3.1.*,!=3.2.*"),
_ => None,
};
if let Some(patched) = patched {
if let Ok(specifier) = VersionSpecifiers::from_str(patched) {
// Given `>=7.2.0<8.0.0`, rewrite to `>=7.2.0,<8.0.0`.
let patched = MISSING_COMMA.replace_all(s, r"$1,$2");
if patched != s {
if let Ok(specifiers) = VersionSpecifiers::from_str(&patched) {
warn!(
"Correcting invalid wildcard bound on version specifier (before: `{s}`; after: `{patched}`)",
"Inserting missing comma into invalid specifier (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(specifier));
return Ok(Self(specifiers));
}
}
// Given `!=~5.0,>=4.12`, rewrite to `!=5.0.*,>=4.12`.
let patched = NOT_EQUAL_TILDE.replace_all(s, r"!=${1}.*");
if patched != s {
if let Ok(specifiers) = VersionSpecifiers::from_str(&patched) {
warn!(
"Adding wildcard after invalid tilde operator (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(specifiers));
}
}
// Given `>=1.9.*`, rewrite to `>=1.9`.
let patched = GREATER_THAN_STAR.replace_all(s, r">=${1}");
if patched != s {
if let Ok(specifiers) = VersionSpecifiers::from_str(&patched) {
warn!(
"Removing star after greater equal operator (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(specifiers));
}
}
// Given `!=3.0*`, rewrite to `!=3.0.*`.
let patched = MISSING_DOT.replace_all(s, r"${1}.*");
if patched != s {
if let Ok(specifiers) = VersionSpecifiers::from_str(&patched) {
warn!(
"Inserting missing dot into invalid specifier (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(specifiers));
}
}
// Given `>=3.6,`, rewrite to `>=3.6`
let patched = TRAILING_COMMA.replace_all(s, r"${1}");
if patched != s {
if let Ok(specifiers) = VersionSpecifiers::from_str(&patched) {
warn!(
"Removing trailing comma from invalid specifier (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(specifiers));
}
}
Err(err)
}
}
@ -76,3 +186,118 @@ impl<'de> Deserialize<'de> for LenientVersionSpecifiers {
Self::from_str(&s).map_err(de::Error::custom)
}
}
#[cfg(test)]
mod tests {
use pep440_rs::VersionSpecifiers;
use std::str::FromStr;
use crate::LenientVersionSpecifiers;
use pep508_rs::Requirement;
use super::LenientRequirement;
#[test]
fn requirement_missing_comma() {
let actual: Requirement = LenientRequirement::from_str("elasticsearch-dsl (>=7.2.0<8.0.0)")
.unwrap()
.into();
let expected: Requirement =
Requirement::from_str("elasticsearch-dsl (>=7.2.0,<8.0.0)").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn requirement_not_equal_tile() {
let actual: Requirement = LenientRequirement::from_str("jupyter-core (!=~5.0,>=4.12)")
.unwrap()
.into();
let expected: Requirement = Requirement::from_str("jupyter-core (!=5.0.*,>=4.12)").unwrap();
assert_eq!(actual, expected);
let actual: Requirement = LenientRequirement::from_str("jupyter-core (!=~5,>=4.12)")
.unwrap()
.into();
let expected: Requirement = Requirement::from_str("jupyter-core (!=5.*,>=4.12)").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn requirement_greater_than_star() {
let actual: Requirement = LenientRequirement::from_str("torch (>=1.9.*)")
.unwrap()
.into();
let expected: Requirement = Requirement::from_str("torch (>=1.9)").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn requirement_missing_dot() {
let actual: Requirement =
LenientRequirement::from_str("pyzmq (>=2.7,!=3.0*,!=3.1*,!=3.2*)")
.unwrap()
.into();
let expected: Requirement =
Requirement::from_str("pyzmq (>=2.7,!=3.0.*,!=3.1.*,!=3.2.*)").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn requirement_trailing_comma() {
let actual: Requirement = LenientRequirement::from_str("pyzmq >=3.6,").unwrap().into();
let expected: Requirement = Requirement::from_str("pyzmq >=3.6").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn specifier_missing_comma() {
let actual: VersionSpecifiers = LenientVersionSpecifiers::from_str(">=7.2.0<8.0.0")
.unwrap()
.into();
let expected: VersionSpecifiers = VersionSpecifiers::from_str(">=7.2.0,<8.0.0").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn specifier_not_equal_tile() {
let actual: VersionSpecifiers = LenientVersionSpecifiers::from_str("!=~5.0,>=4.12")
.unwrap()
.into();
let expected: VersionSpecifiers = VersionSpecifiers::from_str("!=5.0.*,>=4.12").unwrap();
assert_eq!(actual, expected);
let actual: VersionSpecifiers = LenientVersionSpecifiers::from_str("!=~5,>=4.12")
.unwrap()
.into();
let expected: VersionSpecifiers = VersionSpecifiers::from_str("!=5.*,>=4.12").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn specifier_greater_than_star() {
let actual: VersionSpecifiers = LenientVersionSpecifiers::from_str(">=1.9.*")
.unwrap()
.into();
let expected: VersionSpecifiers = VersionSpecifiers::from_str(">=1.9").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn specifier_missing_dot() {
let actual: VersionSpecifiers =
LenientVersionSpecifiers::from_str(">=2.7,!=3.0*,!=3.1*,!=3.2*")
.unwrap()
.into();
let expected: VersionSpecifiers =
VersionSpecifiers::from_str(">=2.7,!=3.0.*,!=3.1.*,!=3.2.*").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn specifier_trailing_comma() {
let actual: VersionSpecifiers =
LenientVersionSpecifiers::from_str(">=3.6,").unwrap().into();
let expected: VersionSpecifiers = VersionSpecifiers::from_str(">=3.6").unwrap();
assert_eq!(actual, expected);
}
}

View file

@ -5,17 +5,15 @@ use std::io;
use std::str::FromStr;
use mailparse::{MailHeaderMap, MailParseError};
use once_cell::sync::Lazy;
use regex::Regex;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tracing::warn;
use crate::lenient_requirement::LenientVersionSpecifiers;
use pep440_rs::{Pep440Error, Version, VersionSpecifiers};
use pep508_rs::{Pep508Error, Requirement};
use puffin_normalize::{ExtraName, InvalidNameError, PackageName};
use crate::lenient_requirement::{LenientRequirement, LenientVersionSpecifiers};
/// Python Package Metadata 2.1 as specified in
/// <https://packaging.python.org/specifications/core-metadata/>
///
@ -210,155 +208,3 @@ impl Metadata21 {
})
}
}
/// Ex) `>=7.2.0<8.0.0`
static MISSING_COMMA: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\d)([<>=~^!])").unwrap());
/// Ex) `!=~5.0`
static NOT_EQUAL_TILDE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!=~((?:\d\.)*\d)").unwrap());
/// Ex) `>=1.9.*`
static GREATER_THAN_STAR: Lazy<Regex> = Lazy::new(|| Regex::new(r">=(\d+\.\d+)\.\*").unwrap());
/// Ex) `!=3.0*`
static MISSING_DOT: Lazy<Regex> = Lazy::new(|| Regex::new(r"(\d\.\d)+\*").unwrap());
/// Ex) `>=3.6,`
static TRAILING_COMMA: Lazy<Regex> = Lazy::new(|| Regex::new(r",\)").unwrap());
/// Like [`Requirement`], but attempts to correct some common errors in user-provided requirements.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
struct LenientRequirement(Requirement);
impl FromStr for LenientRequirement {
type Err = Pep508Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match Requirement::from_str(s) {
Ok(requirement) => Ok(Self(requirement)),
Err(err) => {
// Given `elasticsearch-dsl (>=7.2.0<8.0.0)`, rewrite to `elasticsearch-dsl (>=7.2.0,<8.0.0)`.
let patched = MISSING_COMMA.replace_all(s, r"$1,$2");
if patched != s {
if let Ok(requirement) = Requirement::from_str(&patched) {
warn!(
"Inserting missing comma into invalid requirement (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(requirement));
}
}
// Given `jupyter-core (!=~5.0,>=4.12)`, rewrite to `jupyter-core (!=5.0.*,>=4.12)`.
let patched = NOT_EQUAL_TILDE.replace_all(s, r"!=${1}.*");
if patched != s {
if let Ok(requirement) = Requirement::from_str(&patched) {
warn!(
"Adding wildcard after invalid tilde operator (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(requirement));
}
}
// Given `torch (>=1.9.*)`, rewrite to `torch (>=1.9)`.
let patched = GREATER_THAN_STAR.replace_all(s, r">=${1}");
if patched != s {
if let Ok(requirement) = Requirement::from_str(&patched) {
warn!(
"Removing star after greater equal operator (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(requirement));
}
}
// Given `pyzmq (!=3.0*)`, rewrite to `pyzmq (!=3.0.*)`.
let patched = MISSING_DOT.replace_all(s, r"${1}.*");
if patched != s {
if let Ok(requirement) = Requirement::from_str(&patched) {
warn!(
"Inserting missing dot into invalid requirement (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(requirement));
}
}
// Given `pyzmq (>=3.6,)`, rewrite to `pyzmq (>=3.6)`
let patched = TRAILING_COMMA.replace_all(s, r")");
if patched != s {
if let Ok(requirement) = Requirement::from_str(&patched) {
warn!(
"Removing trailing comma from invalid requirement (before: `{s}`; after: `{patched}`)",
);
return Ok(Self(requirement));
}
}
Err(err)
}
}
}
}
impl From<LenientRequirement> for Requirement {
fn from(requirement: LenientRequirement) -> Self {
requirement.0
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use pep508_rs::Requirement;
use super::LenientRequirement;
#[test]
fn missing_comma() {
let actual: Requirement = LenientRequirement::from_str("elasticsearch-dsl (>=7.2.0<8.0.0)")
.unwrap()
.into();
let expected: Requirement =
Requirement::from_str("elasticsearch-dsl (>=7.2.0,<8.0.0)").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn not_equal_tile() {
let actual: Requirement = LenientRequirement::from_str("jupyter-core (!=~5.0,>=4.12)")
.unwrap()
.into();
let expected: Requirement = Requirement::from_str("jupyter-core (!=5.0.*,>=4.12)").unwrap();
assert_eq!(actual, expected);
let actual: Requirement = LenientRequirement::from_str("jupyter-core (!=~5,>=4.12)")
.unwrap()
.into();
let expected: Requirement = Requirement::from_str("jupyter-core (!=5.*,>=4.12)").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn greater_than_star() {
let actual: Requirement = LenientRequirement::from_str("torch (>=1.9.*)")
.unwrap()
.into();
let expected: Requirement = Requirement::from_str("torch (>=1.9)").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn missing_dot() {
let actual: Requirement =
LenientRequirement::from_str("pyzmq (>=2.7,!=3.0*,!=3.1*,!=3.2*)")
.unwrap()
.into();
let expected: Requirement =
Requirement::from_str("pyzmq (>=2.7,!=3.0.*,!=3.1.*,!=3.2.*)").unwrap();
assert_eq!(actual, expected);
}
#[test]
fn trailing_comma() {
let actual: Requirement = LenientRequirement::from_str("pyzmq (>=3.6,)")
.unwrap()
.into();
let expected: Requirement = Requirement::from_str("pyzmq (>=3.6)").unwrap();
assert_eq!(actual, expected);
}
}