Skip to content

Commit 40b3e4a

Browse files
williamwen42pytorchmergebot
authored andcommitted
[dynamo] expose code execution strategy to python (pytorch#148020)
@anijain2305 this can be used to mark a code object to be skipped/run-only (recursively) while tracing. Pull Request resolved: pytorch#148020 Approved by: https://github.com/jansel
1 parent e74fdbe commit 40b3e4a

File tree

7 files changed

+40
-21
lines changed

7 files changed

+40
-21
lines changed

torch/_C/_dynamo/eval_frame.pyi

+3-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
99
def get_eval_frame_callback() -> DynamoCallback: ...
1010
def reset_code(code: types.CodeType) -> None: ...
1111
def unsupported(obj1: object, obj2: object) -> object: ...
12-
def skip_code(code: types.CodeType) -> None: ...
12+
def set_code_exec_strategy(
13+
code: types.CodeType, strategy: _FrameExecStrategy
14+
) -> None: ...
1315
def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
1416
def raise_sigtrap() -> None: ...
1517

torch/_dynamo/decorators.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
DynamoStance,
2727
innermost_fn,
2828
RunOnlyContext,
29+
skip_code,
2930
)
3031
from .exc import IncorrectUsage
3132
from .external_utils import is_compiling
@@ -39,7 +40,6 @@
3940
reset_code,
4041
set_eval_frame,
4142
set_guard_error_hook,
42-
skip_code,
4343
unsupported,
4444
)
4545

torch/_dynamo/eval_frame.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,10 @@
5656
# see discussion at https://github.com/pytorch/pytorch/issues/120699
5757
from torch._C._dynamo.eval_frame import ( # noqa: F401
5858
reset_code,
59+
set_code_exec_strategy,
5960
set_eval_frame,
6061
set_guard_error_hook,
6162
set_skip_guard_eval_unsafe,
62-
skip_code,
6363
unsupported,
6464
)
6565
from torch._dispatch.python import enable_python_dispatcher
@@ -86,6 +86,7 @@
8686
from .exc import CondOpArgsMismatchError, ShortenTraceback, UserError, UserErrorType
8787
from .hooks import Hooks
8888
from .mutation_guard import install_generation_tagging_init
89+
from .types import FrameAction, FrameExecStrategy
8990
from .utils import common_constant_types, compile_times
9091

9192

@@ -1886,3 +1887,9 @@ def inner_fn(*args, **kwargs):
18861887
return fn(*args, **kwargs)
18871888

18881889
return inner_fn
1890+
1891+
1892+
def skip_code(code: types.CodeType):
1893+
set_code_exec_strategy(
1894+
code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT)
1895+
)

torch/_dynamo/variables/functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def next_variable(self, tx):
527527
# test/dynamo/test_misc.py::test_iterator_limit
528528
raise
529529
except Unsupported as e:
530-
torch._C._dynamo.eval_frame.skip_code(self.get_code())
530+
torch._dynamo.eval_frame.skip_code(self.get_code())
531531
raise SkipFrame from e
532532
finally:
533533
counters["unimplemented"] |= counters["inline_call"]

torch/csrc/dynamo/eval_frame.c

+1-16
Original file line numberDiff line numberDiff line change
@@ -615,21 +615,6 @@ static PyObject* unsupported(PyObject* dummy, PyObject* args) {
615615
return obj2;
616616
}
617617

618-
static PyObject* skip_code(PyObject* dummy, PyObject* obj) {
619-
if (!PyCode_Check(obj)) {
620-
PyErr_SetString(PyExc_TypeError, "expected a code object");
621-
return NULL;
622-
}
623-
624-
PyCodeObject* code = (PyCodeObject*)obj;
625-
ExtraState* extra = get_extra_state(code);
626-
if (extra == NULL) {
627-
extra = init_and_set_extra_state(code);
628-
}
629-
extra_state_set_exec_strategy(extra, (FrameExecStrategy){SKIP, DEFAULT});
630-
Py_RETURN_NONE;
631-
}
632-
633618
static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
634619
if (obj == Py_None) {
635620
obj = NULL;
@@ -676,7 +661,7 @@ static PyMethodDef _methods[] = {
676661
{"get_eval_frame_callback", get_eval_frame_callback_py, METH_NOARGS, NULL},
677662
{"reset_code", reset_code, METH_O, NULL},
678663
{"unsupported", unsupported, METH_VARARGS, NULL},
679-
{"skip_code", skip_code, METH_O, NULL},
664+
{"set_code_exec_strategy", set_code_exec_strategy, METH_VARARGS, NULL},
680665
{"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
681666
{"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL},
682667
{NULL, NULL, 0, NULL}};

torch/csrc/dynamo/eval_frame_cpp.cpp

+24
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,27 @@ PyObject* dynamo__custom_eval_frame(
290290
}
291291
return eval_result;
292292
}
293+
294+
PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* args) {
295+
PyObject* code_obj = nullptr;
296+
PyObject* strategy_obj = nullptr;
297+
if (!PyArg_ParseTuple(args, "OO", &code_obj, &strategy_obj)) {
298+
return nullptr;
299+
}
300+
if (!PyCode_Check(code_obj)) {
301+
PyErr_SetString(PyExc_TypeError, "expected a code object");
302+
return nullptr;
303+
}
304+
305+
PyCodeObject* code = (PyCodeObject*)code_obj;
306+
ExtraState* extra = get_extra_state(code);
307+
if (extra == nullptr) {
308+
extra = init_and_set_extra_state(code);
309+
}
310+
311+
FrameExecStrategy strategy =
312+
py::handle(strategy_obj).cast<FrameExecStrategy>();
313+
314+
extra_state_set_exec_strategy(extra, strategy);
315+
Py_RETURN_NONE;
316+
}

torch/csrc/dynamo/eval_frame_cpp.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#pragma once
22
#include <Python.h>
3-
#include <stdbool.h>
43

54
#include <torch/csrc/dynamo/eval_frame.h>
65
#include <torch/csrc/dynamo/extra_state.h>
@@ -17,6 +16,8 @@ PyObject* dynamo__custom_eval_frame(
1716
int throw_flag,
1817
PyObject* callback);
1918

19+
PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj);
20+
2021
#ifdef __cplusplus
2122

2223
} // extern "C"

0 commit comments

Comments
 (0)