mirror of https://github.com/astral-sh/ruff.git, synced 2025-10-01 06:11:21 +00:00

Vendor benchmark test files (#15878)

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>

parent d9a1034db0
commit 770b7f3439

20 changed files with 4235 additions and 314 deletions
@@ -5,6 +5,7 @@ exclude: |
     .github/workflows/release.yml|
     crates/red_knot_vendored/vendor/.*|
     crates/red_knot_project/resources/.*|
+    crates/ruff_benchmark/resources/.*|
     crates/ruff_linter/resources/.*|
     crates/ruff_linter/src/rules/.*/snapshots/.*|
     crates/ruff_notebook/resources/.*|
Cargo.lock (108 changes, generated)

@@ -190,12 +190,6 @@ version = "0.13.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
 
-[[package]]
-name = "base64"
-version = "0.22.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
-
 [[package]]
 name = "bincode"
 version = "1.3.3"

@@ -2577,28 +2571,13 @@ version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
 
-[[package]]
-name = "ring"
-version = "0.17.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
-dependencies = [
- "cc",
- "cfg-if",
- "getrandom",
- "libc",
- "spin",
- "untrusted",
- "windows-sys 0.52.0",
-]
-
 [[package]]
 name = "ron"
 version = "0.7.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "88073939a61e5b7680558e6be56b419e208420c2adb92be54921fa6b72283f1a"
 dependencies = [
- "base64 0.13.1",
+ "base64",
  "bitflags 1.3.2",
  "serde",
 ]

@@ -2692,11 +2671,7 @@ dependencies = [
  "ruff_python_parser",
  "ruff_python_trivia",
  "rustc-hash 2.1.0",
- "serde",
- "serde_json",
  "tikv-jemallocator",
- "ureq",
- "url",
 ]
 
 [[package]]

@@ -3253,38 +3228,6 @@ dependencies = [
  "windows-sys 0.52.0",
 ]
 
-[[package]]
-name = "rustls"
-version = "0.23.21"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8"
-dependencies = [
- "log",
- "once_cell",
- "ring",
- "rustls-pki-types",
- "rustls-webpki",
- "subtle",
- "zeroize",
-]
-
-[[package]]
-name = "rustls-pki-types"
-version = "1.10.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37"
-
-[[package]]
-name = "rustls-webpki"
-version = "0.102.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
-dependencies = [
- "ring",
- "rustls-pki-types",
- "untrusted",
-]
-
 [[package]]
 name = "rustversion"
 version = "1.0.19"

@@ -3567,12 +3510,6 @@ dependencies = [
  "anstream",
 ]
 
-[[package]]
-name = "spin"
-version = "0.9.8"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
-
 [[package]]
 name = "stable_deref_trait"
 version = "1.2.0"

@@ -3622,12 +3559,6 @@ dependencies = [
  "syn 2.0.96",
 ]
 
-[[package]]
-name = "subtle"
-version = "2.6.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
-
 [[package]]
 name = "syn"
 version = "1.0.109"

@@ -4116,28 +4047,6 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e9df2af067a7953e9c3831320f35c1cc0600c30d44d9f7a12b01db1cd88d6b47"
 
-[[package]]
-name = "untrusted"
-version = "0.9.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
-
-[[package]]
-name = "ureq"
-version = "2.12.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d"
-dependencies = [
- "base64 0.22.1",
- "flate2",
- "log",
- "once_cell",
- "rustls",
- "rustls-pki-types",
- "url",
- "webpki-roots",
-]
-
 [[package]]
 name = "url"
 version = "2.5.4"

@@ -4406,15 +4315,6 @@ dependencies = [
  "wasm-bindgen",
 ]
 
-[[package]]
-name = "webpki-roots"
-version = "0.26.7"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e"
-dependencies = [
- "rustls-pki-types",
-]
-
 [[package]]
 name = "which"
 version = "7.0.1"

@@ -4723,12 +4623,6 @@ dependencies = [
  "synstructure",
 ]
 
-[[package]]
-name = "zeroize"
-version = "1.8.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
-
 [[package]]
 name = "zerovec"
 version = "0.10.4"
Cargo.toml (14 changes)

@@ -134,7 +134,12 @@ serde_with = { version = "3.6.0", default-features = false, features = [
 shellexpand = { version = "3.0.0" }
 similar = { version = "2.4.0", features = ["inline"] }
 smallvec = { version = "1.13.2" }
-snapbox = { version = "0.6.0", features = ["diff", "term-svg", "cmd", "examples"] }
+snapbox = { version = "0.6.0", features = [
+    "diff",
+    "term-svg",
+    "cmd",
+    "examples",
+] }
 static_assertions = "1.1.0"
 strum = { version = "0.26.0", features = ["strum_macros"] }
 strum_macros = { version = "0.26.0" }

@@ -159,7 +164,6 @@ unicode-ident = { version = "1.0.12" }
 unicode-width = { version = "0.2.0" }
 unicode_names2 = { version = "1.2.2" }
 unicode-normalization = { version = "0.1.23" }
-ureq = { version = "2.9.6" }
 url = { version = "2.5.0" }
 uuid = { version = "1.6.1", features = [
     "v4",

@@ -305,7 +309,11 @@ local-artifacts-jobs = ["./build-binaries", "./build-docker"]
 # Publish jobs to run in CI
 publish-jobs = ["./publish-pypi", "./publish-wasm"]
 # Post-announce jobs to run in CI
-post-announce-jobs = ["./notify-dependents", "./publish-docs", "./publish-playground"]
+post-announce-jobs = [
+    "./notify-dependents",
+    "./publish-docs",
+    "./publish-playground",
+]
 # Custom permissions for GitHub Jobs
 github-custom-job-permissions = { "build-docker" = { packages = "write", contents = "read" }, "publish-wasm" = { contents = "read", id-token = "write", packages = "write" } }
 # Whether to install an updater program
@@ -41,10 +41,6 @@ codspeed-criterion-compat = { workspace = true, default-features = false, option
 criterion = { workspace = true, default-features = false }
 rayon = { workspace = true }
 rustc-hash = { workspace = true }
-serde = { workspace = true }
-serde_json = { workspace = true }
-url = { workspace = true }
-ureq = { workspace = true }
 
 [dev-dependencies]
 ruff_db = { workspace = true }
@@ -3,7 +3,10 @@ use std::path::Path;
 use ruff_benchmark::criterion::{
     criterion_group, criterion_main, BenchmarkId, Criterion, Throughput,
 };
-use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError};
+use ruff_benchmark::{
+    TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN,
+};
 use ruff_python_formatter::{format_module_ast, PreviewMode, PyFormatOptions};
 use ruff_python_parser::{parse, Mode};
 use ruff_python_trivia::CommentRanges;
@@ -24,27 +27,20 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
 #[global_allocator]
 static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
 
-fn create_test_cases() -> Result<Vec<TestCase>, TestFileDownloadError> {
-    Ok(vec![
-        TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?),
-        TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?),
-        TestCase::normal(TestFile::try_download(
-            "pydantic/types.py",
-            "https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py",
-        )?),
-        TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?),
-        TestCase::slow(TestFile::try_download(
-            "large/dataset.py",
-            "https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py",
-        )?),
-    ])
+fn create_test_cases() -> Vec<TestCase> {
+    vec![
+        TestCase::fast(NUMPY_GLOBALS.clone()),
+        TestCase::fast(UNICODE_PYPINYIN.clone()),
+        TestCase::normal(PYDANTIC_TYPES.clone()),
+        TestCase::normal(NUMPY_CTYPESLIB.clone()),
+        TestCase::slow(LARGE_DATASET.clone()),
+    ]
 }
 
 fn benchmark_formatter(criterion: &mut Criterion) {
     let mut group = criterion.benchmark_group("formatter");
-    let test_cases = create_test_cases().unwrap();
 
-    for case in test_cases {
+    for case in create_test_cases() {
         group.throughput(Throughput::Bytes(case.code().len() as u64));
 
         group.bench_with_input(
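The lexer, linter, and parser benches below get the same `create_test_cases` rewrite. The `NUMPY_GLOBALS`, `UNICODE_PYPINYIN`, `PYDANTIC_TYPES`, `NUMPY_CTYPESLIB`, and `LARGE_DATASET` statics they clone are defined in ruff_benchmark's library code, which this capture does not include. A minimal sketch of what such a declaration can look like, assuming `TestFile::new` is a const fn (the `TOMLLIB_FILES` array later in this commit constructs `TestFile`s inside a `static` the same way):

    // Hypothetical declarations; the real ones live in ruff_benchmark's
    // library source, outside this capture. `include_str!` embeds the
    // vendored file's contents into the binary at compile time, so the
    // benches need no network access.
    pub static NUMPY_GLOBALS: TestFile = TestFile::new(
        "numpy/globals.py",
        include_str!("../resources/numpy/globals.py"),
    );

    pub static LARGE_DATASET: TestFile = TestFile::new(
        "large/dataset.py",
        include_str!("../resources/large/dataset.py"),
    );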
@@ -1,7 +1,9 @@
 use ruff_benchmark::criterion::{
     criterion_group, criterion_main, measurement::WallTime, BenchmarkId, Criterion, Throughput,
 };
-use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError};
+use ruff_benchmark::{
+    TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN,
+};
 use ruff_python_parser::{lexer, Mode, TokenKind};
 
 #[cfg(target_os = "windows")]

@@ -20,24 +22,18 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
 #[global_allocator]
 static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
 
-fn create_test_cases() -> Result<Vec<TestCase>, TestFileDownloadError> {
-    Ok(vec![
-        TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?),
-        TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?),
-        TestCase::normal(TestFile::try_download(
-            "pydantic/types.py",
-            "https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py",
-        )?),
-        TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?),
-        TestCase::slow(TestFile::try_download(
-            "large/dataset.py",
-            "https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py",
-        )?),
-    ])
+fn create_test_cases() -> Vec<TestCase> {
+    vec![
+        TestCase::fast(NUMPY_GLOBALS.clone()),
+        TestCase::fast(UNICODE_PYPINYIN.clone()),
+        TestCase::normal(PYDANTIC_TYPES.clone()),
+        TestCase::normal(NUMPY_CTYPESLIB.clone()),
+        TestCase::slow(LARGE_DATASET.clone()),
+    ]
 }
 
 fn benchmark_lexer(criterion: &mut Criterion<WallTime>) {
-    let test_cases = create_test_cases().unwrap();
+    let test_cases = create_test_cases();
     let mut group = criterion.benchmark_group("lexer");
 
     for case in test_cases {
@@ -1,7 +1,9 @@
 use ruff_benchmark::criterion::{
     criterion_group, criterion_main, BenchmarkGroup, BenchmarkId, Criterion, Throughput,
 };
-use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError};
+use ruff_benchmark::{
+    TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN,
+};
 use ruff_linter::linter::{lint_only, ParseSource};
 use ruff_linter::rule_selector::PreviewOptions;
 use ruff_linter::settings::rule_table::RuleTable;

@@ -46,24 +48,18 @@ static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
 #[allow(unsafe_code)]
 pub static _rjem_malloc_conf: &[u8] = b"dirty_decay_ms:-1,muzzy_decay_ms:-1\0";
 
-fn create_test_cases() -> Result<Vec<TestCase>, TestFileDownloadError> {
-    Ok(vec![
-        TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?),
-        TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?),
-        TestCase::normal(TestFile::try_download(
-            "pydantic/types.py",
-            "https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py",
-        )?),
-        TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?),
-        TestCase::slow(TestFile::try_download(
-            "large/dataset.py",
-            "https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py",
-        )?),
-    ])
+fn create_test_cases() -> Vec<TestCase> {
+    vec![
+        TestCase::fast(NUMPY_GLOBALS.clone()),
+        TestCase::fast(UNICODE_PYPINYIN.clone()),
+        TestCase::normal(PYDANTIC_TYPES.clone()),
+        TestCase::normal(NUMPY_CTYPESLIB.clone()),
+        TestCase::slow(LARGE_DATASET.clone()),
+    ]
 }
 
 fn benchmark_linter(mut group: BenchmarkGroup, settings: &LinterSettings) {
-    let test_cases = create_test_cases().unwrap();
+    let test_cases = create_test_cases();
 
     for case in test_cases {
         group.throughput(Throughput::Bytes(case.code().len() as u64));
@@ -1,7 +1,9 @@
 use ruff_benchmark::criterion::{
     criterion_group, criterion_main, measurement::WallTime, BenchmarkId, Criterion, Throughput,
 };
-use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError};
+use ruff_benchmark::{
+    TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN,
+};
 use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
 use ruff_python_ast::Stmt;
 use ruff_python_parser::parse_module;

@@ -22,20 +24,14 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
 #[global_allocator]
 static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
 
-fn create_test_cases() -> Result<Vec<TestCase>, TestFileDownloadError> {
-    Ok(vec![
-        TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?),
-        TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?),
-        TestCase::normal(TestFile::try_download(
-            "pydantic/types.py",
-            "https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py",
-        )?),
-        TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?),
-        TestCase::slow(TestFile::try_download(
-            "large/dataset.py",
-            "https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py",
-        )?),
-    ])
+fn create_test_cases() -> Vec<TestCase> {
+    vec![
+        TestCase::fast(NUMPY_GLOBALS.clone()),
+        TestCase::fast(UNICODE_PYPINYIN.clone()),
+        TestCase::normal(PYDANTIC_TYPES.clone()),
+        TestCase::normal(NUMPY_CTYPESLIB.clone()),
+        TestCase::slow(LARGE_DATASET.clone()),
+    ]
 }
 
 struct CountVisitor {

@@ -50,7 +46,7 @@ impl<'a> StatementVisitor<'a> for CountVisitor {
 }
 
 fn benchmark_parser(criterion: &mut Criterion<WallTime>) {
-    let test_cases = create_test_cases().unwrap();
+    let test_cases = create_test_cases();
     let mut group = criterion.benchmark_group("parser");
 
     for case in test_cases {
@@ -24,7 +24,25 @@ struct Case {
     re_path: SystemPathBuf,
 }
 
-const TOMLLIB_312_URL: &str = "https://raw.githubusercontent.com/python/cpython/8e8a4baf652f6e1cee7acde9d78c4b6154539748/Lib/tomllib";
+// "https://raw.githubusercontent.com/python/cpython/8e8a4baf652f6e1cee7acde9d78c4b6154539748/Lib/tomllib";
+static TOMLLIB_FILES: [TestFile; 4] = [
+    TestFile::new(
+        "tomllib/__init__.py",
+        include_str!("../resources/tomllib/__init__.py"),
+    ),
+    TestFile::new(
+        "tomllib/_parser.py",
+        include_str!("../resources/tomllib/_parser.py"),
+    ),
+    TestFile::new(
+        "tomllib/_re.py",
+        include_str!("../resources/tomllib/_re.py"),
+    ),
+    TestFile::new(
+        "tomllib/_types.py",
+        include_str!("../resources/tomllib/_types.py"),
+    ),
+];
 
 /// A structured set of fields we use to do diagnostic comparisons.
 ///

@@ -80,27 +98,19 @@ static EXPECTED_DIAGNOSTICS: &[KeyDiagnosticFields] = &[
     ),
 ];
 
-fn get_test_file(name: &str) -> TestFile {
-    let path = format!("tomllib/{name}");
-    let url = format!("{TOMLLIB_312_URL}/{name}");
-    TestFile::try_download(&path, &url).unwrap()
-}
-
-fn tomllib_path(filename: &str) -> SystemPathBuf {
-    SystemPathBuf::from(format!("/src/tomllib/{filename}").as_str())
+fn tomllib_path(file: &TestFile) -> SystemPathBuf {
+    SystemPathBuf::from("src").join(file.name())
 }
 
 fn setup_case() -> Case {
     let system = TestSystem::default();
     let fs = system.memory_file_system().clone();
 
-    let tomllib_filenames = ["__init__.py", "_parser.py", "_re.py", "_types.py"];
-    fs.write_files(tomllib_filenames.iter().map(|filename| {
-        (
-            tomllib_path(filename),
-            get_test_file(filename).code().to_string(),
-        )
-    }))
+    fs.write_files(
+        TOMLLIB_FILES
+            .iter()
+            .map(|file| (tomllib_path(file), file.code().to_string())),
+    )
     .unwrap();
 
     let src_root = SystemPath::new("/src");

@@ -114,15 +124,22 @@ fn setup_case() -> Case {
     });
 
     let mut db = ProjectDatabase::new(metadata, system).unwrap();
+    let mut tomllib_files = FxHashSet::default();
+    let mut re: Option<File> = None;
 
-    let tomllib_files: FxHashSet<File> = tomllib_filenames
-        .iter()
-        .map(|filename| system_path_to_file(&db, tomllib_path(filename)).unwrap())
-        .collect();
+    for test_file in &TOMLLIB_FILES {
+        let file = system_path_to_file(&db, tomllib_path(test_file)).unwrap();
+        if test_file.name().ends_with("_re.py") {
+            re = Some(file);
+        }
+        tomllib_files.insert(file);
+    }
+
+    let re = re.unwrap();
+
     db.project().set_open_files(&mut db, tomllib_files);
 
-    let re_path = tomllib_path("_re.py");
-    let re = system_path_to_file(&db, &re_path).unwrap();
+    let re_path = re.path(&db).as_system_path().unwrap().to_owned();
     Case {
         db,
         fs,
crates/ruff_benchmark/resources/README.md (16 additions, new file)

@@ -0,0 +1,16 @@
+This directory vendors some files from actual projects.
+This is to benchmark Ruff's performance against real-world
+code instead of synthetic benchmarks.
+
+The following files are included:
+
+* [`numpy/globals`](https://github.com/numpy/numpy/blob/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py)
+* [`numpy/ctypeslib.py`](https://github.com/numpy/numpy/blob/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py)
+* [`pypinyin.py`](https://github.com/mozillazg/python-pinyin/blob/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py)
+* [`pydantic/types.py`](https://github.com/pydantic/pydantic/blob/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py)
+* [`large/dataset.py`](https://github.com/DHI/mikeio/blob/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py)
+* [`tomllib`](https://github.com/python/cpython/tree/8e8a4baf652f6e1cee7acde9d78c4b6154539748/Lib/tomllib) (3.12)
+
+The files are included in the `resources` directory to allow
+running benchmarks offline and for simplicity. They're licensed
+according to their original licenses (see link).
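To make the offline claim concrete, a pared-down sketch of the mechanism: `include_str!` resolves each vendored file at compile time and bakes its contents into the benchmark binary, so `code()` returns the source without any download step. The struct below is an assumption for illustration only; the real `TestFile` lives in ruff_benchmark, and only its constructor calls appear in the diffs above.

    // Illustrative only: a minimal TestFile shape that would explain the
    // usage seen in this commit. The actual type may differ.
    #[derive(Clone)]
    pub struct TestFile {
        name: &'static str,
        code: &'static str,
    }

    impl TestFile {
        // const fn, so it can be called in `static` items such as the
        // TOMLLIB_FILES array in this commit.
        pub const fn new(name: &'static str, code: &'static str) -> Self {
            Self { name, code }
        }

        pub fn name(&self) -> &'static str {
            self.name
        }

        pub fn code(&self) -> &'static str {
            self.code
        }
    }

With this shape, `TestCase::fast(NUMPY_GLOBALS.clone())` in the benches above is just a cheap copy of two `&'static str` pointers.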
crates/ruff_benchmark/resources/large/dataset.py (1617 additions, new file)

File diff suppressed because it is too large
crates/ruff_benchmark/resources/numpy/ctypeslib.py (547 additions, new file)

@@ -0,0 +1,547 @@
+"""
+============================
+``ctypes`` Utility Functions
+============================
+
+See Also
+--------
+load_library : Load a C library.
+ndpointer : Array restype/argtype with verification.
+as_ctypes : Create a ctypes array from an ndarray.
+as_array : Create an ndarray from a ctypes array.
+
+References
+----------
+.. [1] "SciPy Cookbook: ctypes", https://scipy-cookbook.readthedocs.io/items/Ctypes.html
+
+Examples
+--------
+Load the C library:
+
+>>> _lib = np.ctypeslib.load_library('libmystuff', '.')     #doctest: +SKIP
+
+Our result type, an ndarray that must be of type double, be 1-dimensional
+and is C-contiguous in memory:
+
+>>> array_1d_double = np.ctypeslib.ndpointer(
+...                          dtype=np.double,
+...                          ndim=1, flags='CONTIGUOUS')    #doctest: +SKIP
+
+Our C-function typically takes an array and updates its values
+in-place. For example::
+
+    void foo_func(double* x, int length)
+    {
+        int i;
+        for (i = 0; i < length; i++) {
+            x[i] = i*i;
+        }
+    }
+
+We wrap it using:
+
+>>> _lib.foo_func.restype = None                      #doctest: +SKIP
+>>> _lib.foo_func.argtypes = [array_1d_double, c_int] #doctest: +SKIP
+
+Then, we're ready to call ``foo_func``:
+
+>>> out = np.empty(15, dtype=np.double)
+>>> _lib.foo_func(out, len(out))                #doctest: +SKIP
+
+"""
+__all__ = ['load_library', 'ndpointer', 'c_intp', 'as_ctypes', 'as_array',
+           'as_ctypes_type']
+
+import os
+from numpy import (
+    integer, ndarray, dtype as _dtype, asarray, frombuffer
+)
+from numpy.core.multiarray import _flagdict, flagsobj
+
+try:
+    import ctypes
+except ImportError:
+    ctypes = None
+
+if ctypes is None:
+    def _dummy(*args, **kwds):
+        """
+        Dummy object that raises an ImportError if ctypes is not available.
+
+        Raises
+        ------
+        ImportError
+            If ctypes is not available.
+
+        """
+        raise ImportError("ctypes is not available.")
+    load_library = _dummy
+    as_ctypes = _dummy
+    as_array = _dummy
+    from numpy import intp as c_intp
+    _ndptr_base = object
+else:
+    import numpy.core._internal as nic
+    c_intp = nic._getintp_ctype()
+    del nic
+    _ndptr_base = ctypes.c_void_p
+
+    # Adapted from Albert Strasheim
+    def load_library(libname, loader_path):
+        """
+        It is possible to load a library using
+
+        >>> lib = ctypes.cdll[<full_path_name>] # doctest: +SKIP
+
+        But there are cross-platform considerations, such as library file extensions,
+        plus the fact Windows will just load the first library it finds with that name.
+        NumPy supplies the load_library function as a convenience.
+
+        .. versionchanged:: 1.20.0
+            Allow libname and loader_path to take any
+            :term:`python:path-like object`.
+
+        Parameters
+        ----------
+        libname : path-like
+            Name of the library, which can have 'lib' as a prefix,
+            but without an extension.
+        loader_path : path-like
+            Where the library can be found.
+
+        Returns
+        -------
+        ctypes.cdll[libpath] : library object
+            A ctypes library object
+
+        Raises
+        ------
+        OSError
+            If there is no library with the expected extension, or the
+            library is defective and cannot be loaded.
+        """
+        if ctypes.__version__ < '1.0.1':
+            import warnings
+            warnings.warn("All features of ctypes interface may not work "
+                          "with ctypes < 1.0.1", stacklevel=2)
+
+        # Convert path-like objects into strings
+        libname = os.fsdecode(libname)
+        loader_path = os.fsdecode(loader_path)
+
+        ext = os.path.splitext(libname)[1]
+        if not ext:
+            # Try to load library with platform-specific name, otherwise
+            # default to libname.[so|pyd]. Sometimes, these files are built
+            # erroneously on non-linux platforms.
+            from numpy.distutils.misc_util import get_shared_lib_extension
+            so_ext = get_shared_lib_extension()
+            libname_ext = [libname + so_ext]
+            # mac, windows and linux >= py3.2 shared library and loadable
+            # module have different extensions so try both
+            so_ext2 = get_shared_lib_extension(is_python_ext=True)
+            if not so_ext2 == so_ext:
+                libname_ext.insert(0, libname + so_ext2)
+        else:
+            libname_ext = [libname]
+
+        loader_path = os.path.abspath(loader_path)
+        if not os.path.isdir(loader_path):
+            libdir = os.path.dirname(loader_path)
+        else:
+            libdir = loader_path
+
+        for ln in libname_ext:
+            libpath = os.path.join(libdir, ln)
+            if os.path.exists(libpath):
+                try:
+                    return ctypes.cdll[libpath]
+                except OSError:
+                    ## defective lib file
+                    raise
+        ## if no successful return in the libname_ext loop:
+        raise OSError("no file with expected extension")
+
+
+def _num_fromflags(flaglist):
+    num = 0
+    for val in flaglist:
+        num += _flagdict[val]
+    return num
+
+_flagnames = ['C_CONTIGUOUS', 'F_CONTIGUOUS', 'ALIGNED', 'WRITEABLE',
+              'OWNDATA', 'WRITEBACKIFCOPY']
+def _flags_fromnum(num):
+    res = []
+    for key in _flagnames:
+        value = _flagdict[key]
+        if (num & value):
+            res.append(key)
+    return res
+
+
+class _ndptr(_ndptr_base):
+    @classmethod
+    def from_param(cls, obj):
+        if not isinstance(obj, ndarray):
+            raise TypeError("argument must be an ndarray")
+        if cls._dtype_ is not None \
+               and obj.dtype != cls._dtype_:
+            raise TypeError("array must have data type %s" % cls._dtype_)
+        if cls._ndim_ is not None \
+               and obj.ndim != cls._ndim_:
+            raise TypeError("array must have %d dimension(s)" % cls._ndim_)
+        if cls._shape_ is not None \
+               and obj.shape != cls._shape_:
+            raise TypeError("array must have shape %s" % str(cls._shape_))
+        if cls._flags_ is not None \
+               and ((obj.flags.num & cls._flags_) != cls._flags_):
+            raise TypeError("array must have flags %s" %
+                    _flags_fromnum(cls._flags_))
+        return obj.ctypes
+
+
+class _concrete_ndptr(_ndptr):
+    """
+    Like _ndptr, but with `_shape_` and `_dtype_` specified.
+
+    Notably, this means the pointer has enough information to reconstruct
+    the array, which is not generally true.
+    """
+    def _check_retval_(self):
+        """
+        This method is called when this class is used as the .restype
+        attribute for a shared-library function, to automatically wrap the
+        pointer into an array.
+        """
+        return self.contents
+
+    @property
+    def contents(self):
+        """
+        Get an ndarray viewing the data pointed to by this pointer.
+
+        This mirrors the `contents` attribute of a normal ctypes pointer
+        """
+        full_dtype = _dtype((self._dtype_, self._shape_))
+        full_ctype = ctypes.c_char * full_dtype.itemsize
+        buffer = ctypes.cast(self, ctypes.POINTER(full_ctype)).contents
+        return frombuffer(buffer, dtype=full_dtype).squeeze(axis=0)
+
+
+# Factory for an array-checking class with from_param defined for
+# use with ctypes argtypes mechanism
+_pointer_type_cache = {}
+def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
+    """
+    Array-checking restype/argtypes.
+
+    An ndpointer instance is used to describe an ndarray in restypes
+    and argtypes specifications. This approach is more flexible than
+    using, for example, ``POINTER(c_double)``, since several restrictions
+    can be specified, which are verified upon calling the ctypes function.
+    These include data type, number of dimensions, shape and flags. If a
+    given array does not satisfy the specified restrictions,
+    a ``TypeError`` is raised.
+
+    Parameters
+    ----------
+    dtype : data-type, optional
+        Array data-type.
+    ndim : int, optional
+        Number of array dimensions.
+    shape : tuple of ints, optional
+        Array shape.
+    flags : str or tuple of str
+        Array flags; may be one or more of:
+
+        - C_CONTIGUOUS / C / CONTIGUOUS
+        - F_CONTIGUOUS / F / FORTRAN
+        - OWNDATA / O
+        - WRITEABLE / W
+        - ALIGNED / A
+        - WRITEBACKIFCOPY / X
+
+    Returns
+    -------
+    klass : ndpointer type object
+        A type object, which is an ``_ndtpr`` instance containing
+        dtype, ndim, shape and flags information.
+
+    Raises
+    ------
+    TypeError
+        If a given array does not satisfy the specified restrictions.
+
+    Examples
+    --------
+    >>> clib.somefunc.argtypes = [np.ctypeslib.ndpointer(dtype=np.float64,
+    ...                                                  ndim=1,
+    ...                                                  flags='C_CONTIGUOUS')]
+    ... #doctest: +SKIP
+    >>> clib.somefunc(np.array([1, 2, 3], dtype=np.float64))
+    ... #doctest: +SKIP
+
+    """
+
+    # normalize dtype to an Optional[dtype]
+    if dtype is not None:
+        dtype = _dtype(dtype)
+
+    # normalize flags to an Optional[int]
+    num = None
+    if flags is not None:
+        if isinstance(flags, str):
+            flags = flags.split(',')
+        elif isinstance(flags, (int, integer)):
+            num = flags
+            flags = _flags_fromnum(num)
+        elif isinstance(flags, flagsobj):
+            num = flags.num
+            flags = _flags_fromnum(num)
+        if num is None:
+            try:
+                flags = [x.strip().upper() for x in flags]
+            except Exception as e:
+                raise TypeError("invalid flags specification") from e
+            num = _num_fromflags(flags)
+
+    # normalize shape to an Optional[tuple]
+    if shape is not None:
+        try:
+            shape = tuple(shape)
+        except TypeError:
+            # single integer -> 1-tuple
+            shape = (shape,)
+
+    cache_key = (dtype, ndim, shape, num)
+
+    try:
+        return _pointer_type_cache[cache_key]
+    except KeyError:
+        pass
+
+    # produce a name for the new type
+    if dtype is None:
+        name = 'any'
+    elif dtype.names is not None:
+        name = str(id(dtype))
+    else:
+        name = dtype.str
+    if ndim is not None:
+        name += "_%dd" % ndim
+    if shape is not None:
+        name += "_"+"x".join(str(x) for x in shape)
+    if flags is not None:
+        name += "_"+"_".join(flags)
+
+    if dtype is not None and shape is not None:
+        base = _concrete_ndptr
+    else:
+        base = _ndptr
+
+    klass = type("ndpointer_%s"%name, (base,),
+                 {"_dtype_": dtype,
+                  "_shape_" : shape,
+                  "_ndim_" : ndim,
+                  "_flags_" : num})
+    _pointer_type_cache[cache_key] = klass
+    return klass
+
+
+if ctypes is not None:
+    def _ctype_ndarray(element_type, shape):
+        """ Create an ndarray of the given element type and shape """
+        for dim in shape[::-1]:
+            element_type = dim * element_type
+            # prevent the type name include np.ctypeslib
+            element_type.__module__ = None
+        return element_type
+
+
+    def _get_scalar_type_map():
+        """
+        Return a dictionary mapping native endian scalar dtype to ctypes types
+        """
+        ct = ctypes
+        simple_types = [
+            ct.c_byte, ct.c_short, ct.c_int, ct.c_long, ct.c_longlong,
+            ct.c_ubyte, ct.c_ushort, ct.c_uint, ct.c_ulong, ct.c_ulonglong,
+            ct.c_float, ct.c_double,
+            ct.c_bool,
+        ]
+        return {_dtype(ctype): ctype for ctype in simple_types}
+
+
+    _scalar_type_map = _get_scalar_type_map()
+
+
+    def _ctype_from_dtype_scalar(dtype):
+        # swapping twice ensure that `=` is promoted to <, >, or |
+        dtype_with_endian = dtype.newbyteorder('S').newbyteorder('S')
+        dtype_native = dtype.newbyteorder('=')
+        try:
+            ctype = _scalar_type_map[dtype_native]
+        except KeyError as e:
+            raise NotImplementedError(
+                "Converting {!r} to a ctypes type".format(dtype)
+            ) from None
+
+        if dtype_with_endian.byteorder == '>':
+            ctype = ctype.__ctype_be__
+        elif dtype_with_endian.byteorder == '<':
+            ctype = ctype.__ctype_le__
+
+        return ctype
+
+
+    def _ctype_from_dtype_subarray(dtype):
+        element_dtype, shape = dtype.subdtype
+        ctype = _ctype_from_dtype(element_dtype)
+        return _ctype_ndarray(ctype, shape)
+
+
+    def _ctype_from_dtype_structured(dtype):
+        # extract offsets of each field
+        field_data = []
+        for name in dtype.names:
+            field_dtype, offset = dtype.fields[name][:2]
+            field_data.append((offset, name, _ctype_from_dtype(field_dtype)))
+
+        # ctypes doesn't care about field order
+        field_data = sorted(field_data, key=lambda f: f[0])
+
+        if len(field_data) > 1 and all(offset == 0 for offset, name, ctype in field_data):
+            # union, if multiple fields all at address 0
+            size = 0
+            _fields_ = []
+            for offset, name, ctype in field_data:
+                _fields_.append((name, ctype))
+                size = max(size, ctypes.sizeof(ctype))
+
+            # pad to the right size
+            if dtype.itemsize != size:
+                _fields_.append(('', ctypes.c_char * dtype.itemsize))
+
+            # we inserted manual padding, so always `_pack_`
+            return type('union', (ctypes.Union,), dict(
+                _fields_=_fields_,
+                _pack_=1,
+                __module__=None,
+            ))
+        else:
+            last_offset = 0
+            _fields_ = []
+            for offset, name, ctype in field_data:
+                padding = offset - last_offset
+                if padding < 0:
+                    raise NotImplementedError("Overlapping fields")
+                if padding > 0:
+                    _fields_.append(('', ctypes.c_char * padding))
+
+                _fields_.append((name, ctype))
+                last_offset = offset + ctypes.sizeof(ctype)
+
+
+            padding = dtype.itemsize - last_offset
+            if padding > 0:
+                _fields_.append(('', ctypes.c_char * padding))
+
+            # we inserted manual padding, so always `_pack_`
+            return type('struct', (ctypes.Structure,), dict(
+                _fields_=_fields_,
+                _pack_=1,
+                __module__=None,
+            ))
+
+
+    def _ctype_from_dtype(dtype):
+        if dtype.fields is not None:
+            return _ctype_from_dtype_structured(dtype)
+        elif dtype.subdtype is not None:
+            return _ctype_from_dtype_subarray(dtype)
+        else:
+            return _ctype_from_dtype_scalar(dtype)
+
+
+    def as_ctypes_type(dtype):
+        r"""
+        Convert a dtype into a ctypes type.
+
+        Parameters
+        ----------
+        dtype : dtype
+            The dtype to convert
+
+        Returns
+        -------
+        ctype
+            A ctype scalar, union, array, or struct
+
+        Raises
+        ------
+        NotImplementedError
+            If the conversion is not possible
+
+        Notes
+        -----
+        This function does not losslessly round-trip in either direction.
+
+        ``np.dtype(as_ctypes_type(dt))`` will:
+
+         - insert padding fields
+         - reorder fields to be sorted by offset
+         - discard field titles
+
+        ``as_ctypes_type(np.dtype(ctype))`` will:
+
+         - discard the class names of `ctypes.Structure`\ s and
+           `ctypes.Union`\ s
+         - convert single-element `ctypes.Union`\ s into single-element
+           `ctypes.Structure`\ s
+         - insert padding fields
+
+        """
+        return _ctype_from_dtype(_dtype(dtype))
+
+
+    def as_array(obj, shape=None):
+        """
+        Create a numpy array from a ctypes array or POINTER.
+
+        The numpy array shares the memory with the ctypes object.
+
+        The shape parameter must be given if converting from a ctypes POINTER.
+        The shape parameter is ignored if converting from a ctypes array
+        """
+        if isinstance(obj, ctypes._Pointer):
+            # convert pointers to an array of the desired shape
+            if shape is None:
+                raise TypeError(
+                    'as_array() requires a shape argument when called on a '
+                    'pointer')
+            p_arr_type = ctypes.POINTER(_ctype_ndarray(obj._type_, shape))
+            obj = ctypes.cast(obj, p_arr_type).contents
+
+        return asarray(obj)
+
+
+    def as_ctypes(obj):
+        """Create and return a ctypes object from a numpy array. Actually
+        anything that exposes the __array_interface__ is accepted."""
+        ai = obj.__array_interface__
+        if ai["strides"]:
+            raise TypeError("strided arrays not supported")
+        if ai["version"] != 3:
+            raise TypeError("only __array_interface__ version 3 supported")
+        addr, readonly = ai["data"]
+        if readonly:
+            raise TypeError("readonly arrays unsupported")
+
+        # can't use `_dtype((ai["typestr"], ai["shape"]))` here, as it overflows
+        # dtype.itemsize (gh-14214)
+        ctype_scalar = as_ctypes_type(ai["typestr"])
+        result_type = _ctype_ndarray(ctype_scalar, ai["shape"])
+        result = result_type.from_address(addr)
+        result.__keep = obj
+        return result
crates/ruff_benchmark/resources/numpy/globals.py (95 additions, new file)

@@ -0,0 +1,95 @@
+"""
+Module defining global singleton classes.
+
+This module raises a RuntimeError if an attempt to reload it is made. In that
+way the identities of the classes defined here are fixed and will remain so
+even if numpy itself is reloaded. In particular, a function like the following
+will still work correctly after numpy is reloaded::
+
+    def foo(arg=np._NoValue):
+        if arg is np._NoValue:
+            ...
+
+That was not the case when the singleton classes were defined in the numpy
+``__init__.py`` file. See gh-7844 for a discussion of the reload problem that
+motivated this module.
+
+"""
+import enum
+
+from ._utils import set_module as _set_module
+
+__all__ = ['_NoValue', '_CopyMode']
+
+
+# Disallow reloading this module so as to preserve the identities of the
+# classes defined here.
+if '_is_loaded' in globals():
+    raise RuntimeError('Reloading numpy._globals is not allowed')
+_is_loaded = True
+
+
+class _NoValueType:
+    """Special keyword value.
+
+    The instance of this class may be used as the default value assigned to a
+    keyword if no other obvious default (e.g., `None`) is suitable,
+
+    Common reasons for using this keyword are:
+
+    - A new keyword is added to a function, and that function forwards its
+      inputs to another function or method which can be defined outside of
+      NumPy. For example, ``np.std(x)`` calls ``x.std``, so when a ``keepdims``
+      keyword was added that could only be forwarded if the user explicitly
+      specified ``keepdims``; downstream array libraries may not have added
+      the same keyword, so adding ``x.std(..., keepdims=keepdims)``
+      unconditionally could have broken previously working code.
+    - A keyword is being deprecated, and a deprecation warning must only be
+      emitted when the keyword is used.
+
+    """
+    __instance = None
+    def __new__(cls):
+        # ensure that only one instance exists
+        if not cls.__instance:
+            cls.__instance = super().__new__(cls)
+        return cls.__instance
+
+    def __repr__(self):
+        return "<no value>"
+
+
+_NoValue = _NoValueType()
+
+
+@_set_module("numpy")
+class _CopyMode(enum.Enum):
+    """
+    An enumeration for the copy modes supported
+    by numpy.copy() and numpy.array(). The following three modes are supported,
+
+    - ALWAYS: This means that a deep copy of the input
+              array will always be taken.
+    - IF_NEEDED: This means that a deep copy of the input
+                 array will be taken only if necessary.
+    - NEVER: This means that the deep copy will never be taken.
+             If a copy cannot be avoided then a `ValueError` will be
+             raised.
+
+    Note that the buffer-protocol could in theory do copies. NumPy currently
+    assumes an object exporting the buffer protocol will never do this.
+    """
+
+    ALWAYS = True
+    IF_NEEDED = False
+    NEVER = 2
+
+    def __bool__(self):
+        # For backwards compatibility
+        if self == _CopyMode.ALWAYS:
+            return True
+
+        if self == _CopyMode.IF_NEEDED:
+            return False
+
+        raise ValueError(f"{self} is neither True nor False.")
834
crates/ruff_benchmark/resources/pydantic/types.py
Normal file
834
crates/ruff_benchmark/resources/pydantic/types.py
Normal file
|
@ -0,0 +1,834 @@
|
||||||
|
from __future__ import annotations as _annotations
|
||||||
|
|
||||||
|
import abc
|
||||||
|
import dataclasses as _dataclasses
|
||||||
|
import re
|
||||||
|
from datetime import date, datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
ClassVar,
|
||||||
|
FrozenSet,
|
||||||
|
Generic,
|
||||||
|
Hashable,
|
||||||
|
List,
|
||||||
|
Set,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
import annotated_types
|
||||||
|
from pydantic_core import PydanticCustomError, PydanticKnownError, core_schema
|
||||||
|
from typing_extensions import Annotated, Literal
|
||||||
|
|
||||||
|
from ._internal import _fields, _validators
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'Strict',
|
||||||
|
'StrictStr',
|
||||||
|
'conbytes',
|
||||||
|
'conlist',
|
||||||
|
'conset',
|
||||||
|
'confrozenset',
|
||||||
|
'constr',
|
||||||
|
'ImportString',
|
||||||
|
'conint',
|
||||||
|
'PositiveInt',
|
||||||
|
'NegativeInt',
|
||||||
|
'NonNegativeInt',
|
||||||
|
'NonPositiveInt',
|
||||||
|
'confloat',
|
||||||
|
'PositiveFloat',
|
||||||
|
'NegativeFloat',
|
||||||
|
'NonNegativeFloat',
|
||||||
|
'NonPositiveFloat',
|
||||||
|
'FiniteFloat',
|
||||||
|
'condecimal',
|
||||||
|
'UUID1',
|
||||||
|
'UUID3',
|
||||||
|
'UUID4',
|
||||||
|
'UUID5',
|
||||||
|
'FilePath',
|
||||||
|
'DirectoryPath',
|
||||||
|
'Json',
|
||||||
|
'SecretField',
|
||||||
|
'SecretStr',
|
||||||
|
'SecretBytes',
|
||||||
|
'StrictBool',
|
||||||
|
'StrictBytes',
|
||||||
|
'StrictInt',
|
||||||
|
'StrictFloat',
|
||||||
|
'PaymentCardNumber',
|
||||||
|
'ByteSize',
|
||||||
|
'PastDate',
|
||||||
|
'FutureDate',
|
||||||
|
'condate',
|
||||||
|
'AwareDatetime',
|
||||||
|
'NaiveDatetime',
|
||||||
|
]
|
||||||
|
|
||||||
|
from ._internal._core_metadata import build_metadata_dict
|
||||||
|
from ._internal._utils import update_not_none
|
||||||
|
from .json_schema import JsonSchemaMetadata
|
||||||
|
|
||||||
|
|
||||||
|
@_dataclasses.dataclass
|
||||||
|
class Strict(_fields.PydanticMetadata):
|
||||||
|
strict: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BOOLEAN TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
StrictBool = Annotated[bool, Strict()]
|
||||||
|
|
||||||
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTEGER TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
|
||||||
|
def conint(
|
||||||
|
*,
|
||||||
|
strict: bool | None = None,
|
||||||
|
gt: int | None = None,
|
||||||
|
ge: int | None = None,
|
||||||
|
lt: int | None = None,
|
||||||
|
le: int | None = None,
|
||||||
|
multiple_of: int | None = None,
|
||||||
|
) -> type[int]:
|
||||||
|
return Annotated[ # type: ignore[return-value]
|
||||||
|
int,
|
||||||
|
Strict(strict) if strict is not None else None,
|
||||||
|
annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le),
|
||||||
|
annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None,
|
||||||
|
]


PositiveInt = Annotated[int, annotated_types.Gt(0)]
NegativeInt = Annotated[int, annotated_types.Lt(0)]
NonPositiveInt = Annotated[int, annotated_types.Le(0)]
NonNegativeInt = Annotated[int, annotated_types.Ge(0)]
StrictInt = Annotated[int, Strict()]

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLOAT TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@_dataclasses.dataclass
class AllowInfNan(_fields.PydanticMetadata):
    allow_inf_nan: bool = True


def confloat(
    *,
    strict: bool | None = None,
    gt: float | None = None,
    ge: float | None = None,
    lt: float | None = None,
    le: float | None = None,
    multiple_of: float | None = None,
    allow_inf_nan: bool | None = None,
) -> type[float]:
    return Annotated[  # type: ignore[return-value]
        float,
        Strict(strict) if strict is not None else None,
        annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le),
        annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None,
        AllowInfNan(allow_inf_nan) if allow_inf_nan is not None else None,
    ]


PositiveFloat = Annotated[float, annotated_types.Gt(0)]
NegativeFloat = Annotated[float, annotated_types.Lt(0)]
NonPositiveFloat = Annotated[float, annotated_types.Le(0)]
NonNegativeFloat = Annotated[float, annotated_types.Ge(0)]
StrictFloat = Annotated[float, Strict(True)]
FiniteFloat = Annotated[float, AllowInfNan(False)]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTES TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


def conbytes(
    *,
    min_length: int | None = None,
    max_length: int | None = None,
    strict: bool | None = None,
) -> type[bytes]:
    return Annotated[  # type: ignore[return-value]
        bytes,
        Strict(strict) if strict is not None else None,
        annotated_types.Len(min_length or 0, max_length),
    ]


StrictBytes = Annotated[bytes, Strict()]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ STRING TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


def constr(
    *,
    strip_whitespace: bool | None = None,
    to_upper: bool | None = None,
    to_lower: bool | None = None,
    strict: bool | None = None,
    min_length: int | None = None,
    max_length: int | None = None,
    pattern: str | None = None,
) -> type[str]:
    return Annotated[  # type: ignore[return-value]
        str,
        Strict(strict) if strict is not None else None,
        annotated_types.Len(min_length or 0, max_length),
        _fields.PydanticGeneralMetadata(
            strip_whitespace=strip_whitespace,
            to_upper=to_upper,
            to_lower=to_lower,
            pattern=pattern,
        ),
    ]


StrictStr = Annotated[str, Strict()]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ COLLECTION TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
HashableItemType = TypeVar('HashableItemType', bound=Hashable)

def conset(
    item_type: Type[HashableItemType], *, min_length: int | None = None, max_length: int | None = None
) -> Type[Set[HashableItemType]]:
    return Annotated[  # type: ignore[return-value]
        Set[item_type], annotated_types.Len(min_length or 0, max_length)  # type: ignore[valid-type]
    ]


def confrozenset(
    item_type: Type[HashableItemType], *, min_length: int | None = None, max_length: int | None = None
) -> Type[FrozenSet[HashableItemType]]:
    return Annotated[  # type: ignore[return-value]
        FrozenSet[item_type],  # type: ignore[valid-type]
        annotated_types.Len(min_length or 0, max_length),
    ]


AnyItemType = TypeVar('AnyItemType')


def conlist(
    item_type: Type[AnyItemType], *, min_length: int | None = None, max_length: int | None = None
) -> Type[List[AnyItemType]]:
    return Annotated[  # type: ignore[return-value]
        List[item_type],  # type: ignore[valid-type]
        annotated_types.Len(min_length or 0, max_length),
    ]


def contuple(
    item_type: Type[AnyItemType], *, min_length: int | None = None, max_length: int | None = None
) -> Type[Tuple[AnyItemType]]:
    return Annotated[  # type: ignore[return-value]
        Tuple[item_type],
        annotated_types.Len(min_length or 0, max_length),
    ]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~ IMPORT STRING TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

AnyType = TypeVar('AnyType')
if TYPE_CHECKING:
    ImportString = Annotated[AnyType, ...]
else:

    class ImportString:
        @classmethod
        def __class_getitem__(cls, item: AnyType) -> AnyType:
            return Annotated[item, cls()]

        @classmethod
        def __get_pydantic_core_schema__(
            cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
        ) -> core_schema.CoreSchema:
            if schema is None or schema == {'type': 'any'}:
                # Treat bare usage of ImportString (`schema is None`) as the same as ImportString[Any]
                return core_schema.function_plain_schema(lambda v, _: _validators.import_string(v))
            else:
                return core_schema.function_before_schema(lambda v, _: _validators.import_string(v), schema)

        def __repr__(self) -> str:
            return 'ImportString'
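
    # Illustrative behaviour (assumption): a field annotated as e.g.
    # `ImportString[Any]` resolves via `__class_getitem__` to
    # `Annotated[Any, ImportString()]`, so the input string 'math.cos' is
    # validated by importing the dotted path with `_validators.import_string`
    # and yields the actual `math.cos` function.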


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DECIMAL TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


def condecimal(
    *,
    strict: bool | None = None,
    gt: int | Decimal | None = None,
    ge: int | Decimal | None = None,
    lt: int | Decimal | None = None,
    le: int | Decimal | None = None,
    multiple_of: int | Decimal | None = None,
    max_digits: int | None = None,
    decimal_places: int | None = None,
    allow_inf_nan: bool | None = None,
) -> Type[Decimal]:
    return Annotated[  # type: ignore[return-value]
        Decimal,
        Strict(strict) if strict is not None else None,
        annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le),
        annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None,
        _fields.PydanticGeneralMetadata(max_digits=max_digits, decimal_places=decimal_places),
        AllowInfNan(allow_inf_nan) if allow_inf_nan is not None else None,
    ]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ UUID TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@_dataclasses.dataclass(frozen=True)  # Add frozen=True to make it hashable
class UuidVersion:
    uuid_version: Literal[1, 3, 4, 5]

    def __pydantic_modify_json_schema__(self, field_schema: dict[str, Any]) -> None:
        field_schema.pop('anyOf', None)  # remove the bytes/str union
        field_schema.update(type='string', format=f'uuid{self.uuid_version}')

    def __get_pydantic_core_schema__(
        self, schema: core_schema.CoreSchema, **_kwargs: Any
    ) -> core_schema.FunctionSchema:
        return core_schema.function_after_schema(schema, cast(core_schema.ValidatorFunction, self.validate))

    def validate(self, value: UUID, _: core_schema.ValidationInfo) -> UUID:
        if value.version != self.uuid_version:
            raise PydanticCustomError(
                'uuid_version', 'uuid version {required_version} expected', {'required_version': self.uuid_version}
            )
        return value
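
    # Illustrative behaviour (assumption): `Annotated[UUID, UuidVersion(4)]`
    # (the `UUID4` alias below) accepts a `uuid.uuid4()` value but rejects a
    # `uuid.uuid1()` value with a 'uuid_version' error, since `value.version`
    # would be 1 rather than 4.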


UUID1 = Annotated[UUID, UuidVersion(1)]
UUID3 = Annotated[UUID, UuidVersion(3)]
UUID4 = Annotated[UUID, UuidVersion(4)]
UUID5 = Annotated[UUID, UuidVersion(5)]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PATH TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


@_dataclasses.dataclass
class PathType:
    path_type: Literal['file', 'dir', 'new']

    def __pydantic_modify_json_schema__(self, field_schema: dict[str, Any]) -> None:
        format_conversion = {'file': 'file-path', 'dir': 'directory-path'}
        field_schema.update(format=format_conversion.get(self.path_type, 'path'), type='string')

    def __get_pydantic_core_schema__(
        self, schema: core_schema.CoreSchema, **_kwargs: Any
    ) -> core_schema.FunctionSchema:
        function_lookup = {
            'file': cast(core_schema.ValidatorFunction, self.validate_file),
            'dir': cast(core_schema.ValidatorFunction, self.validate_directory),
            'new': cast(core_schema.ValidatorFunction, self.validate_new),
        }

        return core_schema.function_after_schema(
            schema,
            function_lookup[self.path_type],
        )

    @staticmethod
    def validate_file(path: Path, _: core_schema.ValidationInfo) -> Path:
        if path.is_file():
            return path
        else:
            raise PydanticCustomError('path_not_file', 'Path does not point to a file')

    @staticmethod
    def validate_directory(path: Path, _: core_schema.ValidationInfo) -> Path:
        if path.is_dir():
            return path
        else:
            raise PydanticCustomError('path_not_directory', 'Path does not point to a directory')

    @staticmethod
    def validate_new(path: Path, _: core_schema.ValidationInfo) -> Path:
        if path.exists():
            raise PydanticCustomError('path_exists', 'path already exists')
        elif not path.parent.exists():
            raise PydanticCustomError('parent_does_not_exist', 'Parent directory does not exist')
        else:
            return path


FilePath = Annotated[Path, PathType('file')]
DirectoryPath = Annotated[Path, PathType('dir')]
NewPath = Annotated[Path, PathType('new')]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ JSON TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

if TYPE_CHECKING:
    Json = Annotated[AnyType, ...]  # Json[list[str]] will be recognized by type checkers as list[str]

else:

    class Json:
        @classmethod
        def __class_getitem__(cls, item: AnyType) -> AnyType:
            return Annotated[item, cls()]

        @classmethod
        def __get_pydantic_core_schema__(
            cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
        ) -> core_schema.JsonSchema:
            return core_schema.json_schema(schema)

        @classmethod
        def __pydantic_modify_json_schema__(cls, field_schema: dict[str, Any]) -> None:
            field_schema.update(type='string', format='json-string')

        def __repr__(self) -> str:
            return 'Json'

        def __hash__(self) -> int:
            return hash(type(self))

        def __eq__(self, other: Any) -> bool:
            return type(other) == type(self)
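
    # Illustrative usage (assumption): `Json[list[str]]` resolves to an
    # `Annotated` form whose core schema parses a JSON string into the inner
    # type, so the input '["a", "b"]' validates to ['a', 'b'].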


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

SecretType = TypeVar('SecretType', str, bytes)


class SecretField(abc.ABC, Generic[SecretType]):
    _error_kind: str

    def __init__(self, secret_value: SecretType) -> None:
        self._secret_value: SecretType = secret_value

    def get_secret_value(self) -> SecretType:
        return self._secret_value

    @classmethod
    def __get_pydantic_core_schema__(cls, **_kwargs: Any) -> core_schema.FunctionSchema:
        validator = SecretFieldValidator(cls)
        if issubclass(cls, SecretStr):
            # Use a lambda here so that `apply_metadata` can be called on the validator before the override is generated
            override = lambda: core_schema.str_schema(  # noqa E731
                min_length=validator.min_length,
                max_length=validator.max_length,
            )
        elif issubclass(cls, SecretBytes):
            override = lambda: core_schema.bytes_schema(  # noqa E731
                min_length=validator.min_length,
                max_length=validator.max_length,
            )
        else:
            override = None
        metadata = build_metadata_dict(
            update_cs_function=validator.__pydantic_update_schema__,
            js_metadata=JsonSchemaMetadata(core_schema_override=override),
        )
        return core_schema.function_after_schema(
            core_schema.union_schema(
                core_schema.is_instance_schema(cls),
                cls._pre_core_schema(),
                strict=True,
                custom_error_type=cls._error_kind,
            ),
            validator,
            metadata=metadata,
            serialization=core_schema.function_plain_ser_schema(cls._serialize, json_return_type='str'),
        )

    @classmethod
    def _serialize(
        cls, value: SecretField[SecretType], info: core_schema.SerializationInfo
    ) -> str | SecretField[SecretType]:
        if info.mode == 'json':
            # we want the output to always be a string without the `b'` prefix for bytes,
            # hence we just use `secret_display`
            return secret_display(value)
        else:
            return value

    @classmethod
    @abc.abstractmethod
    def _pre_core_schema(cls) -> core_schema.CoreSchema:
        ...

    @classmethod
    def __pydantic_modify_json_schema__(cls, field_schema: dict[str, Any]) -> None:
        update_not_none(
            field_schema,
            type='string',
            writeOnly=True,
            format='password',
        )

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, self.__class__) and self.get_secret_value() == other.get_secret_value()

    def __hash__(self) -> int:
        return hash(self.get_secret_value())

    def __len__(self) -> int:
        return len(self._secret_value)

    @abc.abstractmethod
    def _display(self) -> SecretType:
        ...

    def __str__(self) -> str:
        return str(self._display())

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}({self._display()!r})'


def secret_display(secret_field: SecretField[Any]) -> str:
    return '**********' if secret_field.get_secret_value() else ''


class SecretFieldValidator(_fields.CustomValidator, Generic[SecretType]):
    __slots__ = 'field_type', 'min_length', 'max_length', 'error_prefix'

    def __init__(
        self, field_type: Type[SecretField[SecretType]], min_length: int | None = None, max_length: int | None = None
    ) -> None:
        self.field_type: Type[SecretField[SecretType]] = field_type
        self.min_length = min_length
        self.max_length = max_length
        self.error_prefix: Literal['string', 'bytes'] = 'string' if field_type is SecretStr else 'bytes'

    def __call__(self, __value: SecretField[SecretType] | SecretType, _: core_schema.ValidationInfo) -> Any:
        if self.min_length is not None and len(__value) < self.min_length:
            short_kind: core_schema.ErrorType = f'{self.error_prefix}_too_short'  # type: ignore[assignment]
            raise PydanticKnownError(short_kind, {'min_length': self.min_length})
        if self.max_length is not None and len(__value) > self.max_length:
            long_kind: core_schema.ErrorType = f'{self.error_prefix}_too_long'  # type: ignore[assignment]
            raise PydanticKnownError(long_kind, {'max_length': self.max_length})

        if isinstance(__value, self.field_type):
            return __value
        else:
            return self.field_type(__value)  # type: ignore[arg-type]

    def __pydantic_update_schema__(self, schema: core_schema.CoreSchema, **constraints: Any) -> None:
        self._update_attrs(constraints, {'min_length', 'max_length'})


class SecretStr(SecretField[str]):
    _error_kind = 'string_type'

    @classmethod
    def _pre_core_schema(cls) -> core_schema.CoreSchema:
        return core_schema.str_schema()

    def _display(self) -> str:
        return secret_display(self)
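
# Illustrative behaviour (assumption): str(SecretStr('password')) returns
# '**********' and repr(SecretStr('password')) returns
# "SecretStr('**********')", while .get_secret_value() returns the real value.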


class SecretBytes(SecretField[bytes]):
    _error_kind = 'bytes_type'

    @classmethod
    def _pre_core_schema(cls) -> core_schema.CoreSchema:
        return core_schema.bytes_schema()

    def _display(self) -> bytes:
        return secret_display(self).encode()


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PAYMENT CARD TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~


class PaymentCardBrand(str, Enum):
    # If you add another card type, please also add it to the
    # Hypothesis strategy in `pydantic._hypothesis_plugin`.
    amex = 'American Express'
    mastercard = 'Mastercard'
    visa = 'Visa'
    other = 'other'

    def __str__(self) -> str:
        return self.value


class PaymentCardNumber(str):
    """
    Based on: https://en.wikipedia.org/wiki/Payment_card_number
    """

    strip_whitespace: ClassVar[bool] = True
    min_length: ClassVar[int] = 12
    max_length: ClassVar[int] = 19
    bin: str
    last4: str
    brand: PaymentCardBrand

    def __init__(self, card_number: str):
        self.validate_digits(card_number)

        card_number = self.validate_luhn_check_digit(card_number)

        self.bin = card_number[:6]
        self.last4 = card_number[-4:]
        self.brand = self.validate_brand(card_number)

    @classmethod
    def __get_pydantic_core_schema__(cls, **_kwargs: Any) -> core_schema.FunctionSchema:
        return core_schema.function_after_schema(
            core_schema.str_schema(
                min_length=cls.min_length, max_length=cls.max_length, strip_whitespace=cls.strip_whitespace
            ),
            cls.validate,
        )

    @classmethod
    def validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> 'PaymentCardNumber':
        return cls(__input_value)

    @property
    def masked(self) -> str:
        num_masked = len(self) - 10  # len(bin) + len(last4) == 10
        return f'{self.bin}{"*" * num_masked}{self.last4}'

    @classmethod
    def validate_digits(cls, card_number: str) -> None:
        if not card_number.isdigit():
            raise PydanticCustomError('payment_card_number_digits', 'Card number is not all digits')

    @classmethod
    def validate_luhn_check_digit(cls, card_number: str) -> str:
        """
        Based on: https://en.wikipedia.org/wiki/Luhn_algorithm
        """
        sum_ = int(card_number[-1])
        length = len(card_number)
        parity = length % 2
        for i in range(length - 1):
            digit = int(card_number[i])
            if i % 2 == parity:
                digit *= 2
                if digit > 9:
                    digit -= 9
            sum_ += digit
        valid = sum_ % 10 == 0
        if not valid:
            raise PydanticCustomError('payment_card_number_luhn', 'Card number is not luhn valid')
        return card_number
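
    # Worked example of the Luhn check above (illustrative): for the classic
    # test number '79927398713' the check digit is 3. Doubling every second
    # digit starting from the one just left of the check digit, and
    # subtracting 9 from any doubled result above 9, gives the processed
    # digits 7+9+9+4+7+6+9+7+7+2 = 67; adding the check digit 3 gives 70,
    # which is divisible by 10, so the number passes.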

    @staticmethod
    def validate_brand(card_number: str) -> PaymentCardBrand:
        """
        Validate length based on BIN for major brands:
        https://en.wikipedia.org/wiki/Payment_card_number#Issuer_identification_number_(IIN)
        """
        if card_number[0] == '4':
            brand = PaymentCardBrand.visa
        elif 51 <= int(card_number[:2]) <= 55:
            brand = PaymentCardBrand.mastercard
        elif card_number[:2] in {'34', '37'}:
            brand = PaymentCardBrand.amex
        else:
            brand = PaymentCardBrand.other

        required_length: Union[None, int, str] = None
        if brand == PaymentCardBrand.mastercard:
            required_length = 16
            valid = len(card_number) == required_length
        elif brand == PaymentCardBrand.visa:
            required_length = '13, 16 or 19'
            valid = len(card_number) in {13, 16, 19}
        elif brand == PaymentCardBrand.amex:
            required_length = 15
            valid = len(card_number) == required_length
        else:
            valid = True

        if not valid:
            raise PydanticCustomError(
                'payment_card_number_brand',
                'Length for a {brand} card must be {required_length}',
                {'brand': brand, 'required_length': required_length},
            )
        return brand


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTE SIZE TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

BYTE_SIZES = {
    'b': 1,
    'kb': 10**3,
    'mb': 10**6,
    'gb': 10**9,
    'tb': 10**12,
    'pb': 10**15,
    'eb': 10**18,
    'kib': 2**10,
    'mib': 2**20,
    'gib': 2**30,
    'tib': 2**40,
    'pib': 2**50,
    'eib': 2**60,
}
BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k})
byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE)


class ByteSize(int):
    @classmethod
    def __get_pydantic_core_schema__(cls, **_kwargs: Any) -> core_schema.FunctionPlainSchema:
        # TODO better schema
        return core_schema.function_plain_schema(cls.validate)

    @classmethod
    def validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> 'ByteSize':
        try:
            return cls(int(__input_value))
        except ValueError:
            pass

        str_match = byte_string_re.match(str(__input_value))
        if str_match is None:
            raise PydanticCustomError('byte_size', 'could not parse value and unit from byte string')

        scalar, unit = str_match.groups()
        if unit is None:
            unit = 'b'

        try:
            unit_mult = BYTE_SIZES[unit.lower()]
        except KeyError:
            raise PydanticCustomError('byte_size_unit', 'could not interpret byte unit: {unit}', {'unit': unit})

        return cls(int(float(scalar) * unit_mult))

    def human_readable(self, decimal: bool = False) -> str:
        if decimal:
            divisor = 1000
            units = 'B', 'KB', 'MB', 'GB', 'TB', 'PB'
            final_unit = 'EB'
        else:
            divisor = 1024
            units = 'B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB'
            final_unit = 'EiB'

        num = float(self)
        for unit in units:
            if abs(num) < divisor:
                if unit == 'B':
                    return f'{num:0.0f}{unit}'
                else:
                    return f'{num:0.1f}{unit}'
            num /= divisor

        return f'{num:0.1f}{final_unit}'

    def to(self, unit: str) -> float:
        try:
            unit_div = BYTE_SIZES[unit.lower()]
        except KeyError:
            raise PydanticCustomError('byte_size_unit', 'Could not interpret byte unit: {unit}', {'unit': unit})

        return self / unit_div
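
# Illustrative conversions (assumption, using the class above):
#
#     ByteSize.validate('1.5 KiB', None)   -> ByteSize(1536)
#     ByteSize(5242880).human_readable()   -> '5.0MiB'
#     ByteSize(2048).to('kib')             -> 2.0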


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DATE TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

if TYPE_CHECKING:
    PastDate = Annotated[date, ...]
    FutureDate = Annotated[date, ...]
else:

    class PastDate:
        @classmethod
        def __get_pydantic_core_schema__(
            cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
        ) -> core_schema.CoreSchema:
            if schema is None:
                # used directly as a type
                return core_schema.date_schema(now_op='past')
            else:
                assert schema['type'] == 'date'
                schema['now_op'] = 'past'
                return schema

        def __repr__(self) -> str:
            return 'PastDate'

    class FutureDate:
        @classmethod
        def __get_pydantic_core_schema__(
            cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
        ) -> core_schema.CoreSchema:
            if schema is None:
                # used directly as a type
                return core_schema.date_schema(now_op='future')
            else:
                assert schema['type'] == 'date'
                schema['now_op'] = 'future'
                return schema

        def __repr__(self) -> str:
            return 'FutureDate'


def condate(*, strict: bool | None = None, gt: date | None = None, ge: date | None = None, lt: date | None = None, le: date | None = None) -> type[date]:
    return Annotated[  # type: ignore[return-value]
        date,
        Strict(strict) if strict is not None else None,
        annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le),
    ]


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DATETIME TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

if TYPE_CHECKING:
    AwareDatetime = Annotated[datetime, ...]
    NaiveDatetime = Annotated[datetime, ...]
else:

    class AwareDatetime:
        @classmethod
        def __get_pydantic_core_schema__(
            cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
        ) -> core_schema.CoreSchema:
            if schema is None:
                # used directly as a type
                return core_schema.datetime_schema(tz_constraint='aware')
            else:
                assert schema['type'] == 'datetime'
                schema['tz_constraint'] = 'aware'
                return schema

        def __repr__(self) -> str:
            return 'AwareDatetime'

    class NaiveDatetime:
        @classmethod
        def __get_pydantic_core_schema__(
            cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
        ) -> core_schema.CoreSchema:
            if schema is None:
                # used directly as a type
                return core_schema.datetime_schema(tz_constraint='naive')
            else:
                assert schema['type'] == 'datetime'
                schema['tz_constraint'] = 'naive'
                return schema

        def __repr__(self) -> str:
            return 'NaiveDatetime'
161
crates/ruff_benchmark/resources/pypinyin.py
Normal file
@ -0,0 +1,161 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Handle some special cases in the Hanyu Pinyin scheme.


The Hanyu Pinyin scheme:

* https://zh.wiktionary.org/wiki/%E9%99%84%E5%BD%95:%E6%B1%89%E8%AF%AD%E6%8B%BC%E9%9F%B3%E6%96%B9%E6%A1%88
* http://www.moe.edu.cn/s78/A19/yxs_left/moe_810/s230/195802/t19580201_186000.html
"""  # noqa
from __future__ import unicode_literals

import re

from pypinyin.style._constants import _FINALS

# u -> ü
UV_MAP = {
    'u': 'ü',
    'ū': 'ǖ',
    'ú': 'ǘ',
    'ǔ': 'ǚ',
    'ù': 'ǜ',
}
U_TONES = set(UV_MAP.keys())
# When finals of the ü row are combined with the initials j, q, x,
# they are written as ju (居), qu (区), xu (虚).
UV_RE = re.compile(
    r'^(j|q|x)({tones})(.*)$'.format(
        tones='|'.join(UV_MAP.keys())))
I_TONES = set(['i', 'ī', 'í', 'ǐ', 'ì'])

# iu -> iou
IU_MAP = {
    'iu': 'iou',
    'iū': 'ioū',
    'iú': 'ioú',
    'iǔ': 'ioǔ',
    'iù': 'ioù',
}
IU_TONES = set(IU_MAP.keys())
IU_RE = re.compile(r'^([a-z]+)({tones})$'.format(tones='|'.join(IU_TONES)))

# ui -> uei
UI_MAP = {
    'ui': 'uei',
    'uī': 'ueī',
    'uí': 'ueí',
    'uǐ': 'ueǐ',
    'uì': 'ueì',
}
UI_TONES = set(UI_MAP.keys())
UI_RE = re.compile(r'([a-z]+)({tones})$'.format(tones='|'.join(UI_TONES)))

# un -> uen
UN_MAP = {
    'un': 'uen',
    'ūn': 'ūen',
    'ún': 'úen',
    'ǔn': 'ǔen',
    'ùn': 'ùen',
}
UN_TONES = set(UN_MAP.keys())
UN_RE = re.compile(r'([a-z]+)({tones})$'.format(tones='|'.join(UN_TONES)))


def convert_zero_consonant(pinyin):
    """Zero-initial conversion: restore the original finals.

    Finals of the i row with no initial before them are written as
    yi (衣), ya (呀), ye (耶), yao (腰), you (忧), yan (烟), yin (因),
    yang (央), ying (英), yong (雍).

    Finals of the u row with no initial before them are written as
    wu (乌), wa (蛙), wo (窝), wai (歪), wei (威), wan (弯), wen (温),
    wang (汪), weng (翁).

    Finals of the ü row with no initial before them are written as
    yu (迂), yue (约), yuan (冤), yun (晕); the two dots over ü are omitted.
    """
    raw_pinyin = pinyin
    # y: yu -> v, yi -> i, y -> i
    if raw_pinyin.startswith('y'):
        # the pinyin with the leading y removed
        no_y_py = pinyin[1:]
        first_char = no_y_py[0] if len(no_y_py) > 0 else None

        # yu -> ü: yue -> üe
        if first_char in U_TONES:
            pinyin = UV_MAP[first_char] + pinyin[2:]
        # yi -> i: yi -> i
        elif first_char in I_TONES:
            pinyin = no_y_py
        # y -> i: ya -> ia
        else:
            pinyin = 'i' + no_y_py

    # w: wu -> u, w -> u
    if raw_pinyin.startswith('w'):
        # the pinyin with the leading w removed
        no_w_py = pinyin[1:]
        first_char = no_w_py[0] if len(no_w_py) > 0 else None

        # wu -> u: wu -> u
        if first_char in U_TONES:
            pinyin = pinyin[1:]
        # w -> u: wa -> ua
        else:
            pinyin = 'u' + pinyin[1:]

    # make sure we never produce a final that is not in the finals table
    if pinyin not in _FINALS:
        return raw_pinyin

    return pinyin
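
# Illustrative results of the conversion above (assuming the usual _FINALS
# table): convert_zero_consonant('yu') -> 'ü', convert_zero_consonant('wa')
# -> 'ua', convert_zero_consonant('yi') -> 'i'.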


def convert_uv(pinyin):
    """ü conversion: restore the original finals.

    When finals of the ü row are combined with the initials j, q, x, they
    are written as ju (居), qu (区), xu (虚) and the two dots over ü are
    also omitted; but combined with the initials n, l they are still
    written as nü (女), lü (吕).
    """
    return UV_RE.sub(
        lambda m: ''.join((m.group(1), UV_MAP[m.group(2)], m.group(3))),
        pinyin)


def convert_iou(pinyin):
    """iou conversion: restore the original finals.

    When iou, uei, uen are preceded by an initial, they are written as
    iu, ui, un, e.g. niu (牛), gui (归), lun (论).
    """
    return IU_RE.sub(lambda m: m.group(1) + IU_MAP[m.group(2)], pinyin)


def convert_uei(pinyin):
    """uei conversion: restore the original finals.

    When iou, uei, uen are preceded by an initial, they are written as
    iu, ui, un, e.g. niu (牛), gui (归), lun (论).
    """
    return UI_RE.sub(lambda m: m.group(1) + UI_MAP[m.group(2)], pinyin)


def convert_uen(pinyin):
    """uen conversion: restore the original finals.

    When iou, uei, uen are preceded by an initial, they are written as
    iu, ui, un, e.g. niu (牛), gui (归), lun (论).
    """
    return UN_RE.sub(lambda m: m.group(1) + UN_MAP[m.group(2)], pinyin)


def convert_finals(pinyin):
    """Restore the original finals."""
    pinyin = convert_zero_consonant(pinyin)
    pinyin = convert_uv(pinyin)
    pinyin = convert_iou(pinyin)
    pinyin = convert_uei(pinyin)
    pinyin = convert_uen(pinyin)
    return pinyin
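
# Illustrative pipeline results (assumption):
#
#     convert_finals('niu') -> 'niou'   # iu restored to iou
#     convert_finals('gui') -> 'guei'   # ui restored to uei
#     convert_finals('lun') -> 'luen'   # un restored to uen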
10
crates/ruff_benchmark/resources/tomllib/__init__.py
Normal file
@ -0,0 +1,10 @@
# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
# Licensed to PSF under a Contributor Agreement.

__all__ = ("loads", "load", "TOMLDecodeError")

from ._parser import TOMLDecodeError, load, loads

# Pretend this exception was created here.
TOMLDecodeError.__module__ = __name__
691
crates/ruff_benchmark/resources/tomllib/_parser.py
Normal file
@ -0,0 +1,691 @@
# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
# Licensed to PSF under a Contributor Agreement.

from __future__ import annotations

from collections.abc import Iterable
import string
from types import MappingProxyType
from typing import Any, BinaryIO, NamedTuple

from ._re import (
    RE_DATETIME,
    RE_LOCALTIME,
    RE_NUMBER,
    match_to_datetime,
    match_to_localtime,
    match_to_number,
)
from ._types import Key, ParseFloat, Pos

ASCII_CTRL = frozenset(chr(i) for i in range(32)) | frozenset(chr(127))

# Neither of these sets include quotation mark or backslash. They are
# currently handled as separate cases in the parser functions.
ILLEGAL_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t")
ILLEGAL_MULTILINE_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t\n")

ILLEGAL_LITERAL_STR_CHARS = ILLEGAL_BASIC_STR_CHARS
ILLEGAL_MULTILINE_LITERAL_STR_CHARS = ILLEGAL_MULTILINE_BASIC_STR_CHARS

ILLEGAL_COMMENT_CHARS = ILLEGAL_BASIC_STR_CHARS

TOML_WS = frozenset(" \t")
TOML_WS_AND_NEWLINE = TOML_WS | frozenset("\n")
BARE_KEY_CHARS = frozenset(string.ascii_letters + string.digits + "-_")
KEY_INITIAL_CHARS = BARE_KEY_CHARS | frozenset("\"'")
HEXDIGIT_CHARS = frozenset(string.hexdigits)

BASIC_STR_ESCAPE_REPLACEMENTS = MappingProxyType(
    {
        "\\b": "\u0008",  # backspace
        "\\t": "\u0009",  # tab
        "\\n": "\u000A",  # linefeed
        "\\f": "\u000C",  # form feed
        "\\r": "\u000D",  # carriage return
        '\\"': "\u0022",  # quote
        "\\\\": "\u005C",  # backslash
    }
)


class TOMLDecodeError(ValueError):
    """An error raised if a document is not valid TOML."""


def load(fp: BinaryIO, /, *, parse_float: ParseFloat = float) -> dict[str, Any]:
    """Parse TOML from a binary file object."""
    b = fp.read()
    try:
        s = b.decode()
    except AttributeError:
        raise TypeError(
            "File must be opened in binary mode, e.g. use `open('foo.toml', 'rb')`"
        ) from None
    return loads(s, parse_float=parse_float)


def loads(s: str, /, *, parse_float: ParseFloat = float) -> dict[str, Any]:  # noqa: C901
    """Parse TOML from a string."""

    # The spec allows converting "\r\n" to "\n", even in string
    # literals. Let's do so to simplify parsing.
    src = s.replace("\r\n", "\n")
    pos = 0
    out = Output(NestedDict(), Flags())
    header: Key = ()
    parse_float = make_safe_parse_float(parse_float)

    # Parse one statement at a time
    # (typically means one line in TOML source)
    while True:
        # 1. Skip line leading whitespace
        pos = skip_chars(src, pos, TOML_WS)

        # 2. Parse rules. Expect one of the following:
        #    - end of file
        #    - end of line
        #    - comment
        #    - key/value pair
        #    - append dict to list (and move to its namespace)
        #    - create dict (and move to its namespace)
        # Skip trailing whitespace when applicable.
        try:
            char = src[pos]
        except IndexError:
            break
        if char == "\n":
            pos += 1
            continue
        if char in KEY_INITIAL_CHARS:
            pos = key_value_rule(src, pos, out, header, parse_float)
            pos = skip_chars(src, pos, TOML_WS)
        elif char == "[":
            try:
                second_char: str | None = src[pos + 1]
            except IndexError:
                second_char = None
            out.flags.finalize_pending()
            if second_char == "[":
                pos, header = create_list_rule(src, pos, out)
            else:
                pos, header = create_dict_rule(src, pos, out)
            pos = skip_chars(src, pos, TOML_WS)
        elif char != "#":
            raise suffixed_err(src, pos, "Invalid statement")

        # 3. Skip comment
        pos = skip_comment(src, pos)

        # 4. Expect end of line or end of file
        try:
            char = src[pos]
        except IndexError:
            break
        if char != "\n":
            raise suffixed_err(
                src, pos, "Expected newline or end of document after a statement"
            )
        pos += 1

    return out.data.dict
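
# Illustrative usage (assumption): the parser above turns
#
#     loads('[server]\nport = 8080\ntags = ["a", "b"]')
#
# into {'server': {'port': 8080, 'tags': ['a', 'b']}}.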


class Flags:
    """Flags that map to parsed keys/namespaces."""

    # Marks an immutable namespace (inline array or inline table).
    FROZEN = 0
    # Marks a nest that has been explicitly created and can no longer
    # be opened using the "[table]" syntax.
    EXPLICIT_NEST = 1

    def __init__(self) -> None:
        self._flags: dict[str, dict] = {}
        self._pending_flags: set[tuple[Key, int]] = set()

    def add_pending(self, key: Key, flag: int) -> None:
        self._pending_flags.add((key, flag))

    def finalize_pending(self) -> None:
        for key, flag in self._pending_flags:
            self.set(key, flag, recursive=False)
        self._pending_flags.clear()

    def unset_all(self, key: Key) -> None:
        cont = self._flags
        for k in key[:-1]:
            if k not in cont:
                return
            cont = cont[k]["nested"]
        cont.pop(key[-1], None)

    def set(self, key: Key, flag: int, *, recursive: bool) -> None:  # noqa: A003
        cont = self._flags
        key_parent, key_stem = key[:-1], key[-1]
        for k in key_parent:
            if k not in cont:
                cont[k] = {"flags": set(), "recursive_flags": set(), "nested": {}}
            cont = cont[k]["nested"]
        if key_stem not in cont:
            cont[key_stem] = {"flags": set(), "recursive_flags": set(), "nested": {}}
        cont[key_stem]["recursive_flags" if recursive else "flags"].add(flag)

    def is_(self, key: Key, flag: int) -> bool:
        if not key:
            return False  # document root has no flags
        cont = self._flags
        for k in key[:-1]:
            if k not in cont:
                return False
            inner_cont = cont[k]
            if flag in inner_cont["recursive_flags"]:
                return True
            cont = inner_cont["nested"]
        key_stem = key[-1]
        if key_stem in cont:
            cont = cont[key_stem]
            return flag in cont["flags"] or flag in cont["recursive_flags"]
        return False


class NestedDict:
    def __init__(self) -> None:
        # The parsed content of the TOML document
        self.dict: dict[str, Any] = {}

    def get_or_create_nest(
        self,
        key: Key,
        *,
        access_lists: bool = True,
    ) -> dict:
        cont: Any = self.dict
        for k in key:
            if k not in cont:
                cont[k] = {}
            cont = cont[k]
            if access_lists and isinstance(cont, list):
                cont = cont[-1]
            if not isinstance(cont, dict):
                raise KeyError("There is no nest behind this key")
        return cont

    def append_nest_to_list(self, key: Key) -> None:
        cont = self.get_or_create_nest(key[:-1])
        last_key = key[-1]
        if last_key in cont:
            list_ = cont[last_key]
            if not isinstance(list_, list):
                raise KeyError("An object other than list found behind this key")
            list_.append({})
        else:
            cont[last_key] = [{}]


class Output(NamedTuple):
    data: NestedDict
    flags: Flags


def skip_chars(src: str, pos: Pos, chars: Iterable[str]) -> Pos:
    try:
        while src[pos] in chars:
            pos += 1
    except IndexError:
        pass
    return pos


def skip_until(
    src: str,
    pos: Pos,
    expect: str,
    *,
    error_on: frozenset[str],
    error_on_eof: bool,
) -> Pos:
    try:
        new_pos = src.index(expect, pos)
    except ValueError:
        new_pos = len(src)
        if error_on_eof:
            raise suffixed_err(src, new_pos, f"Expected {expect!r}") from None

    if not error_on.isdisjoint(src[pos:new_pos]):
        while src[pos] not in error_on:
            pos += 1
        raise suffixed_err(src, pos, f"Found invalid character {src[pos]!r}")
    return new_pos


def skip_comment(src: str, pos: Pos) -> Pos:
    try:
        char: str | None = src[pos]
    except IndexError:
        char = None
    if char == "#":
        return skip_until(
            src, pos + 1, "\n", error_on=ILLEGAL_COMMENT_CHARS, error_on_eof=False
        )
    return pos


def skip_comments_and_array_ws(src: str, pos: Pos) -> Pos:
    while True:
        pos_before_skip = pos
        pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE)
        pos = skip_comment(src, pos)
        if pos == pos_before_skip:
            return pos


def create_dict_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
    pos += 1  # Skip "["
    pos = skip_chars(src, pos, TOML_WS)
    pos, key = parse_key(src, pos)

    if out.flags.is_(key, Flags.EXPLICIT_NEST) or out.flags.is_(key, Flags.FROZEN):
        raise suffixed_err(src, pos, f"Cannot declare {key} twice")
    out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
    try:
        out.data.get_or_create_nest(key)
    except KeyError:
        raise suffixed_err(src, pos, "Cannot overwrite a value") from None

    if not src.startswith("]", pos):
        raise suffixed_err(src, pos, "Expected ']' at the end of a table declaration")
    return pos + 1, key


def create_list_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
    pos += 2  # Skip "[["
    pos = skip_chars(src, pos, TOML_WS)
    pos, key = parse_key(src, pos)

    if out.flags.is_(key, Flags.FROZEN):
        raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
    # Free the namespace now that it points to another empty list item...
    out.flags.unset_all(key)
    # ...but this key precisely is still prohibited from table declaration
    out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
    try:
        out.data.append_nest_to_list(key)
    except KeyError:
        raise suffixed_err(src, pos, "Cannot overwrite a value") from None

    if not src.startswith("]]", pos):
        raise suffixed_err(src, pos, "Expected ']]' at the end of an array declaration")
    return pos + 2, key


def key_value_rule(
    src: str, pos: Pos, out: Output, header: Key, parse_float: ParseFloat
) -> Pos:
    pos, key, value = parse_key_value_pair(src, pos, parse_float)
    key_parent, key_stem = key[:-1], key[-1]
    abs_key_parent = header + key_parent

    relative_path_cont_keys = (header + key[:i] for i in range(1, len(key)))
    for cont_key in relative_path_cont_keys:
        # Check that dotted key syntax does not redefine an existing table
        if out.flags.is_(cont_key, Flags.EXPLICIT_NEST):
            raise suffixed_err(src, pos, f"Cannot redefine namespace {cont_key}")
        # Containers in the relative path can't be opened with the table syntax or
        # dotted key/value syntax in following table sections.
        out.flags.add_pending(cont_key, Flags.EXPLICIT_NEST)

    if out.flags.is_(abs_key_parent, Flags.FROZEN):
        raise suffixed_err(
            src, pos, f"Cannot mutate immutable namespace {abs_key_parent}"
        )

    try:
        nest = out.data.get_or_create_nest(abs_key_parent)
    except KeyError:
        raise suffixed_err(src, pos, "Cannot overwrite a value") from None
    if key_stem in nest:
        raise suffixed_err(src, pos, "Cannot overwrite a value")
    # Mark inline table and array namespaces recursively immutable
    if isinstance(value, (dict, list)):
        out.flags.set(header + key, Flags.FROZEN, recursive=True)
    nest[key_stem] = value
    return pos


def parse_key_value_pair(
    src: str, pos: Pos, parse_float: ParseFloat
) -> tuple[Pos, Key, Any]:
    pos, key = parse_key(src, pos)
    try:
        char: str | None = src[pos]
    except IndexError:
        char = None
    if char != "=":
        raise suffixed_err(src, pos, "Expected '=' after a key in a key/value pair")
    pos += 1
    pos = skip_chars(src, pos, TOML_WS)
    pos, value = parse_value(src, pos, parse_float)
    return pos, key, value


def parse_key(src: str, pos: Pos) -> tuple[Pos, Key]:
    pos, key_part = parse_key_part(src, pos)
    key: Key = (key_part,)
    pos = skip_chars(src, pos, TOML_WS)
    while True:
        try:
            char: str | None = src[pos]
        except IndexError:
            char = None
        if char != ".":
            return pos, key
        pos += 1
        pos = skip_chars(src, pos, TOML_WS)
        pos, key_part = parse_key_part(src, pos)
        key += (key_part,)
        pos = skip_chars(src, pos, TOML_WS)


def parse_key_part(src: str, pos: Pos) -> tuple[Pos, str]:
    try:
        char: str | None = src[pos]
    except IndexError:
        char = None
    if char in BARE_KEY_CHARS:
        start_pos = pos
        pos = skip_chars(src, pos, BARE_KEY_CHARS)
        return pos, src[start_pos:pos]
    if char == "'":
        return parse_literal_str(src, pos)
    if char == '"':
        return parse_one_line_basic_str(src, pos)
    raise suffixed_err(src, pos, "Invalid initial character for a key part")


def parse_one_line_basic_str(src: str, pos: Pos) -> tuple[Pos, str]:
    pos += 1
    return parse_basic_str(src, pos, multiline=False)


def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, list]:
    pos += 1
    array: list = []

    pos = skip_comments_and_array_ws(src, pos)
    if src.startswith("]", pos):
        return pos + 1, array
    while True:
        pos, val = parse_value(src, pos, parse_float)
        array.append(val)
        pos = skip_comments_and_array_ws(src, pos)

        c = src[pos : pos + 1]
        if c == "]":
            return pos + 1, array
        if c != ",":
            raise suffixed_err(src, pos, "Unclosed array")
        pos += 1

        pos = skip_comments_and_array_ws(src, pos)
        if src.startswith("]", pos):
            return pos + 1, array


def parse_inline_table(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, dict]:
    pos += 1
    nested_dict = NestedDict()
    flags = Flags()

    pos = skip_chars(src, pos, TOML_WS)
    if src.startswith("}", pos):
        return pos + 1, nested_dict.dict
    while True:
        pos, key, value = parse_key_value_pair(src, pos, parse_float)
        key_parent, key_stem = key[:-1], key[-1]
        if flags.is_(key, Flags.FROZEN):
            raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
        try:
            nest = nested_dict.get_or_create_nest(key_parent, access_lists=False)
        except KeyError:
            raise suffixed_err(src, pos, "Cannot overwrite a value") from None
        if key_stem in nest:
            raise suffixed_err(src, pos, f"Duplicate inline table key {key_stem!r}")
        nest[key_stem] = value
        pos = skip_chars(src, pos, TOML_WS)
        c = src[pos : pos + 1]
        if c == "}":
            return pos + 1, nested_dict.dict
        if c != ",":
            raise suffixed_err(src, pos, "Unclosed inline table")
        if isinstance(value, (dict, list)):
            flags.set(key, Flags.FROZEN, recursive=True)
        pos += 1
        pos = skip_chars(src, pos, TOML_WS)


def parse_basic_str_escape(
    src: str, pos: Pos, *, multiline: bool = False
) -> tuple[Pos, str]:
    escape_id = src[pos : pos + 2]
    pos += 2
    if multiline and escape_id in {"\\ ", "\\\t", "\\\n"}:
        # Skip whitespace until next non-whitespace character or end of
        # the doc. Error if non-whitespace is found before newline.
        if escape_id != "\\\n":
            pos = skip_chars(src, pos, TOML_WS)
            try:
                char = src[pos]
            except IndexError:
                return pos, ""
            if char != "\n":
                raise suffixed_err(src, pos, "Unescaped '\\' in a string")
            pos += 1
        pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE)
        return pos, ""
    if escape_id == "\\u":
        return parse_hex_char(src, pos, 4)
    if escape_id == "\\U":
        return parse_hex_char(src, pos, 8)
    try:
        return pos, BASIC_STR_ESCAPE_REPLACEMENTS[escape_id]
    except KeyError:
        raise suffixed_err(src, pos, "Unescaped '\\' in a string") from None


def parse_basic_str_escape_multiline(src: str, pos: Pos) -> tuple[Pos, str]:
    return parse_basic_str_escape(src, pos, multiline=True)


def parse_hex_char(src: str, pos: Pos, hex_len: int) -> tuple[Pos, str]:
    hex_str = src[pos : pos + hex_len]
    if len(hex_str) != hex_len or not HEXDIGIT_CHARS.issuperset(hex_str):
        raise suffixed_err(src, pos, "Invalid hex value")
    pos += hex_len
    hex_int = int(hex_str, 16)
    if not is_unicode_scalar_value(hex_int):
        raise suffixed_err(src, pos, "Escaped character is not a Unicode scalar value")
    return pos, chr(hex_int)


def parse_literal_str(src: str, pos: Pos) -> tuple[Pos, str]:
    pos += 1  # Skip starting apostrophe
    start_pos = pos
    pos = skip_until(
        src, pos, "'", error_on=ILLEGAL_LITERAL_STR_CHARS, error_on_eof=True
    )
    return pos + 1, src[start_pos:pos]  # Skip ending apostrophe
|
||||||
|
|
||||||
|
|
||||||
|
def parse_multiline_str(src: str, pos: Pos, *, literal: bool) -> tuple[Pos, str]:
|
||||||
|
pos += 3
|
||||||
|
if src.startswith("\n", pos):
|
||||||
|
pos += 1
|
||||||
|
|
||||||
|
if literal:
|
||||||
|
delim = "'"
|
||||||
|
end_pos = skip_until(
|
||||||
|
src,
|
||||||
|
pos,
|
||||||
|
"'''",
|
||||||
|
error_on=ILLEGAL_MULTILINE_LITERAL_STR_CHARS,
|
||||||
|
error_on_eof=True,
|
||||||
|
)
|
||||||
|
result = src[pos:end_pos]
|
||||||
|
pos = end_pos + 3
|
||||||
|
else:
|
||||||
|
delim = '"'
|
||||||
|
pos, result = parse_basic_str(src, pos, multiline=True)
|
||||||
|
|
||||||
|
# Add at maximum two extra apostrophes/quotes if the end sequence
|
||||||
|
# is 4 or 5 chars long instead of just 3.
|
||||||
|
if not src.startswith(delim, pos):
|
||||||
|
return pos, result
|
||||||
|
pos += 1
|
||||||
|
if not src.startswith(delim, pos):
|
||||||
|
return pos, result + delim
|
||||||
|
pos += 1
|
||||||
|
return pos, result + (delim * 2)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]:
|
||||||
|
if multiline:
|
||||||
|
error_on = ILLEGAL_MULTILINE_BASIC_STR_CHARS
|
||||||
|
parse_escapes = parse_basic_str_escape_multiline
|
||||||
|
else:
|
||||||
|
error_on = ILLEGAL_BASIC_STR_CHARS
|
||||||
|
parse_escapes = parse_basic_str_escape
|
||||||
|
result = ""
|
||||||
|
start_pos = pos
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
char = src[pos]
|
||||||
|
except IndexError:
|
||||||
|
raise suffixed_err(src, pos, "Unterminated string") from None
|
||||||
|
if char == '"':
|
||||||
|
if not multiline:
|
||||||
|
return pos + 1, result + src[start_pos:pos]
|
||||||
|
if src.startswith('"""', pos):
|
||||||
|
return pos + 3, result + src[start_pos:pos]
|
||||||
|
pos += 1
|
||||||
|
continue
|
||||||
|
if char == "\\":
|
||||||
|
result += src[start_pos:pos]
|
||||||
|
pos, parsed_escape = parse_escapes(src, pos)
|
||||||
|
result += parsed_escape
|
||||||
|
start_pos = pos
|
||||||
|
continue
|
||||||
|
if char in error_on:
|
||||||
|
raise suffixed_err(src, pos, f"Illegal character {char!r}")
|
||||||
|
pos += 1
|
||||||


def parse_value(  # noqa: C901
    src: str, pos: Pos, parse_float: ParseFloat
) -> tuple[Pos, Any]:
    try:
        char: str | None = src[pos]
    except IndexError:
        char = None

    # IMPORTANT: order conditions based on speed of checking and likelihood

    # Basic strings
    if char == '"':
        if src.startswith('"""', pos):
            return parse_multiline_str(src, pos, literal=False)
        return parse_one_line_basic_str(src, pos)

    # Literal strings
    if char == "'":
        if src.startswith("'''", pos):
            return parse_multiline_str(src, pos, literal=True)
        return parse_literal_str(src, pos)

    # Booleans
    if char == "t":
        if src.startswith("true", pos):
            return pos + 4, True
    if char == "f":
        if src.startswith("false", pos):
            return pos + 5, False

    # Arrays
    if char == "[":
        return parse_array(src, pos, parse_float)

    # Inline tables
    if char == "{":
        return parse_inline_table(src, pos, parse_float)

    # Dates and times
    datetime_match = RE_DATETIME.match(src, pos)
    if datetime_match:
        try:
            datetime_obj = match_to_datetime(datetime_match)
        except ValueError as e:
            raise suffixed_err(src, pos, "Invalid date or datetime") from e
        return datetime_match.end(), datetime_obj
    localtime_match = RE_LOCALTIME.match(src, pos)
    if localtime_match:
        return localtime_match.end(), match_to_localtime(localtime_match)

    # Integers and "normal" floats.
    # The regex will greedily match any type starting with a decimal
    # char, so needs to be located after handling of dates and times.
    number_match = RE_NUMBER.match(src, pos)
    if number_match:
        return number_match.end(), match_to_number(number_match, parse_float)

    # Special floats
    first_three = src[pos : pos + 3]
    if first_three in {"inf", "nan"}:
        return pos + 3, parse_float(first_three)
    first_four = src[pos : pos + 4]
    if first_four in {"-inf", "+inf", "-nan", "+nan"}:
        return pos + 4, parse_float(first_four)

    raise suffixed_err(src, pos, "Invalid value")
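For illustration (not part of the vendored file): the ordering comment above is load-bearing, since RE_NUMBER would otherwise consume the leading digits of a date. A sketch:

import tomllib
from datetime import date

parsed = tomllib.loads("d = 1988-10-27\nn = 1988")
assert parsed["d"] == date(1988, 10, 27)  # tried as a datetime before a number
assert parsed["n"] == 1988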


def suffixed_err(src: str, pos: Pos, msg: str) -> TOMLDecodeError:
    """Return a `TOMLDecodeError` where error message is suffixed with
    coordinates in source."""

    def coord_repr(src: str, pos: Pos) -> str:
        if pos >= len(src):
            return "end of document"
        line = src.count("\n", 0, pos) + 1
        if line == 1:
            column = pos + 1
        else:
            column = pos - src.rindex("\n", 0, pos)
        return f"line {line}, column {column}"

    return TOMLDecodeError(f"{msg} (at {coord_repr(src, pos)})")
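For illustration (not part of the vendored file): coordinates are 1-based, and a position at or past the end of the source reads as "end of document". A sketch:

import tomllib

try:
    tomllib.loads("a = 1\nb =")  # the value after `b =` is missing
except tomllib.TOMLDecodeError as err:
    print(err)  # e.g. "Invalid value (at end of document)"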


def is_unicode_scalar_value(codepoint: int) -> bool:
    return (0 <= codepoint <= 55295) or (57344 <= codepoint <= 1114111)
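For illustration (not part of the vendored file): the decimal bounds are 0xD7FF and 0xE000, i.e. everything except the UTF-16 surrogate block is a Unicode scalar value. A sketch using the function above:

assert is_unicode_scalar_value(0xD7FF)      # 55295, last value before surrogates
assert not is_unicode_scalar_value(0xD800)  # surrogates are excluded
assert not is_unicode_scalar_value(0xDFFF)
assert is_unicode_scalar_value(0x10FFFF)    # 1114111, the maximum code point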


def make_safe_parse_float(parse_float: ParseFloat) -> ParseFloat:
    """A decorator to make `parse_float` safe.

    `parse_float` must not return dicts or lists, because these types
    would be mixed with parsed TOML tables and arrays, thus confusing
    the parser. The returned decorated callable raises `ValueError`
    instead of returning illegal types.
    """
    # The default `float` callable never returns illegal types. Optimize it.
    if parse_float is float:  # type: ignore[comparison-overlap]
        return float

    def safe_parse_float(float_str: str) -> Any:
        float_value = parse_float(float_str)
        if isinstance(float_value, (dict, list)):
            raise ValueError("parse_float must not return dicts or lists")
        return float_value

    return safe_parse_float
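For illustration (not part of the vendored file): this guard is what makes custom `parse_float` callables safe to pass to `loads`. A sketch, assuming the vendored package is importable as tomllib:

from decimal import Decimal

import tomllib

# Decimal keeps the exact textual value instead of rounding to binary.
assert tomllib.loads("pi = 3.14159", parse_float=Decimal)["pi"] == Decimal("3.14159")

# A callable returning a list would corrupt the parsed document structure,
# so the wrapper raises ValueError instead of letting it through.
unsafe = make_safe_parse_float(lambda s: [s])
try:
    unsafe("1.0")
except ValueError:
    pass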
107  crates/ruff_benchmark/resources/tomllib/_re.py  Normal file
@ -0,0 +1,107 @@
# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
# Licensed to PSF under a Contributor Agreement.

from __future__ import annotations

from datetime import date, datetime, time, timedelta, timezone, tzinfo
from functools import lru_cache
import re
from typing import Any

from ._types import ParseFloat

# E.g.
# - 00:32:00.999999
# - 00:32:00
_TIME_RE_STR = r"([01][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])(?:\.([0-9]{1,6})[0-9]*)?"

RE_NUMBER = re.compile(
    r"""
0
(?:
    x[0-9A-Fa-f](?:_?[0-9A-Fa-f])*   # hex
    |
    b[01](?:_?[01])*                 # bin
    |
    o[0-7](?:_?[0-7])*               # oct
)
|
[+-]?(?:0|[1-9](?:_?[0-9])*)         # dec, integer part
(?P<floatpart>
    (?:\.[0-9](?:_?[0-9])*)?         # optional fractional part
    (?:[eE][+-]?[0-9](?:_?[0-9])*)?  # optional exponent part
)
""",
    flags=re.VERBOSE,
)
RE_LOCALTIME = re.compile(_TIME_RE_STR)
RE_DATETIME = re.compile(
    rf"""
([0-9]{{4}})-(0[1-9]|1[0-2])-(0[1-9]|[12][0-9]|3[01])  # date, e.g. 1988-10-27
(?:
    [Tt ]
    {_TIME_RE_STR}
    (?:([Zz])|([+-])([01][0-9]|2[0-3]):([0-5][0-9]))?  # optional time offset
)?
""",
    flags=re.VERBOSE,
)
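For illustration (not part of the vendored file): a sketch of what these patterns accept, evaluated in this module's namespace:

# Numbers: prefixed hex/bin/oct with `_` separators, decimal ints and floats.
assert RE_NUMBER.match("0xDEAD_BEEF")
assert RE_NUMBER.match("1_000")
assert RE_NUMBER.match("6.626e-34")

# The date is mandatory; the time part and its UTC offset are optional.
assert RE_DATETIME.match("1979-05-27T07:32:00Z")
assert RE_DATETIME.match("1979-05-27")
assert RE_LOCALTIME.match("07:32:00.999999")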


def match_to_datetime(match: re.Match) -> datetime | date:
    """Convert a `RE_DATETIME` match to `datetime.datetime` or `datetime.date`.

    Raises ValueError if the match does not correspond to a valid date
    or datetime.
    """
    (
        year_str,
        month_str,
        day_str,
        hour_str,
        minute_str,
        sec_str,
        micros_str,
        zulu_time,
        offset_sign_str,
        offset_hour_str,
        offset_minute_str,
    ) = match.groups()
    year, month, day = int(year_str), int(month_str), int(day_str)
    if hour_str is None:
        return date(year, month, day)
    hour, minute, sec = int(hour_str), int(minute_str), int(sec_str)
    micros = int(micros_str.ljust(6, "0")) if micros_str else 0
    if offset_sign_str:
        tz: tzinfo | None = cached_tz(
            offset_hour_str, offset_minute_str, offset_sign_str
        )
    elif zulu_time:
        tz = timezone.utc
    else:  # local date-time
        tz = None
    return datetime(year, month, day, hour, minute, sec, micros, tzinfo=tz)
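For illustration (not part of the vendored file): a sketch of the conversion, reusing the module's own imports:

dt = match_to_datetime(RE_DATETIME.match("1979-05-27T00:32:00-07:00"))
assert dt.utcoffset().total_seconds() == -7 * 3600

d = match_to_datetime(RE_DATETIME.match("1979-05-27"))
assert d == date(1979, 5, 27)  # no time part: a plain date is returned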


@lru_cache(maxsize=None)
def cached_tz(hour_str: str, minute_str: str, sign_str: str) -> timezone:
    sign = 1 if sign_str == "+" else -1
    return timezone(
        timedelta(
            hours=sign * int(hour_str),
            minutes=sign * int(minute_str),
        )
    )
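For illustration (not part of the vendored file): the memoization means every timestamp with the same offset shares one timezone object:

assert cached_tz("07", "00", "-") is cached_tz("07", "00", "-")  # same cached instance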


def match_to_localtime(match: re.Match) -> time:
    hour_str, minute_str, sec_str, micros_str = match.groups()
    micros = int(micros_str.ljust(6, "0")) if micros_str else 0
    return time(int(hour_str), int(minute_str), int(sec_str), micros)


def match_to_number(match: re.Match, parse_float: ParseFloat) -> Any:
    if match.group("floatpart"):
        return parse_float(match.group())
    return int(match.group(), 0)
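For illustration (not part of the vendored file): `int(s, 0)` infers the base from the `0x`/`0o`/`0b` prefix, and both `int` and `float` accept `_` separators, so the matched text needs no preprocessing:

assert int("0xDEAD_BEEF", 0) == 0xDEADBEEF
assert int("0o755", 0) == 493
assert float("6.626e-34") == 6.626e-34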
10  crates/ruff_benchmark/resources/tomllib/_types.py  Normal file
@ -0,0 +1,10 @@
# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
# Licensed to PSF under a Contributor Agreement.

from typing import Any, Callable, Tuple

# Type annotations
ParseFloat = Callable[[str], Any]
Key = Tuple[str, ...]
Pos = int
crates/ruff_benchmark/src/lib.rs
@ -1,10 +1,32 @@
+use std::path::PathBuf;
+
 pub mod criterion;
 
-use std::fmt::{Display, Formatter};
-use std::path::PathBuf;
-use std::process::Command;
-
-use url::Url;
+pub static NUMPY_GLOBALS: TestFile = TestFile::new(
+    "numpy/globals.py",
+    include_str!("../resources/numpy/globals.py"),
+);
+
+pub static UNICODE_PYPINYIN: TestFile = TestFile::new(
+    "unicode/pypinyin.py",
+    include_str!("../resources/pypinyin.py"),
+);
+
+pub static PYDANTIC_TYPES: TestFile = TestFile::new(
+    "pydantic/types.py",
+    include_str!("../resources/pydantic/types.py"),
+);
+
+pub static NUMPY_CTYPESLIB: TestFile = TestFile::new(
+    "numpy/ctypeslib.py",
+    include_str!("../resources/numpy/ctypeslib.py"),
+);
+
+// "https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py"
+pub static LARGE_DATASET: TestFile = TestFile::new(
+    "large/dataset.py",
+    include_str!("../resources/large/dataset.py"),
+);
 
 /// Relative size of a test case. Benchmarks can use it to configure the time for how long a benchmark should run to get stable results.
 #[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
@ -26,35 +48,33 @@ pub struct TestCase {
 }
 
 impl TestCase {
-    pub fn fast(file: TestFile) -> Self {
+    pub const fn fast(file: TestFile) -> Self {
         Self {
             file,
             speed: TestCaseSpeed::Fast,
         }
     }
 
-    pub fn normal(file: TestFile) -> Self {
+    pub const fn normal(file: TestFile) -> Self {
         Self {
             file,
             speed: TestCaseSpeed::Normal,
         }
     }
 
-    pub fn slow(file: TestFile) -> Self {
+    pub const fn slow(file: TestFile) -> Self {
         Self {
             file,
             speed: TestCaseSpeed::Slow,
         }
     }
-}
 
-impl TestCase {
     pub fn code(&self) -> &str {
-        &self.file.code
+        self.file.code
     }
 
     pub fn name(&self) -> &str {
-        &self.file.name
+        self.file.name
     }
 
     pub fn speed(&self) -> TestCaseSpeed {
@ -62,119 +82,32 @@ impl TestCase {
     }
 
     pub fn path(&self) -> PathBuf {
-        TARGET_DIR.join(self.name())
+        PathBuf::from(file!())
+            .parent()
+            .unwrap()
+            .parent()
+            .unwrap()
+            .join("resources")
+            .join(self.name())
     }
 }
 
 #[derive(Debug, Clone)]
 pub struct TestFile {
-    name: String,
-    code: String,
+    name: &'static str,
+    code: &'static str,
 }
 
 impl TestFile {
-    pub fn code(&self) -> &str {
-        &self.code
-    }
-
-    pub fn name(&self) -> &str {
-        &self.name
-    }
-}
-
-static TARGET_DIR: std::sync::LazyLock<PathBuf> = std::sync::LazyLock::new(|| {
-    cargo_target_directory().unwrap_or_else(|| PathBuf::from("target"))
-});
-
-fn cargo_target_directory() -> Option<PathBuf> {
-    #[derive(serde::Deserialize)]
-    struct Metadata {
-        target_directory: PathBuf,
-    }
-
-    std::env::var_os("CARGO_TARGET_DIR")
-        .map(PathBuf::from)
-        .or_else(|| {
-            let output = Command::new(std::env::var_os("CARGO")?)
-                .args(["metadata", "--format-version", "1"])
-                .output()
-                .ok()?;
-            let metadata: Metadata = serde_json::from_slice(&output.stdout).ok()?;
-            Some(metadata.target_directory)
-        })
-}
-
-impl TestFile {
-    pub fn new(name: String, code: String) -> Self {
+    pub const fn new(name: &'static str, code: &'static str) -> Self {
         Self { name, code }
     }
 
-    #[allow(clippy::print_stderr)]
-    pub fn try_download(name: &str, url: &str) -> Result<TestFile, TestFileDownloadError> {
-        let url = Url::parse(url)?;
-
-        let cached_filename = TARGET_DIR.join(name);
-
-        if let Ok(content) = std::fs::read_to_string(&cached_filename) {
-            Ok(TestFile::new(name.to_string(), content))
-        } else {
-            // File not yet cached, download and cache it in the target directory
-            let response = ureq::get(url.as_str()).call()?;
-
-            let content = response.into_string()?;
-
-            // SAFETY: There's always the `target` directory
-            let parent = cached_filename.parent().unwrap();
-            if let Err(error) = std::fs::create_dir_all(parent) {
-                eprintln!("Failed to create the directory for the test case {name}: {error}");
-            } else if let Err(error) = std::fs::write(cached_filename, &content) {
-                eprintln!("Failed to cache test case file downloaded from {url}: {error}");
-            }
-
-            Ok(TestFile::new(name.to_string(), content))
-        }
-    }
-}
+    pub fn code(&self) -> &str {
+        self.code
+    }
 
-#[derive(Debug)]
-pub enum TestFileDownloadError {
-    UrlParse(url::ParseError),
-    Request(Box<ureq::Error>),
-    Download(std::io::Error),
-}
-
-impl From<url::ParseError> for TestFileDownloadError {
-    fn from(value: url::ParseError) -> Self {
-        Self::UrlParse(value)
-    }
-}
-
-impl From<ureq::Error> for TestFileDownloadError {
-    fn from(value: ureq::Error) -> Self {
-        Self::Request(Box::new(value))
-    }
-}
-
-impl From<std::io::Error> for TestFileDownloadError {
-    fn from(value: std::io::Error) -> Self {
-        Self::Download(value)
-    }
-}
-
-impl Display for TestFileDownloadError {
-    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
-        match self {
-            TestFileDownloadError::UrlParse(inner) => {
-                write!(f, "Failed to parse url: {inner}")
-            }
-            TestFileDownloadError::Request(inner) => {
-                write!(f, "Failed to download file: {inner}")
-            }
-            TestFileDownloadError::Download(inner) => {
-                write!(f, "Failed to download file: {inner}")
-            }
-        }
-    }
-}
-
-impl std::error::Error for TestFileDownloadError {}
+    pub fn name(&self) -> &str {
+        self.name
+    }
+}