
Commit 598e7e5

samdow authored and pytorchmergebot committed
[Reland] Change 'python mode' to 'torch dispatch mode'
Changes the Python Mode name to Torch Dispatch Mode. Now that there is also a Torch Function Mode, this keeps Torch Dispatch Mode and Torch Function Mode consistent with each other.

Pull Request resolved: pytorch#76562
Approved by: https://github.com/zou3519, https://github.com/albanD
1 parent 4796955 commit 598e7e5

18 files changed, +105 -105 lines changed
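For orientation before the per-file diffs: a minimal usage sketch of the renamed API, mirroring test_enable_torch_dispatch_mode_basic in the updated test/test_python_dispatch.py below. LoggingTensorMode is the helper tensor subclass those tests import from torch.testing._internal.logging_tensor; nothing else is assumed beyond what this commit shows.

    import torch
    from torch.testing._internal.logging_tensor import LoggingTensorMode
    from torch.utils._python_dispatch import enable_torch_dispatch_mode  # previously enable_python_mode

    with enable_torch_dispatch_mode(LoggingTensorMode):
        z = torch.empty([])                      # factory call routed through __torch_dispatch__
        assert isinstance(z, LoggingTensorMode)  # mirrors test_enable_torch_dispatch_mode_basic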

aten/src/ATen/ThreadLocalState.cpp

+2 -2

@@ -19,7 +19,7 @@ ThreadLocalState::ThreadLocalState()

   saved_tensors_default_hooks_ = at::SavedTensorDefaultHooks::get_stack();

-  python_mode_state_ = at::impl::PythonModeTLS::get_state();
+  torch_dispatch_mode_state_ = at::impl::TorchDispatchModeTLS::get_state();
 }

 void ThreadLocalState::set_grad_mode(bool enabled) {

@@ -33,7 +33,7 @@ void ThreadLocalState::setThreadLocalState(
   // restore the dispatch key set TLS at the same time.
   c10::AutogradState::set_tls_state(state.autograd_tls_);

-  at::impl::PythonModeTLS::set_state(state.python_mode_state_);
+  at::impl::TorchDispatchModeTLS::set_state(state.torch_dispatch_mode_state_);

   at::impl::PythonTorchFunctionTLS::set_state(state.python_torch_function_state_);

aten/src/ATen/ThreadLocalState.h

+3 -3

@@ -9,7 +9,7 @@

 #include <ATen/record_function.h>
 #include <ATen/FuncTorchTLS.h>
-#include <ATen/core/PythonModeTLS.h>
+#include <ATen/core/TorchDispatchModeTLS.h>
 #include <ATen/PythonTorchFunctionTLS.h>

 namespace at {

@@ -54,8 +54,8 @@ class TORCH_API ThreadLocalState {
   // TLS for AutogradModes
   AutogradState autograd_tls_;

-  // TLS for enable_python_mode (__torch_dispatch__)
-  std::shared_ptr<SafePyObject> python_mode_state_;
+  // TLS for enable_torch_dispatch_mode
+  std::shared_ptr<SafePyObject> torch_dispatch_mode_state_;

   // TLS for __torch_function__ (mode and disable_torch_function)
   at::impl::PythonTorchFunctionTLS python_torch_function_state_;

aten/src/ATen/core/PythonFallbackKernel.cpp

+5 -5

@@ -1,4 +1,4 @@
-#include <ATen/core/PythonModeTLS.h>
+#include <ATen/core/TorchDispatchModeTLS.h>
 #include <ATen/core/PythonFallbackKernel.h>
 #include <c10/core/SafePyObject.h>

@@ -50,10 +50,10 @@ void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
   c10::impl::ExcludeDispatchKeyGuard guard(after_Python_keyset);

-  // If Python Mode is active, use its PyInterpreter for dispatch
-  const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
-  if (maybe_python_mode_state) {
-    maybe_python_mode_state->pyinterpreter()->dispatch(op, stack, maybe_python_mode_state);
+  // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
+  const auto& maybe_torch_dispatch_mode_state = at::impl::TorchDispatchModeTLS::get_state();
+  if (maybe_torch_dispatch_mode_state) {
+    maybe_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack, maybe_torch_dispatch_mode_state);
     return;
   }
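The fallback above consults the thread-local mode state rather than the operator's arguments, so an active mode intercepts calls even when no input is a __torch_dispatch__ subclass. A short sketch of that behavior, mirroring test_enable_torch_dispatch_mode_unrelated_tensors in the test diff below (same assumptions as the sketch near the top):

    import torch
    from torch.testing._internal.logging_tensor import LoggingTensorMode
    from torch.utils._python_dispatch import enable_torch_dispatch_mode

    x = torch.randn([])   # plain tensors; no subclass argument anywhere
    y = torch.randn([])
    with enable_torch_dispatch_mode(LoggingTensorMode):
        z = x + y                                # still dispatched via the mode's PyInterpreter
        assert isinstance(z, LoggingTensorMode)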

aten/src/ATen/core/PythonModeTLS.cpp renamed to aten/src/ATen/core/TorchDispatchModeTLS.cpp

+9 -9

@@ -1,26 +1,26 @@
-#include <ATen/core/PythonModeTLS.h>
+#include <ATen/core/TorchDispatchModeTLS.h>
 #include <c10/core/SafePyObject.h>

 namespace at { namespace impl {

-thread_local std::shared_ptr<SafePyObject> pythonModeState;
+thread_local std::shared_ptr<SafePyObject> torchDispatchModeState;

-void PythonModeTLS::set_state(std::shared_ptr<SafePyObject> state) {
+void TorchDispatchModeTLS::set_state(std::shared_ptr<SafePyObject> state) {
   if (state) {
     c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, true);
     c10::impl::tls_set_dispatch_key_included(DispatchKey::PythonTLSSnapshot, true);
   } else {
-    PythonModeTLS::reset_state();
+    TorchDispatchModeTLS::reset_state();
   }
-  pythonModeState = std::move(state);
+  torchDispatchModeState = std::move(state);
 }

-const std::shared_ptr<SafePyObject>& PythonModeTLS::get_state() {
-  return pythonModeState;
+const std::shared_ptr<SafePyObject>& TorchDispatchModeTLS::get_state() {
+  return torchDispatchModeState;
 }

-void PythonModeTLS::reset_state() {
-  pythonModeState.reset();
+void TorchDispatchModeTLS::reset_state() {
+  torchDispatchModeState.reset();
   c10::impl::tls_set_dispatch_key_included(DispatchKey::Python, false);
   c10::impl::tls_set_dispatch_key_included(DispatchKey::PythonTLSSnapshot, false);
 }

aten/src/ATen/core/PythonModeTLS.h renamed to aten/src/ATen/core/TorchDispatchModeTLS.h

+1 -1

@@ -7,7 +7,7 @@
 namespace at {
 namespace impl {

-struct TORCH_API PythonModeTLS {
+struct TORCH_API TorchDispatchModeTLS {
   static void set_state(std::shared_ptr<SafePyObject> state);
   static const std::shared_ptr<SafePyObject>& get_state();
   static void reset_state();
test/test_decomp.py

+4 -4

@@ -3,7 +3,7 @@
 from collections import defaultdict
 from torch import Tensor
 import torch.autograd
-from torch.utils._python_dispatch import enable_python_mode
+from torch.utils._python_dispatch import enable_torch_dispatch_mode
 from torch._decomp import decomposition_table

 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten

@@ -450,7 +450,7 @@ def check_decomposed(aten_name):
         # explicit clearing is necessary as I will create a fresh mode
         # for each region
         decomposed.clear()
-        with enable_python_mode(DecompCrossRefMode):
+        with enable_torch_dispatch_mode(DecompCrossRefMode):
             decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
         if aten_name in decomposition_names:
             check_decomposed(aten_name)

@@ -459,7 +459,7 @@ def check_decomposed(aten_name):
         cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)

         decomposed.clear()
-        with enable_python_mode(DecompCrossRefMode):
+        with enable_torch_dispatch_mode(DecompCrossRefMode):
             decomp_vjp_fn(cotangents)
         if not run_all:
             check_decomposed(op.aten_backward_name)

@@ -468,7 +468,7 @@ def check_decomposed(aten_name):
         args = [sample_input.input] + list(sample_input.args)
         kwargs = sample_input.kwargs
         decomposed.clear()
-        with enable_python_mode(DecompCrossRefMode):
+        with enable_torch_dispatch_mode(DecompCrossRefMode):
             func(*args, **kwargs)
         if not run_all:
             check_decomposed(aten_name)

test/test_python_dispatch.py

+22 -22

@@ -7,7 +7,7 @@
 from torch.testing._internal.logging_tensor import LoggingTensor, LoggingTensorReentrant, LoggingTensorMode, \
     log_input, capture_logs, no_dispatch
 from torch.utils._pytree import tree_map
-from torch.utils._python_dispatch import enable_python_mode
+from torch.utils._python_dispatch import enable_torch_dispatch_mode

 import logging

@@ -447,28 +447,28 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         res = x.index_put_(idxs, v)
         self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])

-    def test_enable_python_mode_error(self) -> None:
+    def test_enable_torch_dispatch_mode_error(self) -> None:
         with self.assertRaisesRegex(ValueError, "__torch_dispatch__"):
-            with enable_python_mode(torch.Tensor):
+            with enable_torch_dispatch_mode(torch.Tensor):
                 pass
         z = LoggingTensor(torch.empty([]))
         with self.assertRaisesRegex(ValueError, "must be the type"):
-            with enable_python_mode(z):
+            with enable_torch_dispatch_mode(z):
                 pass

-    def test_enable_python_mode_basic(self) -> None:
-        with enable_python_mode(LoggingTensorMode):
+    def test_enable_torch_dispatch_mode_basic(self) -> None:
+        with enable_torch_dispatch_mode(LoggingTensorMode):
             z = torch.empty([])
             self.assertTrue(isinstance(z, LoggingTensorMode))

-    def test_enable_python_mode_unrelated_tensors(self) -> None:
+    def test_enable_torch_dispatch_mode_unrelated_tensors(self) -> None:
         x = torch.randn([])
         y = torch.randn([])
-        with enable_python_mode(LoggingTensorMode):
+        with enable_torch_dispatch_mode(LoggingTensorMode):
             z = x + y
             self.assertTrue(isinstance(z, LoggingTensorMode))

-    def test_enable_python_mode_subclass_priority(self) -> None:
+    def test_enable_torch_dispatch_mode_subclass_priority(self) -> None:
         class ErrorA(RuntimeError):
             pass

@@ -500,30 +500,30 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

         # B has precedence over A due to the subclass relationship
         with self.assertRaises(ErrorB):
-            with enable_python_mode(A):
+            with enable_torch_dispatch_mode(A):
                 b + b
         with self.assertRaises(ErrorB):
-            with enable_python_mode(B):
+            with enable_torch_dispatch_mode(B):
                 a + a
         with self.assertRaises(ErrorB):
-            with enable_python_mode(B):
+            with enable_torch_dispatch_mode(B):
                 a + b

-    def test_enable_python_mode_respects_no_dispatch(self) -> None:
-        with enable_python_mode(LoggingTensorMode):
+    def test_enable_torch_dispatch_mode_respects_no_dispatch(self) -> None:
+        with enable_torch_dispatch_mode(LoggingTensorMode):
             z = torch.ones([2, 3])
             self.assertTrue(isinstance(z, LoggingTensorMode))
             with no_dispatch():
                 expected = torch.ones([2, 3])
             self.assertEqual(z.elem, expected)

-    def test_nested_enable_python_mode(self) -> None:
+    def test_nested_enable_torch_dispatch_mode(self) -> None:
         with self.assertRaisesRegex(RuntimeError, "has already been set"):
-            with enable_python_mode(LoggingTensorMode):
-                with enable_python_mode(LoggingTensorMode):
+            with enable_torch_dispatch_mode(LoggingTensorMode):
+                with enable_torch_dispatch_mode(LoggingTensorMode):
                     pass

-    def test_tolist_numpy_with_python_mode(self) -> None:
+    def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
         x = LoggingTensor(torch.tensor([2.0, 3.0]))
         with self.assertRaisesRegex(RuntimeError, "is not supported for tensor subclasses."):
             x.tolist()

@@ -532,7 +532,7 @@ def test_tolist_numpy_with_python_mode(self) -> None:
         with self.assertRaises(AssertionError):
             self.assertEqual(x, None)

-    def test_enable_python_mode_subclass_autograd_device_check(self) -> None:
+    def test_enable_torch_dispatch_mode_subclass_autograd_device_check(self) -> None:
         class NonWrapperSubclass(torch.Tensor):
             elem: torch.Tensor

@@ -554,7 +554,7 @@ def unwrap(e):
             def wrap(e):
                 return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e

-            # no_dispatch is only needed if you use enable_python_mode.
+            # no_dispatch is only needed if you use enable_torch_dispatch_mode.
             # It prevents infinite recursion.
             with no_dispatch():
                 rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))

@@ -591,7 +591,7 @@ def unwrap(e):
             def wrap(e):
                 return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e

-            # no_dispatch is only needed if you use enable_python_mode.
+            # no_dispatch is only needed if you use enable_torch_dispatch_mode.
             # It prevents infinite recursion.
             with no_dispatch():
                 rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))

@@ -616,7 +616,7 @@ def wrap(e):
         out.backward()

     def test_storage_can_be_converted_to_python_object(self):
-        with enable_python_mode(LoggingTensorMode):
+        with enable_torch_dispatch_mode(LoggingTensorMode):
             s = torch.Storage()
             z = LoggingTensorMode(torch.empty([]))
             z.set_(s)

tools/build_variables.bzl

+2 -2

@@ -898,7 +898,7 @@ libtorch_python_core_sources = [
     "torch/csrc/autograd/profiler_python.cpp",
     "torch/csrc/autograd/python_anomaly_mode.cpp",
     "torch/csrc/autograd/python_saved_variable_hooks.cpp",
-    "torch/csrc/autograd/python_mode.cpp",
+    "torch/csrc/autograd/torch_dispatch_mode.cpp",
     "torch/csrc/autograd/python_cpp_function.cpp",
     "torch/csrc/autograd/python_engine.cpp",
     "torch/csrc/autograd/python_function.cpp",

@@ -1081,7 +1081,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/core/op_registration/infer_schema.cpp",
     "aten/src/ATen/core/op_registration/op_registration.cpp",
     "aten/src/ATen/core/operator_name.cpp",
-    "aten/src/ATen/core/PythonModeTLS.cpp",
+    "aten/src/ATen/core/TorchDispatchModeTLS.cpp",
     "aten/src/ATen/core/register_symbols.cpp",
     "aten/src/ATen/core/class_type.cpp",
     "aten/src/ATen/core/type.cpp",

torch/_C/__init__.pyi.in

+2 -2

@@ -747,8 +747,8 @@ def __set_forward_AD_enabled(enabled: _bool) -> None: ...
 def __is_forward_AD_enabled() -> _bool: ...
 def _register_default_hooks(pack_hook: Callable, unpack_hook: Callable) -> None: ...
 def _reset_default_hooks() -> None: ...
-def _enter_python_mode(cls: Type) -> None: ...
-def _exit_python_mode() -> None: ...
+def _enter_torch_dispatch_mode(cls: Type) -> None: ...
+def _exit_torch_dispatch_mode() -> None: ...

 class _InferenceMode(object):
     def __init__(self, mode: _bool) -> None: ...
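A minimal sketch, not the actual implementation in torch/utils/_python_dispatch.py, of how a context manager could wrap the two renamed private bindings declared above; the helper name is made up for illustration:

    import contextlib
    import torch

    @contextlib.contextmanager
    def torch_dispatch_mode_sketch(mode_cls):
        torch._C._enter_torch_dispatch_mode(mode_cls)   # install the thread-local mode
        try:
            yield
        finally:
            torch._C._exit_torch_dispatch_mode()        # always clear it, even on error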

torch/csrc/autograd/init.cpp

+7 -7

@@ -19,7 +19,7 @@
 #include <torch/csrc/autograd/python_saved_variable_hooks.h>
 #include <torch/csrc/autograd/utils/wrap_outputs.h>
 #include <torch/csrc/autograd/utils/python_arg_parsing.h>
-#include <torch/csrc/autograd/python_mode.h>
+#include <torch/csrc/autograd/torch_dispatch_mode.h>
 #include <torch/csrc/autograd/python_variable.h>
 #include <torch/csrc/autograd/record_function_ops.h>
 #include <torch/csrc/utils/pycfunction_helpers.h>

@@ -603,16 +603,16 @@ static PyObject * python_exit_dual_level(PyObject* _unused, PyObject* args, PyOb
   END_HANDLE_TH_ERRORS
 }

-static PyObject * enter_python_mode(PyObject* _unused, PyObject* arg) {
+static PyObject * enter_torch_dispatch_mode(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
-  PythonMode::enter(arg);
+  TorchDispatchMode::enter(arg);
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }

-static PyObject * exit_python_mode(PyObject* _unused, PyObject* arg) {
+static PyObject * exit_torch_dispatch_mode(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
-  PythonMode::exit();
+  TorchDispatchMode::exit();
   Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS
 }

@@ -664,8 +664,8 @@ static PyMethodDef methods[] = { // NOLINT
   {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
   {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
   {"_exit_dual_level", castPyCFunctionWithKeywords(python_exit_dual_level), METH_VARARGS | METH_KEYWORDS, nullptr},
-  {"_enter_python_mode", enter_python_mode, METH_O, nullptr},
-  {"_exit_python_mode", exit_python_mode, METH_NOARGS, nullptr},
+  {"_enter_torch_dispatch_mode", enter_torch_dispatch_mode, METH_O, nullptr},
+  {"_exit_torch_dispatch_mode", exit_torch_dispatch_mode, METH_NOARGS, nullptr},
   {"_set_torch_function_mode", set_torch_function_mode, METH_O, nullptr},
   {"_get_torch_function_mode", get_torch_function_mode, METH_NOARGS, nullptr},
   {nullptr, nullptr, 0, nullptr}

torch/csrc/autograd/python_mode.cpp

-28
This file was deleted.

torch/csrc/autograd/python_variable.cpp

+4 -4

@@ -33,7 +33,7 @@

 #include <torch/library.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
-#include <torch/csrc/autograd/python_mode.h>
+#include <torch/csrc/autograd/torch_dispatch_mode.h>


 #include <ATen/ATen.h>

@@ -1739,9 +1739,9 @@ bool isPythonTensor(const Tensor& tensor) {
 }

 // NOTE [dispatch_fn's type argument]
-// `type` is nullable and represents the PythonMode going on.
-// Right now we only support a single PythonMode, but in the future we could
-// change this to a stack of PythonModes.
+// `type` is nullable and represents the TorchDispatchMode going on.
+// Right now we only support a single TorchDispatchMode, but in the future we could
+// change this to a stack of TorchDispatchModes.
 //
 // If `type` isn't null, then we consider the type for dispatch by prepending
 // it to the overloaded_args list. `handle_torch_funciton_no_python_arg_parser`
