Skip to content

Commit 7d110e2

Browse files
authored
mx: temporarily disable the rceil tests (#1977)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent 42e1345 commit 7d110e2

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def test_some_zeros(elem_dtype):
105105
_test_mx(data, elem_dtype, block_size)
106106

107107

108+
# TODO(future PR): fix and reenable this test
109+
@pytest.mark.skip(reason="does not pass on B200 yet")
108110
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
109111
def test_to_mx_rceil():
110112
# nan
@@ -119,7 +121,9 @@ def test_to_mx_rceil():
119121
dtype=torch.uint32,
120122
).view(torch.float32)
121123
# fmt: on
122-
ground_truth_scale = torch.tensor([255], dtype=torch.uint8)
124+
ground_truth_scale = torch.tensor([255], dtype=torch.uint8).view(
125+
torch.float8_e8m0fnu
126+
)
123127
# fmt: off
124128
ground_truth_fp8 = torch.tensor(
125129
[
@@ -149,7 +153,7 @@ def test_to_mx_rceil():
149153
dtype=torch.uint32,
150154
).view(torch.float32)
151155
# fmt: on
152-
ground_truth_scale = torch.tensor([0], dtype=torch.uint8)
156+
ground_truth_scale = torch.tensor([0], dtype=torch.uint8).view(torch.float8_e8m0fnu)
153157
ground_truth_fp8 = torch.tensor([0] * 32, dtype=torch.uint8).view(
154158
torch.float8_e4m3fn
155159
)
@@ -170,7 +174,7 @@ def test_to_mx_rceil():
170174
dtype=torch.uint16,
171175
).view(torch.bfloat16)
172176
# fmt: on
173-
ground_truth_scale = torch.tensor([0], dtype=torch.uint8)
177+
ground_truth_scale = torch.tensor([0], dtype=torch.uint8).view(torch.float8_e8m0fnu)
174178
ground_truth_fp8 = torch.tensor([0] * 32, dtype=torch.uint8).view(
175179
torch.float8_e4m3fn
176180
)
@@ -191,7 +195,9 @@ def test_to_mx_rceil():
191195
dtype=torch.uint32,
192196
).view(torch.float32)
193197
# fmt: on
194-
ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
198+
ground_truth_scale = torch.tensor([119], dtype=torch.uint8).view(
199+
torch.float8_e8m0fnu
200+
)
195201
# fmt: off
196202
ground_truth_fp8 = torch.tensor(
197203
[
@@ -220,7 +226,9 @@ def test_to_mx_rceil():
220226
dtype=torch.uint16,
221227
).view(torch.bfloat16)
222228
# fmt: on
223-
ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
229+
ground_truth_scale = torch.tensor([119], dtype=torch.uint8).view(
230+
torch.float8_e8m0fnu
231+
)
224232
# fmt: off
225233
ground_truth_fp8 = torch.tensor(
226234
[
@@ -239,7 +247,7 @@ def test_to_mx_rceil():
239247
torch.testing.assert_close(data_mx._data, ground_truth_fp8)
240248
# zero
241249
data_hp = torch.tensor([0] * 32, dtype=torch.uint32).view(torch.float32)
242-
ground_truth_scale = torch.tensor([0], dtype=torch.uint8)
250+
ground_truth_scale = torch.tensor([0], dtype=torch.uint8).view(torch.float8_e8m0fnu)
243251
ground_truth_fp8 = torch.tensor([0] * 32, dtype=torch.uint8).view(
244252
torch.float8_e4m3fn
245253
)
@@ -260,7 +268,9 @@ def test_to_mx_rceil():
260268
dtype=torch.uint32,
261269
).view(torch.float32)
262270
# fmt: on
263-
ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
271+
ground_truth_scale = torch.tensor([119], dtype=torch.uint8).view(
272+
torch.float8_e8m0fnu
273+
)
264274
# fmt: off
265275
ground_truth_fp8 = torch.tensor(
266276
[
@@ -289,7 +299,9 @@ def test_to_mx_rceil():
289299
dtype=torch.uint16,
290300
).view(torch.bfloat16)
291301
# fmt: on
292-
ground_truth_scale = torch.tensor([119], dtype=torch.uint8)
302+
ground_truth_scale = torch.tensor([119], dtype=torch.uint8).view(
303+
torch.float8_e8m0fnu
304+
)
293305
# fmt: off
294306
ground_truth_fp8 = torch.tensor(
295307
[

0 commit comments

Comments
 (0)