mirror of
https://github.com/python/cpython.git
synced 2025-07-23 11:15:24 +00:00
PEP 465: a dedicated infix operator for matrix multiplication (closes #21176)
This commit is contained in:
parent
2aad6ef774
commit
d51374ed78
42 changed files with 803 additions and 442 deletions
|
@ -349,13 +349,14 @@ static PyTypeObject *And_type;
|
|||
static PyTypeObject *Or_type;
|
||||
static PyTypeObject *operator_type;
|
||||
static PyObject *Add_singleton, *Sub_singleton, *Mult_singleton,
|
||||
*Div_singleton, *Mod_singleton, *Pow_singleton, *LShift_singleton,
|
||||
*RShift_singleton, *BitOr_singleton, *BitXor_singleton, *BitAnd_singleton,
|
||||
*FloorDiv_singleton;
|
||||
*MatMult_singleton, *Div_singleton, *Mod_singleton, *Pow_singleton,
|
||||
*LShift_singleton, *RShift_singleton, *BitOr_singleton, *BitXor_singleton,
|
||||
*BitAnd_singleton, *FloorDiv_singleton;
|
||||
static PyObject* ast2obj_operator(operator_ty);
|
||||
static PyTypeObject *Add_type;
|
||||
static PyTypeObject *Sub_type;
|
||||
static PyTypeObject *Mult_type;
|
||||
static PyTypeObject *MatMult_type;
|
||||
static PyTypeObject *Div_type;
|
||||
static PyTypeObject *Mod_type;
|
||||
static PyTypeObject *Pow_type;
|
||||
|
@ -970,6 +971,10 @@ static int init_types(void)
|
|||
if (!Mult_type) return 0;
|
||||
Mult_singleton = PyType_GenericNew(Mult_type, NULL, NULL);
|
||||
if (!Mult_singleton) return 0;
|
||||
MatMult_type = make_type("MatMult", operator_type, NULL, 0);
|
||||
if (!MatMult_type) return 0;
|
||||
MatMult_singleton = PyType_GenericNew(MatMult_type, NULL, NULL);
|
||||
if (!MatMult_singleton) return 0;
|
||||
Div_type = make_type("Div", operator_type, NULL, 0);
|
||||
if (!Div_type) return 0;
|
||||
Div_singleton = PyType_GenericNew(Div_type, NULL, NULL);
|
||||
|
@ -3232,6 +3237,9 @@ PyObject* ast2obj_operator(operator_ty o)
|
|||
case Mult:
|
||||
Py_INCREF(Mult_singleton);
|
||||
return Mult_singleton;
|
||||
case MatMult:
|
||||
Py_INCREF(MatMult_singleton);
|
||||
return MatMult_singleton;
|
||||
case Div:
|
||||
Py_INCREF(Div_singleton);
|
||||
return Div_singleton;
|
||||
|
@ -6175,6 +6183,14 @@ obj2ast_operator(PyObject* obj, operator_ty* out, PyArena* arena)
|
|||
*out = Mult;
|
||||
return 0;
|
||||
}
|
||||
isinstance = PyObject_IsInstance(obj, (PyObject *)MatMult_type);
|
||||
if (isinstance == -1) {
|
||||
return 1;
|
||||
}
|
||||
if (isinstance) {
|
||||
*out = MatMult;
|
||||
return 0;
|
||||
}
|
||||
isinstance = PyObject_IsInstance(obj, (PyObject *)Div_type);
|
||||
if (isinstance == -1) {
|
||||
return 1;
|
||||
|
@ -6956,6 +6972,8 @@ PyInit__ast(void)
|
|||
if (PyDict_SetItemString(d, "Add", (PyObject*)Add_type) < 0) return NULL;
|
||||
if (PyDict_SetItemString(d, "Sub", (PyObject*)Sub_type) < 0) return NULL;
|
||||
if (PyDict_SetItemString(d, "Mult", (PyObject*)Mult_type) < 0) return NULL;
|
||||
if (PyDict_SetItemString(d, "MatMult", (PyObject*)MatMult_type) < 0) return
|
||||
NULL;
|
||||
if (PyDict_SetItemString(d, "Div", (PyObject*)Div_type) < 0) return NULL;
|
||||
if (PyDict_SetItemString(d, "Mod", (PyObject*)Mod_type) < 0) return NULL;
|
||||
if (PyDict_SetItemString(d, "Pow", (PyObject*)Pow_type) < 0) return NULL;
|
||||
|
|
|
@ -825,6 +825,8 @@ get_operator(const node *n)
|
|||
return Sub;
|
||||
case STAR:
|
||||
return Mult;
|
||||
case AT:
|
||||
return MatMult;
|
||||
case SLASH:
|
||||
return Div;
|
||||
case DOUBLESLASH:
|
||||
|
@ -1030,6 +1032,8 @@ ast_for_augassign(struct compiling *c, const node *n)
|
|||
return Pow;
|
||||
else
|
||||
return Mult;
|
||||
case '@':
|
||||
return MatMult;
|
||||
default:
|
||||
PyErr_Format(PyExc_SystemError, "invalid augassign: %s", STR(n));
|
||||
return (operator_ty)0;
|
||||
|
@ -2266,7 +2270,7 @@ ast_for_expr(struct compiling *c, const node *n)
|
|||
and_expr: shift_expr ('&' shift_expr)*
|
||||
shift_expr: arith_expr (('<<'|'>>') arith_expr)*
|
||||
arith_expr: term (('+'|'-') term)*
|
||||
term: factor (('*'|'/'|'%'|'//') factor)*
|
||||
term: factor (('*'|'@'|'/'|'%'|'//') factor)*
|
||||
factor: ('+'|'-'|'~') factor | power
|
||||
power: atom trailer* ('**' factor)*
|
||||
*/
|
||||
|
@ -2577,7 +2581,7 @@ ast_for_expr_stmt(struct compiling *c, const node *n)
|
|||
/* expr_stmt: testlist_star_expr (augassign (yield_expr|testlist)
|
||||
| ('=' (yield_expr|testlist))*)
|
||||
testlist_star_expr: (test|star_expr) (',' test|star_expr)* [',']
|
||||
augassign: '+=' | '-=' | '*=' | '/=' | '%=' | '&=' | '|=' | '^='
|
||||
augassign: '+=' | '-=' | '*=' | '@=' | '/=' | '%=' | '&=' | '|=' | '^='
|
||||
| '<<=' | '>>=' | '**=' | '//='
|
||||
test: ... here starts the operator precendence dance
|
||||
*/
|
||||
|
|
|
@ -1495,6 +1495,18 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
|
|||
DISPATCH();
|
||||
}
|
||||
|
||||
TARGET(BINARY_MATRIX_MULTIPLY) {
|
||||
PyObject *right = POP();
|
||||
PyObject *left = TOP();
|
||||
PyObject *res = PyNumber_MatrixMultiply(left, right);
|
||||
Py_DECREF(left);
|
||||
Py_DECREF(right);
|
||||
SET_TOP(res);
|
||||
if (res == NULL)
|
||||
goto error;
|
||||
DISPATCH();
|
||||
}
|
||||
|
||||
TARGET(BINARY_TRUE_DIVIDE) {
|
||||
PyObject *divisor = POP();
|
||||
PyObject *dividend = TOP();
|
||||
|
@ -1685,6 +1697,18 @@ PyEval_EvalFrameEx(PyFrameObject *f, int throwflag)
|
|||
DISPATCH();
|
||||
}
|
||||
|
||||
TARGET(INPLACE_MATRIX_MULTIPLY) {
|
||||
PyObject *right = POP();
|
||||
PyObject *left = TOP();
|
||||
PyObject *res = PyNumber_InPlaceMatrixMultiply(left, right);
|
||||
Py_DECREF(left);
|
||||
Py_DECREF(right);
|
||||
SET_TOP(res);
|
||||
if (res == NULL)
|
||||
goto error;
|
||||
DISPATCH();
|
||||
}
|
||||
|
||||
TARGET(INPLACE_TRUE_DIVIDE) {
|
||||
PyObject *divisor = POP();
|
||||
PyObject *dividend = TOP();
|
||||
|
|
|
@ -881,6 +881,7 @@ PyCompile_OpcodeStackEffect(int opcode, int oparg)
|
|||
|
||||
case BINARY_POWER:
|
||||
case BINARY_MULTIPLY:
|
||||
case BINARY_MATRIX_MULTIPLY:
|
||||
case BINARY_MODULO:
|
||||
case BINARY_ADD:
|
||||
case BINARY_SUBTRACT:
|
||||
|
@ -895,6 +896,7 @@ PyCompile_OpcodeStackEffect(int opcode, int oparg)
|
|||
case INPLACE_ADD:
|
||||
case INPLACE_SUBTRACT:
|
||||
case INPLACE_MULTIPLY:
|
||||
case INPLACE_MATRIX_MULTIPLY:
|
||||
case INPLACE_MODULO:
|
||||
return -1;
|
||||
case STORE_SUBSCR:
|
||||
|
@ -2625,6 +2627,8 @@ binop(struct compiler *c, operator_ty op)
|
|||
return BINARY_SUBTRACT;
|
||||
case Mult:
|
||||
return BINARY_MULTIPLY;
|
||||
case MatMult:
|
||||
return BINARY_MATRIX_MULTIPLY;
|
||||
case Div:
|
||||
return BINARY_TRUE_DIVIDE;
|
||||
case Mod:
|
||||
|
@ -2689,6 +2693,8 @@ inplace_binop(struct compiler *c, operator_ty op)
|
|||
return INPLACE_SUBTRACT;
|
||||
case Mult:
|
||||
return INPLACE_MULTIPLY;
|
||||
case MatMult:
|
||||
return INPLACE_MATRIX_MULTIPLY;
|
||||
case Div:
|
||||
return INPLACE_TRUE_DIVIDE;
|
||||
case Mod:
|
||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -15,8 +15,8 @@ static void *opcode_targets[256] = {
|
|||
&&_unknown_opcode,
|
||||
&&_unknown_opcode,
|
||||
&&TARGET_UNARY_INVERT,
|
||||
&&_unknown_opcode,
|
||||
&&_unknown_opcode,
|
||||
&&TARGET_BINARY_MATRIX_MULTIPLY,
|
||||
&&TARGET_INPLACE_MATRIX_MULTIPLY,
|
||||
&&_unknown_opcode,
|
||||
&&TARGET_BINARY_POWER,
|
||||
&&TARGET_BINARY_MULTIPLY,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue