mirror of
https://github.com/python/cpython.git
synced 2025-08-26 19:55:24 +00:00
Close #15573: use value-based memoryview comparisons (patch by Stefan Krah)
This commit is contained in:
parent
5c0b1ca55e
commit
06e1ab0a6b
5 changed files with 778 additions and 132 deletions
|
@ -246,7 +246,7 @@ Create a new memoryview object which references the given object.");
|
|||
(view->suboffsets && view->suboffsets[dest->ndim-1] >= 0)
|
||||
|
||||
Py_LOCAL_INLINE(int)
|
||||
last_dim_is_contiguous(Py_buffer *dest, Py_buffer *src)
|
||||
last_dim_is_contiguous(const Py_buffer *dest, const Py_buffer *src)
|
||||
{
|
||||
assert(dest->ndim > 0 && src->ndim > 0);
|
||||
return (!HAVE_SUBOFFSETS_IN_LAST_DIM(dest) &&
|
||||
|
@ -255,37 +255,63 @@ last_dim_is_contiguous(Py_buffer *dest, Py_buffer *src)
|
|||
src->strides[src->ndim-1] == src->itemsize);
|
||||
}
|
||||
|
||||
/* Check that the logical structure of the destination and source buffers
|
||||
is identical. */
|
||||
static int
|
||||
cmp_structure(Py_buffer *dest, Py_buffer *src)
|
||||
/* This is not a general function for determining format equivalence.
|
||||
It is used in copy_single() and copy_buffer() to weed out non-matching
|
||||
formats. Skipping the '@' character is specifically used in slice
|
||||
assignments, where the lvalue is already known to have a single character
|
||||
format. This is a performance hack that could be rewritten (if properly
|
||||
benchmarked). */
|
||||
Py_LOCAL_INLINE(int)
|
||||
equiv_format(const Py_buffer *dest, const Py_buffer *src)
|
||||
{
|
||||
const char *dfmt, *sfmt;
|
||||
int i;
|
||||
|
||||
assert(dest->format && src->format);
|
||||
dfmt = dest->format[0] == '@' ? dest->format+1 : dest->format;
|
||||
sfmt = src->format[0] == '@' ? src->format+1 : src->format;
|
||||
|
||||
if (strcmp(dfmt, sfmt) != 0 ||
|
||||
dest->itemsize != src->itemsize ||
|
||||
dest->ndim != src->ndim) {
|
||||
goto value_error;
|
||||
dest->itemsize != src->itemsize) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
/* Two shapes are equivalent if they are either equal or identical up
|
||||
to a zero element at the same position. For example, in NumPy arrays
|
||||
the shapes [1, 0, 5] and [1, 0, 7] are equivalent. */
|
||||
Py_LOCAL_INLINE(int)
|
||||
equiv_shape(const Py_buffer *dest, const Py_buffer *src)
|
||||
{
|
||||
int i;
|
||||
|
||||
if (dest->ndim != src->ndim)
|
||||
return 0;
|
||||
|
||||
for (i = 0; i < dest->ndim; i++) {
|
||||
if (dest->shape[i] != src->shape[i])
|
||||
goto value_error;
|
||||
return 0;
|
||||
if (dest->shape[i] == 0)
|
||||
break;
|
||||
}
|
||||
|
||||
return 0;
|
||||
return 1;
|
||||
}
|
||||
|
||||
value_error:
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"ndarray assignment: lvalue and rvalue have different structures");
|
||||
return -1;
|
||||
/* Check that the logical structure of the destination and source buffers
|
||||
is identical. */
|
||||
static int
|
||||
equiv_structure(const Py_buffer *dest, const Py_buffer *src)
|
||||
{
|
||||
if (!equiv_format(dest, src) ||
|
||||
!equiv_shape(dest, src)) {
|
||||
PyErr_SetString(PyExc_ValueError,
|
||||
"ndarray assignment: lvalue and rvalue have different structures");
|
||||
return 0;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
/* Base case for recursive multi-dimensional copying. Contiguous arrays are
|
||||
|
@ -358,7 +384,7 @@ copy_single(Py_buffer *dest, Py_buffer *src)
|
|||
|
||||
assert(dest->ndim == 1);
|
||||
|
||||
if (cmp_structure(dest, src) < 0)
|
||||
if (!equiv_structure(dest, src))
|
||||
return -1;
|
||||
|
||||
if (!last_dim_is_contiguous(dest, src)) {
|
||||
|
@ -390,7 +416,7 @@ copy_buffer(Py_buffer *dest, Py_buffer *src)
|
|||
|
||||
assert(dest->ndim > 0);
|
||||
|
||||
if (cmp_structure(dest, src) < 0)
|
||||
if (!equiv_structure(dest, src))
|
||||
return -1;
|
||||
|
||||
if (!last_dim_is_contiguous(dest, src)) {
|
||||
|
@ -1827,6 +1853,131 @@ err_format:
|
|||
}
|
||||
|
||||
|
||||
/****************************************************************************/
|
||||
/* unpack using the struct module */
|
||||
/****************************************************************************/
|
||||
|
||||
/* For reasonable performance it is necessary to cache all objects required
|
||||
for unpacking. An unpacker can handle the format passed to unpack_from().
|
||||
Invariant: All pointer fields of the struct should either be NULL or valid
|
||||
pointers. */
|
||||
struct unpacker {
|
||||
PyObject *unpack_from; /* Struct.unpack_from(format) */
|
||||
PyObject *mview; /* cached memoryview */
|
||||
char *item; /* buffer for mview */
|
||||
Py_ssize_t itemsize; /* len(item) */
|
||||
};
|
||||
|
||||
static struct unpacker *
|
||||
unpacker_new(void)
|
||||
{
|
||||
struct unpacker *x = PyMem_Malloc(sizeof *x);
|
||||
|
||||
if (x == NULL) {
|
||||
PyErr_NoMemory();
|
||||
return NULL;
|
||||
}
|
||||
|
||||
x->unpack_from = NULL;
|
||||
x->mview = NULL;
|
||||
x->item = NULL;
|
||||
x->itemsize = 0;
|
||||
|
||||
return x;
|
||||
}
|
||||
|
||||
static void
|
||||
unpacker_free(struct unpacker *x)
|
||||
{
|
||||
if (x) {
|
||||
Py_XDECREF(x->unpack_from);
|
||||
Py_XDECREF(x->mview);
|
||||
PyMem_Free(x->item);
|
||||
PyMem_Free(x);
|
||||
}
|
||||
}
|
||||
|
||||
/* Return a new unpacker for the given format. */
|
||||
static struct unpacker *
|
||||
struct_get_unpacker(const char *fmt, Py_ssize_t itemsize)
|
||||
{
|
||||
PyObject *structmodule; /* XXX cache these two */
|
||||
PyObject *Struct = NULL; /* XXX in globals? */
|
||||
PyObject *structobj = NULL;
|
||||
PyObject *format = NULL;
|
||||
struct unpacker *x = NULL;
|
||||
|
||||
structmodule = PyImport_ImportModule("struct");
|
||||
if (structmodule == NULL)
|
||||
return NULL;
|
||||
|
||||
Struct = PyObject_GetAttrString(structmodule, "Struct");
|
||||
Py_DECREF(structmodule);
|
||||
if (Struct == NULL)
|
||||
return NULL;
|
||||
|
||||
x = unpacker_new();
|
||||
if (x == NULL)
|
||||
goto error;
|
||||
|
||||
format = PyBytes_FromString(fmt);
|
||||
if (format == NULL)
|
||||
goto error;
|
||||
|
||||
structobj = PyObject_CallFunctionObjArgs(Struct, format, NULL);
|
||||
if (structobj == NULL)
|
||||
goto error;
|
||||
|
||||
x->unpack_from = PyObject_GetAttrString(structobj, "unpack_from");
|
||||
if (x->unpack_from == NULL)
|
||||
goto error;
|
||||
|
||||
x->item = PyMem_Malloc(itemsize);
|
||||
if (x->item == NULL) {
|
||||
PyErr_NoMemory();
|
||||
goto error;
|
||||
}
|
||||
x->itemsize = itemsize;
|
||||
|
||||
x->mview = PyMemoryView_FromMemory(x->item, itemsize, PyBUF_WRITE);
|
||||
if (x->mview == NULL)
|
||||
goto error;
|
||||
|
||||
|
||||
out:
|
||||
Py_XDECREF(Struct);
|
||||
Py_XDECREF(format);
|
||||
Py_XDECREF(structobj);
|
||||
return x;
|
||||
|
||||
error:
|
||||
unpacker_free(x);
|
||||
x = NULL;
|
||||
goto out;
|
||||
}
|
||||
|
||||
/* unpack a single item */
|
||||
static PyObject *
|
||||
struct_unpack_single(const char *ptr, struct unpacker *x)
|
||||
{
|
||||
PyObject *v;
|
||||
|
||||
memcpy(x->item, ptr, x->itemsize);
|
||||
v = PyObject_CallFunctionObjArgs(x->unpack_from, x->mview, NULL);
|
||||
if (v == NULL)
|
||||
return NULL;
|
||||
|
||||
if (PyTuple_GET_SIZE(v) == 1) {
|
||||
PyObject *tmp = PyTuple_GET_ITEM(v, 0);
|
||||
Py_INCREF(tmp);
|
||||
Py_DECREF(v);
|
||||
return tmp;
|
||||
}
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
|
||||
/****************************************************************************/
|
||||
/* Representations */
|
||||
/****************************************************************************/
|
||||
|
@ -2261,6 +2412,58 @@ static PySequenceMethods memory_as_sequence = {
|
|||
/* Comparisons */
|
||||
/**************************************************************************/
|
||||
|
||||
#define MV_COMPARE_EX -1 /* exception */
|
||||
#define MV_COMPARE_NOT_IMPL -2 /* not implemented */
|
||||
|
||||
/* Translate a StructError to "not equal". Preserve other exceptions. */
|
||||
static int
|
||||
fix_struct_error_int(void)
|
||||
{
|
||||
assert(PyErr_Occurred());
|
||||
/* XXX Cannot get at StructError directly? */
|
||||
if (PyErr_ExceptionMatches(PyExc_ImportError) ||
|
||||
PyErr_ExceptionMatches(PyExc_MemoryError)) {
|
||||
return MV_COMPARE_EX;
|
||||
}
|
||||
/* StructError: invalid or unknown format -> not equal */
|
||||
PyErr_Clear();
|
||||
return 0;
|
||||
}
|
||||
|
||||
/* Unpack and compare single items of p and q using the struct module. */
|
||||
static int
|
||||
struct_unpack_cmp(const char *p, const char *q,
|
||||
struct unpacker *unpack_p, struct unpacker *unpack_q)
|
||||
{
|
||||
PyObject *v, *w;
|
||||
int ret;
|
||||
|
||||
/* At this point any exception from the struct module should not be
|
||||
StructError, since both formats have been accepted already. */
|
||||
v = struct_unpack_single(p, unpack_p);
|
||||
if (v == NULL)
|
||||
return MV_COMPARE_EX;
|
||||
|
||||
w = struct_unpack_single(q, unpack_q);
|
||||
if (w == NULL) {
|
||||
Py_DECREF(v);
|
||||
return MV_COMPARE_EX;
|
||||
}
|
||||
|
||||
/* MV_COMPARE_EX == -1: exceptions are preserved */
|
||||
ret = PyObject_RichCompareBool(v, w, Py_EQ);
|
||||
Py_DECREF(v);
|
||||
Py_DECREF(w);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
/* Unpack and compare single items of p and q. If both p and q have the same
|
||||
single element native format, the comparison uses a fast path (gcc creates
|
||||
a jump table and converts memcpy into simple assignments on x86/x64).
|
||||
|
||||
Otherwise, the comparison is delegated to the struct module, which is
|
||||
30-60x slower. */
|
||||
#define CMP_SINGLE(p, q, type) \
|
||||
do { \
|
||||
type x; \
|
||||
|
@ -2271,11 +2474,12 @@ static PySequenceMethods memory_as_sequence = {
|
|||
} while (0)
|
||||
|
||||
Py_LOCAL_INLINE(int)
|
||||
unpack_cmp(const char *p, const char *q, const char *fmt)
|
||||
unpack_cmp(const char *p, const char *q, char fmt,
|
||||
struct unpacker *unpack_p, struct unpacker *unpack_q)
|
||||
{
|
||||
int equal;
|
||||
|
||||
switch (fmt[0]) {
|
||||
switch (fmt) {
|
||||
|
||||
/* signed integers and fast path for 'B' */
|
||||
case 'B': return *((unsigned char *)p) == *((unsigned char *)q);
|
||||
|
@ -2317,9 +2521,17 @@ unpack_cmp(const char *p, const char *q, const char *fmt)
|
|||
/* pointer */
|
||||
case 'P': CMP_SINGLE(p, q, void *); return equal;
|
||||
|
||||
/* Py_NotImplemented */
|
||||
default: return -1;
|
||||
/* use the struct module */
|
||||
case '_':
|
||||
assert(unpack_p);
|
||||
assert(unpack_q);
|
||||
return struct_unpack_cmp(p, q, unpack_p, unpack_q);
|
||||
}
|
||||
|
||||
/* NOT REACHED */
|
||||
PyErr_SetString(PyExc_RuntimeError,
|
||||
"memoryview: internal error in richcompare");
|
||||
return MV_COMPARE_EX;
|
||||
}
|
||||
|
||||
/* Base case for recursive array comparisons. Assumption: ndim == 1. */
|
||||
|
@ -2327,7 +2539,7 @@ static int
|
|||
cmp_base(const char *p, const char *q, const Py_ssize_t *shape,
|
||||
const Py_ssize_t *pstrides, const Py_ssize_t *psuboffsets,
|
||||
const Py_ssize_t *qstrides, const Py_ssize_t *qsuboffsets,
|
||||
const char *fmt)
|
||||
char fmt, struct unpacker *unpack_p, struct unpacker *unpack_q)
|
||||
{
|
||||
Py_ssize_t i;
|
||||
int equal;
|
||||
|
@ -2335,7 +2547,7 @@ cmp_base(const char *p, const char *q, const Py_ssize_t *shape,
|
|||
for (i = 0; i < shape[0]; p+=pstrides[0], q+=qstrides[0], i++) {
|
||||
const char *xp = ADJUST_PTR(p, psuboffsets);
|
||||
const char *xq = ADJUST_PTR(q, qsuboffsets);
|
||||
equal = unpack_cmp(xp, xq, fmt);
|
||||
equal = unpack_cmp(xp, xq, fmt, unpack_p, unpack_q);
|
||||
if (equal <= 0)
|
||||
return equal;
|
||||
}
|
||||
|
@ -2350,7 +2562,7 @@ cmp_rec(const char *p, const char *q,
|
|||
Py_ssize_t ndim, const Py_ssize_t *shape,
|
||||
const Py_ssize_t *pstrides, const Py_ssize_t *psuboffsets,
|
||||
const Py_ssize_t *qstrides, const Py_ssize_t *qsuboffsets,
|
||||
const char *fmt)
|
||||
char fmt, struct unpacker *unpack_p, struct unpacker *unpack_q)
|
||||
{
|
||||
Py_ssize_t i;
|
||||
int equal;
|
||||
|
@ -2364,7 +2576,7 @@ cmp_rec(const char *p, const char *q,
|
|||
return cmp_base(p, q, shape,
|
||||
pstrides, psuboffsets,
|
||||
qstrides, qsuboffsets,
|
||||
fmt);
|
||||
fmt, unpack_p, unpack_q);
|
||||
}
|
||||
|
||||
for (i = 0; i < shape[0]; p+=pstrides[0], q+=qstrides[0], i++) {
|
||||
|
@ -2373,7 +2585,7 @@ cmp_rec(const char *p, const char *q,
|
|||
equal = cmp_rec(xp, xq, ndim-1, shape+1,
|
||||
pstrides+1, psuboffsets ? psuboffsets+1 : NULL,
|
||||
qstrides+1, qsuboffsets ? qsuboffsets+1 : NULL,
|
||||
fmt);
|
||||
fmt, unpack_p, unpack_q);
|
||||
if (equal <= 0)
|
||||
return equal;
|
||||
}
|
||||
|
@ -2385,9 +2597,12 @@ static PyObject *
|
|||
memory_richcompare(PyObject *v, PyObject *w, int op)
|
||||
{
|
||||
PyObject *res;
|
||||
Py_buffer wbuf, *vv, *ww = NULL;
|
||||
const char *vfmt, *wfmt;
|
||||
int equal = -1; /* Py_NotImplemented */
|
||||
Py_buffer wbuf, *vv;
|
||||
Py_buffer *ww = NULL;
|
||||
struct unpacker *unpack_v = NULL;
|
||||
struct unpacker *unpack_w = NULL;
|
||||
char vfmt, wfmt;
|
||||
int equal = MV_COMPARE_NOT_IMPL;
|
||||
|
||||
if (op != Py_EQ && op != Py_NE)
|
||||
goto result; /* Py_NotImplemented */
|
||||
|
@ -2414,38 +2629,59 @@ memory_richcompare(PyObject *v, PyObject *w, int op)
|
|||
ww = &wbuf;
|
||||
}
|
||||
|
||||
vfmt = adjust_fmt(vv);
|
||||
wfmt = adjust_fmt(ww);
|
||||
if (vfmt == NULL || wfmt == NULL) {
|
||||
PyErr_Clear();
|
||||
goto result; /* Py_NotImplemented */
|
||||
}
|
||||
|
||||
if (cmp_structure(vv, ww) < 0) {
|
||||
if (!equiv_shape(vv, ww)) {
|
||||
PyErr_Clear();
|
||||
equal = 0;
|
||||
goto result;
|
||||
}
|
||||
|
||||
/* Use fast unpacking for identical primitive C type formats. */
|
||||
if (get_native_fmtchar(&vfmt, vv->format) < 0)
|
||||
vfmt = '_';
|
||||
if (get_native_fmtchar(&wfmt, ww->format) < 0)
|
||||
wfmt = '_';
|
||||
if (vfmt == '_' || wfmt == '_' || vfmt != wfmt) {
|
||||
/* Use struct module unpacking. NOTE: Even for equal format strings,
|
||||
memcmp() cannot be used for item comparison since it would give
|
||||
incorrect results in the case of NaNs or uninitialized padding
|
||||
bytes. */
|
||||
vfmt = '_';
|
||||
unpack_v = struct_get_unpacker(vv->format, vv->itemsize);
|
||||
if (unpack_v == NULL) {
|
||||
equal = fix_struct_error_int();
|
||||
goto result;
|
||||
}
|
||||
unpack_w = struct_get_unpacker(ww->format, ww->itemsize);
|
||||
if (unpack_w == NULL) {
|
||||
equal = fix_struct_error_int();
|
||||
goto result;
|
||||
}
|
||||
}
|
||||
|
||||
if (vv->ndim == 0) {
|
||||
equal = unpack_cmp(vv->buf, ww->buf, vfmt);
|
||||
equal = unpack_cmp(vv->buf, ww->buf,
|
||||
vfmt, unpack_v, unpack_w);
|
||||
}
|
||||
else if (vv->ndim == 1) {
|
||||
equal = cmp_base(vv->buf, ww->buf, vv->shape,
|
||||
vv->strides, vv->suboffsets,
|
||||
ww->strides, ww->suboffsets,
|
||||
vfmt);
|
||||
vfmt, unpack_v, unpack_w);
|
||||
}
|
||||
else {
|
||||
equal = cmp_rec(vv->buf, ww->buf, vv->ndim, vv->shape,
|
||||
vv->strides, vv->suboffsets,
|
||||
ww->strides, ww->suboffsets,
|
||||
vfmt);
|
||||
vfmt, unpack_v, unpack_w);
|
||||
}
|
||||
|
||||
result:
|
||||
if (equal < 0)
|
||||
res = Py_NotImplemented;
|
||||
if (equal < 0) {
|
||||
if (equal == MV_COMPARE_NOT_IMPL)
|
||||
res = Py_NotImplemented;
|
||||
else /* exception */
|
||||
res = NULL;
|
||||
}
|
||||
else if ((equal && op == Py_EQ) || (!equal && op == Py_NE))
|
||||
res = Py_True;
|
||||
else
|
||||
|
@ -2453,7 +2689,11 @@ result:
|
|||
|
||||
if (ww == &wbuf)
|
||||
PyBuffer_Release(ww);
|
||||
Py_INCREF(res);
|
||||
|
||||
unpacker_free(unpack_v);
|
||||
unpacker_free(unpack_w);
|
||||
|
||||
Py_XINCREF(res);
|
||||
return res;
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue