Skip to content

Commit 295ccb5

Browse files
committed
gh-92810: Reduce memory usage by ABCMeta.__subclasscheck__
1 parent b1a574f commit 295ccb5

File tree

4 files changed

+220
-58
lines changed

4 files changed

+220
-58
lines changed

Lib/_py_abc.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ def __new__(mcls, name, bases, namespace, /, **kwargs):
4949
cls._abc_cache = WeakSet()
5050
cls._abc_negative_cache = WeakSet()
5151
cls._abc_negative_cache_version = ABCMeta._abc_invalidation_counter
52+
53+
# Performance optimization for common case
54+
cls._abc_should_check_subclasses = False
55+
if "__subclasses__" in namespace:
56+
cls._abc_should_check_subclasses = True
57+
for base in bases:
58+
if hasattr(base, "_abc_should_check_subclasses"):
59+
base._abc_should_check_subclasses = True
5260
return cls
5361

5462
def register(cls, subclass):
@@ -65,8 +73,20 @@ def register(cls, subclass):
6573
if issubclass(cls, subclass):
6674
# This would create a cycle, which is bad for the algorithm below
6775
raise RuntimeError("Refusing to create an inheritance cycle")
76+
# Add registry entry
6877
cls._abc_registry.add(subclass)
6978
ABCMeta._abc_invalidation_counter += 1 # Invalidate negative cache
79+
# Recursively register the subclass in all ABC bases,
80+
# to avoid recursive lookups down the class tree.
81+
# >>> class Ancestor1(ABC): pass
82+
# >>> class Ancestor2(Ancestor1): pass
83+
# >>> class Other: pass
84+
# >>> Ancestor2.register(Other) # calls Ancestor1.register(Other)
85+
# >>> issubclass(Other, Ancestor2) is True
86+
# >>> issubclass(Other, Ancestor1) is True # already in registry
87+
for base in cls.__bases__:
88+
if hasattr(base, "_abc_registry"):
89+
base.register(subclass)
7090
return subclass
7191

7292
def _dump_registry(cls, file=None):
@@ -137,11 +157,16 @@ def __subclasscheck__(cls, subclass):
137157
if issubclass(subclass, rcls):
138158
cls._abc_cache.add(subclass)
139159
return True
140-
# Check if it's a subclass of a subclass (recursive)
141-
for scls in cls.__subclasses__():
142-
if issubclass(subclass, scls):
143-
cls._abc_cache.add(subclass)
144-
return True
160+
# Check if it's a subclass of a subclass (recursive).
161+
# If __subclasses__ contain only ABCs,
162+
# calling issubclass(...) will trigger the same __subclasscheck__
163+
# on *every* element of class inheritance tree.
164+
# Performing that only in resence of `def __subclasses__()` classmethod
165+
if cls._abc_should_check_subclasses:
166+
for scls in cls.__subclasses__():
167+
if issubclass(subclass, scls):
168+
cls._abc_cache.add(subclass)
169+
return True
145170
# No dice; update negative cache
146171
cls._abc_negative_cache.add(subclass)
147172
return False

