Skip to content

Commit 36b09e3

Browse files
authored
ROCm OCP FP8 Support (#1677)
* document ROCm OCP F8 support * lint * Add AMD GPU model and gfx code documentation Add a comment documenting supported AMD GPU models and their corresponding LLVM gfx codes, including Navi4, MI300X, and MI350. * lint * Refactor MI300 float8 dtype detection using utility function Use is_MI300() utility function to simplify MI300 architecture detection for float8 dtypes * lint
1 parent 1d430bc commit 36b09e3

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

torchao/float8/config.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
import torch
1313

14+
from torchao.utils import is_MI300
15+
1416
logger: logging.Logger = logging.getLogger()
1517

1618

@@ -52,7 +54,7 @@ class Float8TypeConfig:
5254
"""
5355
Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
5456
55-
Currently, ROCm only supports fnuz variants.
57+
Currently, ROCm supports (1) the fnuz variants on MI300 and (2) the OCP F8 variants on MI350/Navi4.
5658
"""
5759

5860
# The preferred e4m3 type.
@@ -62,12 +64,9 @@ class Float8TypeConfig:
6264
e5m2_dtype = torch.float8_e5m2
6365

6466
def __post_init__(self):
65-
if torch.version.hip and torch.cuda.is_available():
66-
prop = torch.cuda.get_device_properties(0)
67-
MI300_ARCH = ("gfx940", "gfx941", "gfx942")
68-
if prop.gcnArchName.split(":")[0] in MI300_ARCH:
69-
self.e4m3_dtype = torch.float8_e4m3fnuz
70-
self.e5m2_dtype = torch.float8_e5m2fnuz
67+
if torch.version.hip and torch.cuda.is_available() and is_MI300():
68+
self.e4m3_dtype = torch.float8_e4m3fnuz
69+
self.e5m2_dtype = torch.float8_e5m2fnuz
7170

7271

7372
# User defined type for using the individual F8 type based on config

torchao/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,15 @@ def _torch_version_at_least(min_version):
606606
return is_fbcode() or version("torch") >= min_version
607607

608608

609+
# Supported AMD GPU Models and their LLVM gfx Codes:
610+
#
611+
# | AMD GPU Model | LLVM gfx Code |
612+
# |---------------|------------------------|
613+
# | Navi4 | gfx1200, gfx1201 |
614+
# | MI300X | gfx940, gfx941, gfx942 |
615+
# | MI350 | gfx950 |
616+
617+
609618
def is_MI300():
610619
if torch.cuda.is_available() and torch.version.hip:
611620
mxArchName = ["gfx940", "gfx941", "gfx942"]
@@ -616,6 +625,22 @@ def is_MI300():
616625
return False
617626

618627

628+
def is_MI350():
    """Report whether device 0 is an AMD MI350-class GPU (LLVM code ``gfx950``).

    Returns:
        bool: True only when ``torch.cuda.is_available()`` is true, the torch
        build reports a HIP version (``torch.version.hip`` is truthy), and the
        ``gcnArchName`` of device 0 contains ``"gfx950"``. False otherwise.
    """
    # Guard first: without a usable ROCm device there is nothing to query.
    if not (torch.cuda.is_available() and torch.version.hip):
        return False
    device_arch = torch.cuda.get_device_properties(0).gcnArchName
    # Substring match, so an arch string with extra trailing text still matches.
    return "gfx950" in device_arch
634+
635+
636+
def is_Navi4():
    """Report whether device 0 is an AMD Navi4 GPU (LLVM codes ``gfx1200``/``gfx1201``).

    Returns:
        bool: True only when ``torch.cuda.is_available()`` is true, the torch
        build reports a HIP version (``torch.version.hip`` is truthy), and the
        ``gcnArchName`` of device 0 contains ``"gfx1200"`` or ``"gfx1201"``.
        False otherwise.
    """
    if torch.cuda.is_available() and torch.version.hip:
        archName = torch.cuda.get_device_properties(0).gcnArchName
        # Each code must be tested against archName individually. The previous
        # `"gfx1200" or "gfx1201" in archName` parsed as
        # `"gfx1200" or (...)`, which is always truthy, so every ROCm device
        # (MI300, MI350, ...) was reported as Navi4.
        return any(code in archName for code in ("gfx1200", "gfx1201"))
    return False
642+
643+
619644
def is_sm_at_least_89():
620645
return (
621646
torch.cuda.is_available()

0 commit comments

Comments
 (0)