mirror of
https://github.com/python/cpython.git
synced 2025-09-25 01:43:11 +00:00
[3.14] gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432) (#134442)
gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432)
Refactor shared code in train_dict and finalize_dict
(cherry picked from commit c64a21454b
)
Co-authored-by: Emma Smith <emma@emmatyping.dev>
This commit is contained in:
parent
17bf6ab0c1
commit
cdc92cd9fc
1 changed files with 55 additions and 68 deletions
|
@ -172,6 +172,49 @@ get_zstd_state(PyObject *module)
|
||||||
return (_zstd_state *)state;
|
return (_zstd_state *)state;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Py_ssize_t
|
||||||
|
calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
|
||||||
|
size_t **chunk_sizes)
|
||||||
|
{
|
||||||
|
Py_ssize_t chunks_number;
|
||||||
|
Py_ssize_t sizes_sum;
|
||||||
|
Py_ssize_t i;
|
||||||
|
|
||||||
|
chunks_number = Py_SIZE(samples_sizes);
|
||||||
|
if ((size_t) chunks_number > UINT32_MAX) {
|
||||||
|
PyErr_Format(PyExc_ValueError,
|
||||||
|
"The number of samples should be <= %u.", UINT32_MAX);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Prepare chunk_sizes */
|
||||||
|
*chunk_sizes = PyMem_New(size_t, chunks_number);
|
||||||
|
if (*chunk_sizes == NULL) {
|
||||||
|
PyErr_NoMemory();
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
sizes_sum = 0;
|
||||||
|
for (i = 0; i < chunks_number; i++) {
|
||||||
|
PyObject *size = PyTuple_GetItem(samples_sizes, i);
|
||||||
|
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
|
||||||
|
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
|
||||||
|
PyErr_Format(PyExc_ValueError,
|
||||||
|
"Items in samples_sizes should be an int "
|
||||||
|
"object, with a value between 0 and %u.", SIZE_MAX);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
sizes_sum += (*chunk_sizes)[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sizes_sum != Py_SIZE(samples_bytes)) {
|
||||||
|
PyErr_SetString(PyExc_ValueError,
|
||||||
|
"The samples size tuple doesn't match the concatenation's size.");
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
return chunks_number;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/*[clinic input]
|
/*[clinic input]
|
||||||
_zstd.train_dict
|
_zstd.train_dict
|
||||||
|
@ -192,14 +235,10 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
|
||||||
PyObject *samples_sizes, Py_ssize_t dict_size)
|
PyObject *samples_sizes, Py_ssize_t dict_size)
|
||||||
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
|
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
|
||||||
{
|
{
|
||||||
// TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
|
|
||||||
// are pretty similar. We should see if we can refactor them to share that code.
|
|
||||||
Py_ssize_t chunks_number;
|
|
||||||
size_t *chunk_sizes = NULL;
|
|
||||||
PyObject *dst_dict_bytes = NULL;
|
PyObject *dst_dict_bytes = NULL;
|
||||||
|
size_t *chunk_sizes = NULL;
|
||||||
|
Py_ssize_t chunks_number;
|
||||||
size_t zstd_ret;
|
size_t zstd_ret;
|
||||||
Py_ssize_t sizes_sum;
|
|
||||||
Py_ssize_t i;
|
|
||||||
|
|
||||||
/* Check arguments */
|
/* Check arguments */
|
||||||
if (dict_size <= 0) {
|
if (dict_size <= 0) {
|
||||||
|
@ -207,39 +246,14 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks_number = Py_SIZE(samples_sizes);
|
/* Check that the samples are valid and get their sizes */
|
||||||
if ((size_t) chunks_number > UINT32_MAX) {
|
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
|
||||||
PyErr_Format(PyExc_ValueError,
|
&chunk_sizes);
|
||||||
"The number of samples should be <= %u.", UINT32_MAX);
|
if (chunks_number < 0)
|
||||||
|
{
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Prepare chunk_sizes */
|
|
||||||
chunk_sizes = PyMem_New(size_t, chunks_number);
|
|
||||||
if (chunk_sizes == NULL) {
|
|
||||||
PyErr_NoMemory();
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
|
|
||||||
sizes_sum = 0;
|
|
||||||
for (i = 0; i < chunks_number; i++) {
|
|
||||||
PyObject *size = PyTuple_GetItem(samples_sizes, i);
|
|
||||||
chunk_sizes[i] = PyLong_AsSize_t(size);
|
|
||||||
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
|
|
||||||
PyErr_Format(PyExc_ValueError,
|
|
||||||
"Items in samples_sizes should be an int "
|
|
||||||
"object, with a value between 0 and %u.", SIZE_MAX);
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
sizes_sum += chunk_sizes[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sizes_sum != Py_SIZE(samples_bytes)) {
|
|
||||||
PyErr_SetString(PyExc_ValueError,
|
|
||||||
"The samples size tuple doesn't match the concatenation's size.");
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Allocate dict buffer */
|
/* Allocate dict buffer */
|
||||||
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
|
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
|
||||||
if (dst_dict_bytes == NULL) {
|
if (dst_dict_bytes == NULL) {
|
||||||
|
@ -307,8 +321,6 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
|
||||||
PyObject *dst_dict_bytes = NULL;
|
PyObject *dst_dict_bytes = NULL;
|
||||||
size_t zstd_ret;
|
size_t zstd_ret;
|
||||||
ZDICT_params_t params;
|
ZDICT_params_t params;
|
||||||
Py_ssize_t sizes_sum;
|
|
||||||
Py_ssize_t i;
|
|
||||||
|
|
||||||
/* Check arguments */
|
/* Check arguments */
|
||||||
if (dict_size <= 0) {
|
if (dict_size <= 0) {
|
||||||
|
@ -316,39 +328,14 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks_number = Py_SIZE(samples_sizes);
|
/* Check that the samples are valid and get their sizes */
|
||||||
if ((size_t) chunks_number > UINT32_MAX) {
|
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
|
||||||
PyErr_Format(PyExc_ValueError,
|
&chunk_sizes);
|
||||||
"The number of samples should be <= %u.", UINT32_MAX);
|
if (chunks_number < 0)
|
||||||
|
{
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Prepare chunk_sizes */
|
|
||||||
chunk_sizes = PyMem_New(size_t, chunks_number);
|
|
||||||
if (chunk_sizes == NULL) {
|
|
||||||
PyErr_NoMemory();
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
|
|
||||||
sizes_sum = 0;
|
|
||||||
for (i = 0; i < chunks_number; i++) {
|
|
||||||
PyObject *size = PyTuple_GetItem(samples_sizes, i);
|
|
||||||
chunk_sizes[i] = PyLong_AsSize_t(size);
|
|
||||||
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
|
|
||||||
PyErr_Format(PyExc_ValueError,
|
|
||||||
"Items in samples_sizes should be an int "
|
|
||||||
"object, with a value between 0 and %u.", SIZE_MAX);
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
sizes_sum += chunk_sizes[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if (sizes_sum != Py_SIZE(samples_bytes)) {
|
|
||||||
PyErr_SetString(PyExc_ValueError,
|
|
||||||
"The samples size tuple doesn't match the concatenation's size.");
|
|
||||||
goto error;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Allocate dict buffer */
|
/* Allocate dict buffer */
|
||||||
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
|
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
|
||||||
if (dst_dict_bytes == NULL) {
|
if (dst_dict_bytes == NULL) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue