From a348cfa79cc2b43964c818ea7962f3684a09a9be Mon Sep 17 00:00:00 2001 From: Artem Yerofieiev Date: Tue, 25 Jun 2024 23:56:47 +0000 Subject: [PATCH] Run black on all files --- tools/generate_report.py | 46 ++++++++++++++++++++++----------- tools/run_torchvision.py | 47 ++++++++++++++++++++++++++-------- torch_ttnn/passes/stat_pass.py | 15 +++++------ 3 files changed, 74 insertions(+), 34 deletions(-) diff --git a/tools/generate_report.py b/tools/generate_report.py index 9b8e27bea..e933b7be3 100644 --- a/tools/generate_report.py +++ b/tools/generate_report.py @@ -5,7 +5,8 @@ import math import matplotlib.pyplot as plt -def parse_status_json_files(status_folder, prefix = "fw_"): + +def parse_status_json_files(status_folder, prefix="fw_"): stat_dict = {} titles = set() @@ -15,7 +16,7 @@ def parse_status_json_files(status_folder, prefix = "fw_"): if not p.startswith(prefix): continue try: - with open(os.path.join(status_folder,p)) as f: + with open(os.path.join(status_folder, p)) as f: j = json.load(f) model_name = p.replace(prefix, "").replace(".json", "") stat_dict[model_name] = j @@ -27,11 +28,11 @@ def parse_status_json_files(status_folder, prefix = "fw_"): return titles, stat_dict - def generate_node_count(titles, stat_dict, node_count_csv): def get_op_cnt(op_type, op_infos): op_types = [op_info["op_type"] for op_info in op_infos] return op_types.count(op_type) + rows = [["model_name"] + titles] for model_name in sorted(stat_dict.keys()): stat = stat_dict[model_name] @@ -41,26 +42,28 @@ def get_op_cnt(op_type, op_infos): row.append(cnt) rows.append(row) row = ["TOTAL"] - for i in range(1,len(rows[0])): + for i in range(1, len(rows[0])): row.append(sum([int(rows[j][i]) for j in range(1, len(rows))])) rows.append(row) with open(node_count_csv, "w") as f: - csv.writer(f, quotechar = '"').writerows(rows) + csv.writer(f, quotechar='"').writerows(rows) print(f"{node_count_csv} generated") + def generate_total_size(stat_dict, size_dir, key): - assert(key in ["inputs", "outputs"]) + assert key in ["inputs", "outputs"] op_sizes = {} + def sizeof(dtype: str): if dtype in ["torch.bool"]: - return 1/8 + return 1 / 8 if dtype in ["torch.float32", "torch.int32"]: return 4 if dtype in ["torch.float64", "torch.int64"]: return 8 print(f"{dtype} not support") - assert(0) + assert 0 for model_name in stat_dict.keys(): stat = stat_dict[model_name] @@ -71,7 +74,9 @@ def sizeof(dtype: str): name = f"{op_type}_{idx}" tensor_info = op_info[key][idx] if "shape" in tensor_info.keys() and "dtype" in tensor_info.keys(): - size = math.prod(tensor_info["shape"]) * sizeof(tensor_info["dtype"]) + size = math.prod(tensor_info["shape"]) * sizeof( + tensor_info["dtype"] + ) op_sizes.setdefault(name, []) op_sizes[name].append(size) @@ -88,18 +93,29 @@ def sizeof(dtype: str): plt.cla() print(f"{size_dir} prepared") + if __name__ == "__main__": - out = sys.argv[1] if len(sys.argv) > 1 else os.path.join(os.getcwd(),"stat") - raw = os.path.join(out,"raw") + out = sys.argv[1] if len(sys.argv) > 1 else os.path.join(os.getcwd(), "stat") + raw = os.path.join(out, "raw") assert os.path.isdir(raw) and "cannot find stat/raw folder" - def generate(prefix = "fw_"): + def generate(prefix="fw_"): titles, stat_dict = parse_status_json_files(raw, prefix) if titles == []: return - generate_node_count(titles, stat_dict, os.path.join(out,f"{prefix}node_count.csv")) - generate_total_size(stat_dict, os.path.join(out,f"{prefix}total_input_size_dist/"), key = "inputs") - generate_total_size(stat_dict, os.path.join(out,f"{prefix}total_output_size_dist/"), key = "outputs") + generate_node_count( + titles, stat_dict, os.path.join(out, f"{prefix}node_count.csv") + ) + generate_total_size( + stat_dict, + os.path.join(out, f"{prefix}total_input_size_dist/"), + key="inputs", + ) + generate_total_size( + stat_dict, + os.path.join(out, f"{prefix}total_output_size_dist/"), + key="outputs", + ) generate("fw_") generate("bw_") diff --git a/tools/run_torchvision.py b/tools/run_torchvision.py index 1f9e6a9cd..e9f8f407a 100644 --- a/tools/run_torchvision.py +++ b/tools/run_torchvision.py @@ -3,9 +3,18 @@ import torch import torchvision -def run_model(model_name: str, backend: str, backward: bool, out_path: str, graphviz: bool, to_profile: bool, device = None): + +def run_model( + model_name: str, + backend: str, + backward: bool, + out_path: str, + graphviz: bool, + to_profile: bool, + device=None, +): if model_name == "dinov2_vits14": - m = torch.hub.load('facebookresearch/dinov2', model_name) + m = torch.hub.load("facebookresearch/dinov2", model_name) else: try: m = torchvision.models.get_model(model_name, pretrained=True) @@ -30,16 +39,21 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap option = torch_ttnn.TorchTtnnOption(device=device) m = torch.compile(m, backend=torch_ttnn.backend(option)) elif backend == "torch_stat": - option = torch_stat.TorchStatOption(model_name = model_name, backward = backward, - out = out_path, gen_graphviz=graphviz) + option = torch_stat.TorchStatOption( + model_name=model_name, + backward=backward, + out=out_path, + gen_graphviz=graphviz, + ) m = torch.compile(m, backend=torch_stat.backend(option)) else: - assert(0 and "Unsupport backend") + assert 0 and "Unsupport backend" inputs = [torch.randn([1, 3, 224, 224])] if to_profile: from torch.profiler import profile, record_function, ProfilerActivity + trace_file = os.path.join(out_path, "profile", model_name) os.makedirs(os.path.dirname(trace_file), exist_ok=True) activities = [ProfilerActivity.CPU] @@ -55,17 +69,20 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap if backward: result.backward(torch.ones_like(result)) + if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--out_path", "-o", type = str, default = os.path.join(os.getcwd(),"stat")) - parser.add_argument("--backend", type = str) + parser.add_argument( + "--out_path", "-o", type=str, default=os.path.join(os.getcwd(), "stat") + ) + parser.add_argument("--backend", type=str) parser.add_argument("--graphviz", action="store_true") parser.add_argument("--backward", action="store_true") parser.add_argument("--profile", action="store_true") args = parser.parse_args() - assert(args.backend in ["torch_ttnn", "torch_stat"]) + assert args.backend in ["torch_ttnn", "torch_stat"] if args.backend == "torch_ttnn" and args.backward: - assert(0 and "torch_ttnn not yet support backward") + assert 0 and "torch_ttnn not yet support backward" if args.backend == "torch_ttnn": import torch_ttnn @@ -77,8 +94,16 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap device = torch_ttnn.ttnn.open(0) if args.backend == "torch_ttnn" else None for m in models: try: - run_model(m, args.backend, args.backward, args.out_path, args.graphviz, args.profile, device) + run_model( + m, + args.backend, + args.backward, + args.out_path, + args.graphviz, + args.profile, + device, + ) except: print(f"{m} FAIL") if args.backend == "torch_ttnn": - torch_ttnn.ttnn.close(device) \ No newline at end of file + torch_ttnn.ttnn.close(device) diff --git a/torch_ttnn/passes/stat_pass.py b/torch_ttnn/passes/stat_pass.py index 3f5dc4998..59ba70282 100644 --- a/torch_ttnn/passes/stat_pass.py +++ b/torch_ttnn/passes/stat_pass.py @@ -4,6 +4,7 @@ from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.infra.pass_base import PassBase, PassResult + def parse_fx_stat(gm: torch.fx.GraphModule, example_inputs, out_file): try: FakeTensorProp(gm).propagate(*example_inputs) @@ -16,13 +17,11 @@ def parse_fx_stat(gm: torch.fx.GraphModule, example_inputs, out_file): def get_tensor_info(t): def no_symInt_in_list(the_list): return not any(isinstance(element, torch.SymInt) for element in the_list) + # Only record if the tensor is torch.Tensor # some shape is referenced by a variable, like [2, 256, s0, s1] if isinstance(t, torch.Tensor) and no_symInt_in_list(list(t.shape)): - return { - "shape": list(t.shape), - "dtype": str(t.dtype) - } + return {"shape": list(t.shape), "dtype": str(t.dtype)} else: return {} @@ -48,8 +47,7 @@ def no_symInt_in_list(the_list): # set node's outputs info node_info["outputs"] = [] outputs_info = node.meta["val"] - if isinstance(outputs_info, tuple) or \ - isinstance(outputs_info, list): + if isinstance(outputs_info, tuple) or isinstance(outputs_info, list): for output_info in outputs_info: output = get_tensor_info(output_info) node_info["outputs"].append(output) @@ -57,13 +55,14 @@ def no_symInt_in_list(the_list): output = get_tensor_info(outputs_info) node_info["outputs"].append(output) else: - assert(0 and "unsupport outputs_info") + assert 0 and "unsupport outputs_info" out.append(node_info) os.makedirs(os.path.dirname(out_file), exist_ok=True) with open(out_file, "w") as f: json.dump(out, f, indent=4) + # The pass to collect node's information # Run tools/generate_report.py to genetate report class StatPass(PassBase): @@ -75,4 +74,4 @@ def __init__(self, filename, example_inputs): def call(self, gm: torch.fx.GraphModule): parse_fx_stat(gm, self.example_inputs, self.filename) modified = False - return PassResult(gm, modified) \ No newline at end of file + return PassResult(gm, modified)