Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pre-commit / update workflows / format files #13

Merged
merged 12 commits into from
Jun 26, 2024
7 changes: 4 additions & 3 deletions .github/workflows/before_merge.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ permissions:
id-token: write

jobs:
run-pytest:
validate-pr:
env:
ARCH_NAME: wormhole_b0
TT_METAL_HOME: ${pwd}
Expand All @@ -26,8 +26,9 @@ jobs:
python3 -m venv venv
source venv/bin/activate
python3 -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu
python3 -m pip install --upgrade pip
python3 -m pip install -r requirements-dev.txt
python3 -m pip install --upgrade pip
python3 -m pip install -r requirements-dev.txt
python3 -m pip install pytest-github-report
- name: Run Tests
env:
pytest_verbosity: 2
Expand Down
8 changes: 5 additions & 3 deletions .github/workflows/pull-request.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ on:

jobs:
validate-pr:
runs-on: ["in-service", "n150"]
runs-on: ubuntu-latest
steps:
- name: Skip run
run: echo "Empty check passed"
- name: Checkout
uses: actions/checkout@v4
- name: Black
uses: psf/[email protected]
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
__pycache__
venv
.vscode
stat
*.dot
*.svg
*.csv
*.csv
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
repos:
- repo: https://github.com/psf/black-pre-commit-mirror
rev: 23.10.1
hooks:
- id: black
language_version: python3
5 changes: 3 additions & 2 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
-r requirements.txt
pytest
pytest-github-report
pytest==7.2.2
pytest-timeout==2.2.0
pre-commit==3.0.4
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
torch==2.2.1.0+cpu
torchvision==0.17.1+cpu
tabulate
tabulate==0.9.0
networkx==3.1
graphviz
matplotlib
46 changes: 31 additions & 15 deletions tools/generate_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import math
import matplotlib.pyplot as plt

def parse_status_json_files(status_folder, prefix = "fw_"):

def parse_status_json_files(status_folder, prefix="fw_"):
stat_dict = {}
titles = set()

Expand All @@ -15,7 +16,7 @@ def parse_status_json_files(status_folder, prefix = "fw_"):
if not p.startswith(prefix):
continue
try:
with open(os.path.join(status_folder,p)) as f:
with open(os.path.join(status_folder, p)) as f:
j = json.load(f)
model_name = p.replace(prefix, "").replace(".json", "")
stat_dict[model_name] = j
Expand All @@ -27,11 +28,11 @@ def parse_status_json_files(status_folder, prefix = "fw_"):
return titles, stat_dict



def generate_node_count(titles, stat_dict, node_count_csv):
def get_op_cnt(op_type, op_infos):
op_types = [op_info["op_type"] for op_info in op_infos]
return op_types.count(op_type)

rows = [["model_name"] + titles]
for model_name in sorted(stat_dict.keys()):
stat = stat_dict[model_name]
Expand All @@ -41,26 +42,28 @@ def get_op_cnt(op_type, op_infos):
row.append(cnt)
rows.append(row)
row = ["TOTAL"]
for i in range(1,len(rows[0])):
for i in range(1, len(rows[0])):
row.append(sum([int(rows[j][i]) for j in range(1, len(rows))]))
rows.append(row)

with open(node_count_csv, "w") as f:
csv.writer(f, quotechar = '"').writerows(rows)
csv.writer(f, quotechar='"').writerows(rows)
print(f"{node_count_csv} generated")


def generate_total_size(stat_dict, size_dir, key):
assert(key in ["inputs", "outputs"])
assert key in ["inputs", "outputs"]
op_sizes = {}

def sizeof(dtype: str):
if dtype in ["torch.bool"]:
return 1/8
return 1 / 8
if dtype in ["torch.float32", "torch.int32"]:
return 4
if dtype in ["torch.float64", "torch.int64"]:
return 8
print(f"{dtype} not support")
assert(0)
assert 0

for model_name in stat_dict.keys():
stat = stat_dict[model_name]
Expand All @@ -71,7 +74,9 @@ def sizeof(dtype: str):
name = f"{op_type}_{idx}"
tensor_info = op_info[key][idx]
if "shape" in tensor_info.keys() and "dtype" in tensor_info.keys():
size = math.prod(tensor_info["shape"]) * sizeof(tensor_info["dtype"])
size = math.prod(tensor_info["shape"]) * sizeof(
tensor_info["dtype"]
)
op_sizes.setdefault(name, [])
op_sizes[name].append(size)

Expand All @@ -88,18 +93,29 @@ def sizeof(dtype: str):
plt.cla()
print(f"{size_dir} prepared")


if __name__ == "__main__":
out = sys.argv[1] if len(sys.argv) > 1 else os.path.join(os.getcwd(),"stat")
raw = os.path.join(out,"raw")
out = sys.argv[1] if len(sys.argv) > 1 else os.path.join(os.getcwd(), "stat")
raw = os.path.join(out, "raw")
assert os.path.isdir(raw) and "cannot find stat/raw folder"

def generate(prefix = "fw_"):
def generate(prefix="fw_"):
titles, stat_dict = parse_status_json_files(raw, prefix)
if titles == []:
return
generate_node_count(titles, stat_dict, os.path.join(out,f"{prefix}node_count.csv"))
generate_total_size(stat_dict, os.path.join(out,f"{prefix}total_input_size_dist/"), key = "inputs")
generate_total_size(stat_dict, os.path.join(out,f"{prefix}total_output_size_dist/"), key = "outputs")
generate_node_count(
titles, stat_dict, os.path.join(out, f"{prefix}node_count.csv")
)
generate_total_size(
stat_dict,
os.path.join(out, f"{prefix}total_input_size_dist/"),
key="inputs",
)
generate_total_size(
stat_dict,
os.path.join(out, f"{prefix}total_output_size_dist/"),
key="outputs",
)

