Use serde-untagged to improve some untagged enum error messages (#7822)

## Summary

This is related to https://github.com/astral-sh/uv/issues/7817, but
doesn't close it.
This commit is contained in:
Charlie Marsh 2024-09-30 19:40:21 -04:00 committed by GitHub
parent 67769a4985
commit b6de417c94
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 123 additions and 44 deletions

34
Cargo.lock generated
View file

@ -1077,6 +1077,16 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "erased-serde"
version = "0.4.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24e2389d65ab4fab27dc2a5de7b191e1f6617d1f1c8855c0dc569c94a4cbb18d"
dependencies = [
"serde",
"typeid",
]
[[package]]
name = "errno"
version = "0.3.9"
@ -2717,7 +2727,7 @@ dependencies = [
"indoc",
"libc",
"memoffset 0.9.1",
"parking_lot 0.11.2",
"parking_lot 0.12.3",
"portable-atomic",
"pyo3-build-config",
"pyo3-ffi",
@ -2796,6 +2806,7 @@ dependencies = [
"regex",
"rkyv",
"serde",
"serde-untagged",
"thiserror",
"toml",
"toml_edit",
@ -3518,6 +3529,17 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-untagged"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2676ba99bd82f75cae5cbd2c8eda6fa0b8760f18978ea840e980dd5567b5c5b6"
dependencies = [
"erased-serde",
"serde",
"typeid",
]
[[package]]
name = "serde_derive"
version = "1.0.210"
@ -4256,6 +4278,12 @@ version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0609f771ad9c6155384897e1df4d948e692667cc0588548b68eb44d052b27633"
[[package]]
name = "typeid"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e13db2e0ccd5e14a544e8a246ba2312cd25223f616442d7f2cb0e3db614236e"
[[package]]
name = "typenum"
version = "1.17.0"
@ -4698,6 +4726,7 @@ dependencies = [
"rustc-hash",
"schemars",
"serde",
"serde-untagged",
"serde_json",
"thiserror",
"tracing",
@ -5329,6 +5358,7 @@ dependencies = [
"same-file",
"schemars",
"serde",
"serde-untagged",
"tempfile",
"thiserror",
"tokio",
@ -5551,7 +5581,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.48.0",
"windows-sys 0.59.0",
]
[[package]]

View file

@ -141,6 +141,7 @@ same-file = { version = "1.0.6" }
schemars = { version = "0.8.21", features = ["url"] }
seahash = { version = "4.1.0" }
serde = { version = "1.0.210", features = ["derive"] }
serde-untagged = { version = "0.1.6" }
serde_json = { version = "1.0.128" }
sha2 = { version = "0.10.8" }
smallvec = { version = "1.13.2" }

View file

@ -27,6 +27,7 @@ mailparse = { workspace = true }
regex = { workspace = true }
rkyv = { workspace = true }
serde = { workspace = true }
serde-untagged = { workspace = true }
thiserror = { workspace = true }
toml = { workspace = true }
toml_edit = { workspace = true }

View file

@ -1,9 +1,8 @@
use std::str::FromStr;
use jiff::Timestamp;
use serde::{Deserialize, Deserializer, Serialize};
use pep440_rs::{VersionSpecifiers, VersionSpecifiersParseError};
use serde::{Deserialize, Deserializer, Serialize};
use crate::lenient_requirement::LenientVersionSpecifiers;
@ -71,13 +70,24 @@ where
))
}
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
#[derive(Debug, Clone)]
pub enum CoreMetadata {
Bool(bool),
Hashes(Hashes),
}
impl<'de> Deserialize<'de> for CoreMetadata {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
serde_untagged::UntaggedEnumVisitor::new()
.bool(|bool| Ok(CoreMetadata::Bool(bool)))
.map(|map| map.deserialize().map(CoreMetadata::Hashes))
.deserialize(deserializer)
}
}
impl CoreMetadata {
pub fn is_available(&self) -> bool {
match self {
@ -87,24 +97,25 @@ impl CoreMetadata {
}
}
#[derive(
Debug,
Clone,
PartialEq,
Eq,
Hash,
Deserialize,
rkyv::Archive,
rkyv::Deserialize,
rkyv::Serialize,
)]
#[derive(Debug, Clone, PartialEq, Eq, Hash, rkyv::Archive, rkyv::Deserialize, rkyv::Serialize)]
#[rkyv(derive(Debug))]
#[serde(untagged)]
pub enum Yanked {
Bool(bool),
Reason(String),
}
impl<'de> Deserialize<'de> for Yanked {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
serde_untagged::UntaggedEnumVisitor::new()
.bool(|bool| Ok(Yanked::Bool(bool)))
.string(|string| Ok(Yanked::Reason(string.to_owned())))
.deserialize(deserializer)
}
}
impl Yanked {
pub fn is_yanked(&self) -> bool {
match self {

View file

@ -28,6 +28,7 @@ fs-err = { workspace = true }
rustc-hash = { workspace = true }
schemars = { workspace = true, optional = true }
serde = { workspace = true }
serde-untagged = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }

View file

@ -1,5 +1,5 @@
use serde::{Deserialize, Deserializer};
use std::str::FromStr;
use url::Url;
/// A trusted host, which could be a host or a host-port pair.
@ -33,28 +33,28 @@ impl TrustedHost {
}
}
#[derive(serde::Deserialize)]
#[serde(untagged)]
enum TrustHostWire {
String(String),
Struct {
scheme: Option<String>,
host: String,
port: Option<u16>,
},
}
impl<'de> serde::de::Deserialize<'de> for TrustedHost {
fn deserialize<D>(deserializer: D) -> Result<TrustedHost, D::Error>
impl<'de> Deserialize<'de> for TrustedHost {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
D: Deserializer<'de>,
{
let helper = TrustHostWire::deserialize(deserializer)?;
match helper {
TrustHostWire::String(s) => TrustedHost::from_str(&s).map_err(serde::de::Error::custom),
TrustHostWire::Struct { scheme, host, port } => Ok(TrustedHost { scheme, host, port }),
#[derive(Deserialize)]
struct Inner {
scheme: Option<String>,
host: String,
port: Option<u16>,
}
serde_untagged::UntaggedEnumVisitor::new()
.string(|string| TrustedHost::from_str(string).map_err(serde::de::Error::custom))
.map(|map| {
map.deserialize::<Inner>().map(|inner| TrustedHost {
scheme: inner.scheme,
host: inner.host,
port: inner.port,
})
})
.deserialize(deserializer)
}
}

View file

@ -226,6 +226,29 @@ mod test {
"###);
}
#[tokio::test]
async fn wrong_type() {
let input = indoc! {r#"
[project]
name = "foo"
version = "0.0.0"
dependencies = [
"tqdm",
]
[tool.uv.sources]
tqdm = true
"#};
assert_snapshot!(format_err(input).await, @r###"
error: TOML parse error at line 8, column 8
|
8 | tqdm = true
| ^^^^
invalid type: boolean `true`, expected an array or map
"###);
}
#[tokio::test]
async fn too_many_git_specs() {
let input = indoc! {r#"
@ -264,7 +287,7 @@ mod test {
|
8 | tqdm = { git = "https://github.com/tqdm/tqdm", ref = "baaaaaab" }
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
data did not match any variant of untagged enum SourcesWire
data did not match any variant of untagged enum Source
"###);
}
@ -288,7 +311,7 @@ mod test {
|
8 | tqdm = { path = "tqdm", index = "torch" }
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
data did not match any variant of untagged enum SourcesWire
data did not match any variant of untagged enum Source
"###);
}
@ -348,7 +371,7 @@ mod test {
|
8 | tqdm = { url = "§invalid#+#*Ä" }
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^
data did not match any variant of untagged enum SourcesWire
data did not match any variant of untagged enum Source
"###);
}

View file

@ -32,6 +32,7 @@ rustc-hash = { workspace = true }
same-file = { workspace = true }
schemars = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
serde-untagged = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
toml = { workspace = true }

View file

@ -444,15 +444,26 @@ impl IntoIterator for Sources {
}
}
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "kebab-case", untagged)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema), schemars(untagged))]
#[allow(clippy::large_enum_variant)]
enum SourcesWire {
One(Source),
Many(Vec<Source>),
}
impl<'de> serde::de::Deserialize<'de> for SourcesWire {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::de::Deserializer<'de>,
{
serde_untagged::UntaggedEnumVisitor::new()
.map(|map| map.deserialize().map(SourcesWire::One))
.seq(|seq| seq.deserialize().map(SourcesWire::Many))
.deserialize(deserializer)
}
}
impl TryFrom<SourcesWire> for Sources {
type Error = SourceError;