Lib/abc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ class ABCMeta(type):
104104
"""
105105
def __new__(mcls, name, bases, namespace, /, **kwargs):
106106
cls = super().__new__(mcls, name, bases, namespace, **kwargs)
107-
_abc_init(cls)
107+
_abc_init(cls, bases, namespace)
108108
return cls
109109

110110
def register(cls, subclass):

Modules/_abc.c

Lines changed: 154 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ typedef struct {
5757
PyObject *_abc_cache;
5858
PyObject *_abc_negative_cache;
5959
uint64_t _abc_negative_cache_version;
60+
uint8_t _abc_should_check_subclasses;
6061
} _abc_data;
6162

6263
#define _abc_data_CAST(op) ((_abc_data *)(op))
@@ -73,6 +74,18 @@ set_cache_version(_abc_data *impl, uint64_t version)
7374
FT_ATOMIC_STORE_UINT64_RELAXED(impl->_abc_negative_cache_version, version);
7475
}
7576

77+
static inline uint8_t
78+
get_should_check_subclasses(_abc_data *impl)
79+
{
80+
return FT_ATOMIC_LOAD_UINT8_RELAXED(impl->_abc_should_check_subclasses);
81+
}
82+
83+
static inline void
84+
set_should_check_subclasses(_abc_data *impl)
85+
{
86+
FT_ATOMIC_STORE_UINT8_RELAXED(impl->_abc_should_check_subclasses, 1);
87+
}
88+
7689
static int
7790
abc_data_traverse(PyObject *op, visitproc visit, void *arg)
7891
{
@@ -123,6 +136,7 @@ abc_data_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
123136
self->_abc_cache = NULL;
124137
self->_abc_negative_cache = NULL;
125138
self->_abc_negative_cache_version = get_invalidation_counter(state);
139+
self->_abc_should_check_subclasses = 0;
126140
return (PyObject *) self;
127141
}
128142

@@ -161,6 +175,30 @@ _get_impl(PyObject *module, PyObject *self)
161175
return (_abc_data *)impl;
162176
}
163177

178+
/* If class is inherited from ABC, set data to point to internal ABC state of class, and return 1.
179+
If object is not inherited from ABC, return 0.
180+
If error is encountered, return -1.
181+
*/
182+
static int
183+
_get_optional_impl(_abcmodule_state *state, PyObject *self, _abc_data **data)
184+
{
185+
assert(data != NULL);
186+
PyObject *impl = NULL;
187+
int res = PyObject_GetOptionalAttr(self, &_Py_ID(_abc_impl), &impl);
188+
if (res <= 0) {
189+
*data = NULL;
190+
return res;
191+
}
192+
if (!Py_IS_TYPE(impl, state->_abc_data_type)) {
193+
PyErr_SetString(PyExc_TypeError, "_abc_impl is set to a wrong type");
194+
Py_DECREF(impl);
195+
*data = NULL;
196+
return -1;
197+
}
198+
*data = (_abc_data *)impl;
199+
return 1;
200+
}
201+
164202
static int
165203
_in_weak_set(_abc_data *impl, PyObject **pset, PyObject *obj)
166204
{
@@ -331,39 +369,34 @@ _abc__get_dump(PyObject *module, PyObject *self)
331369
}
332370
PyObject *res;
333371
Py_BEGIN_CRITICAL_SECTION(impl);
334-
res = Py_BuildValue("NNNK",
372+
res = Py_BuildValue("NNNKK",
335373
PySet_New(impl->_abc_registry),
336374
PySet_New(impl->_abc_cache),
337375
PySet_New(impl->_abc_negative_cache),
338-
get_cache_version(impl));
376+
get_cache_version(impl),
377+
get_should_check_subclasses(impl));
339378
Py_END_CRITICAL_SECTION();
340379
Py_DECREF(impl);
341380
return res;
342381
}
343382

344383
// Compute set of abstract method names.
345384
static int
346-
compute_abstract_methods(PyObject *self)
385+
compute_abstract_methods(PyObject *self, PyObject *bases, PyObject *ns)
347386
{
348387
int ret = -1;
349388
PyObject *abstracts = PyFrozenSet_New(NULL);
350389
if (abstracts == NULL) {
351390
return -1;
352391
}
353392

354-
PyObject *ns = NULL, *items = NULL, *bases = NULL; // Py_XDECREF()ed on error.
355-
356393
/* Stage 1: direct abstract methods. */
357-
ns = PyObject_GetAttr(self, &_Py_ID(__dict__));
358-
if (!ns) {
359-
goto error;
360-
}
361-
362394
// We can't use PyDict_Next(ns) even when ns is dict because
363395
// _PyObject_IsAbstract() can mutate ns.
364-
items = PyMapping_Items(ns);
396+
PyObject *items = PyMapping_Items(ns);
365397
if (!items) {
366-
goto error;
398+
Py_DECREF(abstracts);
399+
return -1;
367400
}
368401
assert(PyList_Check(items));
369402
for (Py_ssize_t pos = 0; pos < PyList_GET_SIZE(items); pos++) {
@@ -398,15 +431,6 @@ compute_abstract_methods(PyObject *self)
398431
}
399432

400433
/* Stage 2: inherited abstract methods. */
401-
bases = PyObject_GetAttr(self, &_Py_ID(__bases__));
402-
if (!bases) {
403-
goto error;
404-
}
405-
if (!PyTuple_Check(bases)) {
406-
PyErr_SetString(PyExc_TypeError, "__bases__ is not tuple");
407-
goto error;
408-
}
409-
410434
for (Py_ssize_t pos = 0; pos < PyTuple_GET_SIZE(bases); pos++) {
411435
PyObject *item = PyTuple_GET_ITEM(bases, pos); // borrowed
412436
PyObject *base_abstracts, *iter;
@@ -459,31 +483,58 @@ compute_abstract_methods(PyObject *self)
459483
ret = 0;
460484
error:
461485
Py_DECREF(abstracts);
462-
Py_XDECREF(ns);
463-
Py_XDECREF(items);
464-
Py_XDECREF(bases);
486+
Py_DECREF(items);
465487
return ret;
466488
}
467489

490+
/*
491+
* Notify base classes that child one has __subclasses__ overriden.
492+
* Used as performance optimization in __subclasscheck__
493+
*/
494+
static int
495+
_abc_notify_subclasses_override(_abcmodule_state *state, PyObject *data, PyObject *bases)
496+
{
497+
set_should_check_subclasses((_abc_data*) data);
498+
499+
for (Py_ssize_t pos = 0; pos < PyTuple_GET_SIZE(bases); pos++) {
500+
PyObject *base_class = PyTuple_GET_ITEM(bases, pos); // borrowed
501+
_abc_data *base_impl = NULL;
502+
int base_is_abc = _get_optional_impl(state, base_class, &base_impl);
503+
if (base_is_abc < 0) {
504+
return -1;
505+
}
506+
if (base_is_abc == 0) {
507+
continue;
508+
}
509+
set_should_check_subclasses(base_impl);
510+
Py_DECREF(base_impl);
511+
}
512+
513+
return 0;
514+
}
515+
468516
#define COLLECTION_FLAGS (Py_TPFLAGS_SEQUENCE | Py_TPFLAGS_MAPPING)
469517

470518
/*[clinic input]
471519
@permit_long_summary
472520
_abc._abc_init
473521
474522
self: object
523+
bases: object(subclass_of="&PyTuple_Type")
524+
namespace: object(subclass_of="&PyDict_Type")
475525
/
476526
477527
Internal ABC helper for class set-up. Should be never used outside abc module.
478528
[clinic start generated code]*/
479529

480530
static PyObject *
481-
_abc__abc_init(PyObject *module, PyObject *self)
482-
/*[clinic end generated code: output=594757375714cda1 input=0b3513f947736d39]*/
531+
_abc__abc_init_impl(PyObject *module, PyObject *self, PyObject *bases,
532+
PyObject *namespace)
533+
/*[clinic end generated code: output=a410180fefc86056 input=a984e4f7d36d6298]*/
483534
{
484535
_abcmodule_state *state = get_abc_state(module);
485536
PyObject *data;
486-
if (compute_abstract_methods(self) < 0) {
537+
if (compute_abstract_methods(self, bases, namespace) < 0) {
487538
return NULL;
488539
}
489540

@@ -492,6 +543,12 @@ _abc__abc_init(PyObject *module, PyObject *self)
492543
if (data == NULL) {
493544
return NULL;
494545
}
546+
if (PyDict_ContainsString(namespace, "__subclasses__")) {
547+
if (_abc_notify_subclasses_override(state, data, bases) < 0) {
548+
Py_DECREF(data);
549+
return NULL;
550+
}
551+
}
495552
if (PyObject_SetAttr(self, &_Py_ID(_abc_impl), data) < 0) {
496553
Py_DECREF(data);
497554
return NULL;
@@ -564,6 +621,7 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
564621
if (result < 0) {
565622
return NULL;
566623
}
624+
/* Add registry entry */
567625
_abc_data *impl = _get_impl(module, self);
568626
if (impl == NULL) {
569627
return NULL;
@@ -575,7 +633,43 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
575633
Py_DECREF(impl);
576634

577635
/* Invalidate negative cache */
578-
increment_invalidation_counter(get_abc_state(module));
636+
_abcmodule_state *state = get_abc_state(module);
637+
increment_invalidation_counter(state);
638+
639+
/*
640+
* Recursively register the subclass in all ABC bases,
641+
* to avoid recursive lookups down the class tree.
642+
* >>> class Ancestor1(ABC): pass
643+
* >>> class Ancestor2(Ancestor1): pass
644+
* >>> class Other: pass
645+
* >>> Ancestor2.register(Other) # calls Ancestor1.register(Other)
646+
* >>> issubclass(Other, Ancestor2) is True
647+
* >>> issubclass(Other, Ancestor1) is True # already in registry
648+
*/
649+
PyObject *bases = PyObject_GetAttr(self, &_Py_ID(__bases__));
650+
if (!bases) {
651+
return NULL;
652+
}
653+
if (!PyTuple_Check(bases)) {
654+
PyErr_SetString(PyExc_TypeError, "__bases__ is not tuple");
655+
goto error;
656+
}
657+
for (Py_ssize_t pos = 0; pos < PyTuple_GET_SIZE(bases); pos++) {
658+
PyObject *base = PyTuple_GET_ITEM(bases, pos); // borrowed
659+
int base_is_abc = PyObject_HasAttrWithError(base, &_Py_ID(_abc_impl));
660+
if (base_is_abc < 0) {
661+
goto error;
662+
}
663+
if (base_is_abc == 0) {
664+
continue;
665+
}
666+
PyObject *res = PyObject_CallMethod(base, "register", "O", subclass);
667+
Py_XDECREF(res);
668+
if (!res) {
669+
goto error;
670+
}
671+
}
672+
Py_DECREF(bases);
579673

580674
/* Set Py_TPFLAGS_SEQUENCE or Py_TPFLAGS_MAPPING flag */
581675
if (PyType_Check(self)) {
@@ -588,6 +682,10 @@ _abc__abc_register_impl(PyObject *module, PyObject *self, PyObject *subclass)
588682
}
589683
}
590684
return Py_NewRef(subclass);
685+
686+
error:
687+
Py_DECREF(bases);
688+
return NULL;
591689
}
592690

593691

@@ -788,31 +886,38 @@ _abc__abc_subclasscheck_impl(PyObject *module, PyObject *self,
788886
goto end;
789887
}
790888

791-
/* 6. Check if it's a subclass of a subclass (recursive). */
792-
subclasses = PyObject_CallMethod(self, "__subclasses__", NULL);
793-
if (subclasses == NULL) {
794-
goto end;
795-
}
796-
if (!PyList_Check(subclasses)) {
797-
PyErr_SetString(PyExc_TypeError, "__subclasses__() must return a list");
798-
goto end;
799-
}
800-
for (pos = 0; pos < PyList_GET_SIZE(subclasses); pos++) {
801-
PyObject *scls = PyList_GetItemRef(subclasses, pos);
802-
if (scls == NULL) {
889+
/* 6. Check if it's a subclass of a subclass (recursive).
890+
* If __subclasses__ contain only ABCs,
891+
* calling issubclass(...) will trigger the same __subclasscheck__
892+
* on *every* element of class inheritance tree.
893+
* Performing that only in resence of `def __subclasses__()` classmethod
894+
*/
895+
if (get_should_check_subclasses(impl)) {
896+
subclasses = PyObject_CallMethod(self, "__subclasses__", NULL);
897+
if (subclasses == NULL) {
803898
goto end;
804899
}
805-
int r = PyObject_IsSubclass(subclass, scls);
806-
Py_DECREF(scls);
807-
if (r > 0) {
808-
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
809-
goto end;
810-
}
811-
result = Py_True;
900+
if (!PyList_Check(subclasses)) {
901+
PyErr_SetString(PyExc_TypeError, "__subclasses__() must return a list");
812902
goto end;
813903
}
814-
if (r < 0) {
815-
goto end;
904+
for (pos = 0; pos < PyList_GET_SIZE(subclasses); pos++) {
905+
PyObject *scls = PyList_GetItemRef(subclasses, pos);
906+
if (scls == NULL) {
907+
goto end;
908+
}
909+
int r = PyObject_IsSubclass(subclass, scls);
910+
Py_DECREF(scls);
911+
if (r > 0) {
912+
if (_add_to_weak_set(impl, &impl->_abc_cache, subclass) < 0) {
913+
goto end;
914+
}
915+
result = Py_True;
916+
goto end;
917+
}
918+
if (r < 0) {
919+
goto end;
920+
}
816921
}
817922
}
818923

0 commit comments

Comments
 (0)