We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 8fd16a6 commit 26fe1f1Copy full SHA for 26fe1f1
test/prototype/test_low_bit_optim.py
@@ -386,8 +386,10 @@ def world_size(self) -> int:
386
not TORCH_VERSION_AT_LEAST_2_5, reason="PyTorch>=2.5 is required."
387
)
388
@skip_if_lt_x_gpu(_FSDP_WORLD_SIZE)
389
- @skip_if_rocm("ROCm enablement in progress")
390
def test_fsdp2(self):
+ if torch.version.hip is not None:
391
+ pytest.skip("ROCm enablement in progress")
392
+
393
optim_classes = [low_bit_optim.AdamW8bit, low_bit_optim.AdamW4bit]
394
if torch.cuda.get_device_capability() >= (8, 9):
395
optim_classes.append(low_bit_optim.AdamWFp8)
0 commit comments