Vendor benchmark test files (#15878)

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Micha Reiser, 2025-02-02 18:16:07 +00:00 (committed by GitHub)
parent d9a1034db0
commit 770b7f3439
20 changed files with 4235 additions and 314 deletions


@@ -5,6 +5,7 @@ exclude: |
.github/workflows/release.yml|
crates/red_knot_vendored/vendor/.*|
crates/red_knot_project/resources/.*|
crates/ruff_benchmark/resources/.*|
crates/ruff_linter/resources/.*|
crates/ruff_linter/src/rules/.*/snapshots/.*|
crates/ruff_notebook/resources/.*|

Cargo.lock (generated)

@@ -190,12 +190,6 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
[[package]]
name = "base64"
version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bincode"
version = "1.3.3"
@@ -2577,28 +2571,13 @@ version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "ring"
version = "0.17.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [
"cc",
"cfg-if",
"getrandom",
"libc",
"spin",
"untrusted",
"windows-sys 0.52.0",
]
[[package]]
name = "ron"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88073939a61e5b7680558e6be56b419e208420c2adb92be54921fa6b72283f1a"
dependencies = [
"base64 0.13.1",
"base64",
"bitflags 1.3.2",
"serde",
]
@@ -2692,11 +2671,7 @@ dependencies = [
"ruff_python_parser",
"ruff_python_trivia",
"rustc-hash 2.1.0",
"serde",
"serde_json",
"tikv-jemallocator",
"ureq",
"url",
]
[[package]]
@@ -3253,38 +3228,6 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "rustls"
version = "0.23.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8"
dependencies = [
"log",
"once_cell",
"ring",
"rustls-pki-types",
"rustls-webpki",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-pki-types"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37"
[[package]]
name = "rustls-webpki"
version = "0.102.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
]
[[package]]
name = "rustversion"
version = "1.0.19"
@@ -3567,12 +3510,6 @@ dependencies = [
"anstream",
]
[[package]]
name = "spin"
version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
@@ -3622,12 +3559,6 @@ dependencies = [
"syn 2.0.96",
]
[[package]]
name = "subtle"
version = "2.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292"
[[package]]
name = "syn"
version = "1.0.109"
@@ -4116,28 +4047,6 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9df2af067a7953e9c3831320f35c1cc0600c30d44d9f7a12b01db1cd88d6b47"
[[package]]
name = "untrusted"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
[[package]]
name = "ureq"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d"
dependencies = [
"base64 0.22.1",
"flate2",
"log",
"once_cell",
"rustls",
"rustls-pki-types",
"url",
"webpki-roots",
]
[[package]]
name = "url"
version = "2.5.4"
@@ -4406,15 +4315,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "webpki-roots"
version = "0.26.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d642ff16b7e79272ae451b7322067cdc17cadf68c23264be9d94a32319efe7e"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "7.0.1"
@@ -4723,12 +4623,6 @@ dependencies = [
"synstructure",
]
[[package]]
name = "zeroize"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde"
[[package]]
name = "zerovec"
version = "0.10.4"


@@ -134,7 +134,12 @@ serde_with = { version = "3.6.0", default-features = false, features = [
shellexpand = { version = "3.0.0" }
similar = { version = "2.4.0", features = ["inline"] }
smallvec = { version = "1.13.2" }
snapbox = { version = "0.6.0", features = ["diff", "term-svg", "cmd", "examples"] }
snapbox = { version = "0.6.0", features = [
"diff",
"term-svg",
"cmd",
"examples",
] }
static_assertions = "1.1.0"
strum = { version = "0.26.0", features = ["strum_macros"] }
strum_macros = { version = "0.26.0" }
@@ -159,7 +164,6 @@ unicode-ident = { version = "1.0.12" }
unicode-width = { version = "0.2.0" }
unicode_names2 = { version = "1.2.2" }
unicode-normalization = { version = "0.1.23" }
ureq = { version = "2.9.6" }
url = { version = "2.5.0" }
uuid = { version = "1.6.1", features = [
"v4",
@@ -305,7 +309,11 @@ local-artifacts-jobs = ["./build-binaries", "./build-docker"]
# Publish jobs to run in CI
publish-jobs = ["./publish-pypi", "./publish-wasm"]
# Post-announce jobs to run in CI
post-announce-jobs = ["./notify-dependents", "./publish-docs", "./publish-playground"]
post-announce-jobs = [
"./notify-dependents",
"./publish-docs",
"./publish-playground",
]
# Custom permissions for GitHub Jobs
github-custom-job-permissions = { "build-docker" = { packages = "write", contents = "read" }, "publish-wasm" = { contents = "read", id-token = "write", packages = "write" } }
# Whether to install an updater program


@@ -41,10 +41,6 @@ codspeed-criterion-compat = { workspace = true, default-features = false, option
criterion = { workspace = true, default-features = false }
rayon = { workspace = true }
rustc-hash = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
url = { workspace = true }
ureq = { workspace = true }
[dev-dependencies]
ruff_db = { workspace = true }


@@ -3,7 +3,10 @@ use std::path::Path;
use ruff_benchmark::criterion::{
criterion_group, criterion_main, BenchmarkId, Criterion, Throughput,
};
use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError};
use ruff_benchmark::{
TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN,
};
use ruff_python_formatter::{format_module_ast, PreviewMode, PyFormatOptions};
use ruff_python_parser::{parse, Mode};
use ruff_python_trivia::CommentRanges;
@@ -24,27 +27,20 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
fn create_test_cases() -> Result<Vec<TestCase>, TestFileDownloadError> {
Ok(vec![
TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?),
TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?),
TestCase::normal(TestFile::try_download(
"pydantic/types.py",
"https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py",
)?),
TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?),
TestCase::slow(TestFile::try_download(
"large/dataset.py",
"https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py",
)?),
])
fn create_test_cases() -> Vec<TestCase> {
vec![
TestCase::fast(NUMPY_GLOBALS.clone()),
TestCase::fast(UNICODE_PYPINYIN.clone()),
TestCase::normal(PYDANTIC_TYPES.clone()),
TestCase::normal(NUMPY_CTYPESLIB.clone()),
TestCase::slow(LARGE_DATASET.clone()),
]
}
fn benchmark_formatter(criterion: &mut Criterion) {
let mut group = criterion.benchmark_group("formatter");
let test_cases = create_test_cases().unwrap();
for case in test_cases {
for case in create_test_cases() {
group.throughput(Throughput::Bytes(case.code().len() as u64));
group.bench_with_input(

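The `NUMPY_GLOBALS`-style constants that replace the `TestFile::try_download` calls are statics exported by `ruff_benchmark`; their definitions are not part of this diff. As a minimal sketch, assuming the type simply wraps two `&'static str` fields, a `TestFile` supporting the usage above (const construction for statics, cheap cloning, and the `name`/`code` accessors the benchmarks call) could look like this:

// Sketch only; the real TestFile in the ruff_benchmark crate may differ.
#[derive(Debug, Clone)]
pub struct TestFile {
    name: &'static str,
    code: &'static str,
}

impl TestFile {
    // `const fn` so it can initialize `static` items such as TOMLLIB_FILES below.
    pub const fn new(name: &'static str, code: &'static str) -> Self {
        Self { name, code }
    }

    pub fn name(&self) -> &'static str {
        self.name
    }

    pub fn code(&self) -> &'static str {
        self.code
    }
}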

@@ -1,7 +1,9 @@
use ruff_benchmark::criterion::{
criterion_group, criterion_main, measurement::WallTime, BenchmarkId, Criterion, Throughput,
};
use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError};
use ruff_benchmark::{
TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN,
};
use ruff_python_parser::{lexer, Mode, TokenKind};
#[cfg(target_os = "windows")]
@@ -20,24 +22,18 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
fn create_test_cases() -> Result<Vec<TestCase>, TestFileDownloadError> {
Ok(vec![
TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?),
TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?),
TestCase::normal(TestFile::try_download(
"pydantic/types.py",
"https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py",
)?),
TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?),
TestCase::slow(TestFile::try_download(
"large/dataset.py",
"https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py",
)?),
])
fn create_test_cases() -> Vec<TestCase> {
vec![
TestCase::fast(NUMPY_GLOBALS.clone()),
TestCase::fast(UNICODE_PYPINYIN.clone()),
TestCase::normal(PYDANTIC_TYPES.clone()),
TestCase::normal(NUMPY_CTYPESLIB.clone()),
TestCase::slow(LARGE_DATASET.clone()),
]
}
fn benchmark_lexer(criterion: &mut Criterion<WallTime>) {
let test_cases = create_test_cases().unwrap();
let test_cases = create_test_cases();
let mut group = criterion.benchmark_group("lexer");
for case in test_cases {


@@ -1,7 +1,9 @@
use ruff_benchmark::criterion::{
criterion_group, criterion_main, BenchmarkGroup, BenchmarkId, Criterion, Throughput,
};
use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError};
use ruff_benchmark::{
TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN,
};
use ruff_linter::linter::{lint_only, ParseSource};
use ruff_linter::rule_selector::PreviewOptions;
use ruff_linter::settings::rule_table::RuleTable;
@@ -46,24 +48,18 @@ static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
#[allow(unsafe_code)]
pub static _rjem_malloc_conf: &[u8] = b"dirty_decay_ms:-1,muzzy_decay_ms:-1\0";
fn create_test_cases() -> Result<Vec<TestCase>, TestFileDownloadError> {
Ok(vec![
TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?),
TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?),
TestCase::normal(TestFile::try_download(
"pydantic/types.py",
"https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py",
)?),
TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?),
TestCase::slow(TestFile::try_download(
"large/dataset.py",
"https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py",
)?),
])
fn create_test_cases() -> Vec<TestCase> {
vec![
TestCase::fast(NUMPY_GLOBALS.clone()),
TestCase::fast(UNICODE_PYPINYIN.clone()),
TestCase::normal(PYDANTIC_TYPES.clone()),
TestCase::normal(NUMPY_CTYPESLIB.clone()),
TestCase::slow(LARGE_DATASET.clone()),
]
}
fn benchmark_linter(mut group: BenchmarkGroup, settings: &LinterSettings) {
let test_cases = create_test_cases().unwrap();
let test_cases = create_test_cases();
for case in test_cases {
group.throughput(Throughput::Bytes(case.code().len() as u64));


@@ -1,7 +1,9 @@
use ruff_benchmark::criterion::{
criterion_group, criterion_main, measurement::WallTime, BenchmarkId, Criterion, Throughput,
};
use ruff_benchmark::{TestCase, TestFile, TestFileDownloadError};
use ruff_benchmark::{
TestCase, LARGE_DATASET, NUMPY_CTYPESLIB, NUMPY_GLOBALS, PYDANTIC_TYPES, UNICODE_PYPINYIN,
};
use ruff_python_ast::statement_visitor::{walk_stmt, StatementVisitor};
use ruff_python_ast::Stmt;
use ruff_python_parser::parse_module;
@@ -22,20 +24,14 @@ static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[global_allocator]
static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc;
fn create_test_cases() -> Result<Vec<TestCase>, TestFileDownloadError> {
Ok(vec![
TestCase::fast(TestFile::try_download("numpy/globals.py", "https://raw.githubusercontent.com/numpy/numpy/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py")?),
TestCase::fast(TestFile::try_download("unicode/pypinyin.py", "https://raw.githubusercontent.com/mozillazg/python-pinyin/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py")?),
TestCase::normal(TestFile::try_download(
"pydantic/types.py",
"https://raw.githubusercontent.com/pydantic/pydantic/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py",
)?),
TestCase::normal(TestFile::try_download("numpy/ctypeslib.py", "https://raw.githubusercontent.com/numpy/numpy/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py")?),
TestCase::slow(TestFile::try_download(
"large/dataset.py",
"https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py",
)?),
])
fn create_test_cases() -> Vec<TestCase> {
vec![
TestCase::fast(NUMPY_GLOBALS.clone()),
TestCase::fast(UNICODE_PYPINYIN.clone()),
TestCase::normal(PYDANTIC_TYPES.clone()),
TestCase::normal(NUMPY_CTYPESLIB.clone()),
TestCase::slow(LARGE_DATASET.clone()),
]
}
struct CountVisitor {
@@ -50,7 +46,7 @@ impl<'a> StatementVisitor<'a> for CountVisitor {
}
fn benchmark_parser(criterion: &mut Criterion<WallTime>) {
let test_cases = create_test_cases().unwrap();
let test_cases = create_test_cases();
let mut group = criterion.benchmark_group("parser");
for case in test_cases {


@@ -24,7 +24,25 @@ struct Case {
re_path: SystemPathBuf,
}
const TOMLLIB_312_URL: &str = "https://raw.githubusercontent.com/python/cpython/8e8a4baf652f6e1cee7acde9d78c4b6154539748/Lib/tomllib";
// "https://raw.githubusercontent.com/python/cpython/8e8a4baf652f6e1cee7acde9d78c4b6154539748/Lib/tomllib";
static TOMLLIB_FILES: [TestFile; 4] = [
TestFile::new(
"tomllib/__init__.py",
include_str!("../resources/tomllib/__init__.py"),
),
TestFile::new(
"tomllib/_parser.py",
include_str!("../resources/tomllib/_parser.py"),
),
TestFile::new(
"tomllib/_re.py",
include_str!("../resources/tomllib/_re.py"),
),
TestFile::new(
"tomllib/_types.py",
include_str!("../resources/tomllib/_types.py"),
),
];
/// A structured set of fields we use to do diagnostic comparisons.
///
@@ -80,27 +98,19 @@ static EXPECTED_DIAGNOSTICS: &[KeyDiagnosticFields] = &[
),
];
fn get_test_file(name: &str) -> TestFile {
let path = format!("tomllib/{name}");
let url = format!("{TOMLLIB_312_URL}/{name}");
TestFile::try_download(&path, &url).unwrap()
}
fn tomllib_path(filename: &str) -> SystemPathBuf {
SystemPathBuf::from(format!("/src/tomllib/{filename}").as_str())
fn tomllib_path(file: &TestFile) -> SystemPathBuf {
SystemPathBuf::from("src").join(file.name())
}
fn setup_case() -> Case {
let system = TestSystem::default();
let fs = system.memory_file_system().clone();
let tomllib_filenames = ["__init__.py", "_parser.py", "_re.py", "_types.py"];
fs.write_files(tomllib_filenames.iter().map(|filename| {
(
tomllib_path(filename),
get_test_file(filename).code().to_string(),
)
}))
fs.write_files(
TOMLLIB_FILES
.iter()
.map(|file| (tomllib_path(file), file.code().to_string())),
)
.unwrap();
let src_root = SystemPath::new("/src");
@@ -114,15 +124,22 @@ fn setup_case() -> Case {
});
let mut db = ProjectDatabase::new(metadata, system).unwrap();
let mut tomllib_files = FxHashSet::default();
let mut re: Option<File> = None;
for test_file in &TOMLLIB_FILES {
let file = system_path_to_file(&db, tomllib_path(test_file)).unwrap();
if test_file.name().ends_with("_re.py") {
re = Some(file);
}
tomllib_files.insert(file);
}
let re = re.unwrap();
let tomllib_files: FxHashSet<File> = tomllib_filenames
.iter()
.map(|filename| system_path_to_file(&db, tomllib_path(filename)).unwrap())
.collect();
db.project().set_open_files(&mut db, tomllib_files);
let re_path = tomllib_path("_re.py");
let re = system_path_to_file(&db, &re_path).unwrap();
let re_path = re.path(&db).as_system_path().unwrap().to_owned();
Case {
db,
fs,


@@ -0,0 +1,16 @@
This directory vendors files from real-world projects so that
Ruff's performance is benchmarked against actual code rather
than synthetic benchmarks.
The following files are included:
* [`numpy/globals.py`](https://github.com/numpy/numpy/blob/89d64415e349ca75a25250f22b874aa16e5c0973/numpy/_globals.py)
* [`numpy/ctypeslib.py`](https://github.com/numpy/numpy/blob/e42c9503a14d66adfd41356ef5640c6975c45218/numpy/ctypeslib.py)
* [`pypinyin.py`](https://github.com/mozillazg/python-pinyin/blob/9521e47d96e3583a5477f5e43a2e82d513f27a3f/pypinyin/standard.py)
* [`pydantic/types.py`](https://github.com/pydantic/pydantic/blob/83b3c49e99ceb4599d9286a3d793cea44ac36d4b/pydantic/types.py)
* [`large/dataset.py`](https://github.com/DHI/mikeio/blob/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py)
* [`tomllib`](https://github.com/python/cpython/tree/8e8a4baf652f6e1cee7acde9d78c4b6154539748/Lib/tomllib) (3.12)
The files are checked into the `resources` directory for simplicity
and so that the benchmarks can run offline. Each file remains subject
to its original project's license (see the links above).
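For illustration, here is a sketch (an assumption, not the crate's literal code) of how a vendored file can be embedded at compile time so the benchmarks never hit the network:

```rust
// Sketch: embed a vendored resource with `include_str!` and expose it
// as a static. The static name matches the constants the benchmarks
// import; the exact definition in the crate's source may differ.
pub static NUMPY_GLOBALS: TestFile = TestFile::new(
    "numpy/globals.py",
    include_str!("../resources/numpy/globals.py"),
);
```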

File diff suppressed because it is too large


@@ -0,0 +1,547 @@
"""
============================
``ctypes`` Utility Functions
============================
See Also
--------
load_library : Load a C library.
ndpointer : Array restype/argtype with verification.
as_ctypes : Create a ctypes array from an ndarray.
as_array : Create an ndarray from a ctypes array.
References
----------
.. [1] "SciPy Cookbook: ctypes", https://scipy-cookbook.readthedocs.io/items/Ctypes.html
Examples
--------
Load the C library:
>>> _lib = np.ctypeslib.load_library('libmystuff', '.') #doctest: +SKIP
Our result type, an ndarray that must be of type double, be 1-dimensional
and is C-contiguous in memory:
>>> array_1d_double = np.ctypeslib.ndpointer(
... dtype=np.double,
... ndim=1, flags='CONTIGUOUS') #doctest: +SKIP
Our C-function typically takes an array and updates its values
in-place. For example::
void foo_func(double* x, int length)
{
int i;
for (i = 0; i < length; i++) {
x[i] = i*i;
}
}
We wrap it using:
>>> _lib.foo_func.restype = None #doctest: +SKIP
>>> _lib.foo_func.argtypes = [array_1d_double, c_int] #doctest: +SKIP
Then, we're ready to call ``foo_func``:
>>> out = np.empty(15, dtype=np.double)
>>> _lib.foo_func(out, len(out)) #doctest: +SKIP
"""
__all__ = ['load_library', 'ndpointer', 'c_intp', 'as_ctypes', 'as_array',
'as_ctypes_type']
import os
from numpy import (
integer, ndarray, dtype as _dtype, asarray, frombuffer
)
from numpy.core.multiarray import _flagdict, flagsobj
try:
import ctypes
except ImportError:
ctypes = None
if ctypes is None:
def _dummy(*args, **kwds):
"""
Dummy object that raises an ImportError if ctypes is not available.
Raises
------
ImportError
If ctypes is not available.
"""
raise ImportError("ctypes is not available.")
load_library = _dummy
as_ctypes = _dummy
as_array = _dummy
from numpy import intp as c_intp
_ndptr_base = object
else:
import numpy.core._internal as nic
c_intp = nic._getintp_ctype()
del nic
_ndptr_base = ctypes.c_void_p
# Adapted from Albert Strasheim
def load_library(libname, loader_path):
"""
It is possible to load a library using
>>> lib = ctypes.cdll[<full_path_name>] # doctest: +SKIP
But there are cross-platform considerations, such as library file extensions,
plus the fact Windows will just load the first library it finds with that name.
NumPy supplies the load_library function as a convenience.
.. versionchanged:: 1.20.0
Allow libname and loader_path to take any
:term:`python:path-like object`.
Parameters
----------
libname : path-like
Name of the library, which can have 'lib' as a prefix,
but without an extension.
loader_path : path-like
Where the library can be found.
Returns
-------
ctypes.cdll[libpath] : library object
A ctypes library object
Raises
------
OSError
If there is no library with the expected extension, or the
library is defective and cannot be loaded.
"""
if ctypes.__version__ < '1.0.1':
import warnings
warnings.warn("All features of ctypes interface may not work "
"with ctypes < 1.0.1", stacklevel=2)
# Convert path-like objects into strings
libname = os.fsdecode(libname)
loader_path = os.fsdecode(loader_path)
ext = os.path.splitext(libname)[1]
if not ext:
# Try to load library with platform-specific name, otherwise
# default to libname.[so|pyd]. Sometimes, these files are built
# erroneously on non-linux platforms.
from numpy.distutils.misc_util import get_shared_lib_extension
so_ext = get_shared_lib_extension()
libname_ext = [libname + so_ext]
# mac, windows and linux >= py3.2 shared library and loadable
# module have different extensions so try both
so_ext2 = get_shared_lib_extension(is_python_ext=True)
if not so_ext2 == so_ext:
libname_ext.insert(0, libname + so_ext2)
else:
libname_ext = [libname]
loader_path = os.path.abspath(loader_path)
if not os.path.isdir(loader_path):
libdir = os.path.dirname(loader_path)
else:
libdir = loader_path
for ln in libname_ext:
libpath = os.path.join(libdir, ln)
if os.path.exists(libpath):
try:
return ctypes.cdll[libpath]
except OSError:
## defective lib file
raise
## if no successful return in the libname_ext loop:
raise OSError("no file with expected extension")
def _num_fromflags(flaglist):
num = 0
for val in flaglist:
num += _flagdict[val]
return num
_flagnames = ['C_CONTIGUOUS', 'F_CONTIGUOUS', 'ALIGNED', 'WRITEABLE',
'OWNDATA', 'WRITEBACKIFCOPY']
def _flags_fromnum(num):
res = []
for key in _flagnames:
value = _flagdict[key]
if (num & value):
res.append(key)
return res
class _ndptr(_ndptr_base):
@classmethod
def from_param(cls, obj):
if not isinstance(obj, ndarray):
raise TypeError("argument must be an ndarray")
if cls._dtype_ is not None \
and obj.dtype != cls._dtype_:
raise TypeError("array must have data type %s" % cls._dtype_)
if cls._ndim_ is not None \
and obj.ndim != cls._ndim_:
raise TypeError("array must have %d dimension(s)" % cls._ndim_)
if cls._shape_ is not None \
and obj.shape != cls._shape_:
raise TypeError("array must have shape %s" % str(cls._shape_))
if cls._flags_ is not None \
and ((obj.flags.num & cls._flags_) != cls._flags_):
raise TypeError("array must have flags %s" %
_flags_fromnum(cls._flags_))
return obj.ctypes
class _concrete_ndptr(_ndptr):
"""
Like _ndptr, but with `_shape_` and `_dtype_` specified.
Notably, this means the pointer has enough information to reconstruct
the array, which is not generally true.
"""
def _check_retval_(self):
"""
This method is called when this class is used as the .restype
attribute for a shared-library function, to automatically wrap the
pointer into an array.
"""
return self.contents
@property
def contents(self):
"""
Get an ndarray viewing the data pointed to by this pointer.
This mirrors the `contents` attribute of a normal ctypes pointer
"""
full_dtype = _dtype((self._dtype_, self._shape_))
full_ctype = ctypes.c_char * full_dtype.itemsize
buffer = ctypes.cast(self, ctypes.POINTER(full_ctype)).contents
return frombuffer(buffer, dtype=full_dtype).squeeze(axis=0)
# Factory for an array-checking class with from_param defined for
# use with ctypes argtypes mechanism
_pointer_type_cache = {}
def ndpointer(dtype=None, ndim=None, shape=None, flags=None):
"""
Array-checking restype/argtypes.
An ndpointer instance is used to describe an ndarray in restypes
and argtypes specifications. This approach is more flexible than
using, for example, ``POINTER(c_double)``, since several restrictions
can be specified, which are verified upon calling the ctypes function.
These include data type, number of dimensions, shape and flags. If a
given array does not satisfy the specified restrictions,
a ``TypeError`` is raised.
Parameters
----------
dtype : data-type, optional
Array data-type.
ndim : int, optional
Number of array dimensions.
shape : tuple of ints, optional
Array shape.
flags : str or tuple of str
Array flags; may be one or more of:
- C_CONTIGUOUS / C / CONTIGUOUS
- F_CONTIGUOUS / F / FORTRAN
- OWNDATA / O
- WRITEABLE / W
- ALIGNED / A
- WRITEBACKIFCOPY / X
Returns
-------
klass : ndpointer type object
        A type object, which is an ``_ndptr`` instance containing
dtype, ndim, shape and flags information.
Raises
------
TypeError
If a given array does not satisfy the specified restrictions.
Examples
--------
>>> clib.somefunc.argtypes = [np.ctypeslib.ndpointer(dtype=np.float64,
... ndim=1,
... flags='C_CONTIGUOUS')]
... #doctest: +SKIP
>>> clib.somefunc(np.array([1, 2, 3], dtype=np.float64))
... #doctest: +SKIP
"""
# normalize dtype to an Optional[dtype]
if dtype is not None:
dtype = _dtype(dtype)
# normalize flags to an Optional[int]
num = None
if flags is not None:
if isinstance(flags, str):
flags = flags.split(',')
elif isinstance(flags, (int, integer)):
num = flags
flags = _flags_fromnum(num)
elif isinstance(flags, flagsobj):
num = flags.num
flags = _flags_fromnum(num)
if num is None:
try:
flags = [x.strip().upper() for x in flags]
except Exception as e:
raise TypeError("invalid flags specification") from e
num = _num_fromflags(flags)
# normalize shape to an Optional[tuple]
if shape is not None:
try:
shape = tuple(shape)
except TypeError:
# single integer -> 1-tuple
shape = (shape,)
cache_key = (dtype, ndim, shape, num)
try:
return _pointer_type_cache[cache_key]
except KeyError:
pass
# produce a name for the new type
if dtype is None:
name = 'any'
elif dtype.names is not None:
name = str(id(dtype))
else:
name = dtype.str
if ndim is not None:
name += "_%dd" % ndim
if shape is not None:
name += "_"+"x".join(str(x) for x in shape)
if flags is not None:
name += "_"+"_".join(flags)
if dtype is not None and shape is not None:
base = _concrete_ndptr
else:
base = _ndptr
klass = type("ndpointer_%s"%name, (base,),
{"_dtype_": dtype,
"_shape_" : shape,
"_ndim_" : ndim,
"_flags_" : num})
_pointer_type_cache[cache_key] = klass
return klass
if ctypes is not None:
def _ctype_ndarray(element_type, shape):
""" Create an ndarray of the given element type and shape """
for dim in shape[::-1]:
element_type = dim * element_type
            # prevent the type name from including np.ctypeslib
element_type.__module__ = None
return element_type
def _get_scalar_type_map():
"""
Return a dictionary mapping native endian scalar dtype to ctypes types
"""
ct = ctypes
simple_types = [
ct.c_byte, ct.c_short, ct.c_int, ct.c_long, ct.c_longlong,
ct.c_ubyte, ct.c_ushort, ct.c_uint, ct.c_ulong, ct.c_ulonglong,
ct.c_float, ct.c_double,
ct.c_bool,
]
return {_dtype(ctype): ctype for ctype in simple_types}
_scalar_type_map = _get_scalar_type_map()
def _ctype_from_dtype_scalar(dtype):
        # swapping twice ensures that `=` is promoted to <, >, or |
dtype_with_endian = dtype.newbyteorder('S').newbyteorder('S')
dtype_native = dtype.newbyteorder('=')
try:
ctype = _scalar_type_map[dtype_native]
except KeyError as e:
raise NotImplementedError(
"Converting {!r} to a ctypes type".format(dtype)
) from None
if dtype_with_endian.byteorder == '>':
ctype = ctype.__ctype_be__
elif dtype_with_endian.byteorder == '<':
ctype = ctype.__ctype_le__
return ctype
def _ctype_from_dtype_subarray(dtype):
element_dtype, shape = dtype.subdtype
ctype = _ctype_from_dtype(element_dtype)
return _ctype_ndarray(ctype, shape)
def _ctype_from_dtype_structured(dtype):
# extract offsets of each field
field_data = []
for name in dtype.names:
field_dtype, offset = dtype.fields[name][:2]
field_data.append((offset, name, _ctype_from_dtype(field_dtype)))
# ctypes doesn't care about field order
field_data = sorted(field_data, key=lambda f: f[0])
if len(field_data) > 1 and all(offset == 0 for offset, name, ctype in field_data):
# union, if multiple fields all at address 0
size = 0
_fields_ = []
for offset, name, ctype in field_data:
_fields_.append((name, ctype))
size = max(size, ctypes.sizeof(ctype))
# pad to the right size
if dtype.itemsize != size:
_fields_.append(('', ctypes.c_char * dtype.itemsize))
# we inserted manual padding, so always `_pack_`
return type('union', (ctypes.Union,), dict(
_fields_=_fields_,
_pack_=1,
__module__=None,
))
else:
last_offset = 0
_fields_ = []
for offset, name, ctype in field_data:
padding = offset - last_offset
if padding < 0:
raise NotImplementedError("Overlapping fields")
if padding > 0:
_fields_.append(('', ctypes.c_char * padding))
_fields_.append((name, ctype))
last_offset = offset + ctypes.sizeof(ctype)
padding = dtype.itemsize - last_offset
if padding > 0:
_fields_.append(('', ctypes.c_char * padding))
# we inserted manual padding, so always `_pack_`
return type('struct', (ctypes.Structure,), dict(
_fields_=_fields_,
_pack_=1,
__module__=None,
))
def _ctype_from_dtype(dtype):
if dtype.fields is not None:
return _ctype_from_dtype_structured(dtype)
elif dtype.subdtype is not None:
return _ctype_from_dtype_subarray(dtype)
else:
return _ctype_from_dtype_scalar(dtype)
def as_ctypes_type(dtype):
r"""
Convert a dtype into a ctypes type.
Parameters
----------
dtype : dtype
The dtype to convert
Returns
-------
ctype
A ctype scalar, union, array, or struct
Raises
------
NotImplementedError
If the conversion is not possible
Notes
-----
This function does not losslessly round-trip in either direction.
``np.dtype(as_ctypes_type(dt))`` will:
- insert padding fields
- reorder fields to be sorted by offset
- discard field titles
``as_ctypes_type(np.dtype(ctype))`` will:
- discard the class names of `ctypes.Structure`\ s and
`ctypes.Union`\ s
- convert single-element `ctypes.Union`\ s into single-element
`ctypes.Structure`\ s
- insert padding fields
"""
return _ctype_from_dtype(_dtype(dtype))
def as_array(obj, shape=None):
"""
Create a numpy array from a ctypes array or POINTER.
The numpy array shares the memory with the ctypes object.
The shape parameter must be given if converting from a ctypes POINTER.
The shape parameter is ignored if converting from a ctypes array
"""
if isinstance(obj, ctypes._Pointer):
# convert pointers to an array of the desired shape
if shape is None:
raise TypeError(
'as_array() requires a shape argument when called on a '
'pointer')
p_arr_type = ctypes.POINTER(_ctype_ndarray(obj._type_, shape))
obj = ctypes.cast(obj, p_arr_type).contents
return asarray(obj)
def as_ctypes(obj):
"""Create and return a ctypes object from a numpy array. Actually
anything that exposes the __array_interface__ is accepted."""
ai = obj.__array_interface__
if ai["strides"]:
raise TypeError("strided arrays not supported")
if ai["version"] != 3:
raise TypeError("only __array_interface__ version 3 supported")
addr, readonly = ai["data"]
if readonly:
raise TypeError("readonly arrays unsupported")
# can't use `_dtype((ai["typestr"], ai["shape"]))` here, as it overflows
# dtype.itemsize (gh-14214)
ctype_scalar = as_ctypes_type(ai["typestr"])
result_type = _ctype_ndarray(ctype_scalar, ai["shape"])
result = result_type.from_address(addr)
result.__keep = obj
return result


@@ -0,0 +1,95 @@
"""
Module defining global singleton classes.
This module raises a RuntimeError if an attempt to reload it is made. In that
way the identities of the classes defined here are fixed and will remain so
even if numpy itself is reloaded. In particular, a function like the following
will still work correctly after numpy is reloaded::
def foo(arg=np._NoValue):
if arg is np._NoValue:
...
That was not the case when the singleton classes were defined in the numpy
``__init__.py`` file. See gh-7844 for a discussion of the reload problem that
motivated this module.
"""
import enum
from ._utils import set_module as _set_module
__all__ = ['_NoValue', '_CopyMode']
# Disallow reloading this module so as to preserve the identities of the
# classes defined here.
if '_is_loaded' in globals():
raise RuntimeError('Reloading numpy._globals is not allowed')
_is_loaded = True
class _NoValueType:
"""Special keyword value.
The instance of this class may be used as the default value assigned to a
keyword if no other obvious default (e.g., `None`) is suitable.
Common reasons for using this keyword are:
- A new keyword is added to a function, and that function forwards its
inputs to another function or method which can be defined outside of
NumPy. For example, ``np.std(x)`` calls ``x.std``, so when a ``keepdims``
keyword was added that could only be forwarded if the user explicitly
specified ``keepdims``; downstream array libraries may not have added
the same keyword, so adding ``x.std(..., keepdims=keepdims)``
unconditionally could have broken previously working code.
- A keyword is being deprecated, and a deprecation warning must only be
emitted when the keyword is used.
"""
__instance = None
def __new__(cls):
# ensure that only one instance exists
if not cls.__instance:
cls.__instance = super().__new__(cls)
return cls.__instance
def __repr__(self):
return "<no value>"
_NoValue = _NoValueType()
@_set_module("numpy")
class _CopyMode(enum.Enum):
"""
An enumeration for the copy modes supported
by numpy.copy() and numpy.array(). The following three modes are supported:
- ALWAYS: This means that a deep copy of the input
array will always be taken.
- IF_NEEDED: This means that a deep copy of the input
array will be taken only if necessary.
- NEVER: This means that the deep copy will never be taken.
If a copy cannot be avoided then a `ValueError` will be
raised.
Note that the buffer-protocol could in theory do copies. NumPy currently
assumes an object exporting the buffer protocol will never do this.
"""
ALWAYS = True
IF_NEEDED = False
NEVER = 2
def __bool__(self):
# For backwards compatibility
if self == _CopyMode.ALWAYS:
return True
if self == _CopyMode.IF_NEEDED:
return False
raise ValueError(f"{self} is neither True nor False.")


@@ -0,0 +1,834 @@
from __future__ import annotations as _annotations
import abc
import dataclasses as _dataclasses
import re
from datetime import date, datetime
from decimal import Decimal
from enum import Enum
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
FrozenSet,
Generic,
Hashable,
List,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from uuid import UUID
import annotated_types
from pydantic_core import PydanticCustomError, PydanticKnownError, core_schema
from typing_extensions import Annotated, Literal
from ._internal import _fields, _validators
__all__ = [
'Strict',
'StrictStr',
'conbytes',
'conlist',
'conset',
'confrozenset',
'constr',
'ImportString',
'conint',
'PositiveInt',
'NegativeInt',
'NonNegativeInt',
'NonPositiveInt',
'confloat',
'PositiveFloat',
'NegativeFloat',
'NonNegativeFloat',
'NonPositiveFloat',
'FiniteFloat',
'condecimal',
'UUID1',
'UUID3',
'UUID4',
'UUID5',
'FilePath',
'DirectoryPath',
'Json',
'SecretField',
'SecretStr',
'SecretBytes',
'StrictBool',
'StrictBytes',
'StrictInt',
'StrictFloat',
'PaymentCardNumber',
'ByteSize',
'PastDate',
'FutureDate',
'condate',
'AwareDatetime',
'NaiveDatetime',
]
from ._internal._core_metadata import build_metadata_dict
from ._internal._utils import update_not_none
from .json_schema import JsonSchemaMetadata
@_dataclasses.dataclass
class Strict(_fields.PydanticMetadata):
strict: bool = True
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BOOLEAN TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
StrictBool = Annotated[bool, Strict()]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTEGER TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def conint(
*,
strict: bool | None = None,
gt: int | None = None,
ge: int | None = None,
lt: int | None = None,
le: int | None = None,
multiple_of: int | None = None,
) -> type[int]:
return Annotated[ # type: ignore[return-value]
int,
Strict(strict) if strict is not None else None,
annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le),
annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None,
]
PositiveInt = Annotated[int, annotated_types.Gt(0)]
NegativeInt = Annotated[int, annotated_types.Lt(0)]
NonPositiveInt = Annotated[int, annotated_types.Le(0)]
NonNegativeInt = Annotated[int, annotated_types.Ge(0)]
StrictInt = Annotated[int, Strict()]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FLOAT TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@_dataclasses.dataclass
class AllowInfNan(_fields.PydanticMetadata):
allow_inf_nan: bool = True
def confloat(
*,
strict: bool | None = None,
gt: float | None = None,
ge: float | None = None,
lt: float | None = None,
le: float | None = None,
multiple_of: float | None = None,
allow_inf_nan: bool | None = None,
) -> type[float]:
return Annotated[ # type: ignore[return-value]
float,
Strict(strict) if strict is not None else None,
annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le),
annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None,
AllowInfNan(allow_inf_nan) if allow_inf_nan is not None else None,
]
PositiveFloat = Annotated[float, annotated_types.Gt(0)]
NegativeFloat = Annotated[float, annotated_types.Lt(0)]
NonPositiveFloat = Annotated[float, annotated_types.Le(0)]
NonNegativeFloat = Annotated[float, annotated_types.Ge(0)]
StrictFloat = Annotated[float, Strict(True)]
FiniteFloat = Annotated[float, AllowInfNan(False)]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTES TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def conbytes(
*,
min_length: int | None = None,
max_length: int | None = None,
strict: bool | None = None,
) -> type[bytes]:
return Annotated[ # type: ignore[return-value]
bytes,
Strict(strict) if strict is not None else None,
annotated_types.Len(min_length or 0, max_length),
]
StrictBytes = Annotated[bytes, Strict()]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ STRING TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def constr(
*,
strip_whitespace: bool | None = None,
to_upper: bool | None = None,
to_lower: bool | None = None,
strict: bool | None = None,
min_length: int | None = None,
max_length: int | None = None,
pattern: str | None = None,
) -> type[str]:
return Annotated[ # type: ignore[return-value]
str,
Strict(strict) if strict is not None else None,
annotated_types.Len(min_length or 0, max_length),
_fields.PydanticGeneralMetadata(
strip_whitespace=strip_whitespace,
to_upper=to_upper,
to_lower=to_lower,
pattern=pattern,
),
]
StrictStr = Annotated[str, Strict()]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ COLLECTION TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
HashableItemType = TypeVar('HashableItemType', bound=Hashable)
def conset(
item_type: Type[HashableItemType], *, min_length: int = None, max_length: int = None
) -> Type[Set[HashableItemType]]:
return Annotated[ # type: ignore[return-value]
Set[item_type], annotated_types.Len(min_length or 0, max_length) # type: ignore[valid-type]
]
def confrozenset(
item_type: Type[HashableItemType], *, min_length: int | None = None, max_length: int | None = None
) -> Type[FrozenSet[HashableItemType]]:
return Annotated[ # type: ignore[return-value]
FrozenSet[item_type], # type: ignore[valid-type]
annotated_types.Len(min_length or 0, max_length),
]
AnyItemType = TypeVar('AnyItemType')
def conlist(
item_type: Type[AnyItemType], *, min_length: int | None = None, max_length: int | None = None
) -> Type[List[AnyItemType]]:
return Annotated[ # type: ignore[return-value]
List[item_type], # type: ignore[valid-type]
annotated_types.Len(min_length or 0, max_length),
]
def contuple(
item_type: Type[AnyItemType], *, min_length: int | None = None, max_length: int | None = None
) -> Type[Tuple[AnyItemType]]:
return Annotated[ # type: ignore[return-value]
Tuple[item_type],
annotated_types.Len(min_length or 0, max_length),
]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~ IMPORT STRING TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
AnyType = TypeVar('AnyType')
if TYPE_CHECKING:
ImportString = Annotated[AnyType, ...]
else:
class ImportString:
@classmethod
def __class_getitem__(cls, item: AnyType) -> AnyType:
return Annotated[item, cls()]
@classmethod
def __get_pydantic_core_schema__(
cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
) -> core_schema.CoreSchema:
if schema is None or schema == {'type': 'any'}:
# Treat bare usage of ImportString (`schema is None`) as the same as ImportString[Any]
return core_schema.function_plain_schema(lambda v, _: _validators.import_string(v))
else:
return core_schema.function_before_schema(lambda v, _: _validators.import_string(v), schema)
def __repr__(self) -> str:
return 'ImportString'
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DECIMAL TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def condecimal(
*,
strict: bool | None = None,
gt: int | Decimal | None = None,
ge: int | Decimal | None = None,
lt: int | Decimal | None = None,
le: int | Decimal | None = None,
multiple_of: int | Decimal | None = None,
max_digits: int | None = None,
decimal_places: int | None = None,
allow_inf_nan: bool | None = None,
) -> Type[Decimal]:
return Annotated[ # type: ignore[return-value]
Decimal,
Strict(strict) if strict is not None else None,
annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le),
annotated_types.MultipleOf(multiple_of) if multiple_of is not None else None,
_fields.PydanticGeneralMetadata(max_digits=max_digits, decimal_places=decimal_places),
AllowInfNan(allow_inf_nan) if allow_inf_nan is not None else None,
]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ UUID TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@_dataclasses.dataclass(frozen=True) # Add frozen=True to make it hashable
class UuidVersion:
uuid_version: Literal[1, 3, 4, 5]
def __pydantic_modify_json_schema__(self, field_schema: dict[str, Any]) -> None:
field_schema.pop('anyOf', None) # remove the bytes/str union
field_schema.update(type='string', format=f'uuid{self.uuid_version}')
def __get_pydantic_core_schema__(
self, schema: core_schema.CoreSchema, **_kwargs: Any
) -> core_schema.FunctionSchema:
return core_schema.function_after_schema(schema, cast(core_schema.ValidatorFunction, self.validate))
def validate(self, value: UUID, _: core_schema.ValidationInfo) -> UUID:
if value.version != self.uuid_version:
raise PydanticCustomError(
'uuid_version', 'uuid version {required_version} expected', {'required_version': self.uuid_version}
)
return value
UUID1 = Annotated[UUID, UuidVersion(1)]
UUID3 = Annotated[UUID, UuidVersion(3)]
UUID4 = Annotated[UUID, UuidVersion(4)]
UUID5 = Annotated[UUID, UuidVersion(5)]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PATH TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@_dataclasses.dataclass
class PathType:
path_type: Literal['file', 'dir', 'new']
def __pydantic_modify_json_schema__(self, field_schema: dict[str, Any]) -> None:
format_conversion = {'file': 'file-path', 'dir': 'directory-path'}
field_schema.update(format=format_conversion.get(self.path_type, 'path'), type='string')
def __get_pydantic_core_schema__(
self, schema: core_schema.CoreSchema, **_kwargs: Any
) -> core_schema.FunctionSchema:
function_lookup = {
'file': cast(core_schema.ValidatorFunction, self.validate_file),
'dir': cast(core_schema.ValidatorFunction, self.validate_directory),
'new': cast(core_schema.ValidatorFunction, self.validate_new),
}
return core_schema.function_after_schema(
schema,
function_lookup[self.path_type],
)
@staticmethod
def validate_file(path: Path, _: core_schema.ValidationInfo) -> Path:
if path.is_file():
return path
else:
raise PydanticCustomError('path_not_file', 'Path does not point to a file')
@staticmethod
def validate_directory(path: Path, _: core_schema.ValidationInfo) -> Path:
if path.is_dir():
return path
else:
raise PydanticCustomError('path_not_directory', 'Path does not point to a directory')
@staticmethod
def validate_new(path: Path, _: core_schema.ValidationInfo) -> Path:
if path.exists():
raise PydanticCustomError('path_exists', 'path already exists')
elif not path.parent.exists():
raise PydanticCustomError('parent_does_not_exist', 'Parent directory does not exist')
else:
return path
FilePath = Annotated[Path, PathType('file')]
DirectoryPath = Annotated[Path, PathType('dir')]
NewPath = Annotated[Path, PathType('new')]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ JSON TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if TYPE_CHECKING:
Json = Annotated[AnyType, ...] # Json[list[str]] will be recognized by type checkers as list[str]
else:
class Json:
@classmethod
def __class_getitem__(cls, item: AnyType) -> AnyType:
return Annotated[item, cls()]
@classmethod
def __get_pydantic_core_schema__(
cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
) -> core_schema.JsonSchema:
return core_schema.json_schema(schema)
@classmethod
def __pydantic_modify_json_schema__(cls, field_schema: dict[str, Any]) -> None:
field_schema.update(type='string', format='json-string')
def __repr__(self) -> str:
return 'Json'
def __hash__(self) -> int:
return hash(type(self))
def __eq__(self, other: Any) -> bool:
return type(other) == type(self)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SECRET TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
SecretType = TypeVar('SecretType', str, bytes)
class SecretField(abc.ABC, Generic[SecretType]):
_error_kind: str
def __init__(self, secret_value: SecretType) -> None:
self._secret_value: SecretType = secret_value
def get_secret_value(self) -> SecretType:
return self._secret_value
@classmethod
def __get_pydantic_core_schema__(cls, **_kwargs: Any) -> core_schema.FunctionSchema:
validator = SecretFieldValidator(cls)
if issubclass(cls, SecretStr):
# Use a lambda here so that `apply_metadata` can be called on the validator before the override is generated
override = lambda: core_schema.str_schema( # noqa E731
min_length=validator.min_length,
max_length=validator.max_length,
)
elif issubclass(cls, SecretBytes):
override = lambda: core_schema.bytes_schema( # noqa E731
min_length=validator.min_length,
max_length=validator.max_length,
)
else:
override = None
metadata = build_metadata_dict(
update_cs_function=validator.__pydantic_update_schema__,
js_metadata=JsonSchemaMetadata(core_schema_override=override),
)
return core_schema.function_after_schema(
core_schema.union_schema(
core_schema.is_instance_schema(cls),
cls._pre_core_schema(),
strict=True,
custom_error_type=cls._error_kind,
),
validator,
metadata=metadata,
serialization=core_schema.function_plain_ser_schema(cls._serialize, json_return_type='str'),
)
@classmethod
def _serialize(
cls, value: SecretField[SecretType], info: core_schema.SerializationInfo
) -> str | SecretField[SecretType]:
if info.mode == 'json':
            # we want the output to always be a string without the `b'` prefix for bytes,
# hence we just use `secret_display`
return secret_display(value)
else:
return value
@classmethod
@abc.abstractmethod
def _pre_core_schema(cls) -> core_schema.CoreSchema:
...
@classmethod
def __pydantic_modify_json_schema__(cls, field_schema: dict[str, Any]) -> None:
update_not_none(
field_schema,
type='string',
writeOnly=True,
format='password',
)
def __eq__(self, other: Any) -> bool:
return isinstance(other, self.__class__) and self.get_secret_value() == other.get_secret_value()
def __hash__(self) -> int:
return hash(self.get_secret_value())
def __len__(self) -> int:
return len(self._secret_value)
@abc.abstractmethod
def _display(self) -> SecretType:
...
def __str__(self) -> str:
return str(self._display())
def __repr__(self) -> str:
return f'{self.__class__.__name__}({self._display()!r})'
def secret_display(secret_field: SecretField[Any]) -> str:
return '**********' if secret_field.get_secret_value() else ''
class SecretFieldValidator(_fields.CustomValidator, Generic[SecretType]):
__slots__ = 'field_type', 'min_length', 'max_length', 'error_prefix'
def __init__(
self, field_type: Type[SecretField[SecretType]], min_length: int | None = None, max_length: int | None = None
) -> None:
self.field_type: Type[SecretField[SecretType]] = field_type
self.min_length = min_length
self.max_length = max_length
self.error_prefix: Literal['string', 'bytes'] = 'string' if field_type is SecretStr else 'bytes'
def __call__(self, __value: SecretField[SecretType] | SecretType, _: core_schema.ValidationInfo) -> Any:
if self.min_length is not None and len(__value) < self.min_length:
short_kind: core_schema.ErrorType = f'{self.error_prefix}_too_short' # type: ignore[assignment]
raise PydanticKnownError(short_kind, {'min_length': self.min_length})
if self.max_length is not None and len(__value) > self.max_length:
long_kind: core_schema.ErrorType = f'{self.error_prefix}_too_long' # type: ignore[assignment]
raise PydanticKnownError(long_kind, {'max_length': self.max_length})
if isinstance(__value, self.field_type):
return __value
else:
return self.field_type(__value) # type: ignore[arg-type]
def __pydantic_update_schema__(self, schema: core_schema.CoreSchema, **constraints: Any) -> None:
self._update_attrs(constraints, {'min_length', 'max_length'})
class SecretStr(SecretField[str]):
_error_kind = 'string_type'
@classmethod
def _pre_core_schema(cls) -> core_schema.CoreSchema:
return core_schema.str_schema()
def _display(self) -> str:
return secret_display(self)
class SecretBytes(SecretField[bytes]):
_error_kind = 'bytes_type'
@classmethod
def _pre_core_schema(cls) -> core_schema.CoreSchema:
return core_schema.bytes_schema()
def _display(self) -> bytes:
return secret_display(self).encode()
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PAYMENT CARD TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
class PaymentCardBrand(str, Enum):
# If you add another card type, please also add it to the
# Hypothesis strategy in `pydantic._hypothesis_plugin`.
amex = 'American Express'
mastercard = 'Mastercard'
visa = 'Visa'
other = 'other'
def __str__(self) -> str:
return self.value
class PaymentCardNumber(str):
"""
Based on: https://en.wikipedia.org/wiki/Payment_card_number
"""
strip_whitespace: ClassVar[bool] = True
min_length: ClassVar[int] = 12
max_length: ClassVar[int] = 19
bin: str
last4: str
brand: PaymentCardBrand
def __init__(self, card_number: str):
self.validate_digits(card_number)
card_number = self.validate_luhn_check_digit(card_number)
self.bin = card_number[:6]
self.last4 = card_number[-4:]
self.brand = self.validate_brand(card_number)
@classmethod
def __get_pydantic_core_schema__(cls, **_kwargs: Any) -> core_schema.FunctionSchema:
return core_schema.function_after_schema(
core_schema.str_schema(
min_length=cls.min_length, max_length=cls.max_length, strip_whitespace=cls.strip_whitespace
),
cls.validate,
)
@classmethod
def validate(cls, __input_value: str, _: core_schema.ValidationInfo) -> 'PaymentCardNumber':
return cls(__input_value)
@property
def masked(self) -> str:
num_masked = len(self) - 10 # len(bin) + len(last4) == 10
return f'{self.bin}{"*" * num_masked}{self.last4}'
@classmethod
def validate_digits(cls, card_number: str) -> None:
if not card_number.isdigit():
raise PydanticCustomError('payment_card_number_digits', 'Card number is not all digits')
@classmethod
def validate_luhn_check_digit(cls, card_number: str) -> str:
"""
Based on: https://en.wikipedia.org/wiki/Luhn_algorithm
"""
sum_ = int(card_number[-1])
length = len(card_number)
parity = length % 2
for i in range(length - 1):
digit = int(card_number[i])
if i % 2 == parity:
digit *= 2
if digit > 9:
digit -= 9
sum_ += digit
valid = sum_ % 10 == 0
if not valid:
raise PydanticCustomError('payment_card_number_luhn', 'Card number is not luhn valid')
return card_number
@staticmethod
def validate_brand(card_number: str) -> PaymentCardBrand:
"""
Validate length based on BIN for major brands:
https://en.wikipedia.org/wiki/Payment_card_number#Issuer_identification_number_(IIN)
"""
if card_number[0] == '4':
brand = PaymentCardBrand.visa
elif 51 <= int(card_number[:2]) <= 55:
brand = PaymentCardBrand.mastercard
elif card_number[:2] in {'34', '37'}:
brand = PaymentCardBrand.amex
else:
brand = PaymentCardBrand.other
required_length: Union[None, int, str] = None
        if brand == PaymentCardBrand.mastercard:
required_length = 16
valid = len(card_number) == required_length
elif brand == PaymentCardBrand.visa:
required_length = '13, 16 or 19'
valid = len(card_number) in {13, 16, 19}
elif brand == PaymentCardBrand.amex:
required_length = 15
valid = len(card_number) == required_length
else:
valid = True
if not valid:
raise PydanticCustomError(
'payment_card_number_brand',
'Length for a {brand} card must be {required_length}',
{'brand': brand, 'required_length': required_length},
)
return brand
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BYTE SIZE TYPE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
BYTE_SIZES = {
'b': 1,
'kb': 10**3,
'mb': 10**6,
'gb': 10**9,
'tb': 10**12,
'pb': 10**15,
'eb': 10**18,
'kib': 2**10,
'mib': 2**20,
'gib': 2**30,
'tib': 2**40,
'pib': 2**50,
'eib': 2**60,
}
BYTE_SIZES.update({k.lower()[0]: v for k, v in BYTE_SIZES.items() if 'i' not in k})
byte_string_re = re.compile(r'^\s*(\d*\.?\d+)\s*(\w+)?', re.IGNORECASE)
class ByteSize(int):
@classmethod
def __get_pydantic_core_schema__(cls, **_kwargs: Any) -> core_schema.FunctionPlainSchema:
# TODO better schema
return core_schema.function_plain_schema(cls.validate)
@classmethod
def validate(cls, __input_value: Any, _: core_schema.ValidationInfo) -> 'ByteSize':
try:
return cls(int(__input_value))
except ValueError:
pass
str_match = byte_string_re.match(str(__input_value))
if str_match is None:
raise PydanticCustomError('byte_size', 'could not parse value and unit from byte string')
scalar, unit = str_match.groups()
if unit is None:
unit = 'b'
try:
unit_mult = BYTE_SIZES[unit.lower()]
except KeyError:
raise PydanticCustomError('byte_size_unit', 'could not interpret byte unit: {unit}', {'unit': unit})
return cls(int(float(scalar) * unit_mult))
def human_readable(self, decimal: bool = False) -> str:
if decimal:
divisor = 1000
units = 'B', 'KB', 'MB', 'GB', 'TB', 'PB'
final_unit = 'EB'
else:
divisor = 1024
units = 'B', 'KiB', 'MiB', 'GiB', 'TiB', 'PiB'
final_unit = 'EiB'
num = float(self)
for unit in units:
if abs(num) < divisor:
if unit == 'B':
return f'{num:0.0f}{unit}'
else:
return f'{num:0.1f}{unit}'
num /= divisor
return f'{num:0.1f}{final_unit}'
def to(self, unit: str) -> float:
try:
unit_div = BYTE_SIZES[unit.lower()]
except KeyError:
raise PydanticCustomError('byte_size_unit', 'Could not interpret byte unit: {unit}', {'unit': unit})
return self / unit_div
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DATE TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if TYPE_CHECKING:
PastDate = Annotated[date, ...]
FutureDate = Annotated[date, ...]
else:
class PastDate:
@classmethod
def __get_pydantic_core_schema__(
cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
) -> core_schema.CoreSchema:
if schema is None:
# used directly as a type
return core_schema.date_schema(now_op='past')
else:
assert schema['type'] == 'date'
schema['now_op'] = 'past'
return schema
def __repr__(self) -> str:
return 'PastDate'
class FutureDate:
@classmethod
def __get_pydantic_core_schema__(
cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
) -> core_schema.CoreSchema:
if schema is None:
# used directly as a type
return core_schema.date_schema(now_op='future')
else:
assert schema['type'] == 'date'
schema['now_op'] = 'future'
return schema
def __repr__(self) -> str:
return 'FutureDate'
def condate(*, strict: bool = None, gt: date = None, ge: date = None, lt: date = None, le: date = None) -> type[date]:
return Annotated[ # type: ignore[return-value]
date,
Strict(strict) if strict is not None else None,
annotated_types.Interval(gt=gt, ge=ge, lt=lt, le=le),
]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DATETIME TYPES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if TYPE_CHECKING:
AwareDatetime = Annotated[datetime, ...]
NaiveDatetime = Annotated[datetime, ...]
else:
class AwareDatetime:
@classmethod
def __get_pydantic_core_schema__(
cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
) -> core_schema.CoreSchema:
if schema is None:
# used directly as a type
return core_schema.datetime_schema(tz_constraint='aware')
else:
assert schema['type'] == 'datetime'
schema['tz_constraint'] = 'aware'
return schema
def __repr__(self) -> str:
return 'AwareDatetime'
class NaiveDatetime:
@classmethod
def __get_pydantic_core_schema__(
cls, schema: core_schema.CoreSchema | None = None, **_kwargs: Any
) -> core_schema.CoreSchema:
if schema is None:
# used directly as a type
return core_schema.datetime_schema(tz_constraint='naive')
else:
assert schema['type'] == 'datetime'
schema['tz_constraint'] = 'naive'
return schema
def __repr__(self) -> str:
return 'NaiveDatetime'
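# A minimal usage sketch (hypothetical; assumes `BaseModel` is importable
# from pydantic):
#
#     class Meeting(BaseModel):
#         starts_at: AwareDatetime   # must carry tzinfo
#         local_note: NaiveDatetime  # must not carry tzinfo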

View file

@ -0,0 +1,161 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
处理汉语拼音方案中的一些特殊情况
汉语拼音方案:
* https://zh.wiktionary.org/wiki/%E9%99%84%E5%BD%95:%E6%B1%89%E8%AF%AD%E6%8B%BC%E9%9F%B3%E6%96%B9%E6%A1%88
* http://www.moe.edu.cn/s78/A19/yxs_left/moe_810/s230/195802/t19580201_186000.html
""" # noqa
from __future__ import unicode_literals
import re
from pypinyin.style._constants import _FINALS
# u -> ü
UV_MAP = {
'u': 'ü',
'ū': 'ǖ',
'ú': 'ǘ',
'ǔ': 'ǚ',
'ù': 'ǜ',
}
U_TONES = set(UV_MAP.keys())
# Finals in the ü row are written as ju, qu, xu when combined with the initials j, q, x
UV_RE = re.compile(
r'^(j|q|x)({tones})(.*)$'.format(
tones='|'.join(UV_MAP.keys())))
I_TONES = set(['i', 'ī', 'í', 'ǐ', 'ì'])
# iu -> iou
IU_MAP = {
'iu': 'iou',
    'iū': 'ioū',
    'iú': 'ioú',
    'iǔ': 'ioǔ',
    'iù': 'ioù',
}
IU_TONES = set(IU_MAP.keys())
IU_RE = re.compile(r'^([a-z]+)({tones})$'.format(tones='|'.join(IU_TONES)))
# ui -> uei
UI_MAP = {
'ui': 'uei',
    'uī': 'ueī',
    'uí': 'ueí',
    'uǐ': 'ueǐ',
    'uì': 'ueì',
}
UI_TONES = set(UI_MAP.keys())
UI_RE = re.compile(r'([a-z]+)({tones})$'.format(tones='|'.join(UI_TONES)))
# un -> uen
UN_MAP = {
'un': 'uen',
'ūn': 'ūen',
'ún': 'úen',
'ǔn': 'ǔen',
'ùn': 'ùen',
}
UN_TONES = set(UN_MAP.keys())
UN_RE = re.compile(r'([a-z]+)({tones})$'.format(tones='|'.join(UN_TONES)))
def convert_zero_consonant(pinyin):
"""零声母转换,还原原始的韵母
i行的韵母前面没有声母的时候写成yi()ya()ye()yao()
you()yan()yin()yang()ying()yong()
u行的韵母前面没有声母的时候写成wu()wa()wo()wai()
wei()wan()wen()wang()weng()
ü行的韵母前面没有声母的时候写成yu()yue()yuan()
yun()ü上两点省略
"""
raw_pinyin = pinyin
# y: yu -> v, yi -> i, y -> i
if raw_pinyin.startswith('y'):
        # the pinyin with the leading 'y' removed
no_y_py = pinyin[1:]
first_char = no_y_py[0] if len(no_y_py) > 0 else None
# yu -> ü: yue -> üe
if first_char in U_TONES:
pinyin = UV_MAP[first_char] + pinyin[2:]
# yi -> i: yi -> i
elif first_char in I_TONES:
pinyin = no_y_py
# y -> i: ya -> ia
else:
pinyin = 'i' + no_y_py
# w: wu -> u, w -> u
if raw_pinyin.startswith('w'):
        # the pinyin with the leading 'w' removed
no_w_py = pinyin[1:]
first_char = no_w_py[0] if len(no_w_py) > 0 else None
# wu -> u: wu -> u
if first_char in U_TONES:
pinyin = pinyin[1:]
# w -> u: wa -> ua
else:
pinyin = 'u' + pinyin[1:]
    # Make sure we never return a final that is missing from the finals table
if pinyin not in _FINALS:
return raw_pinyin
return pinyin
def convert_uv(pinyin):
"""ü 转换,还原原始的韵母
ü行的韵跟声母jqx拼的时候写成ju()qu()xu()
ü上两点也省略但是跟声母nl拼的时候仍然写成nü()()
"""
return UV_RE.sub(
lambda m: ''.join((m.group(1), UV_MAP[m.group(2)], m.group(3))),
pinyin)
def convert_iou(pinyin):
"""iou 转换,还原原始的韵母
iouueiuen前面加声母的时候写成iuuiun
例如niu()gui()lun()
"""
return IU_RE.sub(lambda m: m.group(1) + IU_MAP[m.group(2)], pinyin)
def convert_uei(pinyin):
"""uei 转换,还原原始的韵母
iouueiuen前面加声母的时候写成iuuiun
例如niu()gui()lun()
"""
return UI_RE.sub(lambda m: m.group(1) + UI_MAP[m.group(2)], pinyin)
def convert_uen(pinyin):
"""uen 转换,还原原始的韵母
iouueiuen前面加声母的时候写成iuuiun
例如niu()gui()lun()
"""
return UN_RE.sub(lambda m: m.group(1) + UN_MAP[m.group(2)], pinyin)
def convert_finals(pinyin):
"""还原原始的韵母"""
pinyin = convert_zero_consonant(pinyin)
pinyin = convert_uv(pinyin)
pinyin = convert_iou(pinyin)
pinyin = convert_uei(pinyin)
pinyin = convert_uen(pinyin)
return pinyin
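# A minimal sketch of the pipeline (hypothetical inputs; convert_zero_consonant
# only keeps its result when it appears in the finals table):
#
#     convert_finals('yue')  # -> 'üe'   (zero consonant: yu -> ü)
#     convert_finals('niu')  # -> 'niou' (iu -> iou)
#     convert_finals('gui')  # -> 'guei' (ui -> uei)
#     convert_finals('lun')  # -> 'luen' (un -> uen)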

View file

@ -0,0 +1,10 @@
# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
# Licensed to PSF under a Contributor Agreement.
__all__ = ("loads", "load", "TOMLDecodeError")
from ._parser import TOMLDecodeError, load, loads
# Pretend this exception was created here.
TOMLDecodeError.__module__ = __name__

View file

@ -0,0 +1,691 @@
# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
# Licensed to PSF under a Contributor Agreement.
from __future__ import annotations
from collections.abc import Iterable
import string
from types import MappingProxyType
from typing import Any, BinaryIO, NamedTuple
from ._re import (
RE_DATETIME,
RE_LOCALTIME,
RE_NUMBER,
match_to_datetime,
match_to_localtime,
match_to_number,
)
from ._types import Key, ParseFloat, Pos
ASCII_CTRL = frozenset(chr(i) for i in range(32)) | frozenset(chr(127))
# Neither of these sets includes the quotation mark or the backslash. They
# are currently handled as separate cases in the parser functions.
ILLEGAL_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t")
ILLEGAL_MULTILINE_BASIC_STR_CHARS = ASCII_CTRL - frozenset("\t\n")
ILLEGAL_LITERAL_STR_CHARS = ILLEGAL_BASIC_STR_CHARS
ILLEGAL_MULTILINE_LITERAL_STR_CHARS = ILLEGAL_MULTILINE_BASIC_STR_CHARS
ILLEGAL_COMMENT_CHARS = ILLEGAL_BASIC_STR_CHARS
TOML_WS = frozenset(" \t")
TOML_WS_AND_NEWLINE = TOML_WS | frozenset("\n")
BARE_KEY_CHARS = frozenset(string.ascii_letters + string.digits + "-_")
KEY_INITIAL_CHARS = BARE_KEY_CHARS | frozenset("\"'")
HEXDIGIT_CHARS = frozenset(string.hexdigits)
BASIC_STR_ESCAPE_REPLACEMENTS = MappingProxyType(
{
"\\b": "\u0008", # backspace
"\\t": "\u0009", # tab
"\\n": "\u000A", # linefeed
"\\f": "\u000C", # form feed
"\\r": "\u000D", # carriage return
'\\"': "\u0022", # quote
"\\\\": "\u005C", # backslash
}
)
class TOMLDecodeError(ValueError):
"""An error raised if a document is not valid TOML."""
def load(fp: BinaryIO, /, *, parse_float: ParseFloat = float) -> dict[str, Any]:
"""Parse TOML from a binary file object."""
b = fp.read()
try:
s = b.decode()
except AttributeError:
raise TypeError(
"File must be opened in binary mode, e.g. use `open('foo.toml', 'rb')`"
) from None
return loads(s, parse_float=parse_float)
def loads(s: str, /, *, parse_float: ParseFloat = float) -> dict[str, Any]: # noqa: C901
"""Parse TOML from a string."""
# The spec allows converting "\r\n" to "\n", even in string
# literals. Let's do so to simplify parsing.
src = s.replace("\r\n", "\n")
pos = 0
out = Output(NestedDict(), Flags())
header: Key = ()
parse_float = make_safe_parse_float(parse_float)
# Parse one statement at a time
# (typically means one line in TOML source)
while True:
# 1. Skip line leading whitespace
pos = skip_chars(src, pos, TOML_WS)
# 2. Parse rules. Expect one of the following:
# - end of file
# - end of line
# - comment
# - key/value pair
# - append dict to list (and move to its namespace)
# - create dict (and move to its namespace)
# Skip trailing whitespace when applicable.
try:
char = src[pos]
except IndexError:
break
if char == "\n":
pos += 1
continue
if char in KEY_INITIAL_CHARS:
pos = key_value_rule(src, pos, out, header, parse_float)
pos = skip_chars(src, pos, TOML_WS)
elif char == "[":
try:
second_char: str | None = src[pos + 1]
except IndexError:
second_char = None
out.flags.finalize_pending()
if second_char == "[":
pos, header = create_list_rule(src, pos, out)
else:
pos, header = create_dict_rule(src, pos, out)
pos = skip_chars(src, pos, TOML_WS)
elif char != "#":
raise suffixed_err(src, pos, "Invalid statement")
# 3. Skip comment
pos = skip_comment(src, pos)
# 4. Expect end of line or end of file
try:
char = src[pos]
except IndexError:
break
if char != "\n":
raise suffixed_err(
src, pos, "Expected newline or end of document after a statement"
)
pos += 1
return out.data.dict
class Flags:
"""Flags that map to parsed keys/namespaces."""
# Marks an immutable namespace (inline array or inline table).
FROZEN = 0
# Marks a nest that has been explicitly created and can no longer
# be opened using the "[table]" syntax.
EXPLICIT_NEST = 1
def __init__(self) -> None:
self._flags: dict[str, dict] = {}
self._pending_flags: set[tuple[Key, int]] = set()
def add_pending(self, key: Key, flag: int) -> None:
self._pending_flags.add((key, flag))
def finalize_pending(self) -> None:
for key, flag in self._pending_flags:
self.set(key, flag, recursive=False)
self._pending_flags.clear()
def unset_all(self, key: Key) -> None:
cont = self._flags
for k in key[:-1]:
if k not in cont:
return
cont = cont[k]["nested"]
cont.pop(key[-1], None)
def set(self, key: Key, flag: int, *, recursive: bool) -> None: # noqa: A003
cont = self._flags
key_parent, key_stem = key[:-1], key[-1]
for k in key_parent:
if k not in cont:
cont[k] = {"flags": set(), "recursive_flags": set(), "nested": {}}
cont = cont[k]["nested"]
if key_stem not in cont:
cont[key_stem] = {"flags": set(), "recursive_flags": set(), "nested": {}}
cont[key_stem]["recursive_flags" if recursive else "flags"].add(flag)
def is_(self, key: Key, flag: int) -> bool:
if not key:
return False # document root has no flags
cont = self._flags
for k in key[:-1]:
if k not in cont:
return False
inner_cont = cont[k]
if flag in inner_cont["recursive_flags"]:
return True
cont = inner_cont["nested"]
key_stem = key[-1]
if key_stem in cont:
cont = cont[key_stem]
return flag in cont["flags"] or flag in cont["recursive_flags"]
return False
class NestedDict:
def __init__(self) -> None:
# The parsed content of the TOML document
self.dict: dict[str, Any] = {}
def get_or_create_nest(
self,
key: Key,
*,
access_lists: bool = True,
) -> dict:
cont: Any = self.dict
for k in key:
if k not in cont:
cont[k] = {}
cont = cont[k]
if access_lists and isinstance(cont, list):
cont = cont[-1]
if not isinstance(cont, dict):
raise KeyError("There is no nest behind this key")
return cont
def append_nest_to_list(self, key: Key) -> None:
cont = self.get_or_create_nest(key[:-1])
last_key = key[-1]
if last_key in cont:
list_ = cont[last_key]
if not isinstance(list_, list):
raise KeyError("An object other than list found behind this key")
list_.append({})
else:
cont[last_key] = [{}]
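# A minimal sketch of how NestedDict builds nested tables (hypothetical):
#
#     nd = NestedDict()
#     nd.get_or_create_nest(('a', 'b'))['x'] = 1
#     nd.dict  # -> {'a': {'b': {'x': 1}}}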
class Output(NamedTuple):
data: NestedDict
flags: Flags
def skip_chars(src: str, pos: Pos, chars: Iterable[str]) -> Pos:
try:
while src[pos] in chars:
pos += 1
except IndexError:
pass
return pos
def skip_until(
src: str,
pos: Pos,
expect: str,
*,
error_on: frozenset[str],
error_on_eof: bool,
) -> Pos:
try:
new_pos = src.index(expect, pos)
except ValueError:
new_pos = len(src)
if error_on_eof:
raise suffixed_err(src, new_pos, f"Expected {expect!r}") from None
if not error_on.isdisjoint(src[pos:new_pos]):
while src[pos] not in error_on:
pos += 1
raise suffixed_err(src, pos, f"Found invalid character {src[pos]!r}")
return new_pos
def skip_comment(src: str, pos: Pos) -> Pos:
try:
char: str | None = src[pos]
except IndexError:
char = None
if char == "#":
return skip_until(
src, pos + 1, "\n", error_on=ILLEGAL_COMMENT_CHARS, error_on_eof=False
)
return pos
def skip_comments_and_array_ws(src: str, pos: Pos) -> Pos:
while True:
pos_before_skip = pos
pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE)
pos = skip_comment(src, pos)
if pos == pos_before_skip:
return pos
def create_dict_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
pos += 1 # Skip "["
pos = skip_chars(src, pos, TOML_WS)
pos, key = parse_key(src, pos)
if out.flags.is_(key, Flags.EXPLICIT_NEST) or out.flags.is_(key, Flags.FROZEN):
raise suffixed_err(src, pos, f"Cannot declare {key} twice")
out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
try:
out.data.get_or_create_nest(key)
except KeyError:
raise suffixed_err(src, pos, "Cannot overwrite a value") from None
if not src.startswith("]", pos):
raise suffixed_err(src, pos, "Expected ']' at the end of a table declaration")
return pos + 1, key
def create_list_rule(src: str, pos: Pos, out: Output) -> tuple[Pos, Key]:
pos += 2 # Skip "[["
pos = skip_chars(src, pos, TOML_WS)
pos, key = parse_key(src, pos)
if out.flags.is_(key, Flags.FROZEN):
raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
# Free the namespace now that it points to another empty list item...
out.flags.unset_all(key)
# ...but this key precisely is still prohibited from table declaration
out.flags.set(key, Flags.EXPLICIT_NEST, recursive=False)
try:
out.data.append_nest_to_list(key)
except KeyError:
raise suffixed_err(src, pos, "Cannot overwrite a value") from None
if not src.startswith("]]", pos):
raise suffixed_err(src, pos, "Expected ']]' at the end of an array declaration")
return pos + 2, key
def key_value_rule(
src: str, pos: Pos, out: Output, header: Key, parse_float: ParseFloat
) -> Pos:
pos, key, value = parse_key_value_pair(src, pos, parse_float)
key_parent, key_stem = key[:-1], key[-1]
abs_key_parent = header + key_parent
relative_path_cont_keys = (header + key[:i] for i in range(1, len(key)))
for cont_key in relative_path_cont_keys:
# Check that dotted key syntax does not redefine an existing table
if out.flags.is_(cont_key, Flags.EXPLICIT_NEST):
raise suffixed_err(src, pos, f"Cannot redefine namespace {cont_key}")
# Containers in the relative path can't be opened with the table syntax or
# dotted key/value syntax in following table sections.
out.flags.add_pending(cont_key, Flags.EXPLICIT_NEST)
if out.flags.is_(abs_key_parent, Flags.FROZEN):
raise suffixed_err(
src, pos, f"Cannot mutate immutable namespace {abs_key_parent}"
)
try:
nest = out.data.get_or_create_nest(abs_key_parent)
except KeyError:
raise suffixed_err(src, pos, "Cannot overwrite a value") from None
if key_stem in nest:
raise suffixed_err(src, pos, "Cannot overwrite a value")
# Mark inline table and array namespaces recursively immutable
if isinstance(value, (dict, list)):
out.flags.set(header + key, Flags.FROZEN, recursive=True)
nest[key_stem] = value
return pos
def parse_key_value_pair(
src: str, pos: Pos, parse_float: ParseFloat
) -> tuple[Pos, Key, Any]:
pos, key = parse_key(src, pos)
try:
char: str | None = src[pos]
except IndexError:
char = None
if char != "=":
raise suffixed_err(src, pos, "Expected '=' after a key in a key/value pair")
pos += 1
pos = skip_chars(src, pos, TOML_WS)
pos, value = parse_value(src, pos, parse_float)
return pos, key, value
def parse_key(src: str, pos: Pos) -> tuple[Pos, Key]:
pos, key_part = parse_key_part(src, pos)
key: Key = (key_part,)
pos = skip_chars(src, pos, TOML_WS)
while True:
try:
char: str | None = src[pos]
except IndexError:
char = None
if char != ".":
return pos, key
pos += 1
pos = skip_chars(src, pos, TOML_WS)
pos, key_part = parse_key_part(src, pos)
key += (key_part,)
pos = skip_chars(src, pos, TOML_WS)
def parse_key_part(src: str, pos: Pos) -> tuple[Pos, str]:
try:
char: str | None = src[pos]
except IndexError:
char = None
if char in BARE_KEY_CHARS:
start_pos = pos
pos = skip_chars(src, pos, BARE_KEY_CHARS)
return pos, src[start_pos:pos]
if char == "'":
return parse_literal_str(src, pos)
if char == '"':
return parse_one_line_basic_str(src, pos)
raise suffixed_err(src, pos, "Invalid initial character for a key part")
def parse_one_line_basic_str(src: str, pos: Pos) -> tuple[Pos, str]:
pos += 1
return parse_basic_str(src, pos, multiline=False)
def parse_array(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, list]:
pos += 1
array: list = []
pos = skip_comments_and_array_ws(src, pos)
if src.startswith("]", pos):
return pos + 1, array
while True:
pos, val = parse_value(src, pos, parse_float)
array.append(val)
pos = skip_comments_and_array_ws(src, pos)
c = src[pos : pos + 1]
if c == "]":
return pos + 1, array
if c != ",":
raise suffixed_err(src, pos, "Unclosed array")
pos += 1
pos = skip_comments_and_array_ws(src, pos)
if src.startswith("]", pos):
return pos + 1, array
def parse_inline_table(src: str, pos: Pos, parse_float: ParseFloat) -> tuple[Pos, dict]:
pos += 1
nested_dict = NestedDict()
flags = Flags()
pos = skip_chars(src, pos, TOML_WS)
if src.startswith("}", pos):
return pos + 1, nested_dict.dict
while True:
pos, key, value = parse_key_value_pair(src, pos, parse_float)
key_parent, key_stem = key[:-1], key[-1]
if flags.is_(key, Flags.FROZEN):
raise suffixed_err(src, pos, f"Cannot mutate immutable namespace {key}")
try:
nest = nested_dict.get_or_create_nest(key_parent, access_lists=False)
except KeyError:
raise suffixed_err(src, pos, "Cannot overwrite a value") from None
if key_stem in nest:
raise suffixed_err(src, pos, f"Duplicate inline table key {key_stem!r}")
nest[key_stem] = value
pos = skip_chars(src, pos, TOML_WS)
c = src[pos : pos + 1]
if c == "}":
return pos + 1, nested_dict.dict
if c != ",":
raise suffixed_err(src, pos, "Unclosed inline table")
if isinstance(value, (dict, list)):
flags.set(key, Flags.FROZEN, recursive=True)
pos += 1
pos = skip_chars(src, pos, TOML_WS)
def parse_basic_str_escape(
src: str, pos: Pos, *, multiline: bool = False
) -> tuple[Pos, str]:
escape_id = src[pos : pos + 2]
pos += 2
if multiline and escape_id in {"\\ ", "\\\t", "\\\n"}:
# Skip whitespace until next non-whitespace character or end of
# the doc. Error if non-whitespace is found before newline.
if escape_id != "\\\n":
pos = skip_chars(src, pos, TOML_WS)
try:
char = src[pos]
except IndexError:
return pos, ""
if char != "\n":
raise suffixed_err(src, pos, "Unescaped '\\' in a string")
pos += 1
pos = skip_chars(src, pos, TOML_WS_AND_NEWLINE)
return pos, ""
if escape_id == "\\u":
return parse_hex_char(src, pos, 4)
if escape_id == "\\U":
return parse_hex_char(src, pos, 8)
try:
return pos, BASIC_STR_ESCAPE_REPLACEMENTS[escape_id]
except KeyError:
raise suffixed_err(src, pos, "Unescaped '\\' in a string") from None
def parse_basic_str_escape_multiline(src: str, pos: Pos) -> tuple[Pos, str]:
return parse_basic_str_escape(src, pos, multiline=True)
def parse_hex_char(src: str, pos: Pos, hex_len: int) -> tuple[Pos, str]:
hex_str = src[pos : pos + hex_len]
if len(hex_str) != hex_len or not HEXDIGIT_CHARS.issuperset(hex_str):
raise suffixed_err(src, pos, "Invalid hex value")
pos += hex_len
hex_int = int(hex_str, 16)
if not is_unicode_scalar_value(hex_int):
raise suffixed_err(src, pos, "Escaped character is not a Unicode scalar value")
return pos, chr(hex_int)
def parse_literal_str(src: str, pos: Pos) -> tuple[Pos, str]:
pos += 1 # Skip starting apostrophe
start_pos = pos
pos = skip_until(
src, pos, "'", error_on=ILLEGAL_LITERAL_STR_CHARS, error_on_eof=True
)
return pos + 1, src[start_pos:pos] # Skip ending apostrophe
def parse_multiline_str(src: str, pos: Pos, *, literal: bool) -> tuple[Pos, str]:
pos += 3
if src.startswith("\n", pos):
pos += 1
if literal:
delim = "'"
end_pos = skip_until(
src,
pos,
"'''",
error_on=ILLEGAL_MULTILINE_LITERAL_STR_CHARS,
error_on_eof=True,
)
result = src[pos:end_pos]
pos = end_pos + 3
else:
delim = '"'
pos, result = parse_basic_str(src, pos, multiline=True)
# Add at maximum two extra apostrophes/quotes if the end sequence
# is 4 or 5 chars long instead of just 3.
if not src.startswith(delim, pos):
return pos, result
pos += 1
if not src.startswith(delim, pos):
return pos, result + delim
pos += 1
return pos, result + (delim * 2)
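# E.g. for the basic multiline string `"""x""""` (four closing quotes), the
# string body ends at the first run of three quotes and the one extra quote
# is re-appended above, yielding the content 'x"'.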
def parse_basic_str(src: str, pos: Pos, *, multiline: bool) -> tuple[Pos, str]:
if multiline:
error_on = ILLEGAL_MULTILINE_BASIC_STR_CHARS
parse_escapes = parse_basic_str_escape_multiline
else:
error_on = ILLEGAL_BASIC_STR_CHARS
parse_escapes = parse_basic_str_escape
result = ""
start_pos = pos
while True:
try:
char = src[pos]
except IndexError:
raise suffixed_err(src, pos, "Unterminated string") from None
if char == '"':
if not multiline:
return pos + 1, result + src[start_pos:pos]
if src.startswith('"""', pos):
return pos + 3, result + src[start_pos:pos]
pos += 1
continue
if char == "\\":
result += src[start_pos:pos]
pos, parsed_escape = parse_escapes(src, pos)
result += parsed_escape
start_pos = pos
continue
if char in error_on:
raise suffixed_err(src, pos, f"Illegal character {char!r}")
pos += 1
def parse_value( # noqa: C901
src: str, pos: Pos, parse_float: ParseFloat
) -> tuple[Pos, Any]:
try:
char: str | None = src[pos]
except IndexError:
char = None
# IMPORTANT: order conditions based on speed of checking and likelihood
# Basic strings
if char == '"':
if src.startswith('"""', pos):
return parse_multiline_str(src, pos, literal=False)
return parse_one_line_basic_str(src, pos)
# Literal strings
if char == "'":
if src.startswith("'''", pos):
return parse_multiline_str(src, pos, literal=True)
return parse_literal_str(src, pos)
# Booleans
if char == "t":
if src.startswith("true", pos):
return pos + 4, True
if char == "f":
if src.startswith("false", pos):
return pos + 5, False
# Arrays
if char == "[":
return parse_array(src, pos, parse_float)
# Inline tables
if char == "{":
return parse_inline_table(src, pos, parse_float)
# Dates and times
datetime_match = RE_DATETIME.match(src, pos)
if datetime_match:
try:
datetime_obj = match_to_datetime(datetime_match)
except ValueError as e:
raise suffixed_err(src, pos, "Invalid date or datetime") from e
return datetime_match.end(), datetime_obj
localtime_match = RE_LOCALTIME.match(src, pos)
if localtime_match:
return localtime_match.end(), match_to_localtime(localtime_match)
# Integers and "normal" floats.
    # The regex will greedily match any type starting with a decimal
    # char, so it needs to be located after the handling of dates and times.
number_match = RE_NUMBER.match(src, pos)
if number_match:
return number_match.end(), match_to_number(number_match, parse_float)
# Special floats
first_three = src[pos : pos + 3]
if first_three in {"inf", "nan"}:
return pos + 3, parse_float(first_three)
first_four = src[pos : pos + 4]
if first_four in {"-inf", "+inf", "-nan", "+nan"}:
return pos + 4, parse_float(first_four)
raise suffixed_err(src, pos, "Invalid value")
def suffixed_err(src: str, pos: Pos, msg: str) -> TOMLDecodeError:
"""Return a `TOMLDecodeError` where error message is suffixed with
coordinates in source."""
def coord_repr(src: str, pos: Pos) -> str:
if pos >= len(src):
return "end of document"
line = src.count("\n", 0, pos) + 1
if line == 1:
column = pos + 1
else:
column = pos - src.rindex("\n", 0, pos)
return f"line {line}, column {column}"
return TOMLDecodeError(f"{msg} (at {coord_repr(src, pos)})")
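# E.g. with src = 'a\nbc = 1' and pos = 5 (the '=' sign), the message is
# suffixed with '(at line 2, column 4)'.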
def is_unicode_scalar_value(codepoint: int) -> bool:
return (0 <= codepoint <= 55295) or (57344 <= codepoint <= 1114111)
def make_safe_parse_float(parse_float: ParseFloat) -> ParseFloat:
"""A decorator to make `parse_float` safe.
`parse_float` must not return dicts or lists, because these types
would be mixed with parsed TOML tables and arrays, thus confusing
the parser. The returned decorated callable raises `ValueError`
instead of returning illegal types.
"""
# The default `float` callable never returns illegal types. Optimize it.
if parse_float is float: # type: ignore[comparison-overlap]
return float
def safe_parse_float(float_str: str) -> Any:
float_value = parse_float(float_str)
if isinstance(float_value, (dict, list)):
raise ValueError("parse_float must not return dicts or lists")
return float_value
return safe_parse_float
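# A minimal usage sketch (assumes this module's public API, which mirrors the
# standard library's `tomllib`):
#
#     loads('x = 1\n[table]\ny = "a"\n')
#     # -> {'x': 1, 'table': {'y': 'a'}}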

View file

@ -0,0 +1,107 @@
# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
# Licensed to PSF under a Contributor Agreement.
from __future__ import annotations
from datetime import date, datetime, time, timedelta, timezone, tzinfo
from functools import lru_cache
import re
from typing import Any
from ._types import ParseFloat
# E.g.
# - 00:32:00.999999
# - 00:32:00
_TIME_RE_STR = r"([01][0-9]|2[0-3]):([0-5][0-9]):([0-5][0-9])(?:\.([0-9]{1,6})[0-9]*)?"
RE_NUMBER = re.compile(
r"""
0
(?:
x[0-9A-Fa-f](?:_?[0-9A-Fa-f])* # hex
|
b[01](?:_?[01])* # bin
|
o[0-7](?:_?[0-7])* # oct
)
|
[+-]?(?:0|[1-9](?:_?[0-9])*) # dec, integer part
(?P<floatpart>
(?:\.[0-9](?:_?[0-9])*)? # optional fractional part
(?:[eE][+-]?[0-9](?:_?[0-9])*)? # optional exponent part
)
""",
flags=re.VERBOSE,
)
RE_LOCALTIME = re.compile(_TIME_RE_STR)
RE_DATETIME = re.compile(
rf"""
([0-9]{{4}})-(0[1-9]|1[0-2])-(0[1-9]|[12][0-9]|3[01]) # date, e.g. 1988-10-27
(?:
[Tt ]
{_TIME_RE_STR}
(?:([Zz])|([+-])([01][0-9]|2[0-3]):([0-5][0-9]))? # optional time offset
)?
""",
flags=re.VERBOSE,
)
def match_to_datetime(match: re.Match) -> datetime | date:
"""Convert a `RE_DATETIME` match to `datetime.datetime` or `datetime.date`.
Raises ValueError if the match does not correspond to a valid date
or datetime.
"""
(
year_str,
month_str,
day_str,
hour_str,
minute_str,
sec_str,
micros_str,
zulu_time,
offset_sign_str,
offset_hour_str,
offset_minute_str,
) = match.groups()
year, month, day = int(year_str), int(month_str), int(day_str)
if hour_str is None:
return date(year, month, day)
hour, minute, sec = int(hour_str), int(minute_str), int(sec_str)
micros = int(micros_str.ljust(6, "0")) if micros_str else 0
if offset_sign_str:
tz: tzinfo | None = cached_tz(
offset_hour_str, offset_minute_str, offset_sign_str
)
elif zulu_time:
tz = timezone.utc
else: # local date-time
tz = None
return datetime(year, month, day, hour, minute, sec, micros, tzinfo=tz)
@lru_cache(maxsize=None)
def cached_tz(hour_str: str, minute_str: str, sign_str: str) -> timezone:
sign = 1 if sign_str == "+" else -1
return timezone(
timedelta(
hours=sign * int(hour_str),
minutes=sign * int(minute_str),
)
)
def match_to_localtime(match: re.Match) -> time:
hour_str, minute_str, sec_str, micros_str = match.groups()
micros = int(micros_str.ljust(6, "0")) if micros_str else 0
return time(int(hour_str), int(minute_str), int(sec_str), micros)
def match_to_number(match: re.Match, parse_float: ParseFloat) -> Any:
if match.group("floatpart"):
return parse_float(match.group())
return int(match.group(), 0)
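# A minimal sketch (hypothetical inputs; `RE_NUMBER.match` returns None for
# non-numeric text, so real callers check the match before converting):
#
#     match_to_number(RE_NUMBER.match('1.5e2'), float)   # -> 150.0
#     match_to_number(RE_NUMBER.match('0xBEEF'), float)  # -> 48879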

View file

@ -0,0 +1,10 @@
# SPDX-License-Identifier: MIT
# SPDX-FileCopyrightText: 2021 Taneli Hukkinen
# Licensed to PSF under a Contributor Agreement.
from typing import Any, Callable, Tuple
# Type annotations
ParseFloat = Callable[[str], Any]
Key = Tuple[str, ...]
Pos = int

View file

@ -1,10 +1,32 @@
use std::path::PathBuf;
pub mod criterion;
use std::fmt::{Display, Formatter};
use std::path::PathBuf;
use std::process::Command;
pub static NUMPY_GLOBALS: TestFile = TestFile::new(
"numpy/globals.py",
include_str!("../resources/numpy/globals.py"),
);
use url::Url;
pub static UNICODE_PYPINYIN: TestFile = TestFile::new(
"unicode/pypinyin.py",
include_str!("../resources/pypinyin.py"),
);
pub static PYDANTIC_TYPES: TestFile = TestFile::new(
"pydantic/types.py",
include_str!("../resources/pydantic/types.py"),
);
pub static NUMPY_CTYPESLIB: TestFile = TestFile::new(
"numpy/ctypeslib.py",
include_str!("../resources/numpy/ctypeslib.py"),
);
// "https://raw.githubusercontent.com/DHI/mikeio/b7d26418f4db2909b0aa965253dbe83194d7bb5b/tests/test_dataset.py"
pub static LARGE_DATASET: TestFile = TestFile::new(
"large/dataset.py",
include_str!("../resources/large/dataset.py"),
);
/// Relative size of a test case. Benchmarks can use it to configure the time for how long a benchmark should run to get stable results.
#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd)]
@ -26,35 +48,33 @@ pub struct TestCase {
}
impl TestCase {
pub fn fast(file: TestFile) -> Self {
pub const fn fast(file: TestFile) -> Self {
Self {
file,
speed: TestCaseSpeed::Fast,
}
}
pub fn normal(file: TestFile) -> Self {
pub const fn normal(file: TestFile) -> Self {
Self {
file,
speed: TestCaseSpeed::Normal,
}
}
pub fn slow(file: TestFile) -> Self {
pub const fn slow(file: TestFile) -> Self {
Self {
file,
speed: TestCaseSpeed::Slow,
}
}
}
impl TestCase {
pub fn code(&self) -> &str {
&self.file.code
self.file.code
}
pub fn name(&self) -> &str {
&self.file.name
self.file.name
}
pub fn speed(&self) -> TestCaseSpeed {
@ -62,119 +82,32 @@ impl TestCase {
}
pub fn path(&self) -> PathBuf {
TARGET_DIR.join(self.name())
PathBuf::from(file!())
.parent()
.unwrap()
.parent()
.unwrap()
.join("resources")
.join(self.name())
}
}
#[derive(Debug, Clone)]
pub struct TestFile {
name: String,
code: String,
name: &'static str,
code: &'static str,
}
impl TestFile {
pub fn code(&self) -> &str {
&self.code
}
pub fn name(&self) -> &str {
&self.name
}
}
static TARGET_DIR: std::sync::LazyLock<PathBuf> = std::sync::LazyLock::new(|| {
cargo_target_directory().unwrap_or_else(|| PathBuf::from("target"))
});
fn cargo_target_directory() -> Option<PathBuf> {
#[derive(serde::Deserialize)]
struct Metadata {
target_directory: PathBuf,
}
std::env::var_os("CARGO_TARGET_DIR")
.map(PathBuf::from)
.or_else(|| {
let output = Command::new(std::env::var_os("CARGO")?)
.args(["metadata", "--format-version", "1"])
.output()
.ok()?;
let metadata: Metadata = serde_json::from_slice(&output.stdout).ok()?;
Some(metadata.target_directory)
})
}
impl TestFile {
pub fn new(name: String, code: String) -> Self {
pub const fn new(name: &'static str, code: &'static str) -> Self {
Self { name, code }
}
#[allow(clippy::print_stderr)]
pub fn try_download(name: &str, url: &str) -> Result<TestFile, TestFileDownloadError> {
let url = Url::parse(url)?;
pub fn code(&self) -> &str {
self.code
}
let cached_filename = TARGET_DIR.join(name);
if let Ok(content) = std::fs::read_to_string(&cached_filename) {
Ok(TestFile::new(name.to_string(), content))
} else {
// File not yet cached, download and cache it in the target directory
let response = ureq::get(url.as_str()).call()?;
let content = response.into_string()?;
// SAFETY: There's always the `target` directory
let parent = cached_filename.parent().unwrap();
if let Err(error) = std::fs::create_dir_all(parent) {
eprintln!("Failed to create the directory for the test case {name}: {error}");
} else if let Err(error) = std::fs::write(cached_filename, &content) {
eprintln!("Failed to cache test case file downloaded from {url}: {error}");
}
Ok(TestFile::new(name.to_string(), content))
}
pub fn name(&self) -> &str {
self.name
}
}
#[derive(Debug)]
pub enum TestFileDownloadError {
UrlParse(url::ParseError),
Request(Box<ureq::Error>),
Download(std::io::Error),
}
impl From<url::ParseError> for TestFileDownloadError {
fn from(value: url::ParseError) -> Self {
Self::UrlParse(value)
}
}
impl From<ureq::Error> for TestFileDownloadError {
fn from(value: ureq::Error) -> Self {
Self::Request(Box::new(value))
}
}
impl From<std::io::Error> for TestFileDownloadError {
fn from(value: std::io::Error) -> Self {
Self::Download(value)
}
}
impl Display for TestFileDownloadError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
TestFileDownloadError::UrlParse(inner) => {
write!(f, "Failed to parse url: {inner}")
}
TestFileDownloadError::Request(inner) => {
write!(f, "Failed to download file: {inner}")
}
TestFileDownloadError::Download(inner) => {
write!(f, "Failed to download file: {inner}")
}
}
}
}
impl std::error::Error for TestFileDownloadError {}