bpo-32630: Use contextvars in decimal (GH-5278)

This commit is contained in:
Yury Selivanov 2018-01-27 13:46:46 -05:00 committed by GitHub
parent bc4123b0b3
commit f13f12d8da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 101 deletions

View file

@ -122,10 +122,7 @@ incr_false(void)
}
/* Key for thread state dictionary */
static PyObject *tls_context_key = NULL;
/* Invariant: NULL or the most recently accessed thread local context */
static PyDecContextObject *cached_context = NULL;
static PyContextVar *current_context_var;
/* Template for creating new thread contexts, calling Context() without
* arguments and initializing the module_context on first access. */
@ -1220,10 +1217,6 @@ context_new(PyTypeObject *type, PyObject *args UNUSED, PyObject *kwds UNUSED)
static void
context_dealloc(PyDecContextObject *self)
{
if (self == cached_context) {
cached_context = NULL;
}
Py_XDECREF(self->traps);
Py_XDECREF(self->flags);
Py_TYPE(self)->tp_free(self);
@ -1498,69 +1491,38 @@ static PyGetSetDef context_getsets [] =
* operation.
*/
/* Get the context from the thread state dictionary. */
static PyObject *
current_context_from_dict(void)
init_current_context(void)
{
PyObject *dict;
PyObject *tl_context;
PyThreadState *tstate;
dict = PyThreadState_GetDict();
if (dict == NULL) {
PyErr_SetString(PyExc_RuntimeError,
"cannot get thread state");
PyObject *tl_context = context_copy(default_context_template, NULL);
if (tl_context == NULL) {
return NULL;
}
CTX(tl_context)->status = 0;
tl_context = PyDict_GetItemWithError(dict, tls_context_key);
if (tl_context != NULL) {
/* We already have a thread local context. */
CONTEXT_CHECK(tl_context);
}
else {
if (PyErr_Occurred()) {
return NULL;
}
/* Set up a new thread local context. */
tl_context = context_copy(default_context_template, NULL);
if (tl_context == NULL) {
return NULL;
}
CTX(tl_context)->status = 0;
if (PyDict_SetItem(dict, tls_context_key, tl_context) < 0) {
Py_DECREF(tl_context);
return NULL;
}
PyContextToken *tok = PyContextVar_Set(current_context_var, tl_context);
if (tok == NULL) {
Py_DECREF(tl_context);
return NULL;
}
Py_DECREF(tok);
/* Cache the context of the current thread, assuming that it
* will be accessed several times before a thread switch. */
tstate = PyThreadState_GET();
if (tstate) {
cached_context = (PyDecContextObject *)tl_context;
cached_context->tstate = tstate;
}
/* Borrowed reference with refcount==1 */
return tl_context;
}
/* Return borrowed reference to thread local context. */
static PyObject *
static inline PyObject *
current_context(void)
{
PyThreadState *tstate;
tstate = PyThreadState_GET();
if (cached_context && cached_context->tstate == tstate) {
return (PyObject *)cached_context;
PyObject *tl_context;
if (PyContextVar_Get(current_context_var, NULL, &tl_context) < 0) {
return NULL;
}
return current_context_from_dict();
if (tl_context != NULL) {
return tl_context;
}
return init_current_context();
}
/* ctxobj := borrowed reference to the current context */
@ -1568,47 +1530,22 @@ current_context(void)
ctxobj = current_context(); \
if (ctxobj == NULL) { \
return NULL; \
}
/* ctx := pointer to the mpd_context_t struct of the current context */
#define CURRENT_CONTEXT_ADDR(ctx) { \
PyObject *_c_t_x_o_b_j = current_context(); \
if (_c_t_x_o_b_j == NULL) { \
return NULL; \
} \
ctx = CTX(_c_t_x_o_b_j); \
}
} \
Py_DECREF(ctxobj);
/* Return a new reference to the current context */
static PyObject *
PyDec_GetCurrentContext(PyObject *self UNUSED, PyObject *args UNUSED)
{
PyObject *context;
context = current_context();
if (context == NULL) {
return NULL;
}
Py_INCREF(context);
return context;
return current_context();
}
/* Set the thread local context to a new context, decrement old reference */
static PyObject *
PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
{
PyObject *dict;
CONTEXT_CHECK(v);
dict = PyThreadState_GetDict();
if (dict == NULL) {
PyErr_SetString(PyExc_RuntimeError,
"cannot get thread state");
return NULL;
}
/* If the new context is one of the templates, make a copy.
* This is the current behavior of decimal.py. */
if (v == default_context_template ||
@ -1624,13 +1561,13 @@ PyDec_SetCurrentContext(PyObject *self UNUSED, PyObject *v)
Py_INCREF(v);
}
cached_context = NULL;
if (PyDict_SetItem(dict, tls_context_key, v) < 0) {
Py_DECREF(v);
PyContextToken *tok = PyContextVar_Set(current_context_var, v);
Py_DECREF(v);
if (tok == NULL) {
return NULL;
}
Py_DECREF(tok);
Py_DECREF(v);
Py_RETURN_NONE;
}
@ -4458,6 +4395,7 @@ _dec_hash(PyDecObject *v)
if (context == NULL) {
return -1;
}
Py_DECREF(context);
if (mpd_isspecial(MPD(v))) {
if (mpd_issnan(MPD(v))) {
@ -5599,6 +5537,11 @@ PyInit__decimal(void)
mpd_free = PyMem_Free;
mpd_setminalloc(_Py_DEC_MINALLOC);
/* Init context variable */
current_context_var = PyContextVar_New("decimal_context", NULL);
if (current_context_var == NULL) {
goto error;
}
/* Init external C-API functions */
_py_long_multiply = PyLong_Type.tp_as_number->nb_multiply;
@ -5768,7 +5711,6 @@ PyInit__decimal(void)
CHECK_INT(PyModule_AddObject(m, "DefaultContext",
default_context_template));
ASSIGN_PTR(tls_context_key, PyUnicode_FromString("___DECIMAL_CTX__"));
Py_INCREF(Py_True);
CHECK_INT(PyModule_AddObject(m, "HAVE_THREADS", Py_True));
@ -5827,9 +5769,9 @@ error:
Py_CLEAR(SignalTuple); /* GCOV_NOT_REACHED */
Py_CLEAR(DecimalTuple); /* GCOV_NOT_REACHED */
Py_CLEAR(default_context_template); /* GCOV_NOT_REACHED */
Py_CLEAR(tls_context_key); /* GCOV_NOT_REACHED */
Py_CLEAR(basic_context_template); /* GCOV_NOT_REACHED */
Py_CLEAR(extended_context_template); /* GCOV_NOT_REACHED */
Py_CLEAR(current_context_var); /* GCOV_NOT_REACHED */
Py_CLEAR(m); /* GCOV_NOT_REACHED */
return NULL; /* GCOV_NOT_REACHED */