Skip to content

Commit 324d75e

Browse files
committed
gh-150724: Optimize JIT keyword calls to exact args
1 parent 84630e2 commit 324d75e

3 files changed

Lines changed: 304 additions & 2 deletions

File tree

Lib/test/test_capi/test_opt.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,89 @@ def dummy(x):
715715
self.assertNotIn("_GUARD_CODE_VERSION__PUSH_FRAME", uops)
716716
self.assertNotIn("_GUARD_IP__PUSH_FRAME", uops)
717717

718+
def assert_kw_call_optimized(self, ex):
    """Assert that executor *ex* holds the exact-args form of a keyword call.

    Checks, in order, that the uop names of *ex*:
      * contain no "_PY_FRAME_KW";
      * contain a uop whose name starts with "_INIT_CALL_PY_EXACT_ARGS"
        (the first such uop is used as the anchor);
      * contain a "_POP_TOP" and a "_CHECK_STACK_SPACE_OPERAND" before the
        anchor, with the last stack check coming before the last pop;
      * contain no "_CHECK_FUNCTION_EXACT_ARGS" from that pop up to the
        anchor.

    Returns (uops, pop_index, init_index) so callers can inspect the uops
    emitted between the pop and the anchor.
    """
    uops = get_opnames(ex)
    self.assertNotIn("_PY_FRAME_KW", uops)

    # Anchor: first uop whose name begins with _INIT_CALL_PY_EXACT_ARGS.
    init_index = None
    for position, name in enumerate(uops):
        if name.startswith("_INIT_CALL_PY_EXACT_ARGS"):
            init_index = position
            break
    self.assertIsNotNone(init_index, uops)

    def last_index_before_init(wanted):
        # Highest index of `wanted` within uops[:init_index], or None.
        hits = [
            position
            for position, name in enumerate(uops[:init_index])
            if name == wanted
        ]
        return hits[-1] if hits else None

    pop_index = last_index_before_init("_POP_TOP")
    self.assertIsNotNone(pop_index, uops)
    stack_check_index = last_index_before_init("_CHECK_STACK_SPACE_OPERAND")
    self.assertIsNotNone(stack_check_index, uops)

    self.assertLess(stack_check_index, pop_index, uops)
    self.assertNotIn("_CHECK_FUNCTION_EXACT_ARGS", uops[pop_index:init_index])
    return uops, pop_index, init_index
742+
743+
def test_call_kw_py_exact_args(self):
    # Keywords are passed in the opposite order to the parameters (b, then
    # a).  The traced functions below are kept exactly as written: their
    # bytecode is what the optimizer sees.
    def callee(x, a, b):
        return x + a + b

    def testfunc(n):
        total = 0
        for i in range(n):
            total += callee(i, b=2, a=1)
        return total

    res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)

    # Every iteration contributes i + 1 + 2.
    iterations = TIER2_THRESHOLD
    expected = iterations * (iterations - 1) // 2 + 3 * iterations
    self.assertEqual(res, expected)
    self.assertIsNotNone(ex)

    uops, pop_index, init_index = self.assert_kw_call_optimized(ex)
    # Some _SWAP* uop must appear between the _POP_TOP and the frame init.
    between = uops[pop_index:init_index]
    has_swap = any(name.startswith("_SWAP") for name in between)
    self.assertTrue(has_swap, uops)
    self.assertIn("_BINARY_OP_ADD_INT", uops)
762+
763+
def test_call_kw_py_exact_args_no_reorder(self):
    # Keywords are passed in the same order as the parameters (a, then b).
    # The traced functions below are kept exactly as written: their
    # bytecode is what the optimizer sees.
    def callee(x, a, b):
        return x + a + b

    def testfunc(n):
        total = 0
        for i in range(n):
            total += callee(i, a=1, b=2)
        return total

    res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)

    # Every iteration contributes i + 1 + 2.
    iterations = TIER2_THRESHOLD
    expected = iterations * (iterations - 1) // 2 + 3 * iterations
    self.assertEqual(res, expected)
    self.assertIsNotNone(ex)

    uops, pop_index, init_index = self.assert_kw_call_optimized(ex)
    # No _SWAP* uop may appear between the _POP_TOP and the frame init.
    between = uops[pop_index:init_index]
    has_swap = any(name.startswith("_SWAP") for name in between)
    self.assertFalse(has_swap, uops)
