
Commit 027c0d7

aws-rhsoln authored and pytorchmergebot committed
fixed compilations on xla tensor print (pytorch#71147)

Summary: Fixes multiple compilations on xla tensor print. Please check the conversation here: pytorch/xla#3253

This is done to avoid compilations during tensor printing. Torch performs some tensor operations, like slicing, to make the tensor readable. These operations result in compilations, so to avoid them the tensor is copied to the CPU before printing. Example:

```
dev = xm.xla_device()

def test_linear(input_shape=(8, 1024)):
    import pdb
    pdb.set_trace()
    linear = torch.nn.Linear(in_features=1024, out_features=4096, bias=True).to(dev)
    inp = torch.randn(*input_shape).to(dev)
    output = linear(inp)
    xm.mark_step()
    return output
```

Returning from this function would have resulted in 63 compilations, since pdb prints the value of the returned output, which in this case is an xla tensor. With the current change, there is no compilation.

Pull Request resolved: pytorch#71147
Reviewed By: shunting314
Differential Revision: D33795177
Pulled By: wconstab
fbshipit-source-id: 74b53d9a1cb7ef67f9d8b0a32064f3896be449b5
(cherry picked from commit a9e0687)
1 parent 76a2c22 commit 027c0d7

File tree

1 file changed: +6 −0 lines changed


torch/_tensor_str.py

Lines changed: 6 additions & 0 deletions
```diff
@@ -318,6 +318,12 @@ def _str_intern(inp):
             or (self.device.type == 'cuda' and torch.cuda.current_device() != self.device.index):
         suffixes.append('device=\'' + str(self.device) + '\'')

+    # Tensor printing performs tensor operations like slice, indexing, etc to make it in a
+    # representable format. These operations on xla/lazy tensor results in compilations. Hence,
+    # to avoid compilations, copying the tensor to cpu before printing.
+    if self.device.type == 'xla' or self.device.type == 'lazy':
+        self = self.to('cpu')
+
     # TODO: add an API to map real -> complex dtypes
     _default_complex_dtype = torch.cdouble if torch.get_default_dtype() == torch.double else torch.cfloat
     has_default_dtype = self.dtype in (torch.get_default_dtype(), _default_complex_dtype, torch.int64, torch.bool)
```
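The guard above is small but easy to misread. A minimal, self-contained sketch of the same logic, using a hypothetical `FakeTensor` stand-in (not part of torch) so the device-transfer behavior can be seen without an XLA backend:

```python
# Sketch (assumed, simplified) of the device guard this commit adds to
# torch/_tensor_str.py: tensors on compilation-backed devices ('xla', 'lazy')
# are moved to CPU before the string formatter slices/indexes them, so that
# printing does not trigger device compilations. FakeTensor is a hypothetical
# stand-in for torch.Tensor used only for illustration.

class FakeTensor:
    """Records device transfers performed via .to()."""
    def __init__(self, device_type, transfers=None):
        self.device_type = device_type
        self.transfers = list(transfers or [])

    def to(self, device):
        # Returns a copy on the target device, remembering the transfer.
        return FakeTensor(device, self.transfers + [device])

def prepare_for_print(t):
    # Mirrors the added guard: copy xla/lazy tensors to cpu before formatting.
    if t.device_type in ('xla', 'lazy'):
        t = t.to('cpu')
    return t

print(prepare_for_print(FakeTensor('xla')).device_type)   # cpu
print(prepare_for_print(FakeTensor('cuda')).device_type)  # cuda (untouched)
```

Note the design choice: rebinding `self` to a CPU copy inside `_str_intern` changes only what the formatter sees; the caller's tensor stays on its original device.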

0 commit comments