generate("fw_")
generate("bw_")
47 changes: 36 additions & 11 deletions tools/run_torchvision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,18 @@
import torch
import torchvision

def run_model(model_name: str, backend: str, backward: bool, out_path: str, graphviz: bool, to_profile: bool, device = None):

def run_model(
model_name: str,
backend: str,
backward: bool,
out_path: str,
graphviz: bool,
to_profile: bool,
device=None,
):
if model_name == "dinov2_vits14":
m = torch.hub.load('facebookresearch/dinov2', model_name)
m = torch.hub.load("facebookresearch/dinov2", model_name)
else:
try:
m = torchvision.models.get_model(model_name, pretrained=True)
Expand All @@ -30,16 +39,21 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap
option = torch_ttnn.TorchTtnnOption(device=device)
m = torch.compile(m, backend=torch_ttnn.backend(option))
elif backend == "torch_stat":
option = torch_stat.TorchStatOption(model_name = model_name, backward = backward,
out = out_path, gen_graphviz=graphviz)
option = torch_stat.TorchStatOption(
model_name=model_name,
backward=backward,
out=out_path,
gen_graphviz=graphviz,
)
m = torch.compile(m, backend=torch_stat.backend(option))
else:
assert(0 and "Unsupport backend")
assert 0 and "Unsupport backend"

inputs = [torch.randn([1, 3, 224, 224])]

if to_profile:
from torch.profiler import profile, record_function, ProfilerActivity

trace_file = os.path.join(out_path, "profile", model_name)
os.makedirs(os.path.dirname(trace_file), exist_ok=True)
activities = [ProfilerActivity.CPU]
Expand All @@ -55,17 +69,20 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap
if backward:
result.backward(torch.ones_like(result))


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--out_path", "-o", type = str, default = os.path.join(os.getcwd(),"stat"))
parser.add_argument("--backend", type = str)
parser.add_argument(
"--out_path", "-o", type=str, default=os.path.join(os.getcwd(), "stat")
)
parser.add_argument("--backend", type=str)
parser.add_argument("--graphviz", action="store_true")
parser.add_argument("--backward", action="store_true")
parser.add_argument("--profile", action="store_true")
args = parser.parse_args()
assert(args.backend in ["torch_ttnn", "torch_stat"])
assert args.backend in ["torch_ttnn", "torch_stat"]
if args.backend == "torch_ttnn" and args.backward:
assert(0 and "torch_ttnn not yet support backward")
assert 0 and "torch_ttnn not yet support backward"

if args.backend == "torch_ttnn":
import torch_ttnn
Expand All @@ -77,8 +94,16 @@ def run_model(model_name: str, backend: str, backward: bool, out_path: str, grap
device = torch_ttnn.ttnn.open(0) if args.backend == "torch_ttnn" else None
for m in models:
try:
run_model(m, args.backend, args.backward, args.out_path, args.graphviz, args.profile, device)
run_model(
m,
args.backend,
args.backward,
args.out_path,
args.graphviz,
args.profile,
device,
)
except:
print(f"{m} FAIL")
if args.backend == "torch_ttnn":
torch_ttnn.ttnn.close(device)
torch_ttnn.ttnn.close(device)
15 changes: 7 additions & 8 deletions torch_ttnn/passes/stat_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.fx.passes.infra.pass_base import PassBase, PassResult


def parse_fx_stat(gm: torch.fx.GraphModule, example_inputs, out_file):
try:
FakeTensorProp(gm).propagate(*example_inputs)
Expand All @@ -16,13 +17,11 @@ def parse_fx_stat(gm: torch.fx.GraphModule, example_inputs, out_file):
def get_tensor_info(t):
def no_symInt_in_list(the_list):
return not any(isinstance(element, torch.SymInt) for element in the_list)

# Only record the value if it is a torch.Tensor whose shape is fully static;
# some shapes contain symbolic dimensions, like [2, 256, s0, s1]
if isinstance(t, torch.Tensor) and no_symInt_in_list(list(t.shape)):
return {
"shape": list(t.shape),
"dtype": str(t.dtype)
}
return {"shape": list(t.shape), "dtype": str(t.dtype)}
else:
return {}

Expand All @@ -48,22 +47,22 @@ def no_symInt_in_list(the_list):
# set node's outputs info
node_info["outputs"] = []
outputs_info = node.meta["val"]
if isinstance(outputs_info, tuple) or \
isinstance(outputs_info, list):
if isinstance(outputs_info, tuple) or isinstance(outputs_info, list):
for output_info in outputs_info:
output = get_tensor_info(output_info)
node_info["outputs"].append(output)
elif isinstance(outputs_info, torch.Tensor):
output = get_tensor_info(outputs_info)
node_info["outputs"].append(output)
else:
assert(0 and "unsupport outputs_info")
assert 0 and "unsupport outputs_info"
out.append(node_info)
os.makedirs(os.path.dirname(out_file), exist_ok=True)

with open(out_file, "w") as f:
json.dump(out, f, indent=4)


# The pass to collect node's information
# Run tools/generate_report.py to generate the report
class StatPass(PassBase):
Expand All @@ -75,4 +74,4 @@ def __init__(self, filename, example_inputs):
def call(self, gm: torch.fx.GraphModule):
parse_fx_stat(gm, self.example_inputs, self.filename)
modified = False
return PassResult(gm, modified)
return PassResult(gm, modified)