Skip to content

Commit 3b24191

Browse files
committed
Call cls(fd) if working on a copy
1 parent 79528af commit 3b24191

2 files changed

Lines changed: 30 additions & 13 deletions

File tree

Lib/test/test_dict.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1939,27 +1939,33 @@ def test_fromkeys(self):
19391939
# Subclass which overrides the constructor
19401940
created = frozendict(x=1)
19411941
class FrozenDictSubclass(frozendict):
1942-
def __new__(self):
1943-
return created
1942+
def __new__(cls, *args, **kwargs):
1943+
if args or kwargs:
1944+
return super().__new__(cls, *args, **kwargs)
1945+
else:
1946+
return created
19441947

19451948
fd = FrozenDictSubclass.fromkeys("abc")
19461949
self.assertEqual(fd, frozendict(x=1, a=None, b=None, c=None))
1947-
self.assertEqual(type(fd), frozendict)
1950+
self.assertEqual(type(fd), FrozenDictSubclass)
19481951
self.assertEqual(created, frozendict(x=1))
19491952

19501953
fd = FrozenDictSubclass.fromkeys(frozendict(y=2))
19511954
self.assertEqual(fd, frozendict(x=1, y=None))
1952-
self.assertEqual(type(fd), frozendict)
1955+
self.assertEqual(type(fd), FrozenDictSubclass)
19531956
self.assertEqual(created, frozendict(x=1))
19541957

19551958
# Dict subclass which overrides the constructor
19561959
class DictSubclass(dict):
1957-
def __new__(self):
1958-
return created
1960+
def __new__(cls, *args, **kwargs):
1961+
if args or kwargs:
1962+
return super().__new__(cls, *args, **kwargs)
1963+
else:
1964+
return created
19591965

19601966
fd = DictSubclass.fromkeys("abc")
19611967
self.assertEqual(fd, frozendict(x=1, a=None, b=None, c=None))
1962-
self.assertEqual(type(fd), frozendict)
1968+
self.assertEqual(type(fd), DictSubclass)
19631969
self.assertEqual(created, frozendict(x=1))
19641970

19651971
# Subclass which doesn't override the constructor

Objects/dictobject.c

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3419,12 +3419,13 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
34193419
PyObject *key;
34203420
PyObject *d;
34213421
int status;
3422+
int need_copy = 0;
34223423

34233424
PyTypeObject *cls_type = _PyType_CAST(cls);
34243425
if (PyObject_IsSubclass(cls, (PyObject*)&PyFrozenDict_Type)
34253426
&& cls_type->tp_new == frozendict_new)
34263427
{
3427-
// gh-151722: Create a frozendict copy which is not tracked by the GC.
3428+
// gh-151722: Create a frozendict which is not tracked by the GC.
34283429
d = frozendict_new_untracked(cls_type);
34293430
}
34303431
else {
@@ -3437,11 +3438,14 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
34373438
// gh-151722: If cls constructor returns a frozendict which is tracked by
34383439
// the GC, create a frozendict copy which is not tracked by the GC.
34393440
//
3440-
// Untracking the dictionary requires tracking again the dictionary on
3441+
// At the function exit, return cls(fd) where fd is a frozendict.
3442+
//
3443+
// Untracking the frozendict requires tracking again the frozendict on
34413444
// error which is more complicated. It's easier to work on a copy.
34423445
if (PyFrozenDict_Check(d) && _PyObject_GC_IS_TRACKED(d)) {
3443-
// Subclass-friendly copy
3444-
PyObject *copy = frozendict_new_untracked(Py_TYPE(d));
3446+
need_copy = 1;
3447+
3448+
PyObject *copy = frozendict_new_untracked(&PyFrozenDict_Type);
34453449
if (copy == NULL) {
34463450
goto Fail;
34473451
}
@@ -3555,8 +3559,15 @@ dict_iter_exit:;
35553559
return NULL;
35563560

35573561
Done:
3558-
// d can be NULL
3559-
if (d != NULL && !_PyObject_GC_IS_TRACKED(d)) {
3562+
if (d == NULL) {
3563+
return NULL;
3564+
}
3565+
3566+
if (need_copy) {
3567+
PyObject *copy = _PyObject_CallOneArg(cls, d);
3568+
Py_SETREF(d, copy);
3569+
}
3570+
else if (!_PyObject_GC_IS_TRACKED(d)) {
35603571
_PyObject_GC_TRACK(d);
35613572
}
35623573
return d;

0 commit comments

Comments
 (0)