File tree Expand file tree Collapse file tree 2 files changed +31
-7
lines changed Expand file tree Collapse file tree 2 files changed +31
-7
lines changed Original file line number Diff line number Diff line change 11
11
12
12
import torch
13
13
14
+ from torchao .utils import is_MI300
15
+
14
16
logger : logging .Logger = logging .getLogger ()
15
17
16
18
@@ -52,7 +54,7 @@ class Float8TypeConfig:
52
54
"""
53
55
Configuration for selecting the preferred float8 type pair, either e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.
54
56
55
- Currently, ROCm only supports fnuz variants.
57
+ Currently, ROCm supports (1) the fnuz variants on MI300 and (2) the OCP F8 variants on MI350/Navi4.
56
58
"""
57
59
58
60
# The preferred e4m3 type.
@@ -62,12 +64,9 @@ class Float8TypeConfig:
62
64
e5m2_dtype = torch .float8_e5m2
63
65
64
66
def __post_init__ (self ):
65
- if torch .version .hip and torch .cuda .is_available ():
66
- prop = torch .cuda .get_device_properties (0 )
67
- MI300_ARCH = ("gfx940" , "gfx941" , "gfx942" )
68
- if prop .gcnArchName .split (":" )[0 ] in MI300_ARCH :
69
- self .e4m3_dtype = torch .float8_e4m3fnuz
70
- self .e5m2_dtype = torch .float8_e5m2fnuz
67
+ if torch .version .hip and torch .cuda .is_available () and is_MI300 ():
68
+ self .e4m3_dtype = torch .float8_e4m3fnuz
69
+ self .e5m2_dtype = torch .float8_e5m2fnuz
71
70
72
71
73
72
# User defined type for using the individual F8 type based on config
Original file line number Diff line number Diff line change @@ -606,6 +606,15 @@ def _torch_version_at_least(min_version):
606
606
return is_fbcode () or version ("torch" ) >= min_version
607
607
608
608
609
+ # Supported AMD GPU Models and their LLVM gfx Codes:
610
+ #
611
+ # | AMD GPU Model | LLVM gfx Code |
612
+ # |---------------|------------------------|
613
+ # | Navi4 | gfx1200, gfx1201 |
614
+ # | MI300X | gfx940, gfx941, gfx942 |
615
+ # | MI350 | gfx950 |
616
+
617
+
609
618
def is_MI300 ():
610
619
if torch .cuda .is_available () and torch .version .hip :
611
620
mxArchName = ["gfx940" , "gfx941" , "gfx942" ]
@@ -616,6 +625,22 @@ def is_MI300():
616
625
return False
617
626
618
627
628
def is_MI350():
    """Report whether device 0 is an AMD MI350-class GPU under ROCm.

    Returns:
        bool: True when CUDA/HIP is available, ``torch.version.hip`` is set,
        and device 0's ``gcnArchName`` contains ``gfx950``; False otherwise.
    """
    # Guard clause: without a usable ROCm device there is nothing to inspect.
    if not (torch.cuda.is_available() and torch.version.hip):
        return False
    device_arch = torch.cuda.get_device_properties(0).gcnArchName
    return "gfx950" in device_arch
634
+
635
+
636
def is_Navi4():
    """Report whether device 0 is an AMD Navi4 GPU under ROCm.

    Device 0's ``gcnArchName`` is checked for the ``gfx1200`` and ``gfx1201``
    LLVM gfx codes.

    Returns:
        bool: True when CUDA/HIP is available, ``torch.version.hip`` is set,
        and device 0's ``gcnArchName`` contains ``gfx1200`` or ``gfx1201``;
        False otherwise.
    """
    if torch.cuda.is_available() and torch.version.hip:
        archName = torch.cuda.get_device_properties(0).gcnArchName
        # Each code must be tested on its own: the expression
        # `"gfx1200" or "gfx1201" in archName` is always truthy, because the
        # non-empty literal "gfx1200" short-circuits the `or`.
        if any(code in archName for code in ("gfx1200", "gfx1201")):
            return True
    return False
642
+
643
+
619
644
def is_sm_at_least_89 ():
620
645
return (
621
646
torch .cuda .is_available ()
You can’t perform that action at this time.
0 commit comments