forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_sparse_semi_structured.py
1244 lines (1043 loc) · 51.7 KB
/
test_sparse_semi_structured.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Owner(s): ["module: sparse"]
# ruff: noqa: F841
import itertools
import random
import unittest
import torch
from torch import nn
import torch.nn.functional as F
from torch.sparse import (
SparseSemiStructuredTensor,
SparseSemiStructuredTensorCUSPARSELT,
SparseSemiStructuredTensorCUTLASS,
to_sparse_semi_structured,
)
from torch.sparse._semi_structured_conversions import (
sparse_semi_structured_from_dense_cutlass,
_sparse_semi_structured_tile,
_compute_compressed_swizzled_bitmask,
)
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import _get_torch_cuda_version, PLATFORM_SUPPORTS_FP8, xfailIfSM89
from torch.testing._internal.common_device_type import (
dtypes,
instantiate_device_type_tests,
)
from torch.testing._internal.common_dtype import all_types_and_complex
import torch._dynamo.test_case
from torch.testing._internal.common_utils import (
parametrize,
run_tests,
subtest,
TestCase,
TEST_WITH_ROCM,
IS_WINDOWS,
)
from torch.testing._internal.inductor_utils import HAS_GPU
import pytest
SEMI_STRUCTURED_SUPPORTED_BACKENDS = dict()
_IS_SM8X = False
_IS_SM9X = False
if torch.cuda.is_available():
_IS_SM8X = torch.cuda.get_device_capability(0)[0] == 8
_IS_SM9X = torch.cuda.get_device_capability(0)[0] == 9
# CUTLASS kernels only work for Ampere
if _IS_SM8X:
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cutlass"] = SparseSemiStructuredTensorCUTLASS
# add cuSPASRELt tests if available
if torch.backends.cusparselt.is_available() and (_IS_SM8X or _IS_SM9X):
SEMI_STRUCTURED_SUPPORTED_BACKENDS["cusparselt"] = SparseSemiStructuredTensorCUSPARSELT
inference_dtypes = dtypes(torch.float16, torch.bfloat16, torch.int8)
training_dtypes = dtypes(torch.float16, torch.bfloat16)
parametrize_backends = parametrize("backend", SEMI_STRUCTURED_SUPPORTED_BACKENDS)
atol_rtol_kw = {
torch.float16: {
"rtol": 1e-3,
"atol": 1e-3,
},
torch.bfloat16: {
"rtol": 1e-1,
"atol": 1e-1,
},
}
def sparse24_largest_mask_2d(original):
sparse = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(original)
return sparse.to_dense().bool()
def sparsify24_dense(original):
return sparse24_largest_mask_2d(original) * original
def rand_sparse_semi_structured_mask(
r, c, dtype=torch.float16, device="cuda", choice=None
):
"""
This function returns a 1:2 sparse matrix of size (r, c).
Note that this means this matrix will also be 2:4 and 4:8 sparse as well.
"""
choices = [[0, 1], [1, 0]]
mask_entries = [choice or random.choice(choices) for i in range(r * c // 2)]
return (
torch.tensor(mask_entries, dtype=dtype, device=device)
.reshape(r, c)
.contiguous()
)
def rand_sparse_semi_structured(r, c, dtype, device, choice=None):
pattern = '2by4' if dtype != torch.float32 else '1by2'
if pattern == '1by2':
ksparse = 2
choices = [
[0, 1],
[1, 0]
]
elif pattern == '2by4':
ksparse = 4
choices = [
[1, 1, 0, 0],
[1, 0, 1, 0],
[1, 0, 0, 1],
[0, 1, 1, 0],
[0, 1, 0, 1],
[0, 0, 1, 1]
]
mask_entries = [choice or random.choice(choices) for i in range(r * c // ksparse)]
mask = torch.tensor(mask_entries, dtype=torch.bool).view(r, c).to(device)
dense = make_tensor(r, c, dtype=dtype, device=device)
dense[dense == 0] = 1 # To prevent zeros except where mask applied.
dense = dense.masked_fill(~mask, 0)
return dense
def rand_sparse_semi_structured_all_patterns(r, c, dtype, device):
pattern = '2by4' if dtype != torch.float32 else '1by2'
if pattern == '1by2':
ksparse = 2
choices = [
[[0, 0], [0, 1]],
[[0, 1], [0, 1]],
[[1, 0], [1, 0]],
[[1, 1], [1, 0]]
]
elif pattern == '2by4':
ksparse = 4
choices = [
[[0, 0, 0, 0], [0, 0, 1, 1]],
[[0, 0, 0, 1], [0, 0, 1, 1]],
[[0, 0, 1, 0], [0, 0, 1, 1]],
[[0, 0, 1, 1], [0, 0, 1, 1]],
[[0, 1, 0, 0], [0, 1, 1, 0]],
[[0, 1, 0, 1], [0, 1, 0, 1]],
[[0, 1, 1, 0], [0, 1, 1, 0]],
[[0, 1, 1, 1], [0, 1, 0, 1]],
[[1, 0, 0, 0], [1, 0, 1, 0]],
[[1, 0, 0, 1], [1, 0, 0, 1]],
[[1, 0, 1, 0], [1, 0, 1, 0]],
[[1, 0, 1, 1], [1, 0, 0, 1]],
[[1, 1, 0, 0], [1, 1, 0, 0]],
[[1, 1, 0, 1], [1, 1, 0, 0]],
[[1, 1, 1, 0], [1, 1, 0, 0]],
[[1, 1, 1, 1], [1, 1, 0, 0]],
]
mask_rows = [random.randint(0, len(choices) - 1) for i in range(r * c // ksparse)]
COL_INV, COL_VAL = 0, 1
mask_entries_inv = [choices[i][COL_INV] for i in mask_rows]
mask_entries_val = [choices[i][COL_VAL] for i in mask_rows]
mask_inv = torch.tensor(mask_entries_inv, dtype=torch.bool).view(r, c).to(device)
mask_val = torch.tensor(mask_entries_val, dtype=torch.bool).view(r, c).to(device)
dense = make_tensor(r, c, dtype=dtype, device=device)
dense[dense == 0] = 1 # To prevent zeros except where mask below applied.
dense_inv = dense.masked_fill(~mask_inv, 0)
dense_val = dense_inv.masked_fill(~mask_val, 0)
return dense_inv, dense_val
class SparseSemiStructuredTensorCompileTest(torch._dynamo.test_case.TestCase):
def setUp(self):
if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0:
self.skipTest('semi-structured sparsity has no available backend!')
super().setUp()
def tearDown(self):
super().tearDown()
@staticmethod
def _test_mlp_contiguous_relu_compile(backend, dense_input_shape):
"""
Test nn.Linear + .contiguous() + nn.ReLU with SparseSemiStructuredTensor + torch.compile
We expect:
(1) The sparse tensor subclass should turn nn.Linear into `aten._structured_sparse_addmm` + `aten.contiguous()`
(2) Inductor should fuse the .contiguous() call into the relu
"""
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = nn.Linear(128, 128)
def forward(self, x):
x = self.linear(x)
x = x.contiguous()
return torch.nn.functional.relu(x)
input = torch.rand(dense_input_shape, device="cuda").half()
model = Model().eval().cuda().half()
mod_linear = model.linear
m, n = mod_linear.weight.shape
mask = torch.Tensor([1, 0, 0, 1]).tile((m, n // 4)).bool().cuda()
# set masked weight
mod_linear.weight = nn.Parameter(mod_linear.weight * mask)
dense_result = model(input)
mod_linear.weight = nn.Parameter(SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].from_dense(mod_linear.weight))
sparse_result = model(input)
model = torch.compile(model, backend="inductor", fullgraph=True)
sparse_compile_result = model(input)
# test that sparse_compile_result and dense_result are numerically close
torch.testing.assert_close(dense_result, sparse_compile_result, rtol=1e-3, atol=1e-3)
# assert sparse and sparse_compile have the same strides,
# as meta registrations may return contiguous tensors when the output is transposed
# https://github.com/pytorch/pytorch/pull/114477
assert sparse_result.stride() == sparse_compile_result.stride()
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
def test_mlp_contiguous_relu_compile_cusparselt(self):
"""
test for cuSPASRELt meta registrations (_cslt_sparse_mm) + torch.compile
"""
for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cusparselt", dense_input_shape)
@unittest.skipIf("cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cutlass not supported on this machine")
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
def test_mlp_contiguous_relu_compile_cutlass(self):
"""
test for CUTLASS meta registrations (_sparse_semi_structured_addmm) + torch.compile
"""
for dense_input_shape in [(1, 128), (64, 128), (128, 128), (64, 128, 128)]:
SparseSemiStructuredTensorCompileTest._test_mlp_contiguous_relu_compile("cutlass", dense_input_shape)
@unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
@unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
def test_sp24_compile(self) -> None:
x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)
def fn(x):
y = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(x)
y = y.t()
return x @ y
# Eager
output = fn(x)
output.backward(output)
# Torch compile
output = torch.compile(fn)(x)
output.backward(output)
class TestSparseSemiStructured(TestCase):
def setUp(self):
if len(SEMI_STRUCTURED_SUPPORTED_BACKENDS) == 0:
self.skipTest('semi-structured sparsity has no available backend!')
if IS_WINDOWS:
self.skipTest("torch.compile not supported on windows")
@inference_dtypes
@parametrize_backends
def test_to_sparse_semi_structured(self, dtype, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)
assert A.shape == A_sparse.shape
assert A.device == A_sparse.device
assert A.dtype == A_sparse.dtype
assert isinstance(A, torch.Tensor)
assert isinstance(A_sparse, SparseSemiStructuredTensor)
@inference_dtypes
@parametrize_backends
@parametrize("dense_input_shape", [(128, 1), (128, 64), (128, 128)])
def test_mm_sparse_first_NN(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A_sparse, B) is correct for float16 and will throw error for int8
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)
B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)
# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
if dtype is torch.int8:
if backend == "cutlass":
with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
sparse_result = torch.mm(A_sparse, B)
else:
with self.assertRaisesRegex(RuntimeError,
"CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
sparse_result = torch.mm(A_sparse, B)
else:
dense_result = torch.mm(A, B)
sparse_result = torch.mm(A_sparse, B)
torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@inference_dtypes
@parametrize_backends
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
def test_mm_sparse_first_NT(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A_sparse, B.t()) is correct for float16/bfloat16
and will throw an error for int8 + padding
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
A = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)
B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)
# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
if dtype is torch.int8 and dense_input_shape in {(1, 128)}:
# padding with int8 throws an error because transposing B yields a contiguous output
# and row-row 2:4 sparse @ dense with NN is not supported by cuSPARSELt or CUTLASS.
if backend == "cutlass":
with self.assertRaisesRegex(RuntimeError, "spgemm_cutlass_dispatch_layouts"):
sparse_result = torch.mm(A_sparse, B.t())
else:
with self.assertRaisesRegex(RuntimeError,
"CUDA error: operation not supported when calling `cusparseLtMatmulDescriptorInit"):
sparse_result = torch.mm(A_sparse, B.t())
elif dtype is torch.int8:
# test transpose
dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
sparse_result = torch.mm(A_sparse, B.t())
torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
else:
# test transpose
dense_result = torch.mm(A, B.t())
sparse_result = torch.mm(A_sparse, B.t())
torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@inference_dtypes
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@parametrize_backends
def test_mm_sparse_first_TN(self, dtype, dense_input_shape, device, backend):
"""
Ensure torch.mm(A_sparse.t(), B) throws error
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
A = rand_sparse_semi_structured_mask(128, 256, dtype=dtype)
A_sparse = to_sparse_semi_structured(A)
B = torch.rand(dense_input_shape, device=A_sparse.device).to(dtype)
with self.assertRaisesRegex(
NotImplementedError,
r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
):
torch.mm(A_sparse.t(), B)
@inference_dtypes
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@parametrize_backends
def test_mm_sparse_second_NT(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A, B_sparse.t()) is correct
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
B_sparse = to_sparse_semi_structured(B)
A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)
# Currently we don't support int matmul on GPU, so evaluate on CPU and copy over
if dtype is torch.int8:
dense_result = torch.mm(A.cpu(), B.t().cpu()).to(device, dtype=torch.int8)
sparse_result = torch.mm(A, B_sparse.t())
else:
dense_result = torch.mm(A, B.t())
sparse_result = torch.mm(A, B_sparse.t())
torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@inference_dtypes
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128)])
@parametrize_backends
def test_mm_sparse_second_NN(self, dense_input_shape, dtype, device, backend):
"""
Ensure torch.mm(A, B_sparse) throws error
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
B = rand_sparse_semi_structured_mask(256, 128, dtype=dtype)
B_sparse = to_sparse_semi_structured(B)
A = torch.rand(dense_input_shape, device=B_sparse.device).to(dtype)
with self.assertRaisesRegex(
NotImplementedError,
r"`SparseSemiStructuredTensor.*` matmul: operation is not supported",
):
sparse_result = torch.mm(A, B_sparse)
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
@parametrize("inference_mode", [subtest(True), subtest(False)])
@parametrize_backends
def test_linear(self, dense_input_shape, inference_mode, device, backend):
"""
Test nn.Linear has the same numerics
"""
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
input = torch.rand((dense_input_shape), device=device).half()
model = nn.Linear(128, 256).to(device).half()
m, n = model.weight.shape
mask = rand_sparse_semi_structured_mask(m, n, device=device, dtype=torch.bool)
# set masked weight
model.weight = nn.Parameter(model.weight * mask)
dense_result = model(input)
model.weight = nn.Parameter(to_sparse_semi_structured(model.weight))
if inference_mode:
with torch.inference_mode():
sparse_result = model(input)
else:
sparse_result = model(input)
torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@parametrize("dense_input_shape", [(1, 128), (64, 128), (128, 128), (64, 128, 128)])
@parametrize_backends
def test_mlp(self, device, dense_input_shape, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
input = torch.rand(dense_input_shape, device=device).half()
model = (
nn.Sequential(
nn.Linear(128, 256),
nn.Linear(256, 128),
)
.half()
.to(device)
)
for i in range(2):
m, n = model[i].weight.shape
mask = rand_sparse_semi_structured_mask(
m, n, device=device, dtype=torch.bool
)
# set masked weight
model[i].weight = nn.Parameter(model[i].weight * mask)
dense_result = model(input)
for i in range(2):
model[i].weight = nn.Parameter(to_sparse_semi_structured(model[i].weight))
sparse_result = model(input)
torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
@parametrize_backends
def test_values(self, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
A = rand_sparse_semi_structured_mask(128, 128)
A_sparse = to_sparse_semi_structured(A)
assert A_sparse.values().shape == (128, 64)
assert (A_sparse.values() == 1).all()
@parametrize_backends
def test_indices(self, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
A = rand_sparse_semi_structured_mask(128, 128)
A_sparse = to_sparse_semi_structured(A)
assert A_sparse.indices().shape == (128, 8)
@inference_dtypes
@parametrize_backends
def test_min_sparse_shape(self, dtype, device, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
config = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS[dtype]
A = rand_sparse_semi_structured_mask(config.sparse_min_rows, config.sparse_min_cols, dtype=dtype, device=device)
A_sparse = to_sparse_semi_structured(A)
B = torch.rand((config.sparse_min_cols, config.dense_min_cols), device=device).to(dtype)
if dtype == torch.int8:
dense_res = torch.mm(A.cpu(), B.cpu()).to(device, dtype=torch.int8)
# int8 sparse matmul not supported for R/R -> R layout, so we transpose one of the arguments to get R/C -> R
B_t = B.t().contiguous()
sparse_res = torch.mm(A_sparse, B_t.t())
else:
dense_res = torch.mm(A, B)
sparse_res = torch.mm(A_sparse, B)
torch.testing.assert_close(sparse_res, dense_res, rtol=1e-3, atol=1e-3)
@inference_dtypes
@parametrize_backends
def test_unsupported_shape(self, dtype, device, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
A = rand_sparse_semi_structured_mask(2, 2, dtype=dtype, device=device)
with self.assertRaisesRegex(RuntimeError, "Error original_tensor.shape"):
A_sparse = to_sparse_semi_structured(A)
@dtypes(*all_types_and_complex())
@parametrize_backends
def test_unsupported_dtype(self, dtype, device, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
A = rand_sparse_semi_structured_mask(128, 128, dtype=dtype, device=device)
if dtype not in SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend]._DTYPE_SHAPE_CONSTRAINTS:
with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dtype"):
A_sparse = to_sparse_semi_structured(A)
else:
A_sparse = to_sparse_semi_structured(A)
@parametrize_backends
def test_unsupported_dim(self, device, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
A = torch.rand(128, 128, 128, device=device, dtype=torch.float16)
with self.assertRaisesRegex(RuntimeError, "Error original_tensor.dim"):
A_sparse = to_sparse_semi_structured(A)
def create_random_mask(shape) -> torch.Tensor:
r = random.Random(0)
mask = torch.zeros(shape, dtype=torch.bool)
for line in range(mask.shape[0]):
for col in range(0, mask.shape[1], 4):
sparsity = r.choice(
[
[False, False, True, True],
[False, True, False, True],
[True, False, False, True],
[False, True, True, False],
[True, False, True, False],
[True, True, False, False],
]
)
mask[line, col : col + 4] = torch.tensor(sparsity, dtype=torch.bool)
return mask
class TestSparseSemiStructuredTraining(TestCase):
def setUp(self):
if not _IS_SM8X:
self.skipTest("SparseSemiStructuredTensor training only supported on SM8x (Ampere)")
if IS_WINDOWS:
self.skipTest('CUTLASS not supported on windows')
@training_dtypes
def test_prune_dense_static_sort(self, dtype) -> None:
# Ideally we would like to clone and compare, but that won't work because the sorting order will be different
# instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
dense = torch.randn(128, 128, device="cuda", dtype=dtype)
pruned = _sparse_semi_structured_tile(dense)
# CUTLASS
reference_cutlass = SparseSemiStructuredTensorCUTLASS.prune_dense_static_sort(pruned, algorithm="largest_abs_values_greedy")
torch.testing.assert_close(pruned, reference_cutlass.to_dense())
packed_cutlass, meta_cutlass = sparse_semi_structured_from_dense_cutlass(pruned)
packed_t_cutlass, meta_t_cutlass = sparse_semi_structured_from_dense_cutlass(pruned.t().contiguous())
meta_cutlass = meta_cutlass.as_strided(reference_cutlass.meta.shape, reference_cutlass.meta.stride())
meta_t_cutlass = meta_t_cutlass.as_strided(reference_cutlass.meta_t.shape, reference_cutlass.meta_t.stride())
compressed_swizzled_bitmask = _compute_compressed_swizzled_bitmask(pruned)
compressed_swizzled_bitmask = compressed_swizzled_bitmask.as_strided(reference_cutlass.compressed_swizzled_bitmask.shape,
reference_cutlass.compressed_swizzled_bitmask.stride())
cutlass = SparseSemiStructuredTensorCUTLASS(dense.shape,
packed_cutlass,
meta_cutlass,
packed_t_cutlass,
meta_t_cutlass,
compressed_swizzled_bitmask)
torch.testing.assert_close(reference_cutlass.to_dense(), cutlass.to_dense())
# CUSPARSELT
reference_cusparselt = SparseSemiStructuredTensorCUSPARSELT.prune_dense_static_sort(pruned,
algorithm="largest_abs_values_greedy")
torch.testing.assert_close(pruned, reference_cusparselt.to_dense())
packed_cusparselt = torch._cslt_compress(pruned)
packed_t_cusparselt = torch._cslt_compress(pruned.t().contiguous())
cusparselt = SparseSemiStructuredTensorCUSPARSELT(dense.shape,
packed_cusparselt,
None,
packed_t_cusparselt,
None,
compressed_swizzled_bitmask)
torch.testing.assert_close(reference_cusparselt.to_dense(), cusparselt.to_dense())
@training_dtypes
@parametrize_backends
def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
inp = torch.tensor(
[[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
device="cuda",
dtype=dtype,
)
inp = F.pad(inp, (0, 128 - 4, 0, 128 - 4), "constant", 1)
sInp = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(inp, algorithm="largest_abs_values_greedy")
mask = sInp.to_dense() / inp
assert mask[:4, :4].int().tolist() == [
[1, 1, 0, 0],
[0, 1, 1, 0],
[0, 0, 1, 1],
[1, 0, 0, 1],
]
@training_dtypes
def test_gemm(self, dtype) -> None:
M, N, K = 32, 32, 64
a = torch.randn([M, K], device="cuda", dtype=dtype)
b = torch.randn([K, N], device="cuda", dtype=dtype)
mask = rand_sparse_semi_structured_mask(M, K, dtype=torch.bool)
a.masked_fill_(~mask, 0)
a_sparse = to_sparse_semi_structured(a)
masked_a = a * mask
ref_out = masked_a @ b
sp24_out = a_sparse @ b
torch.testing.assert_close(ref_out, sp24_out, **atol_rtol_kw[dtype])
@training_dtypes
@parametrize_backends
def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
M, N = 128, 256
# Construct x to make sure we always have exactly 8 elements per 4x4 tile
a = (4 * torch.arange(8))[:, None] + torch.arange(8)[None, :]
a = a.repeat(M // 8, N // 8)
assert a.shape == (M, N)
a = a.cuda().to(dtype)
b = torch.randn([a.shape[1], 128], device="cuda", dtype=dtype)
a_sparse = SEMI_STRUCTURED_SUPPORTED_BACKENDS[backend].prune_dense_static_sort(a)
mask_dense = sparse24_largest_mask_2d(a).to(dtype)
if backend == "cutlass":
assert isinstance(a_sparse, SparseSemiStructuredTensorCUTLASS)
(packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(
mask_dense, use_cutlass=True)
sparse_mask = SparseSemiStructuredTensorCUTLASS(
mask_dense.shape,
packed=packed,
meta=meta,
packed_t=packed_t,
meta_t=meta_t,
compressed_swizzled_bitmask=bitmask,
)
torch.testing.assert_close(a_sparse.meta.view(torch.short), sparse_mask.meta)
ref_gemm = (mask_dense * a) @ b
pack_gemm = a_sparse @ b
torch.testing.assert_close(ref_gemm, pack_gemm, **atol_rtol_kw[dtype])
@training_dtypes
def test_pack_both_ways_id(self, dtype) -> None:
N = 512
torch.manual_seed(0)
a = torch.randn([N, N], dtype=dtype, device="cuda")
b = torch.eye(N, dtype=dtype, device="cuda")
packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[
:4
]
# Heuristic to ensure we pack the same values
torch.testing.assert_close(
packed.to(torch.float64).sum(), packed_t.to(torch.float64).sum()
)
mask_dense = sparse24_largest_mask_2d(a.to(dtype))
ref_gemm = mask_dense * a
# Test A@B
pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed, meta).t()
max_diff = (ref_gemm - pack_gemm).abs().argmax()
torch.testing.assert_close(
ref_gemm, pack_gemm,
**atol_rtol_kw[dtype]
), f"packed is wrong at pos: ({max_diff // N}, {max_diff % N})"
# Test A.t@B
pack_gemm = torch._sparse_semi_structured_linear(b.t(), packed_t, meta_t)
max_diff = (ref_gemm - pack_gemm).abs().argmax()
torch.testing.assert_close(
ref_gemm, pack_gemm,
**atol_rtol_kw[dtype]
), f"packed_t is wrong at pos: ({max_diff // N}, {max_diff % N})"
@training_dtypes
def test_pack_both_ways_edge_case1(self, dtype) -> None:
# In this case, the heuristic will keep 7 values out of 16
# instead of 8. let's see how the kernel handles this
quad = torch.tensor(
[
[2, -1, -2, -3], # Should be packed as `2 <null>`
[-1, 8, -1, 6],
[-1, -1, 4, 5],
[-1, 3, 7, -1],
],
dtype=dtype,
device="cuda",
)
a = torch.randn([32, 64], dtype=dtype, device="cuda")
a[:4, :4] = quad
packed, meta, packed_t, meta_t = torch._sparse_semi_structured_tile(a)[:4]
# Check first line in A
assert packed[0, 0].item() == 2
assert packed[0, 1].item() == 0
# And first column in A.t
assert packed_t[0, 0].item() == 2
assert packed_t[0, 1].item() == 0
@training_dtypes
def test_sp24_apply(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
(
packed,
meta,
packed_t,
meta_t,
bitmask,
) = torch._sparse_semi_structured_tile(x)
packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
torch.testing.assert_close(packed, packed2)
torch.testing.assert_close(packed_t, packed_t2)
@training_dtypes
def test_sp24_apply_dense(self, dtype) -> None:
M, N = 256, 1024
x = torch.randn([M, N], dtype=dtype, device="cuda")
(
packed,
meta,
packed_t,
meta_t,
bitmask,
) = torch._sparse_semi_structured_tile(x)
expected = SparseSemiStructuredTensorCUTLASS(
x.shape,
packed=packed,
meta=meta,
packed_t=packed_t,
meta_t=meta_t,
compressed_swizzled_bitmask=bitmask,
).to_dense()
packed2, packed_t2 = torch._sparse_semi_structured_apply(x, bitmask)
sparse = SparseSemiStructuredTensorCUTLASS(
x.shape,
packed=packed2,
meta=meta,
packed_t=packed_t2,
meta_t=meta_t,
compressed_swizzled_bitmask=bitmask,
)
dense = torch._sparse_semi_structured_apply_dense(x, bitmask)
torch.testing.assert_close(dense, expected)
torch.testing.assert_close(sparse.to_dense(), expected)
@training_dtypes
def test_sp24_matmuls(self, dtype) -> None:
M, N, K = 64, 256, 1024
a = torch.randn([M, K], device="cuda", dtype=dtype)
b = torch.randn([K, N], device="cuda", dtype=dtype)
a_m = sparse24_largest_mask_2d(a)
b_m = sparse24_largest_mask_2d(b)
(packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(a)
a_s = SparseSemiStructuredTensorCUTLASS(
a.shape,
packed=packed,
meta=meta,
packed_t=packed_t,
meta_t=meta_t,
compressed_swizzled_bitmask=bitmask,
)
(packed, meta, packed_t, meta_t, bitmask) = torch._sparse_semi_structured_tile(b)
b_s = SparseSemiStructuredTensorCUTLASS(
b.shape,
packed=packed,
meta=meta,
packed_t=packed_t,
meta_t=meta_t,
compressed_swizzled_bitmask=bitmask,
)
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, rtol=1e-1, atol=1.5e-1)
torch.testing.assert_close(a @ b_s, a @ (b * b_m), rtol=1e-1, atol=1.5e-1)
torch.testing.assert_close(
a @ a_s.t(), a @ (a * a_m).t(), rtol=1e-1, atol=1.5e-1
)
torch.testing.assert_close(
a_s.t() @ a, (a * a_m).t() @ a, rtol=1e-1, atol=1e-1
)
def test_sp24_matmuls_mat_vec(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([128], device="cuda", dtype=torch.float16)
a_m = sparse24_largest_mask_2d(a)
a_s = to_sparse_semi_structured(a)
with pytest.raises(NotImplementedError):
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
def test_sp24_matmuls_bmm(self) -> None:
a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
a_m = sparse24_largest_mask_2d(a)
a_s = to_sparse_semi_structured(a)
with pytest.raises(NotImplementedError):
torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])
class TestSparseSemiStructuredCUTLASS(TestCase):
"""
This contains CUTLASS specific tests for
- torch._sparse_semi_structured_linear
"""
def setUp(self):
SparseSemiStructuredTensor._FORCE_CUTLASS = True
if "cutlass" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS:
self.skipTest('CUTLASS not enabled')
def tearDown(self):
SparseSemiStructuredTensor._FORCE_CUTLASS = False
super(self.__class__, self).tearDown()
@unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS")
@inference_dtypes
def test_linear_cutlass(self, device, dtype):
def run_test(batch_shape, m, n, k, device, dtype, dtype_out, add_bias, activation, rtol, atol):
weight = rand_sparse_semi_structured(m, k, dtype, device)
input = make_tensor((*batch_shape, n, k), dtype=dtype, device=device)
bias = make_tensor((m,), dtype=dtype_out, device=device) if add_bias else None
dtype_dense = torch.float32
input_dense = input.to(dtype_dense)
weight_dense = weight.to(dtype_dense)
bias_dense = bias.to(dtype_dense) if add_bias else None
output0 = torch.nn.functional.linear(input_dense, weight_dense, bias=bias_dense)
if activation == "relu":
relu = torch.nn.ReLU()
output0 = relu(output0)
elif activation == "silu":
silu = torch.nn.SiLU()
output0 = silu(output0)
compressed = to_sparse_semi_structured(weight)
weight_sparse = compressed.values()
meta = compressed.indices()
output1 = torch._sparse_semi_structured_linear(input, weight_sparse, meta, bias=bias, activation=activation,
out_dtype=dtype_out if dtype == torch.int8 else None)
torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)
if dtype == torch.float32:
# Inputs are converted to TF32 internally for sparse GEMM,
# so make dense GEMM to do the same for matching results.
orig = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = True
batch_shapes = [[], [3], [3, 1]]
dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
activations = [None, "relu", "silu"]
rtol, atol = 1e-3, 1e-3
if dtype == torch.bfloat16:
rtol, atol = 5e-3, 5e-3
elif dtype == torch.float32:
rtol, atol = 1e-3, 75e-2
for batch_shape, m, n, k, add_bias, activation in \
itertools.product(batch_shapes, range(3), range(3), range(3), (False, True), activations):
if activation == "silu" and dtype == torch.int8:
continue # SiLU not supported for integer inputs
m = 2 ** m * 32
n = 2 ** n * 32
k = 2 ** k * 128
run_test(batch_shape, m, n, k, device, dtype, dtype_out[dtype], add_bias, activation, rtol, atol)
if dtype == torch.float32:
torch.backends.cuda.matmul.allow_tf32 = orig
@unittest.skipIf(TEST_WITH_ROCM or IS_WINDOWS, "ROCm and Windows doesn't support CUTLASS")
@parametrize("backend", ["cutlass"])
@inference_dtypes
def test_sparse_semi_structured_ops_cutlass(self, device, dtype, backend):
SparseSemiStructuredTensor._FORCE_CUTLASS = (backend == "cutlass")
if backend == "cutlass" and IS_WINDOWS:
self.skipTest("CUTLASS not supported on Windows")
def run_test(m, n, k, device, dtype, dtype_out, use_input, rtol, atol):
mat1 = rand_sparse_semi_structured(m, k, dtype, device)
# mat2 transposed as int8 case supports only row-major/column-major combination
mat2 = make_tensor((n, k), dtype=dtype, device=device).t()
input = make_tensor((m,), dtype=dtype_out, device=device) if use_input else None
if use_input:
if dtype.is_floating_point:
alpha = 1.3
beta = -0.7
else:
alpha = 2
beta = -3
dtype_dense = torch.float32
mat1_dense = mat1.to(dtype_dense)
mat2_dense = mat2.to(dtype_dense)
if not use_input:
output0 = torch.mm(mat1_dense, mat2_dense)
else:
input_dense = input.to(dtype_dense)[:, None]
output0 = torch.addmm(input_dense, mat1_dense, mat2_dense, alpha=alpha, beta=beta)
compressed = to_sparse_semi_structured(mat1)
mat1_sparse = compressed.values()
mat1_meta = compressed.indices()
if not use_input:
output1 = torch._sparse_semi_structured_mm(mat1_sparse, mat1_meta, mat2, out_dtype=dtype_out)
else:
output1 = torch._sparse_semi_structured_addmm(
input, mat1_sparse, mat1_meta, mat2, alpha=alpha, beta=beta, out_dtype=dtype_out
)
torch.testing.assert_close(output1.to(dtype_dense), output0, rtol=rtol, atol=atol)
if dtype == torch.float32:
# Inputs are converted to TF32 internally for sparse GEMM,
# so make dense GEMM to do the same for matching results.
orig = torch.backends.cuda.matmul.allow_tf32
torch.backends.cuda.matmul.allow_tf32 = True
dtype_out = {torch.int8: torch.int32, torch.half: torch.half, torch.bfloat16: torch.bfloat16, torch.float32: torch.float32}
rtol, atol = 1e-3, 1e-3
if dtype == torch.bfloat16:
rtol, atol = 5e-3, 5e-3
elif dtype == torch.float32:
rtol, atol = 1e-3, 75e-2
for m, n, k, use_input in \
itertools.product(range(3), range(3), range(3), (False, True)):
m = 2 ** m * 32
n = 2 ** n * 32
k = 2 ** k * 128
run_test(m, n, k, device, dtype, dtype_out[dtype], use_input, rtol, atol)
if dtype == torch.float32:
torch.backends.cuda.matmul.allow_tf32 = orig
@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
@inference_dtypes
def test_conversions(self, device, dtype):
def run_test(r, c, device, dtype):
dense_ref = rand_sparse_semi_structured(r, c, dtype, device)
compressed = to_sparse_semi_structured(dense_ref)
# The torch.ops.aten._to_sparse_semi_structured operator
# uses CUTLASS to perform conversion from given dense
# matrix to the pair of corresponding sparse and metadata