-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Preserve intermediate node meta for composite ops #50
Preserve intermediate node meta for composite ops #50
Conversation
@@ -35,6 +35,10 @@ def test_div(device, input_shapes): | |||
assert target.count(ttnn.mul) == 1 | |||
assert target.index(ttnn.reciprocal) < target.index(ttnn.mul) | |||
assert nodes[target.index(ttnn.mul)].args[1].target == ttnn.reciprocal | |||
# Intermediate node meta check if preserved | |||
for node in nodes: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any idea how this can be added centrally so one don't have to check it like this in every test?
@@ -235,6 +239,10 @@ def ReplaceMoreTtManually(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: | |||
full_node = g.call_function( | |||
ttnn.full, args=(arg_metadata.size(),), kwargs=new_kwargs | |||
) | |||
# Intermediate node meta is not preserved, this ensures retention | |||
# Are output dims same for full node and actual node to be replaced? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think output dims are not guaranteed to be the same due to tiling
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the meta field stores non-tiled output shape for each of these.
@@ -266,15 +274,22 @@ def ReplaceMoreTtManually(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: | |||
# out = beta * input + alpha * (batch1 @ batch2) | |||
# if beta is 0, input is ignored, and nan and inf in it will not be propogated | |||
new_node = g.call_function(ttnn.matmul, args=(args[1], args[2])) | |||
# Intermediate node meta is not preserved, this ensures retention | |||
new_node.meta["val"] = node.meta["val"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any ideas on how to make this a common routine for conversions to avoid duplication and make sure it is consistently used across all conversions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please think how to make such things a part of the infrastructure, so that each individual conversion don't have to care about it that much
…espective tests (tenstorrent#50) Co-authored-by: Artem Yerofieiev <[email protected]>
* change tracer type as a decorator, its API change, too * torch_ttnn.backend apply tracer config * mv json dump out of parse_fx_graph * Wrap leaf func with function form instead of decorator * Rename TorchTtnnOption to TenstorrentBackendOption * Revert "Use subgraph_rewriter" This reverts commit fc09080. * Extrace mock_ttnn to a standalone package * Register ttnn backend * Add setup.py and pyproject.toml * Update README.md * mv torch_ttnn/tracer.py tracer/tracer.py * Add model_sow2_list1 in tools/stat_models * Fix counter bug * fix try except * detr_resnet50 retain_graph=True * Update README.md for package building * Update test case for ttnn inferface change - ttnn.open(0) -> ttnn.open_device(device_id=0) - ttnn.close(d) -> ttnn.close_device(d) * Convert 14 pointwise binary operations Add conversion and unit test cases - add - eq - gt - logical_and - logical_or - logical_xor - lt - maximum - minimum - mul - ne - pow - sub - xlogy * Convert 35 pointwise unary operations - aten.abs - aten.acos - aten.acosh - aten.asin - aten.asinh - aten.atan - aten.atan2 # binary - aten.atanh - aten.clone - aten.cos - aten.cosh - aten.erf - aten.exp - aten.expm1 - aten.gelu - aten.hardtanh - aten.isinf - aten.isnan - aten.leaky_relu - aten.log - aten.log10 - aten.log1p - aten.log2 - aten.logical_not - aten.neg - aten.reciprocal - aten.relu - aten.rsqrt - aten.sigmoid - aten.sign - aten.sin - aten.sinh - aten.sqrt - aten.tan - aten.tanh * Convert 3 pointwise trinary operations - addcdiv - addcmul - where * Convert 2 matmul operations Also use fx.subgraph_rewriter - matmul - linear * Simplify op conversion * Fix wrap for ttnn & update test - ttnn.add(and other ops) don't have __name__, so torch.compile will fail. We hard patch the op with the __name__ - Now ttnn need a to_layout before computation * Fix more ops unit test * Simpify pass insertion for gen_graphviz * Update test cases for to_layour * Fix add_data_move support kwargs * Support linear without bias * Don't gen graphviz for pointwise unary test * Fix three ops, and verify some ops 3 op are fixed - clone - hardtanh - leaky_relu Following ops are verifyed and wont be fixed - atan2: ttnn bug while x or y is 0 - pow: exponent don't suporrt tensor type - xlogy: y==1 should be 0 * Simplify binary test, Add scalar support * Supprt addcmul & addcdiv * Fix support of addcdiv and where - addcdiv with option is a special case in torch, need special pattern matching - ttnn.where(c, x, y) c can not be bool, need cast * Convert and test repeat op * Add silu conversion * Update new test cases according to torch.compile interface change * Simplify unary test cases, reuse impl code * Update test_stat import statment * Add group_norm conversion * Try convert layer_norm but the result is different * Support repeat * Update trinary tests * Support concat * Support split * Update format * [wip] group_norm * group_norm use customized replacement transformer method not work because it replace native_group_norm->getitem as ttnn.group_norm pattern replacement not work because it use symbolic trace, which using proxy, not value, however, we need to do the value dependent conversion * Add more test_group_norm * Add more test_layer_norm * Remove unused patterns/norm.py * refactor, rename, add comment * Support x.t() * move layer_norm impl to customized_replace * updatfor ttnn version 3023ec0f7 * Refactor test case Don't know why but if I reuse some code in test case, wrap it as a function, it will fail if vscode run more than 8 cases at the same time. It is really strange. I currently have no idea how it happen, I can just refactor the case to dup code. * Merge .gitignore * Resolve conflicts in README.md * Use incoming test_fall_back.py * Use our tests/tools/test_stats.py * Resolve generate_report.py * Preserve intermediate node meta for composite ops and add checks in respective tests (#50) Co-authored-by: Artem Yerofieiev <[email protected]> * Resolve torch_ttnn/__init__.py * Resolve confict * Remove duplicate entries from .gitignore * Update metadata in setup.py * Correct the name of the test module in test_datamove.py * Remove duplicate test for softmax * Fix merge errors, now binary test passed. * Remove test_if.py as `if` is already tested by lowering/misc/test_if.py * Remove test_only_add_matmul.py, superseded by lowering/matmul/test_only_add_matmul.py * Test group_norm * Test torch.matmul -> ttnn.matmul * Test compiling torch.nn.functional.linear * Refactor test for CSE * Remove test_norm.py as we've done its migration * This test no longer tests falling back to torch op but division, which should be handled in lowering/eltwise/binary/test_div.py instead * Convert tests for unary eltwise ops * Fix loading pytest arguments * Convert tests for binary eltwise ops * Fix precision test for bivariate ops * Fix precision test for univariate ops * Remove test_pointwise_trinary.py for flexibility Trivariate functions differ too much to share the same testing protocol * Test torch.reshape -> ttnn.reshape * Test compiling aten.repeat * Test compiling torch.cat * Remove test_datamove.py because all its tests have been moved to lowering/tensor_manipulation/ * Remove test already covered by test_resnet.py * Use more descriptive names in torch_ttnn/patterns/add.py Co-authored-by: Artem Yerofieiev <[email protected]> * Use more descriptive names in torch_ttnn/patterns/addcdiv.py Co-authored-by: Artem Yerofieiev <[email protected]> * Simpler path to resolve ttnn.addcdiv Co-authored-by: Artem Yerofieiev <[email protected]> * Make test names unique for easier batch testing with pytest * Fix import target_wrappers * Move code in add_coreops_pass.py to add_data_move_pass.py first to help refactoring * Refactor lowering univariate functions * Simplify control flow in lowering * Refactor lowering bivariate functions * Sort ops to lower * Lower to ttnn.atan2 * Lower to ttnn.leaky_relu * Lower default hardtanh to ttnn.clip for now * Lower addcdiv and addcmul * Remove migrated lowering code * Fix names of ttnn ops * Remove duplicate entries in the op list * Remove unused pattern file * Remove the file that is already merged into add_data_move_pass.py * Test broadcasting for bivariate ops * Test broadcasting for all bivariate ops * Remove intermediate test for `div` so we can test broadcastinig * Regroup ops based on API docs https://docs.tenstorrent.com/tt-metal/latest/ttnn/ttnn/api.html * Remove ops not working: - Comparison ops - Logical ops - aten.pow.Tensor_Tensor * Apply @kevinwuTT's patch on tt.repeat at model teardown * Format the code with `black` for a consistent style * Reformat with the specified config * Mark tests xfail based on #64 * Remove test for unsupported pow(tensor, tensor) * Mark broadcasting issues (#64) with atan2 and xlogy * Reformat the code * Mark broadcasting issues (#64) with (min|max)imum * Mark broadcasting issues (#64) with subtraction * Mark numerical issues with atan2 * Try setting realistic tolerance for low precision math ops * Tolerate more with pointwise unary math ops * Reflect that we convert torch.hardtanh to ttnn.clip for now * Remove test for unsupported group norm * Mark conversion failure with `linear` * Fix test_clone.py for the patch * Mark argument mismtach in `arange` (#65) * Link #66 to test_linear.py * Mark lowering issues with tensor manipulation (#67) * Reciprocal needs an offset because it has a pole at 0 * More tolerance for matmul for its accumulated error * Symetrically mark xfail for (min|max)imum * Merge human-made docs from README.md to docs/README.md.in * Do not use braces for shell variables to avoid clashing with .format * Generate README.md with new metrics * Mark xfail for xlogy involving broadcasting xlogy asserts the same size for inputs for now --------- Co-authored-by: swimdi <[email protected]> Co-authored-by: yoco <[email protected]> Co-authored-by: Zahid Wakeel <[email protected]> Co-authored-by: Artem Yerofieiev <[email protected]> Co-authored-by: yoco <[email protected]>
Brief