781+
782+
def test_call_kw_bound_method_exact_args(self):
    # Same keyword call as the plain-function test, made through a bound
    # method.  The class and traced function below are kept exactly as
    # written: their bytecode is what the optimizer sees.
    class C:
        def callee(self, x, a, b):
            return x + a + b

    obj = C()

    def testfunc(n):
        total = 0
        for i in range(n):
            total += obj.callee(i, b=2, a=1)
        return total

    res, ex = self._run_with_optimizer(testfunc, TIER2_THRESHOLD)

    # Every iteration contributes i + 1 + 2.
    iterations = TIER2_THRESHOLD
    expected = iterations * (iterations - 1) // 2 + 3 * iterations
    self.assertEqual(res, expected)
    self.assertIsNotNone(ex)

    # Only the uop names are needed here; the two indices are ignored.
    uops, _, _ = self.assert_kw_call_optimized(ex)
    self.assertIn("_BINARY_OP_ADD_INT", uops)
800+
718801
def test_int_type_propagate_through_range(self):
719802
def testfunc(n):
720803

Python/optimizer_bytecodes.c

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1286,7 +1286,117 @@ dummy_func(void) {
12861286
}
12871287

12881288
// Optimizer case for _PY_FRAME_KW.
//
// When the callable and kwnames are known constants and every keyword maps
// one-to-one onto a parameter of the callee, this case emits
// _CHECK_STACK_SPACE_OPERAND, _POP_TOP, zero or more _SWAPs and
// _INIT_CALL_PY_EXACT_ARGS, and builds the new frame symbol from the
// arguments in callee-local order.  Otherwise it falls back to a frame
// symbol created with no argument symbols.
//
// NOTE(review): this view is a rendered diff; the "NNNN-" line is the
// implementation removed by the commit and the "NNNN+" lines replace it.
op(_PY_FRAME_KW, (callable, self_or_null, args[oparg], kwnames -- new_frame)) {
1289-
new_frame = PyJitRef_WrapInvalid(frame_new_from_symbol(ctx, callable, NULL, 0));
1289+
// Set to true only once a complete argument -> parameter mapping is found.
bool valid = false;
1290+
// Both constants are checked against NULL before they are used below.
PyObject *func_o = sym_get_const(ctx, callable);
1291+
PyObject *kwnames_o = sym_get_const(ctx, kwnames);
1292+
bool has_self = sym_is_not_null(self_or_null);
1293+
PyCodeObject *co = NULL;
1294+
Py_ssize_t total_args = 0;
1295+
// desired[k]: index into args[] of the value that must end up in stack
// slot k (slot 0 being args[0]) before the frame is initialised.
int desired[256];
1296+
// Argument symbols in callee-local order; slot 0 is self_or_null when
// has_self is true.  One larger than desired[] to leave room for it.
JitOptRef frame_args[257];
1297+
1298+
// Proceed only if self_or_null is resolved one way or the other
// (sym_is_not_null or sym_is_null), callable is a constant function object
// and kwnames is a constant exact tuple.  oparg <= 256 keeps every index
// below within the fixed-size arrays.
if ((has_self || sym_is_null(self_or_null)) &&
1299+
func_o != NULL && PyFunction_Check(func_o) &&
1300+
kwnames_o != NULL && PyTuple_CheckExact(kwnames_o) &&
1301+
oparg <= 256)
1302+
{
1303+
PyFunctionObject *func = (PyFunctionObject *)func_o;
1304+
co = (PyCodeObject *)func->func_code;
1305+
Py_ssize_t kwcount = PyTuple_GET_SIZE(kwnames_o);
1306+
// self counts as one extra argument when present.
total_args = oparg + has_self;
1307+
// Arguments not named by kwnames (self included) ...
Py_ssize_t positional_args = total_args - kwcount;
1308+
// ... and the subset of those that live in args[].
Py_ssize_t positional_stack_args = positional_args - has_self;
1309+
1310+
// The callee must have CO_OPTIMIZED set with neither CO_VARARGS nor
// CO_VARKEYWORDS, no keyword-only parameters, and exactly as many
// parameters as arguments supplied.  positional_args >= has_self is
// equivalent to kwcount <= oparg, so positional_stack_args is not negative.
if ((co->co_flags & (CO_OPTIMIZED | CO_VARARGS | CO_VARKEYWORDS)) == CO_OPTIMIZED &&
1311+
co->co_kwonlyargcount == 0 &&
1312+
co->co_argcount == total_args &&
1313+
positional_args >= has_self)
1314+
{
1315+
// source_for_local[j]: index into args[] that feeds callee local j.
// -1 means "not assigned yet"; -2 marks the slot taken by self.
int source_for_local[257];
1316+
for (int i = 0; i < total_args; i++) {
1317+
source_for_local[i] = -1;
1318+
}
1319+
if (has_self) {
1320+
source_for_local[0] = -2;
1321+
}
1322+
// Positional arguments fill the leading locals (after self) in order.
for (int i = 0; i < positional_stack_args; i++) {
1323+
source_for_local[has_self + i] = i;
1324+
}
1325+
1326+
valid = true;
1327+
// Match each keyword name to a parameter of the callee.
for (Py_ssize_t i = 0; valid && i < kwcount; i++) {
1328+
PyObject *keyword = PyTuple_GET_ITEM(kwnames_o, i);
1329+
if (!PyUnicode_CheckExact(keyword)) {
1330+
valid = false;
1331+
break;
1332+
}
1333+
int target = -1;
1334+
// Starting at co_posonlyargcount skips positional-only parameters.
for (int j = co->co_posonlyargcount; j < co->co_argcount; j++) {
1335+
PyObject *varname = PyTuple_GET_ITEM(co->co_localsplusnames, j);
1336+
// Identity test first, then value comparison.
if (keyword == varname || PyUnicode_Equal(keyword, varname)) {
1337+
target = j;
1338+
break;
1339+
}
1340+
}
1341+
// Reject an unmatched keyword, or a parameter that already has a source
// (self, a positional argument, or an earlier keyword).
if (target < has_self || target < 0 || source_for_local[target] != -1) {
1342+
valid = false;
1343+
break;
1344+
}
1345+
// The value for keyword i sits after the positional values in args[].
source_for_local[target] = (int)(positional_stack_args + i);
1346+
}
1347+
1348+
if (has_self) {
1349+
frame_args[0] = self_or_null;
1350+
}
1351+
// Every parameter needs a source; record the target stack order
// (desired[]) and the frame's argument symbols (frame_args[]).
for (int local = 0; valid && local < co->co_argcount; local++) {
1352+
if (source_for_local[local] == -1) {
1353+
valid = false;
1354+
break;
1355+
}
1356+
// Local 0 is skipped when it is self: it was stored above and has no
// entry in args[].
if (local >= has_self) {
1357+
int source = source_for_local[local];
1358+
desired[local - has_self] = source;
1359+
frame_args[local] = args[source];
1360+
}
1361+
}
1362+
}
1363+
}
1364+
1365+
if (!valid) {
1366+
// Fallback: same frame symbol as the removed implementation, created
// with no argument symbols.
new_frame = PyJitRef_WrapInvalid(frame_new_from_symbol(ctx, callable, NULL, 0));
1367+
}
1368+
else {
1369+
// current[k]: which original args[] index sits in stack slot k while the
// swaps emitted below are simulated.
int current[256];
1370+
for (int i = 0; i < oparg; i++) {
1371+
current[i] = i;
1372+
}
1373+
1374+
// Stack-space check with the callee's frame size as its operand.
ADD_OP(_CHECK_STACK_SPACE_OPERAND, 0, co->co_framesize);
1375+
// kwnames is the last (topmost) stack input of this uop; pop it.
ADD_OP(_POP_TOP, 0, 0);
1376+
// Fix slots from the bottom up.  The final slot needs no pass of its own:
// once every other slot is right, it is right too.
for (int pos = 0; pos < oparg - 1; pos++) {
1377+
int source = desired[pos];
1378+
int source_pos = pos;
1379+
// Find where the wanted value currently is.  desired[] is a permutation of
// 0..oparg-1 and slots below pos are already final, so the scan ends.
while (current[source_pos] != source) {
1380+
source_pos++;
1381+
}
1382+
if (source_pos != pos) {
1383+
int top = oparg - 1;
1384+
// First bring the wanted value to the top slot, unless it is there already.
if (source_pos != top) {
1385+
ADD_OP(_SWAP, oparg - source_pos, 0);
1386+
int temp = current[source_pos];
1387+
current[source_pos] = current[top];
1388+
current[top] = temp;
1389+
}
1390+
// Then exchange the top slot with slot pos.
ADD_OP(_SWAP, oparg - pos, 0);
1391+
int temp = current[pos];
1392+
current[pos] = current[top];
1393+
current[top] = temp;
1394+
}
1395+
}
1396+
1397+
ADD_OP(_INIT_CALL_PY_EXACT_ARGS, oparg, 0);
1398+
// Unlike the fallback, this frame symbol is given the reordered arguments.
new_frame = PyJitRef_WrapInvalid(frame_new_from_symbol(ctx, callable, frame_args, (int)total_args));
1399+
}
12901400
}
12911401

12921402
op(_PY_FRAME_EX, (func_st, null, callargs_st, kwargs_st -- ex_frame)) {

Python/optimizer_cases.c.h

Lines changed: 110 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)