Skip to content

Commit

Permalink
Run black on all files
Browse files Browse the repository at this point in the history
  • Loading branch information
ayerofieiev-tt committed Jun 25, 2024
1 parent 54e328e commit a348cfa
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 34 deletions.
46 changes: 31 additions & 15 deletions tools/generate_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import math
import matplotlib.pyplot as plt

def parse_status_json_files(status_folder, prefix = "fw_"):

def parse_status_json_files(status_folder, prefix="fw_"):
stat_dict = {}
titles = set()

Expand All @@ -15,7 +16,7 @@ def parse_status_json_files(status_folder, prefix = "fw_"):
if not p.startswith(prefix):
continue
try:
with open(os.path.join(status_folder,p)) as f:
with open(os.path.join(status_folder, p)) as f:
j = json.load(f)
model_name = p.replace(prefix, "").replace(".json", "")
stat_dict[model_name] = j
Expand All @@ -27,11 +28,11 @@ def parse_status_json_files(status_folder, prefix = "fw_"):
return titles, stat_dict



def generate_node_count(titles, stat_dict, node_count_csv):
def get_op_cnt(op_type, op_infos):
op_types = [op_info["op_type"] for op_info in op_infos]
return op_types.count(op_type)

rows = [["model_name"] + titles]
for model_name in sorted(stat_dict.keys()):
stat = stat_dict[model_name]
Expand All @@ -41,26 +42,28 @@ def get_op_cnt(op_type, op_infos):
row.append(cnt)
rows.append(row)
row = ["TOTAL"]
for i in range(1,len(rows[0])):
for i in range(1, len(rows[0])):
row.append(sum([int(rows[j][i]) for j in range(1, len(rows))]))
rows.append(row)

with open(node_count_csv, "w") as f:
csv.writer(f, quotechar = '"').writerows(rows)
csv.writer(f, quotechar='"').writerows(rows)
print(f"{node_count_csv} generated")


def generate_total_size(stat_dict, size_dir, key):
assert(key in ["inputs", "outputs"])
assert key in ["inputs", "outputs"]
op_sizes = {}

def sizeof(dtype: str):
if dtype in ["torch.bool"]:
return 1/8
return 1 / 8
if dtype in ["torch.float32", "torch.int32"]:
return 4
if dtype in ["torch.float64", "torch.int64"]:
return 8
print(f"{dtype} not support")
assert(0)
assert 0

for model_name in stat_dict.keys():
stat = stat_dict[model_name]
Expand All @@ -71,7 +74,9 @@ def sizeof(dtype: str):
name = f"{op_type}_{idx}"
tensor_info = op_info[key][idx]
if "shape" in tensor_info.keys() and "dtype" in tensor_info.keys():
size = math.prod(tensor_info["shape"]) * sizeof(tensor_info["dtype"])
size = math.prod(tensor_info["shape"]) * sizeof(
tensor_info["dtype"]
)
op_sizes.setdefault(name, [])
op_sizes[name].append(size)

Expand All @@ -88,18 +93,29 @@ def sizeof(dtype: str):
plt.cla()
print(f"{size_dir} prepared")


if __name__ == "__main__":
out = sys.argv[1] if len(sys.argv) > 1 else os.path.join(os.getcwd(),"stat")
raw = os.path.join(out,"raw")
out = sys.argv[1] if len(sys.argv) > 1 else os.path.join(os.getcwd(), "stat")
raw = os.path.join(out, "raw")
assert os.path.isdir(raw) and "cannot find stat/raw folder"

def generate(prefix = "fw_"):
def generate(prefix="fw_"):
titles, stat_dict = parse_status_json_files(raw, prefix)
if titles == []:
return
generate_node_count(titles, stat_dict, os.path.join(out,f"{prefix}node_count.csv"))
generate_total_size(stat_dict, os.path.join(out,f"{prefix}total_input_size_dist/"), key = "inputs")
generate_total_size(stat_dict, os.path.join(out,f"{prefix}total_output_size_dist/"), key = "outputs")
generate_node_count(
titles, stat_dict, os.path.join(out, f"{prefix}node_count.csv")
)
generate_total_size(
stat_dict,
os.path.join(out, f"{prefix}total_input_size_dist/"),
key="inputs",
)
generate_total_size(
stat_dict,
os.path.join(out, f"{prefix}total_output_size_dist/"),
key="outputs",
)

generate("fw_")
generate("bw_")
47 changes: 36 additions & 11 deletions tools/run_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,18 @@
import torch
import torchvision

def run_model(model_name: str, backend: str, backward: bool, out_path: str, graphviz: bool, to_profile: bool, device = None):

