Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ROCm OCP FP8 Support #1677

Merged
merged 6 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torchao/dtypes/uintx/marlin_qqq_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __tensor_unflatten__(
def get_plain(self):
from torchao.quantization.marlin_qqq import (
unpack_from_marlin_qqq,
) # avoid circular import
)

int_data_expanded, s_group_expanded, s_channel_expanded = (
unpack_from_marlin_qqq(
Expand Down Expand Up @@ -207,7 +207,7 @@ def from_plain(
from torchao.quantization.marlin_qqq import (
const,
pack_to_marlin_qqq,
) # avoid circular import
)

assert isinstance(_layout, MarlinQQQLayout)

Expand Down
4 changes: 2 additions & 2 deletions torchao/dtypes/uintx/marlin_sparse_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def __tensor_unflatten__(
def get_plain(self):
from torchao.sparsity.marlin import (
unpack_from_marlin_24,
) # avoid circular import
)

int_data_expanded, scales_expanded = unpack_from_marlin_24(
self.int_data,
Expand All @@ -220,7 +220,7 @@ def from_plain(
from torchao.sparsity.marlin import (
const,
pack_to_marlin_24,
) # avoid circular import
)

assert isinstance(_layout, MarlinSparseLayout)

Expand Down
13 changes: 6 additions & 7 deletions torchao/float8/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

import torch

from torchao.utils import is_MI300

logger: logging.Logger = logging.getLogger()


Expand Down Expand Up @@ -58,7 +60,7 @@ class Float8TypeConfig:
"""
Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.

Currently, ROCm only supports fnuz variants.
Currently, ROCm supports the fnuz float8 variants on MI300, and the OCP FP8 variants on MI350/Navi4.
"""

# The preferred e4m3 type.
Expand All @@ -68,12 +70,9 @@ class Float8TypeConfig:
e5m2_dtype = torch.float8_e5m2

def __post_init__(self):
    """Select the fnuz float8 dtype pair on MI300-class ROCm devices.

    MI300 (gfx940/941/942) only supports the fnuz float8 variants, so the
    defaults are overridden there; other ROCm devices (e.g. MI350/Navi4)
    keep the default OCP e4m3fn/e5m2 types.

    NOTE: this span contained diff residue (both the old inline arch check
    and the new `is_MI300()` call); resolved to the merged version that
    delegates architecture detection to the shared helper.
    """
    if torch.version.hip and torch.cuda.is_available() and is_MI300():
        self.e4m3_dtype = torch.float8_e4m3fnuz
        self.e5m2_dtype = torch.float8_e5m2fnuz


# User defined type for using the individual F8 type based on config
Expand Down
27 changes: 26 additions & 1 deletion torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,16 +604,41 @@ def _torch_version_at_least(min_version):
return is_fbcode() or version("torch") >= min_version


# Supported AMD GPU Models and their LLVM gfx Codes:
#
# | AMD GPU Model | LLVM gfx Code |
# |---------------|------------------------|
# | Navi4 | gfx1200, gfx1201 |
# | MI300X | gfx940, gfx941, gfx942 |
# | MI350 | gfx950 |


def is_MI300():
    """Return True when the active GPU is an AMD MI300-family device under ROCm.

    MI300X reports one of the gfx940/gfx941/gfx942 LLVM arch codes.
    Returns False on CPU-only environments, non-HIP builds, and other
    GPU architectures.

    NOTE: this span contained diff residue (a duplicate
    `get_device_properties()` call without a device index); resolved to the
    single explicit `get_device_properties(0)` call.
    """
    if torch.cuda.is_available() and torch.version.hip:
        archName = torch.cuda.get_device_properties(0).gcnArchName
        for arch in ("gfx940", "gfx941", "gfx942"):
            if arch in archName:
                return True
    return False


def is_MI350():
    """Return True when the active GPU is an AMD MI350 (gfx950) under ROCm.

    Returns False on CPU-only environments, non-HIP builds, and other
    GPU architectures.
    """
    # Guard clauses: bail out early unless a ROCm/HIP device is present.
    if not torch.cuda.is_available():
        return False
    if not torch.version.hip:
        return False
    device_arch = torch.cuda.get_device_properties(0).gcnArchName
    return "gfx950" in device_arch


def is_Navi4():
    """Return True when the active GPU is an AMD Navi4 device under ROCm.

    Navi4 reports one of the gfx1200/gfx1201 LLVM arch codes. Returns
    False on CPU-only environments, non-HIP builds, and other GPU
    architectures.
    """
    if torch.cuda.is_available() and torch.version.hip:
        archName = torch.cuda.get_device_properties(0).gcnArchName
        # Bug fix: the original condition was
        #   if "gfx1200" or "gfx1201" in archName:
        # which is always True because the non-empty literal "gfx1200" is
        # truthy — every ROCm device was misdetected as Navi4. Each code
        # must be tested against archName individually.
        if any(arch in archName for arch in ("gfx1200", "gfx1201")):
            return True
    return False


def is_sm_at_least_89():
return (
torch.cuda.is_available()
Expand Down
Loading