diff --git a/.github/workflows/before_merge.yaml b/.github/workflows/before_merge.yaml index 036b95e99..b48e84996 100644 --- a/.github/workflows/before_merge.yaml +++ b/.github/workflows/before_merge.yaml @@ -12,7 +12,7 @@ permissions: id-token: write jobs: - run-pytest: + validate-pr: env: ARCH_NAME: wormhole_b0 TT_METAL_HOME: ${pwd} @@ -26,8 +26,9 @@ jobs: python3 -m venv venv source venv/bin/activate python3 -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu - python3 -m pip install --upgrade pip - python3 -m pip install -r requirements-dev.txt + python3 -m pip install --upgrade pip + python3 -m pip install -r requirements-dev.txt + python3 -m pip install pytest-github-report - name: Run Tests env: pytest_verbosity: 2 diff --git a/.github/workflows/pull-request.yaml b/.github/workflows/pull-request.yaml index 568c158cc..cbce3605a 100644 --- a/.github/workflows/pull-request.yaml +++ b/.github/workflows/pull-request.yaml @@ -5,7 +5,9 @@ on: jobs: validate-pr: - runs-on: ["in-service", "n150"] + runs-on: ubuntu-latest steps: - - name: Skip run - run: echo "Empty check passed" + - name: Checkout + uses: actions/checkout@v4 + - name: Black + uses: psf/black@23.10.1 diff --git a/.gitignore b/.gitignore index e2225f44b..dd0d9e98f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ __pycache__ +venv .vscode stat *.dot *.svg -*.csv \ No newline at end of file +*.csv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 000000000..3514657e5 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,6 @@ +repos: +- repo: https://github.com/psf/black-pre-commit-mirror + rev: 23.10.1 + hooks: + - id: black + language_version: python3 diff --git a/requirements-dev.txt b/requirements-dev.txt index db50088c6..76b8cc39b 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ -r requirements.txt -pytest -pytest-github-report +pytest==7.2.2 +pytest-timeout==2.2.0 +pre-commit==3.0.4 diff --git a/requirements.txt b/requirements.txt index 5240d2108..04730422c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ torch==2.2.1.0+cpu torchvision==0.17.1+cpu -tabulate +tabulate==0.9.0 +networkx==3.1 graphviz matplotlib diff --git a/tools/generate_report.py b/tools/generate_report.py index 9b8e27bea..e933b7be3 100644 --- a/tools/generate_report.py +++ b/tools/generate_report.py @@ -5,7 +5,8 @@ import math import matplotlib.pyplot as plt -def parse_status_json_files(status_folder, prefix = "fw_"): + +def parse_status_json_files(status_folder, prefix="fw_"): stat_dict = {} titles = set() @@ -15,7 +16,7 @@ def parse_status_json_files(status_folder, prefix = "fw_"): if not p.startswith(prefix): continue try: - with open(os.path.join(status_folder,p)) as f: + with open(os.path.join(status_folder, p)) as f: j = json.load(f) model_name = p.replace(prefix, "").replace(".json", "") stat_dict[model_name] = j @@ -27,11 +28,11 @@ def parse_status_json_files(status_folder, prefix = "fw_"): return titles, stat_dict - def generate_node_count(titles, stat_dict, node_count_csv): def get_op_cnt(op_type, op_infos): op_types = [op_info["op_type"] for op_info in op_infos] return op_types.count(op_type) + rows = [["model_name"] + titles] for model_name in sorted(stat_dict.keys()): stat = stat_dict[model_name] @@ -41,26 +42,28 @@ def get_op_cnt(op_type, op_infos): row.append(cnt) rows.append(row) row = ["TOTAL"] - for i in range(1,len(rows[0])): + for i in range(1, len(rows[0])): row.append(sum([int(rows[j][i]) for j in range(1, len(rows))])) rows.append(row) with open(node_count_csv, "w") as f: - csv.writer(f, quotechar = '"').writerows(rows) + csv.writer(f, quotechar='"').writerows(rows) print(f"{node_count_csv} generated") + def generate_total_size(stat_dict, size_dir, key): - assert(key in ["inputs", "outputs"]) + assert key in ["inputs", "outputs"] op_sizes = {} + def sizeof(dtype: str): if dtype in ["torch.bool"]: - return 1/8 + return 1 / 8 if dtype in ["torch.float32", "torch.int32"]: return 4 if dtype in ["torch.float64", "torch.int64"]: return 8 print(f"{dtype} not support") - assert(0) + assert 0 for model_name in stat_dict.keys(): stat = stat_dict[model_name] @@ -71,7 +74,9 @@ def sizeof(dtype: str): name = f"{op_type}_{idx}" tensor_info = op_info[key][idx] if "shape" in tensor_info.keys() and "dtype" in tensor_info.keys(): - size = math.prod(tensor_info["shape"]) * sizeof(tensor_info["dtype"]) + size = math.prod(tensor_info["shape"]) * sizeof( + tensor_info["dtype"] + ) op_sizes.setdefault(name, []) op_sizes[name].append(size) @@ -88,18 +93,29 @@ def sizeof(dtype: str): plt.cla() print(f"{size_dir} prepared") + if __name__ == "__main__": - out = sys.argv[1] if len(sys.argv) > 1 else os.path.join(os.getcwd(),"stat") - raw = os.path.join(out,"raw") + out = sys.argv[1] if len(sys.argv) > 1 else os.path.join(os.getcwd(), "stat") + raw = os.path.join(out, "raw") assert os.path.isdir(raw) and "cannot find stat/raw folder" - def generate(prefix = "fw_"): + def generate(prefix="fw_"): titles, stat_dict = parse_status_json_files(raw, prefix) if titles == []: return - generate_node_count(titles, stat_dict, os.path.join(out,f"{prefix}node_count.csv")) - generate_total_size(stat_dict, os.path.join(out,f"{prefix}total_input_size_dist/"), key = "inputs") - generate_total_size(stat_dict, os.path.join(out,f"{prefix}total_output_size_dist/"), key = "outputs") + generate_node_count( + titles, stat_dict, os.path.join(out, f"{prefix}node_count.csv") + ) + generate_total_size( + stat_dict, + os.path.join(out, f"{prefix}total_input_size_dist/"), + key="inputs", + ) + generate_total_size( + stat_dict, + os.path.join(out, f"{prefix}total_output_size_dist/"), + key="outputs", + ) generate("fw_") generate("bw_") diff --git a/tools/run_torchvision.py b/tools/run_torchvision.py index 1f9e6a9cd..e9f8f407a 100644 --- a/tools/run_torchvision.py +++ b/tools/run_torchvision.py @@ -3,9 +3,18 @@ import torch import torchvision -def run_model(model_name: str, backend: str, backward: bool, out_path: str, graphviz: bool, to_profile: bool, device = None): + +def run_model( + model_name: str, + backend: str, + backward: bool, + out_path: str, + graphviz: bool, + to_profile: bool, + device=None, +): if model_name == "dinov2_vits14": - m = torch.hub.load('facebookresearch/dinov2', model_name) + m = torch.hub.load("facebookresearch/dinov2", model_name) else: try: m = torchvision.models.get_model(model_name, pretrained=True) @@ -30,16 +39,21 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap option = torch_ttnn.TorchTtnnOption(device=device) m = torch.compile(m, backend=torch_ttnn.backend(option)) elif backend == "torch_stat": - option = torch_stat.TorchStatOption(model_name = model_name, backward = backward, - out = out_path, gen_graphviz=graphviz) + option = torch_stat.TorchStatOption( + model_name=model_name, + backward=backward, + out=out_path, + gen_graphviz=graphviz, + ) m = torch.compile(m, backend=torch_stat.backend(option)) else: - assert(0 and "Unsupport backend") + assert 0 and "Unsupport backend" inputs = [torch.randn([1, 3, 224, 224])] if to_profile: from torch.profiler import profile, record_function, ProfilerActivity + trace_file = os.path.join(out_path, "profile", model_name) os.makedirs(os.path.dirname(trace_file), exist_ok=True) activities = [ProfilerActivity.CPU] @@ -55,17 +69,20 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap if backward: result.backward(torch.ones_like(result)) + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--out_path", "-o", type = str, default = os.path.join(os.getcwd(),"stat")) - parser.add_argument("--backend", type = str) + parser.add_argument( + "--out_path", "-o", type=str, default=os.path.join(os.getcwd(), "stat") + ) + parser.add_argument("--backend", type=str) parser.add_argument("--graphviz", action="store_true") parser.add_argument("--backward", action="store_true") parser.add_argument("--profile", action="store_true") args = parser.parse_args() - assert(args.backend in ["torch_ttnn", "torch_stat"]) + assert args.backend in ["torch_ttnn", "torch_stat"] if args.backend == "torch_ttnn" and args.backward: - assert(0 and "torch_ttnn not yet support backward") + assert 0 and "torch_ttnn not yet support backward" if args.backend == "torch_ttnn": import torch_ttnn @@ -77,8 +94,16 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap device = torch_ttnn.ttnn.open(0) if args.backend == "torch_ttnn" else None for m in models: try: - run_model(m, args.backend, args.backward, args.out_path, args.graphviz, args.profile, device) + run_model( + m, + args.backend, + args.backward, + args.out_path, + args.graphviz, + args.profile, + device, + ) except: print(f"{m} FAIL") if args.backend == "torch_ttnn": - torch_ttnn.ttnn.close(device) \ No newline at end of file + torch_ttnn.ttnn.close(device) diff --git a/torch_ttnn/passes/stat_pass.py b/torch_ttnn/passes/stat_pass.py index 3f5dc4998..59ba70282 100644 --- a/torch_ttnn/passes/stat_pass.py +++ b/torch_ttnn/passes/stat_pass.py @@ -4,6 +4,7 @@ from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.infra.pass_base import PassBase, PassResult + def parse_fx_stat(gm: torch.fx.GraphModule, example_inputs, out_file): try: FakeTensorProp(gm).propagate(*example_inputs) @@ -16,13 +17,11 @@ def parse_fx_stat(gm: torch.fx.GraphModule, example_inputs, out_file): def get_tensor_info(t): def no_symInt_in_list(the_list): return not any(isinstance(element, torch.SymInt) for element in the_list) + # Only record if the tensor is torch.Tensor # some shape is referenced by a variable, like [2, 256, s0, s1] if isinstance(t, torch.Tensor) and no_symInt_in_list(list(t.shape)): - return { - "shape": list(t.shape), - "dtype": str(t.dtype) - } + return {"shape": list(t.shape), "dtype": str(t.dtype)} else: return {} @@ -48,8 +47,7 @@ def no_symInt_in_list(the_list): # set node's outputs info node_info["outputs"] = [] outputs_info = node.meta["val"] - if isinstance(outputs_info, tuple) or \ - isinstance(outputs_info, list): + if isinstance(outputs_info, tuple) or isinstance(outputs_info, list): for output_info in outputs_info: output = get_tensor_info(output_info) node_info["outputs"].append(output) @@ -57,13 +55,14 @@ def no_symInt_in_list(the_list): output = get_tensor_info(outputs_info) node_info["outputs"].append(output) else: - assert(0 and "unsupport outputs_info") + assert 0 and "unsupport outputs_info" out.append(node_info) os.makedirs(os.path.dirname(out_file), exist_ok=True) with open(out_file, "w") as f: json.dump(out, f, indent=4) + # The pass to collect node's information # Run tools/generate_report.py to genetate report class StatPass(PassBase): @@ -75,4 +74,4 @@ def __init__(self, filename, example_inputs): def call(self, gm: torch.fx.GraphModule): parse_fx_stat(gm, self.example_inputs, self.filename) modified = False - return PassResult(gm, modified) \ No newline at end of file + return PassResult(gm, modified)