Generate standalone code that compares the accuracy between corresponding aten ops and lowered TTNN ops #611
@@ -0,0 +1,280 @@
import inspect
move to tools/?
would be good to find a better location, instead of a top level module dir
Moved to tools/. In a future update, I will further separate them into:
tools/
|-- export code base module
|-- generate accuracy code
|-- generate profiling code
Great tooling! Thank you!
Here are some thoughts on how to improve this further
Code Organization
The current script is hard to navigate because it mixes core logic and utilities in a single file.
Can you split the code?
Error Handling
In main_code, the try-except block silently catches all exceptions without specifying the error type. This can mask issues unrelated to file opening.
Solution:
Catch specific exceptions like FileNotFoundError or IOError:
try:
    with open("{full_input_pkl_path}", "rb") as file:
        inputs = pickle.load(file)
except FileNotFoundError:
    with open("{input_pkl_file}", "rb") as file:
        inputs = pickle.load(file)
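The same fallback could also be wrapped in a small helper so the generated script has a single loading path. This is only a sketch: the helper name `load_inputs` and its two path parameters are hypothetical and do not come from the PR.

```python
import pickle


def load_inputs(primary_path, fallback_path):
    """Load pickled inputs, falling back to a second path if the first is missing.

    Hypothetical helper: only FileNotFoundError triggers the fallback, so
    unrelated problems (corrupt pickle, permission errors) still surface.
    """
    try:
        with open(primary_path, "rb") as file:
            return pickle.load(file)
    except FileNotFoundError:
        with open(fallback_path, "rb") as file:
            return pickle.load(file)
```

If the fallback file is missing as well, the second `open` raises `FileNotFoundError` to the caller instead of being swallowed.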
Redundant Code
There is repetitive code, especially for string manipulations and conditions like:
if node.op != "placeholder" and node.op != "output":
Maybe create a function to encapsulate repetitive logic:
def is_valid_node(node):
    return node.op not in ["placeholder", "output"]
Documentation
Consider adding docstrings:
def rename_nodes(graph, prefix):
    """
    Renames nodes in the graph to prevent conflicts with wrapper or built-in functions.

    Args:
        graph: The computational graph to process.
        prefix: A string prefix for renaming.

    Returns:
        The modified graph with renamed nodes.
    """
    for node in graph.nodes:
        if node.op not in ["placeholder", "output"]:
            opname = str(node.target) if str(node.target).startswith("aten.") else node.target.__name__
            if not opname.startswith(("aten.", "ttnn.")):
                node._rename(f"{prefix}_{node.name}")
    return graph
inspect.getsource Edge Cases
I have a minor concern about using inspect.getsource() on the assumption that the target function is always accessible and valid. If node.target is a dynamically created or non-source-mapped function, the call may fail.
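One way to guard against this is to catch the two exceptions that inspect.getsource() is documented to raise: OSError when the source cannot be retrieved and TypeError for built-in objects. The wrapper name `safe_getsource` below is hypothetical, and how the caller handles a missing source is left open.

```python
import inspect


def safe_getsource(target):
    """Return the source of `target`, or None if it cannot be retrieved.

    Hypothetical wrapper: inspect.getsource raises TypeError for built-in
    objects and OSError when no source file can be found for the object.
    """
    try:
        return inspect.getsource(target)
    except (OSError, TypeError):
        return None
```

A caller could then skip the node or emit a placeholder when the result is None, rather than aborting code generation.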
Verbose Debug Printing
There are many print() statements for debugging (print("aten graph args:", arg_nodes)), which might clutter outputs in production. Use Python's logging module for more controlled logging:
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("aten graph args: %s", arg_nodes)
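If the script also exposes a --log-level option (for example via argparse) and passes it to basicConfig, you can control verbosity using: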
python script.py --log-level DEBUG
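Note that logging.basicConfig does not add a command-line flag by itself, so the --log-level option has to be wired up explicitly. A minimal sketch using argparse follows; the function name `configure_logging` is hypothetical and the flag name is taken from the command above.

```python
import argparse
import logging


def configure_logging(argv=None):
    """Parse a --log-level flag and apply it to the root logger.

    Hypothetical helper; returns the numeric level that was applied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--log-level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
    )
    args = parser.parse_args(argv)
    level = getattr(logging, args.log_level)
    # force=True (Python 3.8+) replaces handlers installed by earlier calls.
    logging.basicConfig(level=level, force=True)
    return level


if __name__ == "__main__":
    configure_logging()
    logging.getLogger(__name__).debug("only shown with --log-level DEBUG")
```

Restricting the flag with `choices` makes argparse reject unknown level names before they reach the logging module.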
List Operations
List comprehensions and dictionary lookups (like compute_key) are repeated within loops. Cache results outside loops.
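As an illustration of hoisting repeated work out of a loop, the sketch below builds a lookup set once instead of on every iteration. Both functions and their parameters are hypothetical stand-ins and are not taken from the script under review.

```python
def tag_nodes_slow(node_names, known_ops):
    """Rebuilds the lookup set on every iteration (the pattern to avoid)."""
    tagged = []
    for name in node_names:
        lookup = {op.lower() for op in known_ops}  # recomputed each time
        tagged.append((name, name.lower() in lookup))
    return tagged


def tag_nodes_fast(node_names, known_ops):
    """Builds the lookup set once, outside the loop."""
    lookup = {op.lower() for op in known_ops}
    return [(name, name.lower() in lookup) for name in node_names]
```

Both versions return the same result; the second does the set construction once instead of once per node.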
Ticket
None
Problem description
Some models report poor accuracy (low PCC values). We want to pinpoint the ops that cause this.
What's changed
This PR extracts the aten and ttnn graphs during compilation/runtime and creates a standalone Python script that can be run with minimal dependencies. Between each aten op (or group of ops) and the matching lowered ttnn op(s), a function is called to compare the outputs. Currently, an assertion terminates the script if the PCC is below the desired value, which narrows down the offending set of ops. Input data values from the model run are also exported alongside the script, so synthetic values are not needed.
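To make the comparison step concrete, here is a rough pure-Python sketch of a PCC (Pearson correlation coefficient) check over flattened output values. It is not the PR's implementation: the names `pcc` and `check_pcc`, the 0.99 default threshold, and the handling of constant inputs are all assumptions made for illustration.

```python
import math


def pcc(expected, actual):
    """Pearson correlation coefficient of two equal-length sequences of floats."""
    if len(expected) != len(actual) or not expected:
        raise ValueError("inputs must be non-empty and of equal length")
    mean_e = sum(expected) / len(expected)
    mean_a = sum(actual) / len(actual)
    cov = sum((e - mean_e) * (a - mean_a) for e, a in zip(expected, actual))
    var_e = sum((e - mean_e) ** 2 for e in expected)
    var_a = sum((a - mean_a) ** 2 for a in actual)
    denom = math.sqrt(var_e * var_a)
    if denom == 0.0:
        # Assumption: constant inputs count as a match only if identical.
        return 1.0 if list(expected) == list(actual) else 0.0
    return cov / denom


def check_pcc(name, expected, actual, threshold=0.99):
    """Assert that the PCC for op `name` meets the (hypothetical) threshold."""
    value = pcc(expected, actual)
    assert value >= threshold, f"{name}: PCC {value:.4f} is below {threshold}"
    return value
```

In this sketch the first op whose outputs fall below the threshold raises an AssertionError carrying the op name, which mirrors the terminate-on-failure behaviour described above.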
Suggestions:
There are issues with saving the tensors in this format because of shared memory. save_model won't work because this is not a torch.nn.Module. May separate this into another issue.
inspect.getsource