[3.14] gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432) (#134442)

gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432)

Refactor shared code in train_dict and finalize_dict
(cherry picked from commit c64a21454b)

Co-authored-by: Emma Smith <emma@emmatyping.dev>
This commit is contained in:
Miss Islington (bot) 2025-05-21 18:19:25 +02:00 committed by GitHub
parent 17bf6ab0c1
commit cdc92cd9fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -172,6 +172,49 @@ get_zstd_state(PyObject *module)
return (_zstd_state *)state; return (_zstd_state *)state;
} }
static Py_ssize_t
calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
size_t **chunk_sizes)
{
Py_ssize_t chunks_number;
Py_ssize_t sizes_sum;
Py_ssize_t i;
chunks_number = Py_SIZE(samples_sizes);
if ((size_t) chunks_number > UINT32_MAX) {
PyErr_Format(PyExc_ValueError,
"The number of samples should be <= %u.", UINT32_MAX);
return -1;
}
/* Prepare chunk_sizes */
*chunk_sizes = PyMem_New(size_t, chunks_number);
if (*chunk_sizes == NULL) {
PyErr_NoMemory();
return -1;
}
sizes_sum = 0;
for (i = 0; i < chunks_number; i++) {
PyObject *size = PyTuple_GetItem(samples_sizes, i);
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
PyErr_Format(PyExc_ValueError,
"Items in samples_sizes should be an int "
"object, with a value between 0 and %u.", SIZE_MAX);
return -1;
}
sizes_sum += (*chunk_sizes)[i];
}
if (sizes_sum != Py_SIZE(samples_bytes)) {
PyErr_SetString(PyExc_ValueError,
"The samples size tuple doesn't match the concatenation's size.");
return -1;
}
return chunks_number;
}
/*[clinic input] /*[clinic input]
_zstd.train_dict _zstd.train_dict
@ -192,14 +235,10 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
PyObject *samples_sizes, Py_ssize_t dict_size) PyObject *samples_sizes, Py_ssize_t dict_size)
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/ /*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
{ {
// TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
// are pretty similar. We should see if we can refactor them to share that code.
Py_ssize_t chunks_number;
size_t *chunk_sizes = NULL;
PyObject *dst_dict_bytes = NULL; PyObject *dst_dict_bytes = NULL;
size_t *chunk_sizes = NULL;
Py_ssize_t chunks_number;
size_t zstd_ret; size_t zstd_ret;
Py_ssize_t sizes_sum;
Py_ssize_t i;
/* Check arguments */ /* Check arguments */
if (dict_size <= 0) { if (dict_size <= 0) {
@ -207,39 +246,14 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
return NULL; return NULL;
} }
chunks_number = Py_SIZE(samples_sizes); /* Check that the samples are valid and get their sizes */
if ((size_t) chunks_number > UINT32_MAX) { chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
PyErr_Format(PyExc_ValueError, &chunk_sizes);
"The number of samples should be <= %u.", UINT32_MAX); if (chunks_number < 0)
{
return NULL; return NULL;
} }
/* Prepare chunk_sizes */
chunk_sizes = PyMem_New(size_t, chunks_number);
if (chunk_sizes == NULL) {
PyErr_NoMemory();
goto error;
}
sizes_sum = 0;
for (i = 0; i < chunks_number; i++) {
PyObject *size = PyTuple_GetItem(samples_sizes, i);
chunk_sizes[i] = PyLong_AsSize_t(size);
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
PyErr_Format(PyExc_ValueError,
"Items in samples_sizes should be an int "
"object, with a value between 0 and %u.", SIZE_MAX);
goto error;
}
sizes_sum += chunk_sizes[i];
}
if (sizes_sum != Py_SIZE(samples_bytes)) {
PyErr_SetString(PyExc_ValueError,
"The samples size tuple doesn't match the concatenation's size.");
goto error;
}
/* Allocate dict buffer */ /* Allocate dict buffer */
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size); dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
if (dst_dict_bytes == NULL) { if (dst_dict_bytes == NULL) {
@ -307,8 +321,6 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
PyObject *dst_dict_bytes = NULL; PyObject *dst_dict_bytes = NULL;
size_t zstd_ret; size_t zstd_ret;
ZDICT_params_t params; ZDICT_params_t params;
Py_ssize_t sizes_sum;
Py_ssize_t i;
/* Check arguments */ /* Check arguments */
if (dict_size <= 0) { if (dict_size <= 0) {
@ -316,39 +328,14 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
return NULL; return NULL;
} }
chunks_number = Py_SIZE(samples_sizes); /* Check that the samples are valid and get their sizes */
if ((size_t) chunks_number > UINT32_MAX) { chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
PyErr_Format(PyExc_ValueError, &chunk_sizes);
"The number of samples should be <= %u.", UINT32_MAX); if (chunks_number < 0)
{
return NULL; return NULL;
} }
/* Prepare chunk_sizes */
chunk_sizes = PyMem_New(size_t, chunks_number);
if (chunk_sizes == NULL) {
PyErr_NoMemory();
goto error;
}
sizes_sum = 0;
for (i = 0; i < chunks_number; i++) {
PyObject *size = PyTuple_GetItem(samples_sizes, i);
chunk_sizes[i] = PyLong_AsSize_t(size);
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
PyErr_Format(PyExc_ValueError,
"Items in samples_sizes should be an int "
"object, with a value between 0 and %u.", SIZE_MAX);
goto error;
}
sizes_sum += chunk_sizes[i];
}
if (sizes_sum != Py_SIZE(samples_bytes)) {
PyErr_SetString(PyExc_ValueError,
"The samples size tuple doesn't match the concatenation's size.");
goto error;
}
/* Allocate dict buffer */ /* Allocate dict buffer */
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size); dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
if (dst_dict_bytes == NULL) { if (dst_dict_bytes == NULL) {