mirror of
https://github.com/python/cpython.git
synced 2025-09-26 10:19:53 +00:00
[3.14] gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432) (#134442)
gh-132983: Refactor shared code in train_dict and finalize_dict (GH-134432)
Refactor shared code in train_dict and finalize_dict
(cherry picked from commit c64a21454b
)
Co-authored-by: Emma Smith <emma@emmatyping.dev>
This commit is contained in:
parent
17bf6ab0c1
commit
cdc92cd9fc
1 changed files with 55 additions and 68 deletions
|
@ -172,6 +172,49 @@ get_zstd_state(PyObject *module)
|
|||
return (_zstd_state *)state;
|
||||
}
|
||||
|
||||
static Py_ssize_t
|
||||
calculate_samples_stats(PyBytesObject *samples_bytes, PyObject *samples_sizes,
|
||||
size_t **chunk_sizes)
|
||||
{
|
||||
Py_ssize_t chunks_number;
|
||||
Py_ssize_t sizes_sum;
|
||||
Py_ssize_t i;
|
||||
|
||||
chunks_number = Py_SIZE(samples_sizes);
|
||||
if ((size_t) chunks_number > UINT32_MAX) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"The number of samples should be <= %u.", UINT32_MAX);
|
||||
return -1;
|
||||
}
|
||||
|
||||
/* Prepare chunk_sizes */
|
||||
*chunk_sizes = PyMem_New(size_t, chunks_number);
|
||||
if (*chunk_sizes == NULL) {
|
||||
PyErr_NoMemory();
|
||||
return -1;
|
||||
}
|
||||
|
||||
sizes_sum = 0;
|
||||
for (i = 0; i < chunks_number; i++) {
|
||||
PyObject *size = PyTuple_GetItem(samples_sizes, i);
|
||||
(*chunk_sizes)[i] = PyLong_AsSize_t(size);
|
||||
if ((*chunk_sizes)[i] == (size_t)-1 && PyErr_Occurred()) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Items in samples_sizes should be an int "
|
||||
"object, with a value between 0 and %u.", SIZE_MAX);
|
||||
return -1;
|
||||
}
|
||||
sizes_sum += (*chunk_sizes)[i];
|
||||
}
|
||||
|
||||
if (sizes_sum != Py_SIZE(samples_bytes)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"The samples size tuple doesn't match the concatenation's size.");
|
||||
return -1;
|
||||
}
|
||||
return chunks_number;
|
||||
}
|
||||
|
||||
|
||||
/*[clinic input]
|
||||
_zstd.train_dict
|
||||
|
@ -192,14 +235,10 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
|
|||
PyObject *samples_sizes, Py_ssize_t dict_size)
|
||||
/*[clinic end generated code: output=8e87fe43935e8f77 input=d20dedb21c72cb62]*/
|
||||
{
|
||||
// TODO(emmatyping): The preamble and suffix to this function and _finalize_dict
|
||||
// are pretty similar. We should see if we can refactor them to share that code.
|
||||
Py_ssize_t chunks_number;
|
||||
size_t *chunk_sizes = NULL;
|
||||
PyObject *dst_dict_bytes = NULL;
|
||||
size_t *chunk_sizes = NULL;
|
||||
Py_ssize_t chunks_number;
|
||||
size_t zstd_ret;
|
||||
Py_ssize_t sizes_sum;
|
||||
Py_ssize_t i;
|
||||
|
||||
/* Check arguments */
|
||||
if (dict_size <= 0) {
|
||||
|
@ -207,39 +246,14 @@ _zstd_train_dict_impl(PyObject *module, PyBytesObject *samples_bytes,
|
|||
return NULL;
|
||||
}
|
||||
|
||||
chunks_number = Py_SIZE(samples_sizes);
|
||||
if ((size_t) chunks_number > UINT32_MAX) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"The number of samples should be <= %u.", UINT32_MAX);
|
||||
/* Check that the samples are valid and get their sizes */
|
||||
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
|
||||
&chunk_sizes);
|
||||
if (chunks_number < 0)
|
||||
{
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* Prepare chunk_sizes */
|
||||
chunk_sizes = PyMem_New(size_t, chunks_number);
|
||||
if (chunk_sizes == NULL) {
|
||||
PyErr_NoMemory();
|
||||
goto error;
|
||||
}
|
||||
|
||||
sizes_sum = 0;
|
||||
for (i = 0; i < chunks_number; i++) {
|
||||
PyObject *size = PyTuple_GetItem(samples_sizes, i);
|
||||
chunk_sizes[i] = PyLong_AsSize_t(size);
|
||||
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Items in samples_sizes should be an int "
|
||||
"object, with a value between 0 and %u.", SIZE_MAX);
|
||||
goto error;
|
||||
}
|
||||
sizes_sum += chunk_sizes[i];
|
||||
}
|
||||
|
||||
if (sizes_sum != Py_SIZE(samples_bytes)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"The samples size tuple doesn't match the concatenation's size.");
|
||||
goto error;
|
||||
}
|
||||
|
||||
/* Allocate dict buffer */
|
||||
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
|
||||
if (dst_dict_bytes == NULL) {
|
||||
|
@ -307,8 +321,6 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
|
|||
PyObject *dst_dict_bytes = NULL;
|
||||
size_t zstd_ret;
|
||||
ZDICT_params_t params;
|
||||
Py_ssize_t sizes_sum;
|
||||
Py_ssize_t i;
|
||||
|
||||
/* Check arguments */
|
||||
if (dict_size <= 0) {
|
||||
|
@ -316,39 +328,14 @@ _zstd_finalize_dict_impl(PyObject *module, PyBytesObject *custom_dict_bytes,
|
|||
return NULL;
|
||||
}
|
||||
|
||||
chunks_number = Py_SIZE(samples_sizes);
|
||||
if ((size_t) chunks_number > UINT32_MAX) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"The number of samples should be <= %u.", UINT32_MAX);
|
||||
/* Check that the samples are valid and get their sizes */
|
||||
chunks_number = calculate_samples_stats(samples_bytes, samples_sizes,
|
||||
&chunk_sizes);
|
||||
if (chunks_number < 0)
|
||||
{
|
||||
return NULL;
|
||||
}
|
||||
|
||||
/* Prepare chunk_sizes */
|
||||
chunk_sizes = PyMem_New(size_t, chunks_number);
|
||||
if (chunk_sizes == NULL) {
|
||||
PyErr_NoMemory();
|
||||
goto error;
|
||||
}
|
||||
|
||||
sizes_sum = 0;
|
||||
for (i = 0; i < chunks_number; i++) {
|
||||
PyObject *size = PyTuple_GetItem(samples_sizes, i);
|
||||
chunk_sizes[i] = PyLong_AsSize_t(size);
|
||||
if (chunk_sizes[i] == (size_t)-1 && PyErr_Occurred()) {
|
||||
PyErr_Format(PyExc_ValueError,
|
||||
"Items in samples_sizes should be an int "
|
||||
"object, with a value between 0 and %u.", SIZE_MAX);
|
||||
goto error;
|
||||
}
|
||||
sizes_sum += chunk_sizes[i];
|
||||
}
|
||||
|
||||
if (sizes_sum != Py_SIZE(samples_bytes)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"The samples size tuple doesn't match the concatenation's size.");
|
||||
goto error;
|
||||
}
|
||||
|
||||
/* Allocate dict buffer */
|
||||
dst_dict_bytes = PyBytes_FromStringAndSize(NULL, dict_size);
|
||||
if (dst_dict_bytes == NULL) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue