
Generate standalone code that compares the accuracy between corresponding aten ops and lowered TTNN ops #611

Merged
merged 36 commits into main, Jan 29, 2025

Conversation

@kevinwuTT (Contributor) commented Dec 16, 2024

Ticket

None

Problem description

Some models are reported to have poor accuracy (low PCC values). We want to pinpoint the ops that cause this.

What's changed

This PR extracts the aten and ttnn graphs during compilation/runtime and creates a standalone Python script that can be run with minimal dependencies. Between each aten op and its matching lowered ttnn op(s), a function is called to compare the outputs. Currently, an assertion terminates the script if the PCC is below the desired value, which narrows down the offending set of ops. Input data values from the model run are also exported alongside the script, so synthetic values are not used.
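The per-op comparison that the generated script performs can be sketched roughly as follows. This is an illustrative stand-in, not the PR's actual code: the real generated scripts compare torch tensors, while this sketch uses plain Python sequences, and the names `compute_pcc` / `assert_with_pcc` are assumptions.

```python
import math

def compute_pcc(golden, actual):
    """Pearson correlation coefficient between two flat numeric sequences.
    Illustrative stand-in; the generated scripts operate on torch tensors."""
    n = len(golden)
    mean_g = sum(golden) / n
    mean_a = sum(actual) / n
    cov = sum((g - mean_g) * (a - mean_a) for g, a in zip(golden, actual))
    norm_g = math.sqrt(sum((g - mean_g) ** 2 for g in golden))
    norm_a = math.sqrt(sum((a - mean_a) ** 2 for a in actual))
    return cov / (norm_g * norm_a)

def assert_with_pcc(golden, actual, threshold=0.99):
    """Terminate the script when accuracy drops below the desired PCC."""
    pcc = compute_pcc(golden, actual)
    assert pcc >= threshold, f"PCC {pcc:.6f} is below threshold {threshold}"
    return pcc
```

Because the assertion fires at the first op pair whose PCC falls below the threshold, the traceback points directly at the offending lowering.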

  • Create GitHub action to enable/disable generating the accuracy test scripts
  • Create README or doc to explain how to use this feature
  • Refactor and cleanup

Suggestions:

  • Move generated scripts to tests, similar to autogen tests
  • Consider using safetensors instead of pickle
    There are issues with saving the tensors in this format because of shared memory. save_model won't work because this is not a torch.nn.Module. May separate this into another issue.
RuntimeError: 
            Some tensors share memory, this will lead to duplicate memory on disk and potential differences when loading them again: [{'arg0_1', 'arg27_1'}].
            A potential way to correctly save your model is to use `save_model`.
            More information at https://huggingface.co/docs/safetensors/torch_shared_tensors
  • Error handling of pickle file in main
  • Remaining redundant code
  • Docstrings
  • Verbose debug
  • Considerations for inspect.getsource
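The shared-memory error above can be diagnosed before saving by grouping tensors that alias the same storage. The sketch below is an assumption, not the PR's code: it uses memoryviews over bytearrays as stand-ins for tensors; for torch tensors the storage key would instead be something like `t.untyped_storage().data_ptr()`.

```python
def storage_key(view):
    # Stand-in: memoryviews share storage when backed by the same buffer
    # object. For a torch tensor this would be based on the storage pointer.
    return id(view.obj)

def find_shared_groups(tensors):
    """Group tensor names that alias the same underlying storage,
    mirroring the sets safetensors reports (e.g. {'arg0_1', 'arg27_1'})."""
    by_storage = {}
    for name, view in tensors.items():
        by_storage.setdefault(storage_key(view), []).append(name)
    return [names for names in by_storage.values() if len(names) > 1]

buf = bytearray(16)
tensors = {
    "arg0_1": memoryview(buf),             # aliases arg27_1
    "arg27_1": memoryview(buf)[8:],        # view into the same buffer
    "arg1_1": memoryview(bytearray(8)),    # independent storage
}
```

One common workaround is to clone the aliased tensors (breaking the aliasing) before handing them to safetensors; whether that is acceptable here depends on whether the generated script relies on the aliasing.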

@@ -0,0 +1,280 @@
import inspect
Member

move to tools/?

Member

would be good to find a better location, instead of a top level module dir

Contributor Author

Moved to tools/. In a future update, I will further separate them into

tools/
|-- export code base module
    |-- generate accuracy code
    |-- generate profiling code

@ayerofieiev-tt (Member) left a comment

Great tooling! Thank you!

Here are some thoughts on how to improve this further

Code Organization

The current script is hard to navigate because it mixes core logic and utilities together.
Can you split the code into separate modules?

Error Handling:

In main_code, the try-except block silently catches all exceptions without specifying the error type. This can mask issues unrelated to file opening.
Solution:
Catch specific exceptions like FileNotFoundError or IOError:

try:
    with open("{full_input_pkl_path}", "rb") as file:
        inputs = pickle.load(file)
except FileNotFoundError:
    with open("{input_pkl_file}", "rb") as file:
        inputs = pickle.load(file)
Redundant Code

There is repetitive code, especially for string manipulations and conditions like:
if node.op != "placeholder" and node.op != "output":

Maybe create a function to encapsulate repetitive logic:

def is_valid_node(node):
    return node.op not in ["placeholder", "output"]

Documentation

Consider adding docstrings

def rename_nodes(graph, prefix):
    """
    Renames nodes in the graph to prevent conflicts with wrapper or built-in functions.

    Args:
        graph: The computational graph to process.
        prefix: A string prefix for renaming.

    Returns:
        The modified graph with renamed nodes.
    """
    for node in graph.nodes:
        if node.op not in ["placeholder", "output"]:
            opname = str(node.target) if str(node.target).startswith("aten.") else node.target.__name__
            if not opname.startswith(("aten.", "ttnn.")):
                node._rename(f"{prefix}_{node.name}")
    return graph

inspect.getsource Edge Cases

I am mildly concerned about using inspect.getsource() under the assumption that the target function is always accessible and valid. If node.target is a dynamically created or non-source-mapped function, it may fail.
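A defensive wrapper would keep code generation going when source is unavailable. This is a sketch of one way to handle it (the name `safe_getsource` and the fallback string are assumptions):

```python
import inspect

def safe_getsource(fn, fallback="# <source unavailable>"):
    """inspect.getsource raises TypeError for builtins/C extensions and
    OSError when the source file cannot be located (e.g. functions created
    via exec); fall back instead of crashing the generator."""
    try:
        return inspect.getsource(fn)
    except (OSError, TypeError):
        return fallback
```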

Verbose Debug Printing

There are many print() statements for debugging (print("aten graph args:", arg_nodes)), which might clutter outputs in production. Use Python's logging module for more controlled logging:

import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

logger.info("aten graph args: %s", arg_nodes)

Now you can control verbosity using:

python script.py --log-level DEBUG
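For a --log-level flag like that to work, the script needs a small argparse hook. A minimal sketch (the function name and choices are assumptions):

```python
import argparse
import logging

def setup_logging(argv=None):
    """Wire a --log-level CLI flag to Python's logging module."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--log-level", default="INFO",
                        choices=["DEBUG", "INFO", "WARNING", "ERROR"])
    args = parser.parse_args(argv)
    # force=True resets any prior basicConfig (Python 3.8+).
    logging.basicConfig(level=getattr(logging, args.log_level), force=True)
    return logging.getLogger(__name__)
```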

List Operations

List comprehensions and dictionary lookups (like compute_key) are repeated within loops. Cache results outside loops.
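One lightweight way to cache such repeated lookups is functools.lru_cache. The body below is a hypothetical stand-in for the PR's compute_key (its real signature is not shown here); the point is that repeated calls inside loops hit the cache instead of recomputing:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def compute_key(opname: str) -> str:
    # Hypothetical stand-in for the PR's compute_key.
    return opname.replace("aten.", "")

keys = [compute_key(op) for op in ["aten.add", "aten.add", "aten.mul"]]
```

Note that lru_cache requires hashable arguments, so this works for op-name strings but would need care if keyed on graph nodes directly.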

@kevinwuTT kevinwuTT marked this pull request as ready for review January 6, 2025 15:32
@kevinwuTT kevinwuTT added this pull request to the merge queue Jan 29, 2025
Merged via the queue into main with commit 5fd9996 Jan 29, 2025
1 check passed
@kevinwuTT kevinwuTT deleted the kw/gen_acc_tests branch January 29, 2025 18:28