7
7
from torch .testing ._internal .logging_tensor import LoggingTensor , LoggingTensorReentrant , LoggingTensorMode , \
8
8
log_input , capture_logs , no_dispatch
9
9
from torch .utils ._pytree import tree_map
10
- from torch .utils ._python_dispatch import enable_python_mode
10
+ from torch .utils ._python_dispatch import enable_torch_dispatch_mode
11
11
12
12
import logging
13
13
@@ -447,28 +447,28 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
447
447
res = x .index_put_ (idxs , v )
448
448
self .assertEqual (called_funcs , [torch .ops .aten .index_put_ .default ])
449
449
450
- def test_enable_python_mode_error (self ) -> None :
450
+ def test_enable_torch_dispatch_mode_error (self ) -> None :
451
451
with self .assertRaisesRegex (ValueError , "__torch_dispatch__" ):
452
- with enable_python_mode (torch .Tensor ):
452
+ with enable_torch_dispatch_mode (torch .Tensor ):
453
453
pass
454
454
z = LoggingTensor (torch .empty ([]))
455
455
with self .assertRaisesRegex (ValueError , "must be the type" ):
456
- with enable_python_mode (z ):
456
+ with enable_torch_dispatch_mode (z ):
457
457
pass
458
458
459
- def test_enable_python_mode_basic (self ) -> None :
460
- with enable_python_mode (LoggingTensorMode ):
459
+ def test_enable_torch_dispatch_mode_basic (self ) -> None :
460
+ with enable_torch_dispatch_mode (LoggingTensorMode ):
461
461
z = torch .empty ([])
462
462
self .assertTrue (isinstance (z , LoggingTensorMode ))
463
463
464
- def test_enable_python_mode_unrelated_tensors (self ) -> None :
464
+ def test_enable_torch_dispatch_mode_unrelated_tensors (self ) -> None :
465
465
x = torch .randn ([])
466
466
y = torch .randn ([])
467
- with enable_python_mode (LoggingTensorMode ):
467
+ with enable_torch_dispatch_mode (LoggingTensorMode ):
468
468
z = x + y
469
469
self .assertTrue (isinstance (z , LoggingTensorMode ))
470
470
471
- def test_enable_python_mode_subclass_priority (self ) -> None :
471
+ def test_enable_torch_dispatch_mode_subclass_priority (self ) -> None :
472
472
class ErrorA (RuntimeError ):
473
473
pass
474
474
@@ -500,30 +500,30 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
500
500
501
501
# B has precedence over A due to the subclass relationship
502
502
with self .assertRaises (ErrorB ):
503
- with enable_python_mode (A ):
503
+ with enable_torch_dispatch_mode (A ):
504
504
b + b
505
505
with self .assertRaises (ErrorB ):
506
- with enable_python_mode (B ):
506
+ with enable_torch_dispatch_mode (B ):
507
507
a + a
508
508
with self .assertRaises (ErrorB ):
509
- with enable_python_mode (B ):
509
+ with enable_torch_dispatch_mode (B ):
510
510
a + b
511
511
512
- def test_enable_python_mode_respects_no_dispatch (self ) -> None :
513
- with enable_python_mode (LoggingTensorMode ):
512
+ def test_enable_torch_dispatch_mode_respects_no_dispatch (self ) -> None :
513
+ with enable_torch_dispatch_mode (LoggingTensorMode ):
514
514
z = torch .ones ([2 , 3 ])
515
515
self .assertTrue (isinstance (z , LoggingTensorMode ))
516
516
with no_dispatch ():
517
517
expected = torch .ones ([2 , 3 ])
518
518
self .assertEqual (z .elem , expected )
519
519
520
- def test_nested_enable_python_mode (self ) -> None :
520
+ def test_nested_enable_torch_dispatch_mode (self ) -> None :
521
521
with self .assertRaisesRegex (RuntimeError , "has already been set" ):
522
- with enable_python_mode (LoggingTensorMode ):
523
- with enable_python_mode (LoggingTensorMode ):
522
+ with enable_torch_dispatch_mode (LoggingTensorMode ):
523
+ with enable_torch_dispatch_mode (LoggingTensorMode ):
524
524
pass
525
525
526
- def test_tolist_numpy_with_python_mode (self ) -> None :
526
+ def test_tolist_numpy_with_torch_dispatch_mode (self ) -> None :
527
527
x = LoggingTensor (torch .tensor ([2.0 , 3.0 ]))
528
528
with self .assertRaisesRegex (RuntimeError , "is not supported for tensor subclasses." ):
529
529
x .tolist ()
@@ -532,7 +532,7 @@ def test_tolist_numpy_with_python_mode(self) -> None:
532
532
with self .assertRaises (AssertionError ):
533
533
self .assertEqual (x , None )
534
534
535
- def test_enable_python_mode_subclass_autograd_device_check (self ) -> None :
535
+ def test_enable_torch_dispatch_mode_subclass_autograd_device_check (self ) -> None :
536
536
class NonWrapperSubclass (torch .Tensor ):
537
537
elem : torch .Tensor
538
538
@@ -554,7 +554,7 @@ def unwrap(e):
554
554
def wrap (e ):
555
555
return NonWrapperSubclass (e ) if isinstance (e , torch .Tensor ) else e
556
556
557
- # no_dispatch is only needed if you use enable_python_mode .
557
+ # no_dispatch is only needed if you use enable_torch_dispatch_mode .
558
558
# It prevents infinite recursion.
559
559
with no_dispatch ():
560
560
rs = tree_map (wrap , func (* tree_map (unwrap , args ), ** tree_map (unwrap , kwargs )))
@@ -591,7 +591,7 @@ def unwrap(e):
591
591
def wrap (e ):
592
592
return SubclassWithNone (e ) if isinstance (e , torch .Tensor ) else e
593
593
594
- # no_dispatch is only needed if you use enable_python_mode .
594
+ # no_dispatch is only needed if you use enable_torch_dispatch_mode .
595
595
# It prevents infinite recursion.
596
596
with no_dispatch ():
597
597
rs = tree_map (wrap , func (* tree_map (unwrap , args ), ** tree_map (unwrap , kwargs )))
@@ -616,7 +616,7 @@ def wrap(e):
616
616
out .backward ()
617
617
618
618
def test_storage_can_be_converted_to_python_object (self ):
619
- with enable_python_mode (LoggingTensorMode ):
619
+ with enable_torch_dispatch_mode (LoggingTensorMode ):
620
620
s = torch .Storage ()
621
621
z = LoggingTensorMode (torch .empty ([]))
622
622
z .set_ (s )
0 commit comments