def run_model(
model_name: str,
backend: str,
backward: bool,
out_path: str,
graphviz: bool,
to_profile: bool,
device=None,
):
if model_name == "dinov2_vits14":
m = torch.hub.load('facebookresearch/dinov2', model_name)
m = torch.hub.load("facebookresearch/dinov2", model_name)
else:
try:
m = torchvision.models.get_model(model_name, pretrained=True)
Expand All @@ -30,16 +39,21 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap
option = torch_ttnn.TorchTtnnOption(device=device)
m = torch.compile(m, backend=torch_ttnn.backend(option))
elif backend == "torch_stat":
option = torch_stat.TorchStatOption(model_name = model_name, backward = backward,
out = out_path, gen_graphviz=graphviz)
option = torch_stat.TorchStatOption(
model_name=model_name,
backward=backward,
out=out_path,
gen_graphviz=graphviz,
)
m = torch.compile(m, backend=torch_stat.backend(option))
else:
assert(0 and "Unsupport backend")
assert 0 and "Unsupport backend"

inputs = [torch.randn([1, 3, 224, 224])]

if to_profile:
from torch.profiler import profile, record_function, ProfilerActivity

trace_file = os.path.join(out_path, "profile", model_name)
os.makedirs(os.path.dirname(trace_file), exist_ok=True)
activities = [ProfilerActivity.CPU]
Expand All @@ -55,17 +69,20 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap
if backward:
result.backward(torch.ones_like(result))


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--out_path", "-o", type = str, default = os.path.join(os.getcwd(),"stat"))
parser.add_argument("--backend", type = str)
parser.add_argument(
"--out_path", "-o", type=str, default=os.path.join(os.getcwd(), "stat")
)
parser.add_argument("--backend", type=str)
parser.add_argument("--graphviz", action="store_true")
parser.add_argument("--backward", action="store_true")
parser.add_argument("--profile", action="store_true")
args = parser.parse_args()
assert(args.backend in ["torch_ttnn", "torch_stat"])
assert args.backend in ["torch_ttnn", "torch_stat"]
if args.backend == "torch_ttnn" and args.backward:
assert(0 and "torch_ttnn not yet support backward")
assert 0 and "torch_ttnn not yet support backward"

if args.backend == "torch_ttnn":
import torch_ttnn
Expand All @@ -77,8 +94,16 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap
device = torch_ttnn.ttnn.open(0) if args.backend == "torch_ttnn" else None
for m in models:
try:
run_model(m, args.backend, args.backward, args.out_path, args.graphviz, args.profile, device)
run_model(
m,
args.backend,
args.backward,
args.out_path,
args.graphviz,
args.profile,
device,
)
except:
print(f"{m} FAIL")
if args.backend == "torch_ttnn":
torch_ttnn.ttnn.close(device)
torch_ttnn.ttnn.close(device)
15 changes: 7 additions & 8 deletions torch_ttnn/passes/stat_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.infra.pass_base import PassBase, PassResult


def parse_fx_stat(gm: torch.fx.GraphModule, example_inputs, out_file):
try:
FakeTensorProp(gm).propagate(*example_inputs)
Expand All @@ -16,13 +17,11 @@ def parse_fx_stat(gm: torch.fx.GraphModule, example_inputs, out_file):
def get_tensor_info(t):
def no_symInt_in_list(the_list):
return not any(isinstance(element, torch.SymInt) for element in the_list)

# Only record if the tensor is torch.Tensor
# some shape is referenced by a variable, like [2, 256, s0, s1]
if isinstance(t, torch.Tensor) and no_symInt_in_list(list(t.shape)):
return {
"shape": list(t.shape),
"dtype": str(t.dtype)
}
return {"shape": list(t.shape), "dtype": str(t.dtype)}
else:
return {}

Expand All @@ -48,22 +47,22 @@ def no_symInt_in_list(the_list):
# set node's outputs info
node_info["outputs"] = []
outputs_info = node.meta["val"]
if isinstance(outputs_info, tuple) or \
isinstance(outputs_info, list):
if isinstance(outputs_info, tuple) or isinstance(outputs_info, list):
for output_info in outputs_info:
output = get_tensor_info(output_info)
node_info["outputs"].append(output)
elif isinstance(outputs_info, torch.Tensor):
output = get_tensor_info(outputs_info)
node_info["outputs"].append(output)
else:
assert(0 and "unsupport outputs_info")
assert 0 and "unsupport outputs_info"
out.append(node_info)
os.makedirs(os.path.dirname(out_file), exist_ok=True)

with open(out_file, "w") as f:
json.dump(out, f, indent=4)


# The pass to collect node's information
# Run tools/generate_report.py to generate report
class StatPass(PassBase):
Expand All @@ -75,4 +74,4 @@ def __init__(self, filename, example_inputs):
    def call(self, gm: torch.fx.GraphModule):
        # Dump per-node statistics of `gm` to self.filename (JSON) via
        # parse_fx_stat, using the example inputs captured in __init__.
        parse_fx_stat(gm, self.example_inputs, self.filename)
        # This pass only observes the graph and never mutates it,
        # so it always reports modified=False.
        modified = False
        return PassResult(gm, modified)
return PassResult(gm, modified)

0 comments on commit a348cfa

Please sign in to comment.