From a878536fd940bae52015f8b77fc8f99743416bc2 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 29 Mar 2022 11:51:26 +0000
Subject: [PATCH 1/2] Fixing tensor.numpy on wrapped tensors

Fixes #626

Description:
- Fixing tensor.numpy on wrapped tensors
---
 functorch/_src/monkey_patching.py | 31 +++++++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/functorch/_src/monkey_patching.py b/functorch/_src/monkey_patching.py
index 1b507d908..95ae3a87f 100644
--- a/functorch/_src/monkey_patching.py
+++ b/functorch/_src/monkey_patching.py
@@ -98,3 +98,34 @@ def _backward(*args, **kwargs):
 
 
 setattr(torch.Tensor, 'backward', _backward)
+
+
+# Monkeypatch .numpy() to fetch the underlying tensor and call .numpy() on it
+_old_numpy = torch.Tensor.numpy
+
+
+@functools.wraps(_old_numpy)
+def _numpy(tensor):
+    level = _C.maybe_get_level(tensor)
+    if level == -1:
+        return _old_numpy(tensor)
+
+    if _C.is_functionaltensor(tensor):
+        # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
+        # that it's up to date first
+        torch._sync(tensor)
+
+    value = _C.get_unwrapped(tensor)
+    dl_enabled = _C.tls_set_is_included()
+    try:
+        # Temporarily remove kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey from the included dispatch keys
+        if dl_enabled:
+            _C._set_dynamic_layer_keys_included(False)
+        return value.numpy()
+    finally:
+        # Restore kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
+        if dl_enabled:
+            _C._set_dynamic_layer_keys_included(True)
+
+
+setattr(torch.Tensor, 'numpy', _numpy)

From 9a7358728197b058abed0d72e48e30cba5e68558 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 29 Mar 2022 12:06:40 +0000
Subject: [PATCH 2/2] Added a test

---
 test/test_eager_transforms.py | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/test/test_eager_transforms.py b/test/test_eager_transforms.py
index 39310d49d..7ecfa3d65 100644
--- a/test/test_eager_transforms.py
+++ b/test/test_eager_transforms.py
@@ -863,6 +863,38 @@ def foo(t):
         expected = expected.replace("\n", "").replace(" ", "")
         self.assertEqual(expected, buf)
 
+    @parametrize("op_list_data", [
+        subtest(([vmap, ], [(4, 2), (64, 3, 32, 32)]), name='vmap'),
+        subtest(([vmap, vmap], [(4, 3, 2), (64, 3, 32, 32)]), name='vmap_vmap'),
+        subtest(([grad, ], [(0, ), [], (4, 2), (64, 3, 32, 32)]), name='grad'),
+        subtest(([grad, grad], [[], ]), name='grad_grad'),
+        subtest(([vmap, grad], [(4, 2)]), name='vmap_grad'),
+    ])
+    def test_tensor_numpy(self, device, op_list_data):
+
+        op_list, shapes = op_list_data
+
+        for dt in [torch.float32, torch.float64]:
+            data = [torch.randn(s, dtype=dt, device=device) for s in shapes]
+
+            for x in data:
+
+                def foo(t):
+                    n = t.detach().cpu().numpy()
+                    assert n.shape == x.shape
+                    return t.mean()
+
+                fn = foo
+                bdim = 0
+                for op in reversed(op_list):
+                    if op == vmap:
+                        fn = op(fn, in_dims=bdim)
+                        bdim += 1
+                    else:
+                        fn = op(fn)
+
+                fn(x)
+
     def test_no_grad_outside(self, device):
         x = torch.randn([], device=device, requires_grad=True)
         with torch.no_grad():
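
Usage note (illustration only, not part of the patch series): a minimal sketch of what the fix enables, assuming a functorch build with the two patches above applied. It mirrors the 'vmap_grad' case of the new test; the function name foo and the (4, 2) shape are taken from that test, and the shape assertion is the same one the test makes about the unwrapped value.

import torch
from functorch import vmap, grad

x = torch.randn(4, 2)

def foo(t):
    # Inside the transforms, t is a functorch-wrapped tensor; with the patched
    # torch.Tensor.numpy, this call unwraps it instead of failing.
    n = t.detach().cpu().numpy()
    # Per the new test, the unwrapped array has the shape of the full input.
    assert n.shape == x.shape
    return t.mean()

vmap(grad(foo), in_dims=0)(x)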