diff --git a/tests/lowering/eltwise/test_add.py b/tests/lowering/eltwise/test_add.py index 9737e34e9..ea2ce7046 100644 --- a/tests/lowering/eltwise/test_add.py +++ b/tests/lowering/eltwise/test_add.py @@ -1,7 +1,7 @@ import torch import torch_ttnn import unittest -from torch_ttnn import ttnn +import ttnn from torch.fx.passes.dialect.common.cse_pass import CSEPass diff --git a/tests/lowering/matmul/test_only_add_matmul.py b/tests/lowering/matmul/test_only_add_matmul.py index 9ffe58993..1cdf0b24d 100644 --- a/tests/lowering/matmul/test_only_add_matmul.py +++ b/tests/lowering/matmul/test_only_add_matmul.py @@ -1,7 +1,7 @@ import torch import torch_ttnn import unittest -from torch_ttnn import ttnn +import ttnn from torch_ttnn.utils import check_with_pcc diff --git a/tests/lowering/tensor_manipulation/test_expand.py b/tests/lowering/tensor_manipulation/test_expand.py index 875017e5c..5c280f5ef 100644 --- a/tests/lowering/tensor_manipulation/test_expand.py +++ b/tests/lowering/tensor_manipulation/test_expand.py @@ -1,7 +1,7 @@ import torch import torch_ttnn import unittest -from torch_ttnn import ttnn +import ttnn import tt_lib from torch_ttnn.utils import ( DummyTtnnRowMajorLayout, diff --git a/tests/lowering/test_if.py b/tests/lowering/test_if.py index ea05dd9b7..6b62247b6 100644 --- a/tests/lowering/test_if.py +++ b/tests/lowering/test_if.py @@ -1,7 +1,7 @@ import torch import torch_ttnn import unittest -from torch_ttnn import ttnn +import ttnn from torch.fx.passes.dialect.common.cse_pass import CSEPass diff --git a/tests/lowering/test_more_ops.py b/tests/lowering/test_more_ops.py index 5bf3a0b5d..104839bca 100644 --- a/tests/lowering/test_more_ops.py +++ b/tests/lowering/test_more_ops.py @@ -1,7 +1,7 @@ import torch import torch_ttnn import unittest -from torch_ttnn import ttnn +import ttnn import tt_lib from torch_ttnn.utils import check_with_pcc diff --git a/tests/models/bert/test_bert.py b/tests/models/bert/test_bert.py index 1c0b90f55..90b27df9d 100644 --- a/tests/models/bert/test_bert.py +++ b/tests/models/bert/test_bert.py @@ -1,7 +1,7 @@ import torch import torch_ttnn import unittest -from torch_ttnn import ttnn +import ttnn import collections # Load model directly diff --git a/tests/models/resnet/test_resnet.py b/tests/models/resnet/test_resnet.py index 198d9aad8..b8ff3bbf4 100644 --- a/tests/models/resnet/test_resnet.py +++ b/tests/models/resnet/test_resnet.py @@ -2,7 +2,7 @@ import torchvision import torch_ttnn import unittest -from torch_ttnn import ttnn +import ttnn import collections diff --git a/tests/test_fall_back.py b/tests/test_fall_back.py index c1613b845..2b0579a52 100644 --- a/tests/test_fall_back.py +++ b/tests/test_fall_back.py @@ -1,7 +1,7 @@ import torch import torch_ttnn import unittest -from torch_ttnn import ttnn +import ttnn from torch_ttnn.utils import check_with_pcc diff --git a/tests/tools/test_stats.py b/tests/tools/test_stats.py index ab5c36005..5cdd78bd3 100644 --- a/tests/tools/test_stats.py +++ b/tests/tools/test_stats.py @@ -1,7 +1,7 @@ import os import shutil import torch -from torch_ttnn import torch_stat +import torch_stat import unittest import json from torch.fx.passes.dialect.common.cse_pass import CSEPass diff --git a/torch_ttnn/__init__.py b/torch_ttnn/__init__.py index 60a1067d8..af0df5daf 100644 --- a/torch_ttnn/__init__.py +++ b/torch_ttnn/__init__.py @@ -1,8 +1,10 @@ -from .backend import backend -from .backend import TorchTtnnOption +from torch_ttnn.backend import backend +from torch_ttnn.backend import TorchTtnnOption try: import ttnn -except ImportError: - print("ttnn is not installed, use mock_ttnn instead") - from . import mock_ttnn as ttnn +except ImportError as e: + print( + "ttnn is not installed. Run `python3 -m pip install -r requirements.txt` or `python3 -m pip install -r requirements-dev.txt` if you are developing the compiler" + ) + raise e diff --git a/torch_ttnn/backend.py b/torch_ttnn/backend.py index 2a4ad7f29..0d0b6de5f 100644 --- a/torch_ttnn/backend.py +++ b/torch_ttnn/backend.py @@ -3,16 +3,11 @@ from typing import List import torch._dynamo from functorch.compile import make_boxed_func +import ttnn torch._dynamo.config.suppress_errors = False torch._dynamo.config.verbose = True -try: - import ttnn -except ImportError: - print("ttnn is not installed, use mock_ttnn instead") - from . import mock_ttnn as ttnn - # The backend for torch.compile that converts a graph to use ttnn. # The "option" parameter is a dict that contains one key "torch_ttnn_option". diff --git a/torch_ttnn/fx_graphviz.py b/torch_ttnn/fx_graphviz.py index 23304d384..e681681fb 100644 --- a/torch_ttnn/fx_graphviz.py +++ b/torch_ttnn/fx_graphviz.py @@ -4,11 +4,7 @@ import math from collections import defaultdict -try: - import ttnn -except ImportError: - print("ttnn is not installed, use mock_ttnn instead") - from torch_tnn import mock_ttnn as ttnn +import ttnn def _tensor_weight(t): diff --git a/torch_ttnn/mock_ttnn.py b/torch_ttnn/mock_ttnn.py deleted file mode 100644 index ea8a4bdd0..000000000 --- a/torch_ttnn/mock_ttnn.py +++ /dev/null @@ -1,116 +0,0 @@ -import torch - - -############################################################ -# Device related functions -############################################################ -class Device: - def __init__(self, device_id): - self.device_id: int = device_id - - # def __repr__(self): - # """ - # This function is necessary for torch.fx.graph._register_custom_builtin. - # The generated code will use this function to generate the code. - # The compiler must use `torch.fx.graph._register_custom_builtin` - # to register the custom builtin. - # - # However, the ttnn module's Device does not provide a __repr__ function. - # So the mock here can not provide a __repr__ function either. - # We have to find other way. See `torch_ttnn.py` for details. - # """ - # return f"Device({self.device_id})" - - -def open_device(device_id): - print(f"Device {device_id} is opened") - return Device(device_id) - - -def close_device(device): - print(f"Device {device.device_id} is closed") - pass - - -def from_torch(tensor): - return tensor - - -def to_torch(tensor): - return tensor - - -def from_device(tensor): - print(f"Tensor with shape {tensor.shape} is moved from device") - return tensor - - -def to_device(tensor, device): - print(f"Tensor with shape {tensor.shape} is moved to device {device.device_id}") - return tensor - - -def to_layout(tensor, layout): - print(f"Tensor with shape {tensor.shape} is convert to layout {layout}") - return tensor - - -############################################################ -# Operations -############################################################ - - -@torch.fx.wrap -def add(x, y): - z = x + y - return z - - -@torch.fx.wrap -def matmul(x, y): - mm = torch.ops.aten.mm(x, y) - return mm - - -@torch.fx.wrap -def sub(x, y): - z = x - y - return z - - -@torch.fx.wrap -def mul(x, y): - z = x * y - return z - - -@torch.fx.wrap -def softmax(x, axis): - r = torch.softmax(x, axis) - return r - - -@torch.fx.wrap -def tanh(x): - r = torch.tanh(x) - return r - - -@torch.fx.wrap -def reshape(x, new_shape): - r = torch.reshape(x, new_shape) - return r - - -@torch.fx.wrap -def permute(x, order): - r = torch.permute(x, order) - return r - - -ROW_MAJOR_LAYOUT = 0 -TILE_LAYOUT = 1 - -# Wrap the functions so that they can be used in torch.fx -# and block the further recusive tracing. See: -# https://pytorch.org/docs/stable/fx.html#torch.fx.wrap diff --git a/torch_ttnn/passes/graphviz_pass.py b/torch_ttnn/passes/graphviz_pass.py index 37e0937f1..d0928bd55 100644 --- a/torch_ttnn/passes/graphviz_pass.py +++ b/torch_ttnn/passes/graphviz_pass.py @@ -1,11 +1,7 @@ import torch from torch_ttnn import fx_graphviz -try: - import ttnn -except ImportError: - print("ttnn is not installed, use mock_ttnn instead") - from torch_ttnn import mock_ttnn as ttnn +import ttnn from torch.fx.passes.infra.pass_base import PassBase, PassResult diff --git a/torch_ttnn/passes/lowering/add_data_move_pass.py b/torch_ttnn/passes/lowering/add_data_move_pass.py index 4531f196a..7379e81e3 100644 --- a/torch_ttnn/passes/lowering/add_data_move_pass.py +++ b/torch_ttnn/passes/lowering/add_data_move_pass.py @@ -1,4 +1,5 @@ import torch +import ttnn from torch_ttnn.utils import ( DummyTtnnUint32, DummyTtnnRowMajorLayout, @@ -6,11 +7,6 @@ DummyDevice, ) -try: - import ttnn -except ImportError: - print("ttnn is not installed, use mock_ttnn instead") - from torch_ttnn import mock_ttnn as ttnn from torch.fx.passes.infra.pass_base import PassBase, PassResult diff --git a/torch_ttnn/passes/lowering/eliminate_data_move_pass.py b/torch_ttnn/passes/lowering/eliminate_data_move_pass.py index f71b575ee..a0e9f7407 100644 --- a/torch_ttnn/passes/lowering/eliminate_data_move_pass.py +++ b/torch_ttnn/passes/lowering/eliminate_data_move_pass.py @@ -1,10 +1,5 @@ import torch - -try: - import ttnn -except ImportError: - print("ttnn is not installed, use mock_ttnn instead") - from torch_ttnn import mock_ttnn as ttnn +import ttnn from torch.fx.passes.infra.pass_base import PassBase, PassResult diff --git a/torch_ttnn/passes/lowering/permute_reshape_tuple.py b/torch_ttnn/passes/lowering/permute_reshape_tuple.py index f07ee53af..cf5aa9f49 100644 --- a/torch_ttnn/passes/lowering/permute_reshape_tuple.py +++ b/torch_ttnn/passes/lowering/permute_reshape_tuple.py @@ -1,10 +1,5 @@ import torch - -try: - import ttnn -except ImportError: - print("ttnn is not installed, use mock_ttnn instead") - from torch_ttnn import mock_ttnn as ttnn +import ttnn from torch.fx.passes.infra.pass_base import PassBase, PassResult diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py index 5920a0306..027807bd9 100644 --- a/torch_ttnn/passes/lowering/to_tt_pass.py +++ b/torch_ttnn/passes/lowering/to_tt_pass.py @@ -1,4 +1,5 @@ import torch +import ttnn from torch_ttnn.utils import ( GraphCleanup, DummyTtlTensorTensorMemoryLayoutInterleaved, @@ -9,12 +10,6 @@ ) import numpy as np -try: - import ttnn -except ImportError: - print("ttnn is not installed, use mock_ttnn instead") - from torch_ttnn import mock_ttnn as ttnn - from torch.fx.passes.infra.pass_base import PassBase, PassResult import torch.fx.traceback as fx_traceback diff --git a/torch_ttnn/patterns/add.py b/torch_ttnn/patterns/add.py index b2711eaaa..13dbe7db9 100644 --- a/torch_ttnn/patterns/add.py +++ b/torch_ttnn/patterns/add.py @@ -2,12 +2,7 @@ # This file may not be needed anymore. import torch - -try: - import ttnn -except ImportError: - print("ttnn is not installed, use mock_ttnn instead") - from torch_ttnn import mock_ttnn as ttnn +import ttnn # NOTE(yoco) The name `add` must be the same as the name of the function. diff --git a/torch_ttnn/patterns/mm.py b/torch_ttnn/patterns/mm.py index f9da9e779..bd8c75084 100644 --- a/torch_ttnn/patterns/mm.py +++ b/torch_ttnn/patterns/mm.py @@ -2,12 +2,7 @@ # This file may not be needed anymore. import torch - -try: - import ttnn -except ImportError: - print("ttnn is not installed, use mock_ttnn instead") - from torch_ttnn import mock_ttnn as ttnn +import ttnn # NOTE(yoco) The name `matmul` must be the same as the name of the function.