gh-105499: Merge typing.Union and types.UnionType (#105511)

Co-authored-by: Alex Waygood <Alex.Waygood@Gmail.com>
Co-authored-by: Ken Jin <kenjin@python.org>
Co-authored-by: Carl Meyer <carl@oddbird.net>
This commit is contained in:
Jelle Zijlstra 2025-03-04 11:44:19 -08:00 committed by GitHub
parent e091520fdb
commit dc6d66f44c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 562 additions and 327 deletions

View file

@ -2,8 +2,8 @@
#include "Python.h"
#include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK, PyAnnotateFormat
#include "pycore_typevarobject.h"
#include "pycore_unionobject.h" // _Py_union_type_or
#include "pycore_unionobject.h" // _Py_union_type_or, _Py_union_from_tuple
#include "structmember.h"
/*[clinic input]
class typevar "typevarobject *" "&_PyTypeVar_Type"
@ -370,9 +370,13 @@ type_check(PyObject *arg, const char *msg)
static PyObject *
make_union(PyObject *self, PyObject *other)
{
PyObject *args[2] = {self, other};
PyObject *result = call_typing_func_object("_make_union", args, 2);
return result;
PyObject *args = PyTuple_Pack(2, self, other);
if (args == NULL) {
return NULL;
}
PyObject *u = _Py_union_from_tuple(args);
Py_DECREF(args);
return u;
}
static PyObject *

View file

@ -1,17 +1,17 @@
// types.UnionType -- used to represent e.g. Union[int, str], int | str
// typing.Union -- used to represent e.g. Union[int, str], int | str
#include "Python.h"
#include "pycore_object.h" // _PyObject_GC_TRACK/UNTRACK
#include "pycore_typevarobject.h" // _PyTypeAlias_Type, _Py_typing_type_repr
#include "pycore_unionobject.h"
static PyObject *make_union(PyObject *);
typedef struct {
PyObject_HEAD
PyObject *args;
PyObject *args; // all args (tuple)
PyObject *hashable_args; // frozenset or NULL
PyObject *unhashable_args; // tuple or NULL
PyObject *parameters;
PyObject *weakreflist;
} unionobject;
static void
@ -20,8 +20,13 @@ unionobject_dealloc(PyObject *self)
unionobject *alias = (unionobject *)self;
_PyObject_GC_UNTRACK(self);
if (alias->weakreflist != NULL) {
PyObject_ClearWeakRefs((PyObject *)alias);
}
Py_XDECREF(alias->args);
Py_XDECREF(alias->hashable_args);
Py_XDECREF(alias->unhashable_args);
Py_XDECREF(alias->parameters);
Py_TYPE(self)->tp_free(self);
}
@ -31,6 +36,8 @@ union_traverse(PyObject *self, visitproc visit, void *arg)
{
unionobject *alias = (unionobject *)self;
Py_VISIT(alias->args);
Py_VISIT(alias->hashable_args);
Py_VISIT(alias->unhashable_args);
Py_VISIT(alias->parameters);
return 0;
}
@ -39,13 +46,67 @@ static Py_hash_t
union_hash(PyObject *self)
{
// tp_hash implementation for union objects. This hunk interleaves the
// removed frozenset-of-args implementation with its replacement, which
// hashes the precomputed hashable_args and refuses to hash a union that
// holds any unhashable member.
unionobject *alias = (unionobject *)self;
PyObject *args = PyFrozenSet_New(alias->args);
if (args == NULL) {
return (Py_hash_t)-1;
// If there are any unhashable args, treat this union as unhashable.
// Otherwise, two unions might compare equal but have different hashes.
if (alias->unhashable_args) {
// Attempt to get an error from one of the values.
assert(PyTuple_CheckExact(alias->unhashable_args));
Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args);
for (Py_ssize_t i = 0; i < n; i++) {
PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i);
Py_hash_t hash = PyObject_Hash(arg);
if (hash == -1) {
return -1;
}
}
// The unhashable values somehow became hashable again. Still raise
// an error.
// n is a Py_ssize_t, so it must be formatted with %zd; %d would make
// the formatter read it as an int.
PyErr_Format(PyExc_TypeError, "union contains %zd unhashable elements", n);
return -1;
}
Py_hash_t hash = PyObject_Hash(args);
Py_DECREF(args);
return hash;
return PyObject_Hash(alias->hashable_args);
}
// Compare two unions for equality, ignoring member order.
//
// Returns 1 when equal, 0 when different, and -1 when a comparison raised.
// The hashable members are compared as sets; the unhashable members (if any)
// must have the same count and each side must contain every member of the
// other.
static int
unions_equal(unionobject *a, unionobject *b)
{
    int same_hashable = PyObject_RichCompareBool(a->hashable_args,
                                                 b->hashable_args, Py_EQ);
    if (same_hashable <= 0) {
        // -1: comparison failed; 0: hashable members differ.
        return same_hashable;
    }

    PyObject *needles = a->unhashable_args;
    PyObject *haystack = b->unhashable_args;
    if (needles == NULL && haystack == NULL) {
        return 1;
    }
    if (needles == NULL || haystack == NULL) {
        // Only one side carries unhashable members.
        return 0;
    }

    Py_ssize_t count = PyTuple_GET_SIZE(needles);
    if (count != PyTuple_GET_SIZE(haystack)) {
        return 0;
    }

    // First pass: every member of a is in b. Second pass: the reverse.
    for (int pass = 0; pass < 2; pass++) {
        for (Py_ssize_t k = 0; k < count; k++) {
            PyObject *member = PyTuple_GET_ITEM(needles, k);
            int found = PySequence_Contains(haystack, member);
            if (found == -1) {
                return -1;
            }
            if (!found) {
                return 0;
            }
        }
        PyObject *swap = needles;
        needles = haystack;
        haystack = swap;
    }
    return 1;
}
static PyObject *
@ -55,95 +116,130 @@ union_richcompare(PyObject *a, PyObject *b, int op)
Py_RETURN_NOTIMPLEMENTED;
}
PyObject *a_set = PySet_New(((unionobject*)a)->args);
if (a_set == NULL) {
int equal = unions_equal((unionobject*)a, (unionobject*)b);
if (equal == -1) {
return NULL;
}
PyObject *b_set = PySet_New(((unionobject*)b)->args);
if (b_set == NULL) {
Py_DECREF(a_set);
return NULL;
}
PyObject *result = PyObject_RichCompare(a_set, b_set, op);
Py_DECREF(b_set);
Py_DECREF(a_set);
return result;
}
// Decide whether two union members count as the same member.
//
// Two generic aliases are compared with == (result may be -1 on error);
// any other pair is compared by identity.
static int
is_same(PyObject *left, PyObject *right)
{
    if (_PyGenericAlias_Check(left) && _PyGenericAlias_Check(right)) {
        return PyObject_RichCompareBool(left, right, Py_EQ);
    }
    return left == right;
}
// Search the first `size` entries of `items` for a member matching `obj`
// according to is_same().
//
// Returns the first non-zero is_same() result (1 for a match, -1 for an
// error) or 0 when nothing matched.
static int
contains(PyObject **items, Py_ssize_t size, PyObject *obj)
{
    Py_ssize_t idx = 0;
    while (idx < size) {
        int verdict = is_same(items[idx], obj);
        if (verdict != 0) {
            return verdict;
        }
        idx++;
    }
    return 0;
}
// Build a tuple holding items1 followed by those entries of items2 that
// contains() does not find in items1.
//
// Returns a new tuple, or NULL. NULL is returned both when contains()
// reports an error and when every entry of items2 was a duplicate (the
// tuple is never allocated in that case), so the caller has to tell the
// two apart itself.
static PyObject *
merge(PyObject **items1, Py_ssize_t size1,
PyObject **items2, Py_ssize_t size2)
{
// Allocated lazily, on the first entry of items2 that is not a duplicate.
PyObject *tuple = NULL;
// Next free slot in tuple.
Py_ssize_t pos = 0;
for (Py_ssize_t i = 0; i < size2; i++) {
PyObject *arg = items2[i];
int is_duplicate = contains(items1, size1, arg);
if (is_duplicate < 0) {
Py_XDECREF(tuple);
return NULL;
}
if (is_duplicate) {
continue;
}
if (tuple == NULL) {
// Room for all of items1 plus the entries of items2 not yet examined.
tuple = PyTuple_New(size1 + size2 - i);
if (tuple == NULL) {
return NULL;
}
// Copy items1 into the front of the new tuple.
for (; pos < size1; pos++) {
PyObject *a = items1[pos];
PyTuple_SET_ITEM(tuple, pos, Py_NewRef(a));
}
}
PyTuple_SET_ITEM(tuple, pos, Py_NewRef(arg));
pos++;
}
if (tuple) {
// Shrink to the number of slots actually filled; the resize status is
// deliberately discarded here.
(void) _PyTuple_Resize(&tuple, pos);
}
return tuple;
}
static PyObject **
get_types(PyObject **obj, Py_ssize_t *size)
{
if (*obj == Py_None) {
*obj = (PyObject *)&_PyNone_Type;
}
if (_PyUnion_Check(*obj)) {
PyObject *args = ((unionobject *) *obj)->args;
*size = PyTuple_GET_SIZE(args);
return &PyTuple_GET_ITEM(args, 0);
if (op == Py_EQ) {
return PyBool_FromLong(equal);
}
else {
*size = 1;
return obj;
return PyBool_FromLong(!equal);
}
}
// Scratch state used while collecting the members of a union.
// Set up by unionbuilder_init() and released by unionbuilder_finalize().
typedef struct {
PyObject *args; // list of every accepted member, in insertion order
PyObject *hashable_args; // set of the hashable members, used to skip duplicates
PyObject *unhashable_args; // list of unhashable members, or NULL until one is added
bool is_checked; // whether unionbuilder_add_single() calls type_check() on each member
} unionbuilder;
static bool unionbuilder_add_tuple(unionbuilder *, PyObject *);
static PyObject *make_union(unionbuilder *);
static PyObject *type_check(PyObject *, const char *);
// Prepare `ub` for use: an empty args list, an empty hashable set, no
// unhashable list yet, and the requested checking mode.
//
// Returns false (with nothing left to release) when an allocation fails.
static bool
unionbuilder_init(unionbuilder *ub, bool is_checked)
{
    ub->args = PyList_New(0);
    if (!ub->args) {
        return false;
    }

    ub->hashable_args = PySet_New(NULL);
    if (!ub->hashable_args) {
        // Undo the list allocation so the caller has nothing to clean up.
        Py_DECREF(ub->args);
        return false;
    }

    ub->is_checked = is_checked;
    ub->unhashable_args = NULL;
    return true;
}
// Release the references held by a builder that unionbuilder_init()
// set up successfully.
static void
unionbuilder_finalize(unionbuilder *ub)
{
Py_DECREF(ub->args);
Py_DECREF(ub->hashable_args);
// The unhashable list only exists if an unhashable member was added.
Py_XDECREF(ub->unhashable_args);
}
// Add one member to the builder without type-checking it.
//
// Hashable members are de-duplicated through the hashable set; members whose
// hash fails are de-duplicated by a linear search of the unhashable list,
// which is created on first use. A member already present is silently
// skipped. Returns false with an exception set on failure.
static bool
unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg)
{
    if (PyObject_Hash(arg) != -1) {
        int seen = PySet_Contains(ub->hashable_args, arg);
        if (seen < 0) {
            return false;
        }
        if (seen == 1) {
            return true;
        }
        if (PySet_Add(ub->hashable_args, arg) < 0) {
            return false;
        }
        return PyList_Append(ub->args, arg) == 0;
    }

    // Hashing failed: drop that error and track the member as unhashable.
    PyErr_Clear();
    if (ub->unhashable_args != NULL) {
        int seen = PySequence_Contains(ub->unhashable_args, arg);
        if (seen < 0) {
            return false;
        }
        if (seen == 1) {
            return true;
        }
    }
    else {
        ub->unhashable_args = PyList_New(0);
        if (ub->unhashable_args == NULL) {
            return false;
        }
    }
    if (PyList_Append(ub->unhashable_args, arg) < 0) {
        return false;
    }
    return PyList_Append(ub->args, arg) == 0;
}
// Add one argument to the builder.
//
// None is replaced by NoneType, an existing union is flattened by adding
// each of its members, and anything else is added directly. When the builder
// is in checked mode the argument first goes through type_check().
// Returns false with an exception set on failure.
static bool
unionbuilder_add_single(unionbuilder *ub, PyObject *arg)
{
    if (Py_IsNone(arg)) {
        arg = (PyObject *)&_PyNone_Type;  // immortal, so no refcounting needed
    }
    else if (_PyUnion_Check(arg)) {
        return unionbuilder_add_tuple(ub, ((unionobject *)arg)->args);
    }

    if (!ub->is_checked) {
        return unionbuilder_add_single_unchecked(ub, arg);
    }

    PyObject *checked = type_check(arg, "Union[arg, ...]: each arg must be a type.");
    if (checked == NULL) {
        return false;
    }
    bool added = unionbuilder_add_single_unchecked(ub, checked);
    Py_DECREF(checked);
    return added;
}
// Add every item of `tuple` to the builder, in order.
// Stops and returns false at the first item that fails to be added.
static bool
unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple)
{
    Py_ssize_t count = PyTuple_GET_SIZE(tuple);
    Py_ssize_t idx = 0;
    while (idx < count) {
        PyObject *member = PyTuple_GET_ITEM(tuple, idx);
        if (!unionbuilder_add_single(ub, member)) {
            return false;
        }
        idx++;
    }
    return true;
}
static int
is_unionable(PyObject *obj)
{
@ -164,19 +260,18 @@ _Py_union_type_or(PyObject* self, PyObject* other)
Py_RETURN_NOTIMPLEMENTED;
}
Py_ssize_t size1, size2;
PyObject **items1 = get_types(&self, &size1);
PyObject **items2 = get_types(&other, &size2);
PyObject *tuple = merge(items1, size1, items2, size2);
if (tuple == NULL) {
if (PyErr_Occurred()) {
return NULL;
}
return Py_NewRef(self);
unionbuilder ub;
// unchecked because we already checked is_unionable()
if (!unionbuilder_init(&ub, false)) {
return NULL;
}
if (!unionbuilder_add_single(&ub, self) ||
!unionbuilder_add_single(&ub, other)) {
unionbuilder_finalize(&ub);
return NULL;
}
PyObject *new_union = make_union(tuple);
Py_DECREF(tuple);
PyObject *new_union = make_union(&ub);
return new_union;
}
@ -202,6 +297,18 @@ union_repr(PyObject *self)
goto error;
}
}
#if 0
PyUnicodeWriter_WriteUTF8(writer, "|args=", 6);
PyUnicodeWriter_WriteRepr(writer, alias->args);
PyUnicodeWriter_WriteUTF8(writer, "|h=", 3);
PyUnicodeWriter_WriteRepr(writer, alias->hashable_args);
if (alias->unhashable_args) {
PyUnicodeWriter_WriteUTF8(writer, "|u=", 3);
PyUnicodeWriter_WriteRepr(writer, alias->unhashable_args);
}
#endif
return PyUnicodeWriter_Finish(writer);
error:
@ -231,21 +338,7 @@ union_getitem(PyObject *self, PyObject *item)
return NULL;
}
PyObject *res;
Py_ssize_t nargs = PyTuple_GET_SIZE(newargs);
if (nargs == 0) {
res = make_union(newargs);
}
else {
res = Py_NewRef(PyTuple_GET_ITEM(newargs, 0));
for (Py_ssize_t iarg = 1; iarg < nargs; iarg++) {
PyObject *arg = PyTuple_GET_ITEM(newargs, iarg);
Py_SETREF(res, PyNumber_Or(res, arg));
if (res == NULL) {
break;
}
}
}
PyObject *res = _Py_union_from_tuple(newargs);
Py_DECREF(newargs);
return res;
}
@ -267,7 +360,25 @@ union_parameters(PyObject *self, void *Py_UNUSED(unused))
return Py_NewRef(alias->parameters);
}
// Getter shared by __name__ and __qualname__: always the string "Union".
static PyObject *
union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
    PyObject *name = PyUnicode_FromString("Union");
    return name;
}
// Getter for __origin__: a new reference to the union type object itself.
static PyObject *
union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
    PyObject *origin = (PyObject *)&_PyUnion_Type;
    return Py_NewRef(origin);
}
static PyGetSetDef union_properties[] = {
{"__name__", union_name, NULL,
PyDoc_STR("Name of the type"), NULL},
{"__qualname__", union_name, NULL,
PyDoc_STR("Qualified name of the type"), NULL},
{"__origin__", union_origin, NULL,
PyDoc_STR("Always returns the type"), NULL},
{"__parameters__", union_parameters, (setter)NULL,
PyDoc_STR("Type variables in the types.UnionType."), NULL},
{0}
@ -306,10 +417,88 @@ _Py_union_args(PyObject *self)
return ((unionobject *) self)->args;
}
// Import the `typing` module, look up the attribute `name` on it and call it
// with the `nargs` positional arguments in `args`.
//
// Returns the call result (new reference), or NULL with an exception set if
// the import, the lookup or the call fails.
static PyObject *
call_typing_func_object(const char *name, PyObject **args, size_t nargs)
{
    PyObject *outcome = NULL;
    PyObject *typing = PyImport_ImportModule("typing");
    if (typing != NULL) {
        PyObject *func = PyObject_GetAttrString(typing, name);
        if (func != NULL) {
            outcome = PyObject_Vectorcall(func, args, nargs, NULL);
            Py_DECREF(func);
        }
        Py_DECREF(typing);
    }
    return outcome;
}
// Validate `arg` as a union member and return the object to store for it.
//
// None maps to its type; anything is_unionable() accepts is returned as a
// new reference without leaving C; everything else is handed to
// typing._type_check together with `msg`. Returns NULL with an exception
// set on failure.
static PyObject *
type_check(PyObject *arg, const char *msg)
{
    if (Py_IsNone(arg)) {
        // NoneType is immortal, so don't need an INCREF
        return (PyObject *)Py_TYPE(arg);
    }
    if (is_unionable(arg)) {
        // Fast path to avoid calling into typing.py
        return Py_NewRef(arg);
    }

    PyObject *text = PyUnicode_FromString(msg);
    if (text == NULL) {
        return NULL;
    }
    PyObject *call_args[2] = {arg, text};
    PyObject *checked = call_typing_func_object("_type_check", call_args, 2);
    Py_DECREF(text);
    return checked;
}
// Build a union from `args` with type checking enabled.
//
// `args` is either an exact tuple, whose items become the members, or a
// single object, which becomes the only argument. The result is whatever
// make_union() produces for the collected members. Returns NULL with an
// exception set on failure.
PyObject *
_Py_union_from_tuple(PyObject *args)
{
    unionbuilder ub;
    if (!unionbuilder_init(&ub, true)) {
        return NULL;
    }

    bool collected = PyTuple_CheckExact(args)
        ? unionbuilder_add_tuple(&ub, args)
        : unionbuilder_add_single(&ub, args);
    if (!collected) {
        // The builder still owns its list/set here; release them so a
        // rejected argument does not leak references.
        unionbuilder_finalize(&ub);
        return NULL;
    }

    // make_union() finalizes the builder on every path.
    return make_union(&ub);
}
// __class_getitem__ implementation: builds a union from the subscript
// by delegating to _Py_union_from_tuple(). `cls` is not used.
static PyObject *
union_class_getitem(PyObject *cls, PyObject *args)
{
    PyObject *made = _Py_union_from_tuple(args);
    return made;
}
// __mro_entries__ implementation: always fails with a TypeError naming the
// union, so a union cannot be used as a base class. `args` is not used.
static PyObject *
union_mro_entries(PyObject *self, PyObject *args)
{
    PyErr_Format(PyExc_TypeError, "Cannot subclass %R", self);
    return NULL;
}
static PyMethodDef union_methods[] = {
{"__mro_entries__", union_mro_entries, METH_O},
{"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")},
{0}
};
PyTypeObject _PyUnion_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
.tp_name = "types.UnionType",
.tp_doc = PyDoc_STR("Represent a PEP 604 union type\n"
.tp_name = "typing.Union",
.tp_doc = PyDoc_STR("Represent a union type\n"
"\n"
"E.g. for int | str"),
.tp_basicsize = sizeof(unionobject),
@ -321,25 +510,64 @@ PyTypeObject _PyUnion_Type = {
.tp_hash = union_hash,
.tp_getattro = union_getattro,
.tp_members = union_members,
.tp_methods = union_methods,
.tp_richcompare = union_richcompare,
.tp_as_mapping = &union_as_mapping,
.tp_as_number = &union_as_number,
.tp_repr = union_repr,
.tp_getset = union_properties,
.tp_weaklistoffset = offsetof(unionobject, weakreflist),
};
static PyObject *
make_union(PyObject *args)
make_union(unionbuilder *ub)
{
assert(PyTuple_CheckExact(args));
Py_ssize_t n = PyList_GET_SIZE(ub->args);
if (n == 0) {
PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types.");
unionbuilder_finalize(ub);
return NULL;
}
if (n == 1) {
PyObject *result = PyList_GET_ITEM(ub->args, 0);
Py_INCREF(result);
unionbuilder_finalize(ub);
return result;
}
PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL;
args = PyList_AsTuple(ub->args);
if (args == NULL) {
goto error;
}
hashable_args = PyFrozenSet_New(ub->hashable_args);
if (hashable_args == NULL) {
goto error;
}
if (ub->unhashable_args != NULL) {
unhashable_args = PyList_AsTuple(ub->unhashable_args);
if (unhashable_args == NULL) {
goto error;
}
}
unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
if (result == NULL) {
return NULL;
goto error;
}
unionbuilder_finalize(ub);
result->parameters = NULL;
result->args = Py_NewRef(args);
result->args = args;
result->hashable_args = hashable_args;
result->unhashable_args = unhashable_args;
result->weakreflist = NULL;
_PyObject_GC_TRACK(result);
return (PyObject*)result;
error:
Py_XDECREF(args);
Py_XDECREF(hashable_args);
Py_XDECREF(unhashable_args);
unionbuilder_finalize(ub);
return NULL;
}