Skip to content

[ROCm] preshuffled weight mm #1702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Apr 8, 2025
Merged

[ROCm] preshuffled weight mm #1702

merged 17 commits into from
Apr 8, 2025

Conversation

jeffdaily
Copy link
Contributor

Adds SwizzleTensor subclass that wraps a Tensor and reorders the contents to be suitable for HIPBLASLT_ORDER_COL16_4R8. SwizzleTensor intercepts torch.mm and replaces with custom calls to hipblaslt.

Copy link

pytorch-bot bot commented Feb 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1702

Note: Links to docs will display an error until the docs builds have been completed.

❌ 29 New Failures

As of commit 8b57424 with merge base 70fc520 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Feb 11, 2025
@jeffdaily jeffdaily marked this pull request as draft February 11, 2025 20:55
from torchao.swizzle.swizzle_tensor import SwizzleTensor

aten = torch.ops.aten
SWIZZLE_OPS_TABLE: Dict[Any, Any] = {}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given how you use this, won't you be able to use the builtin version of this: https://pytorch.org/docs/stable/library.html#torch.library.register_torch_dispatch

@facebook-github-bot
Copy link
Contributor

@jcaip has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jeffdaily jeffdaily marked this pull request as ready for review March 11, 2025 20:18
@facebook-github-bot
Copy link
Contributor

@mxz297 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

1 similar comment
@facebook-github-bot
Copy link
Contributor

@mxz297 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Copy link
Contributor

@mxz297 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@jeffdaily
Copy link
Contributor Author

Note that other extensions aren't compiling currently but weren't added by this PR. I've had to modify conflicts in setup.py now that certain extensions are being built for rocm target. swizzle extension if compiled by itself is building and working fine.

@jeffdaily
Copy link
Contributor Author

Had to apply this diff to this PR to get setup.py to work on rocm. Comment out sparse_marlin and other extension source files and only leave swizzle alone.

diff --git a/setup.py b/setup.py
index c181267f..3d10de5c 100644
--- a/setup.py
+++ b/setup.py
@@ -350,15 +350,15 @@ def get_extensions():
     extensions_hip_dir = os.path.join(
         extensions_dir, "cuda", "tensor_core_tiled_layout"
     )
-    hip_sources = list(
-        glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
-    )
-    extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
-    hip_sources += list(
-        glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
-    )
+    #hip_sources = list(
+    #    glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
+    #)
+    #extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
+    #hip_sources += list(
+    #    glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
+    #)
     extensions_hip_dir = os.path.join(extensions_dir, "rocm")
-    hip_sources += list(
+    hip_sources = list(
         glob.glob(os.path.join(extensions_hip_dir, "**/*.hip"), recursive=True)
     )
     hip_sources += list(

@facebook-github-bot
Copy link
Contributor

@mxz297 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@mxz297 mxz297 self-requested a review April 8, 2025 20:00
@mxz297 mxz297 merged commit 5abaa35 into pytorch:main Apr 8, 2025
7 of 36 checks passed
@airMeng
Copy link
Collaborator

airMeng commented Apr 9, 2025

seems the PR breaks ruff lint checks and test-mps-ops (macos-m1-stable)

@mxz297 @jeffdaily

jerryzh168 added a commit that referenced this pull request Apr 9, 2025
@jerryzh168
Copy link
Contributor

reverting, please fix ruff before landing: https://hud.pytorch.org/pr/pytorch/ao/1702#40204953928

jerryzh168 added a commit that referenced this pull request Apr 9, 2025
Revert "[ROCm] preshuffled weight mm (#1702)"

This reverts commit 5abaa35.
@mxz297
Copy link

mxz297 commented Apr 9, 2025

@jeffdaily it seems like in test-mps-ops (macos-m1-stable), we still have

copying torchao/csrc/rocm/swizzle/swizzle.cpp -> build/lib.macosx-11.1-arm64-cpython-312/torchao/csrc/rocm/swizzle

swizzle source file used for compilation?

The ruff code analysis failures also needs to be addressed.

@jeffdaily
Copy link
Contributor Author

New PR here #2044

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: rocm
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants