From 8a2f062eb0fa9490ca1065954382eac48c67faba Mon Sep 17 00:00:00 2001
From: Kevin Wu
Date: Wed, 17 Jul 2024 20:14:07 +0000
Subject: [PATCH] Generate schema metrics and input variation metrics

---
 tests/models/bert/test_bert.py     |   2 +-
 tests/models/bloom/test_bloom.py   |   2 +-
 tests/models/falcon/test_falcon.py |   2 +-
 tests/models/gpt2/test_gpt2.py     |   2 +-
 tests/models/llama/test_llama.py   |   2 +-
 tests/models/mnist/test_mnist.py   |   2 +-
 tests/models/resnet/test_resnet.py |   2 +-
 tests/models/yolos/test_yolos.py   |   2 +-
 tools/collect_metrics.py           |  67 ++++++++++-
 torch_ttnn/backend.py              |  38 +++---
 torch_ttnn/metrics.py              | 180 +++++++++++++++++++++++++++++
 torch_ttnn/utils.py                |  42 -------
 12 files changed, 266 insertions(+), 77 deletions(-)
 create mode 100644 torch_ttnn/metrics.py

diff --git a/tests/models/bert/test_bert.py b/tests/models/bert/test_bert.py
index 78c0e9722..64618251c 100644
--- a/tests/models/bert/test_bert.py
+++ b/tests/models/bert/test_bert.py
@@ -2,7 +2,7 @@
 import torch_ttnn
 import unittest
 import ttnn
-from torch_ttnn.utils import RunTimeMetrics
+from torch_ttnn.metrics import RunTimeMetrics
 
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForQuestionAnswering
diff --git a/tests/models/bloom/test_bloom.py b/tests/models/bloom/test_bloom.py
index 9686ab69e..7cd267707 100644
--- a/tests/models/bloom/test_bloom.py
+++ b/tests/models/bloom/test_bloom.py
@@ -2,7 +2,7 @@
 import torch_ttnn
 import unittest
 import ttnn
-from torch_ttnn.utils import RunTimeMetrics
+from torch_ttnn.metrics import RunTimeMetrics
 
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForCausalLM
diff --git a/tests/models/falcon/test_falcon.py b/tests/models/falcon/test_falcon.py
index 3e1da911a..1f6f69081 100644
--- a/tests/models/falcon/test_falcon.py
+++ b/tests/models/falcon/test_falcon.py
@@ -2,7 +2,7 @@
 import torch_ttnn
 import unittest
 import ttnn
-from torch_ttnn.utils import RunTimeMetrics
+from torch_ttnn.metrics import RunTimeMetrics
 
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForCausalLM
diff --git a/tests/models/gpt2/test_gpt2.py b/tests/models/gpt2/test_gpt2.py
index 240c5b2ff..9f184f207 100644
--- a/tests/models/gpt2/test_gpt2.py
+++ b/tests/models/gpt2/test_gpt2.py
@@ -2,7 +2,7 @@
 import torch_ttnn
 import unittest
 import ttnn
-from torch_ttnn.utils import RunTimeMetrics
+from torch_ttnn.metrics import RunTimeMetrics
 
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForSequenceClassification
diff --git a/tests/models/llama/test_llama.py b/tests/models/llama/test_llama.py
index 2dc4eaff0..cc278a2cd 100644
--- a/tests/models/llama/test_llama.py
+++ b/tests/models/llama/test_llama.py
@@ -2,7 +2,7 @@
 import torch_ttnn
 import unittest
 import ttnn
-from torch_ttnn.utils import RunTimeMetrics
+from torch_ttnn.metrics import RunTimeMetrics
 
 # Load model directly
 from transformers import AutoTokenizer, AutoModelForCausalLM
diff --git a/tests/models/mnist/test_mnist.py b/tests/models/mnist/test_mnist.py
index ed50dafb8..e52ac896f 100644
--- a/tests/models/mnist/test_mnist.py
+++ b/tests/models/mnist/test_mnist.py
@@ -2,7 +2,7 @@
 import torch_ttnn
 import unittest
 import ttnn
-from torch_ttnn.utils import RunTimeMetrics
+from torch_ttnn.metrics import RunTimeMetrics
 from tests.utils import check_with_pcc
 from torchvision import transforms, datasets
 from torch.utils.data import DataLoader
diff --git a/tests/models/resnet/test_resnet.py b/tests/models/resnet/test_resnet.py
index 00b8f1114..6d475b0bf 100644
--- a/tests/models/resnet/test_resnet.py
+++ b/tests/models/resnet/test_resnet.py
@@ -4,7 +4,7 @@
 import unittest
 import ttnn
 import collections
-from torch_ttnn.utils import RunTimeMetrics
+from torch_ttnn.metrics import RunTimeMetrics
 from tests.utils import check_with_pcc
 
 
diff --git a/tests/models/yolos/test_yolos.py b/tests/models/yolos/test_yolos.py
index ff6785713..36223ca68 100644
--- a/tests/models/yolos/test_yolos.py
+++ b/tests/models/yolos/test_yolos.py
@@ -2,7 +2,7 @@
 import torch_ttnn
 import unittest
 import ttnn
-from torch_ttnn.utils import RunTimeMetrics
+from torch_ttnn.metrics import RunTimeMetrics
 from PIL import Image
 import requests
 
diff --git a/tools/collect_metrics.py b/tools/collect_metrics.py
index 9e7d90c54..07971bceb 100644
--- a/tools/collect_metrics.py
+++ b/tools/collect_metrics.py
@@ -4,6 +4,7 @@
 import csv
 from pathlib import Path
 import pandas as pd
+import numpy as np
 
 from tests.utils import comp_pcc
 
@@ -62,6 +63,9 @@ def load_pickle(path: str):
     # Holds the concatenation of all the metrics for each model
     all_metrics = []
 
+    # Holds the concatenation of input variation metrics for all models
+    all_input_var_metrics = {}
+
     # Assumed directory structure example. Some files will not exist if test failed.
     """
     pytorch2.0_ttnn
@@ -156,9 +160,9 @@ def load_pickle(path: str):
                 f"{compiled_outputs}\n"
             )
             accuracy_metric = {
-                "accuracy": round(accuracy, 2)
-                if not isinstance(accuracy, str)
-                else accuracy
+                "accuracy": (
+                    round(accuracy, 2) if not isinstance(accuracy, str) else accuracy
+                )
             }
 
         # Add links that point to the directory of the model in the model name
@@ -203,6 +207,63 @@ def load_pickle(path: str):
 
         all_metrics.append(cat_metrics_remapped)
 
+        # Process input variation metrics. Currently, this is not per model, but per op.
+        input_var_metrics_path = model_path / "aten_ops_input_variations.pickle"
+        input_var_metrics = (
+            load_pickle(input_var_metrics_path)
+            if os.path.isfile(input_var_metrics_path)
+            else {}
+        )
+
+        for key, val in input_var_metrics.items():
+            if key not in all_input_var_metrics:
+                all_input_var_metrics[key] = val
+            else:
+                # Only append if both the shape and the value are new for this op
+                for shape, value in zip(val["input_shapes"], val["input_values"]):
+                    if (
+                        shape not in all_input_var_metrics[key]["input_shapes"]
+                        and value not in all_input_var_metrics[key]["input_values"]
+                    ):
+                        all_input_var_metrics[key]["input_shapes"].append(shape)
+                        all_input_var_metrics[key]["input_values"].append(value)
+
+    # Write input variation metrics to csv
+    if all_input_var_metrics:
+        # Holds the rows to generate csv
+        input_var_list_for_csv = {}
+        # Turn input_shapes and input_values into individual columns
+        for val in list(all_input_var_metrics.values()):
+            # Holds the variations of the input string
+            input_var_list = []
+            for shapes, values in zip(val["input_shapes"], val["input_values"]):
+                # Holds each individual input to be joined into a string
+                input_string_list = []
+                for i, (shape, value) in enumerate(zip(shapes, values)):
+                    # This instance is a kwarg
+                    if isinstance(value, tuple):
+                        arg_name = value[0]
+                        arg_type = val["schema"]["kwargs"][arg_name]
+                        arg_val = f" = {value[1]}"
+                    else:
+                        arg_type = val["schema"]["args"][i][0]
+                        arg_name = val["schema"]["args"][i][1]
+                        arg_val = f" = {value}" if value else ""
+
+                    arg_shape = f"<{shape}>" if shape else ""
+
+                    input_string_list.append(
+                        f"{arg_type}{arg_shape} {arg_name}{arg_val}"
+                    )
+                input_var_list.append(", ".join(input_string_list))
+            input_var_list_for_csv[val["opname"]] = input_var_list
+
+        df = pd.DataFrame(
+            {key: pd.Series(value) for key, value in input_var_list_for_csv.items()}
+        )
+        df.to_csv("input_variations.csv", encoding="utf-8", index=False)
+        print("Data written to input_variations.csv")
+
     # Write metrics to csv
     if all_metrics:
         with open(f"metrics.csv", "w", newline="") as f:
diff --git a/torch_ttnn/backend.py b/torch_ttnn/backend.py
index 6c2c2d92e..56ca9b902 100644
--- a/torch_ttnn/backend.py
+++ b/torch_ttnn/backend.py
@@ -7,6 +7,7 @@
 import pickle
 from pathlib import Path
 import os
+import torch_ttnn.metrics as metrics
 
 torch._dynamo.config.suppress_errors = False
 torch._dynamo.config.verbose = True
@@ -74,24 +75,19 @@ def aten_backend(
     option: TorchTtnnOption = options["torch_ttnn_option"]
 
-    # Helper function to count the number of aten ops in the graph currently
-    # Returns a tuple of (total ops, total unique ops)
-    def count_aten_ops():
-        aten_ops = [
-            str(node.target)
-            for node in list(gm.graph.nodes)
-            if node.op in ["call_function", "call_method"]
-            and isinstance(node.target, torch._ops.OpOverload)
-            and "aten" in str(node.target)
-        ]
-        return (len(aten_ops), len(set(aten_ops)))
-
     # Save the number of aten ops before compilation
     if option.metrics_path:
         (
             option._metrics["torch_ops_before"],
             option._metrics["torch_ops_unique_before"],
-        ) = count_aten_ops()
+        ) = metrics.count_aten_ops(gm.graph.nodes)
+
+        input_variations = metrics.collect_input_variations_from_nodes(gm.graph.nodes)
+        # Save the input variation data
+        p = Path(f"metrics/{option.metrics_path}")
+        pickle_out_path = p / "aten_ops_input_variations.pickle"
+        with open(pickle_out_path, "wb") as f:
+            pickle.dump(input_variations, f)
 
     # Register ttnn objects as graph globals
     register_ttnn_objects(option)
@@ -140,17 +136,11 @@ def count_aten_ops():
         (
             option._metrics["torch_ops_remain"],
             option._metrics["torch_ops_unique_remain"],
-        ) = count_aten_ops()
-        # Save the number of to/from_device ops in current graph
-        to_from_device_ops = [
-            node.target.__name__
-            for node in list(gm.graph.nodes)
-            if node.op in ["call_function", "call_method"]
-            and (
-                "ttnn.to" in node.target.__name__ or "ttnn.from" in node.target.__name__
-            )
-        ]
-        option._metrics["to_from_device_ops"] = len(to_from_device_ops)
+        ) = metrics.count_aten_ops(gm.graph.nodes)
+        option._metrics["to_from_device_ops"] = metrics.count_to_from_device_ops(
+            gm.graph.nodes
+        )
+
         # Save the data as pickle files
         p = Path(f"metrics/{option.metrics_path}")
         pickle_out_path = p / "compiled-op_metrics.pickle"
diff --git a/torch_ttnn/metrics.py b/torch_ttnn/metrics.py
new file mode 100644
index 000000000..e1d46d690
--- /dev/null
+++ b/torch_ttnn/metrics.py
@@ -0,0 +1,180 @@
+import torch
+import pickle
+import time
+import os
+from pathlib import Path
+
+
+# Count the number of aten ops currently in the graph.
+# Returns a tuple of (total ops, total unique ops).
+def count_aten_ops(nodes: list):
+    aten_ops = [
+        str(node.target)
+        for node in nodes
+        if node.op in ["call_function", "call_method"]
+        and isinstance(node.target, torch._ops.OpOverload)
+        and "aten" in str(node.target)
+    ]
+    return (len(aten_ops), len(set(aten_ops)))
+
+
+# Count the number of to/from_device ops in the current graph
+def count_to_from_device_ops(nodes: list):
+    to_from_device_ops = [
+        node.target.__name__
+        for node in nodes
+        if node.op in ["call_function", "call_method"]
+        and ("ttnn.to" in node.target.__name__ or "ttnn.from" in node.target.__name__)
+    ]
+    return len(to_from_device_ops)
+
+
+def collect_schema_from_nodes(nodes: list):
+    """Collect a list of opname, schema, input_shapes, and input_values for all nodes.
+
+    Returns:
+    ```
+    [
+        {
+            'opname': str,
+            'schema': {"args": list(tuple), "kwargs": dict(str: str)},
+            'input_shapes': list(str),
+            'input_values': list(str|tuple),
+        },
+    ]
+    ```
+    """
+    collection = []
+    for node in nodes:
+        if hasattr(node.target, "_schema"):
+            node_stats = {}
+            # Collect the opname
+            node_stats["opname"] = str(node.target)
+
+            # Get schema from op
+            pos_args = [
+                (str(arg.type), str(arg.name))
+                for arg in node.target._schema.arguments
+                if not arg.kwarg_only
+            ]
+            kw_args = {
+                str(arg.name): str(arg.type)
+                for arg in node.target._schema.arguments
+                if arg.kwarg_only
+            }
+            op_schema = {"args": pos_args, "kwargs": kw_args}
+            node_stats["schema"] = op_schema
+
+            arg_shapes = []
+            arg_values = []
+            for arg in node.args:
+                # Collect the input values. Unknown values are recorded as "".
+                # INT64_MAX (aten's open-ended slice bound) is recorded as -1.
+                if not isinstance(arg, torch.fx.node.Node):
+                    if isinstance(arg, int) and arg == 9223372036854775807:
+                        arg_values.append(str(-1))
+                    else:
+                        arg_values.append(str(arg))
+                else:
+                    arg_values.append("")
+
+                # Collect the input shapes from the metadata if possible.
+                if hasattr(arg, "meta") and "val" in arg.meta:
+                    arg_shapes.append(str(list(arg.meta["val"].size())))
+                else:
+                    arg_shapes.append("")
+
+            # Collect any additional kwargs values if they exist
+            for key, val in node.kwargs.items():
+                # TODO: can kwargs values be other nodes?
+                arg_values.append((key, val))
+                arg_shapes.append("")
+
+            # Store the collected input shapes and values for this node
+            node_stats["input_shapes"] = arg_shapes
+            node_stats["input_values"] = arg_values
+
+            collection.append(node_stats)
+    return collection
+
+
+def collect_input_variations_from_nodes(nodes: list):
+    """Creates a dictionary of unique ops (keyed by opname) with their schema and input variations.
+
+    Returns:
+    ```
+    {
+        <opname>:
+        {
+            'opname': str,
+            'schema': {"args": list(tuple), "kwargs": dict(str: str)},
+            'input_shapes': list(list(str)),
+            'input_values': list(list(str|tuple)),
+        },
+        <opname>: {...},
+    }
+    ```
+
+    """
+    schemas = collect_schema_from_nodes(nodes)
+    collection = {}
+    for node in schemas:
+        if "schema" in node:
+            opname = node["opname"]
+            input_shapes = node["input_shapes"]
+            input_values = node["input_values"]
+            # Create a new entry if opname has not been registered
+            if opname not in collection:
+                entry = {
+                    "opname": opname,
+                    "schema": node["schema"],
+                    "input_shapes": [input_shapes],
+                    "input_values": [input_values],
+                }
+                collection[opname] = entry
+            else:
+                if (
+                    input_shapes not in collection[opname]["input_shapes"]
+                    and input_values not in collection[opname]["input_values"]
+                ):
+                    collection[opname]["input_shapes"].append(input_shapes)
+                    collection[opname]["input_values"].append(input_values)
+    return collection
+
+
+def RunTimeMetrics(path: str, prefix: str, f):
+    """
+    Measure the runtime of the model in seconds.
+    * Exports a pickle file containing success and runtime data
+    * Exports outputs in PyTorch tensor format
+
+    Parameters:
+        path (str): Typically the name of the model
+        prefix (str): Either "original" or "compiled" is recommended
+        f: A zero-argument callable (e.g. a lambda) that runs the model
+
+    Example:
+        output = RunTimeMetrics("BERT", "compiled", lambda: model(**inputs))
+
+    Returns:
+        Output from the model, or None if the run fails
+    """
+    p = Path(f"metrics/{path}")
+    pt_out_path = p / f"{prefix}-outputs.pt"
+    pickle_out_path = p / f"{prefix}-runtime_metrics.pickle"
+    os.makedirs(p, exist_ok=True)
+    try:
+        start = time.perf_counter()
+        ret = f()
+        end = time.perf_counter()
+        runtime_metrics = {"success": "✔️", "run_time": round(end - start, 2)}
+
+        torch.save(ret, pt_out_path)
+    except Exception:
+        runtime_metrics = {"success": "✘"}
+        ret = None
+
+    with open(pickle_out_path, "wb") as f:
+        pickle.dump(runtime_metrics, f)
+
+    return ret
diff --git a/torch_ttnn/utils.py b/torch_ttnn/utils.py
index 6ab77e0e8..f90e95884 100644
--- a/torch_ttnn/utils.py
+++ b/torch_ttnn/utils.py
@@ -1,8 +1,4 @@
 import torch
-import time
-import os
-from pathlib import Path
-import pickle
 
 
 def GraphCleanup(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
@@ -13,44 +9,6 @@ def GraphCleanup(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
     return gm
 
 
-def RunTimeMetrics(path: str, prefix: str, f):
-    """
-    Measure the runtime of the model in seconds.
-    * Exports a pickle file containing success and runtime data
-    * Exports outputs in Pytorch tensor format
-
-    Parameters:
-        path (str): Typically the name of the model
-        prefix (str): Either "original" or "compiled" is recommended
-        f: lambda function of model run
-
-    Example:
-        output = RunTimeMetrics("BERT", "compiled", lambda: model(**inputs))
-
-    Returns:
-        Output from the model or None if model fails
-    """
-    p = Path(f"metrics/{path}")
-    pt_out_path = p / f"{prefix}-outputs.pt"
-    pickle_out_path = p / f"{prefix}-runtime_metrics.pickle"
-    os.makedirs(p, exist_ok=True)
-    try:
-        start = time.perf_counter()
-        ret = f()
-        end = time.perf_counter()
-        runtime_metrics = {"success": "✔️", "run_time": round(end - start, 2)}
-
-        torch.save(ret, pt_out_path)
-    except:
-        runtime_metrics = {"success": "✘"}
-        ret = None
-
-    with open(pickle_out_path, "wb") as f:
-        pickle.dump(runtime_metrics, f)
-
-    return ret
-
-
 # Ttnn globals added with torch.fx._register_custom_builtin
 class TtnnDevice:
     def __repr__(self):
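
For reviewers, below is a minimal usage sketch of the new `torch_ttnn.metrics` helpers; it is not part of the patch. The toy function and the use of `make_fx` are illustrative assumptions: `collect_schema_from_nodes` reads `node.meta["val"]` for shapes, which an aten-level `make_fx` trace populates but a plain `symbolic_trace` may not, and exact metadata behavior can vary by torch version.

```python
# Hypothetical example; toy_fn and the make_fx tracing are assumptions.
import torch
from torch.fx.experimental.proxy_tensor import make_fx

import torch_ttnn.metrics as metrics


def toy_fn(x):
    # Traces to aten.add.Tensor and aten.relu.default at the aten level
    return torch.relu(x + 1)


gm = make_fx(toy_fn)(torch.randn(2, 3))
nodes = list(gm.graph.nodes)

# (total ops, unique ops) counted over aten OpOverload call nodes
total_ops, unique_ops = metrics.count_aten_ops(nodes)
print(f"aten ops: {total_ops} total, {unique_ops} unique")

# One entry per opname, holding each distinct shape/value variation
for opname, entry in metrics.collect_input_variations_from_nodes(nodes).items():
    print(opname, entry["input_shapes"], entry["input_values"])

# RunTimeMetrics times a call and writes metrics/<path>/<prefix>-* artifacts
out = metrics.RunTimeMetrics("ToyFn", "original", lambda: toy_fn(torch.randn(2, 3)))
```

The pickle written by the backend (`aten_ops_input_variations.pickle`) has the same per-op dictionary shape as the return value of `collect_input_variations_from_nodes`, which is what `tools/collect_metrics.py` merges across models before emitting `input_variations.csv`.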