-
Notifications
You must be signed in to change notification settings - Fork 235
/
Copy pathutils.py
700 lines (552 loc) · 23.3 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
import functools
import itertools
import re
import time
from functools import reduce
from importlib.metadata import version
from math import gcd
from typing import Any, Callable, Tuple
import torch
import torch.nn.utils.parametrize as parametrize
from torch.utils._python_dispatch import return_and_correct_aliasing
# Public API of this module: only these names are exported by
# `from <module> import *`. Other module-level helpers (e.g. `fill_defaults`,
# `compute_max_diff`, `parse_version`) remain importable by name.
__all__ = [
    # Benchmarking / profiling helpers
    "benchmark_model",
    "profiler_runner",
    # Device / hardware queries and test helpers
    "get_available_devices",
    "get_compute_capability",
    "skip_if_compute_capability_less_than",
    "benchmark_torch_function_in_microseconds",
    # Misc utilities
    "find_multiple",
    "_register_custom_op",
    "get_model_size_in_bytes",
    "unwrap_tensor_subclass",
    "TorchAOBaseTensor",
    # torch version feature flags
    "TORCH_VERSION_AT_LEAST_2_2",
    "TORCH_VERSION_AT_LEAST_2_3",
    "TORCH_VERSION_AT_LEAST_2_4",
    "TORCH_VERSION_AT_LEAST_2_5",
    "TORCH_VERSION_AT_LEAST_2_6",
    "TORCH_VERSION_AT_LEAST_2_7",
    # Needs to be deprecated in the future
    "TORCH_VERSION_AFTER_2_2",
    "TORCH_VERSION_AFTER_2_3",
    "TORCH_VERSION_AFTER_2_4",
    "TORCH_VERSION_AFTER_2_5",
    # GPU architecture checks
    "is_MI300",
    "is_sm_at_least_89",
    "is_sm_at_least_90",
]
# Referenced from: https://github.com/pytorch/pytorch/blob/9105d54c6b37099575c0059ef274c86c4dc80c57/torch/ao/quantization/utils.py#L711
def _assert_and_get_unique_device(module: torch.nn.Module) -> Any:
    """
    Return the single device shared by all parameters and buffers of ``module``.

    Returns ``None`` when the module owns no parameters or buffers.
    Raises ``AssertionError`` when its tensors span more than one device.
    """
    param_devices = {param.device for param in module.parameters()}
    buffer_devices = {buf.device for buf in module.buffers()}
    devices = param_devices | buffer_devices
    assert len(devices) <= 1, (
        "prepare only works with cpu or single-device CUDA modules, "
        f"but got devices {devices}"
    )
    # At most one element here; fall back to None for an empty set.
    return next(iter(devices), None)
def benchmark_model(model, num_runs, args=(), kwargs=None, device_type=None):
"""Benchmark model runs with `args` and `kwargs` both are optional"""
if kwargs is None:
kwargs = {}
if device_type is None:
assert isinstance(
model, torch.nn.Module
), "Expecting `model` to be torch.nn.Module if device_type is not provided"
device_type = _assert_and_get_unique_device(model).type
if device_type == "cuda":
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(*args, **kwargs)
end_event.record()
torch.cuda.synchronize()
return start_event.elapsed_time(end_event) / num_runs
elif device_type == "mps":
torch.mps.synchronize()
start_event = torch.mps.event.Event(enable_timing=True)
end_event = torch.mps.event.Event(enable_timing=True)
start_event.record()
# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(*args, **kwargs)
end_event.record()
torch.mps.synchronize()
return start_event.elapsed_time(end_event) / num_runs
elif device_type == "cpu":
torch.cpu.synchronize()
start_time = time.time()
# benchmark
for _ in range(num_runs):
with torch.autograd.profiler.record_function("timed region"):
model(*args, **kwargs)
end_time = time.time()
torch.cpu.synchronize()
average_time_per_run = (end_time - start_time) / num_runs
return average_time_per_run
def profiler_runner(path, fn, *args, **kwargs):
    """Run ``fn(*args, **kwargs)`` under ``torch.profiler`` and save a Chrome trace.

    CPU and CUDA activities are requested with ``record_shapes=True``. The trace
    is exported to ``path`` after ``fn`` returns, and ``fn``'s result is returned.
    """
    activities = [
        torch.profiler.ProfilerActivity.CPU,
        torch.profiler.ProfilerActivity.CUDA,
    ]
    with torch.profiler.profile(activities=activities, record_shapes=True) as prof:
        outcome = fn(*args, **kwargs)
    prof.export_chrome_trace(path)
    return outcome
def get_available_devices():
    """Return the device types usable in this process, always starting with "cpu".

    At most one of "cuda" / "xpu" is reported, and CUDA is checked first.
    "mps" is probed only when ``TORCH_VERSION_AT_LEAST_2_5`` is true.
    """
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    # `torch.xpu` may be absent on some torch builds; probe it defensively so
    # CUDA-less hosts don't hit an AttributeError here.
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        devices.append("xpu")
    if TORCH_VERSION_AT_LEAST_2_5 and torch.mps.is_available():
        devices.append("mps")
    return devices
def get_compute_capability():
    """Return the current CUDA device's compute capability as a float, e.g. 8.9.

    Returns 0.0 when CUDA is not available.
    """
    if not torch.cuda.is_available():
        return 0.0
    major, minor = torch.cuda.get_device_capability()
    return float(f"{major}.{minor}")
def skip_if_compute_capability_less_than(min_capability):
    """Decorator factory: skip a test when the CUDA compute capability is too low.

    The wrapped test raises ``unittest.SkipTest`` when ``get_compute_capability()``
    (0.0 without CUDA) is below ``min_capability``; otherwise it runs normally.

    Args:
        min_capability: minimum compute capability as a float, e.g. 8.9.
    """
    import unittest

    def decorator(test_func):
        # functools.wraps keeps the test's __name__/__doc__/__wrapped__ so test
        # runners and reports still identify the original test.
        @functools.wraps(test_func)
        def wrapper(*args, **kwargs):
            if get_compute_capability() < min_capability:
                raise unittest.SkipTest(
                    f"Compute capability is less than {min_capability}"
                )
            return test_func(*args, **kwargs)

        return wrapper

    return decorator
def compute_max_diff(output: torch.Tensor, output_ref: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute error of ``output`` relative to ``output_ref``.

    Despite the name, this computes
    ``mean(|output - output_ref|) / mean(|output_ref|)`` as a scalar tensor.
    NOTE(review): the result is inf/nan when ``output_ref`` is all zeros.
    """
    abs_error = (output - output_ref).abs().mean()
    ref_magnitude = output_ref.abs().mean()
    return abs_error / ref_magnitude
def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
    """Return the mean runtime of ``f(*args, **kwargs)`` in microseconds.

    ``f`` is called twice untimed as warmup, then measured with
    ``torch.utils.benchmark.Timer.blocked_autorange``.
    """
    # Imported lazily: this avoids importing numpy when torchao module is loaded.
    from torch.utils.benchmark import Timer

    for _ in range(2):  # manual warmup
        f(*args, **kwargs)
    timer = Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
    )
    seconds_per_call = timer.blocked_autorange().mean
    return seconds_per_call * 1e6
def find_multiple(n: int, *args: int) -> int:
    """Round ``n`` up to the nearest multiple of the least common multiple of ``args``.

    With no ``args`` the LCM is 1, so ``n`` is returned unchanged.

    Example:
        >>> find_multiple(10, 4)
        12
        >>> find_multiple(13, 4, 6)
        24
    """
    # Fold args into their LCM; the trailing 1 makes the no-args case well-defined.
    k: int = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
    if n % k == 0:
        return n
    return n + k - (n % k)
def _register_custom_op(lib):
    """Return a decorator that registers a function as a custom op in ``lib``.

    This preserves some high level operators for torch.export.export
    while still allowing them to be decomposed for the inductor path.

    Requirement: make sure `fn.__name__[1:]` is the operator name you want to register
    (i.e. the function name starts with `_`).

    NOTE: This should be applied at the top, after all other decorators have been applied.
    NOTE: We haven't tested the case when `fn` accepts a tensor subclass instance as input,
    e.g. a uint4 tensor subclass instance, and we'll probably need to figure out what would make
    sense for downstream systems (like executorch) to accept as well.

    When ``TORCH_VERSION_AT_LEAST_2_5`` is false the decorator returns ``fn`` unchanged.

    Example:
        lib = torch.library.Library("my_namespace", "FRAGMENT")

        register_custom_op = _register_custom_op(lib)

        @register_custom_op
        def _the_op_that_needs_to_be_preserved(...)
            ...

        # after this, `_the_op_that_needs_to_be_preserved` will be preserved as
        # torch.ops.my_namespace.the_op_that_needs_to_be_preserved operator after
        # torch.export.export / torch._export.export_for_training
    """
    from torch._inductor.decomposition import register_decomposition

    def decorator(fn):
        if TORCH_VERSION_AT_LEAST_2_5:
            from torch._library.infer_schema import infer_schema

            # expecting fn.__name__ starts with `_` and we want to take the rest
            # to be the name of the custom op
            assert (
                fn.__name__[0] == "_"
            ), f"Expecting function name starts with `_`, got {fn.__name__}"
            # Rejects names containing ".", "<" or ">" (e.g. "<lambda>").
            # NOTE(review): this inspects `__name__`, which for a nested function
            # is its bare name; "<locals>" only appears in `__qualname__`, so
            # local functions are not actually caught here.
            assert not any(
                c in fn.__name__ for c in ".<>"
            ), f"Expecting op to be defined in normal functions, not lambda or local: {fn.__name__}"
            op_name = fn.__name__[1:]
            # Build the op schema string from fn's signature (no mutated args).
            schema = op_name + infer_schema(fn, mutates_args={})
            lib.define(schema)
            # Register fn as the op's CompositeImplicitAutograd kernel.
            lib.impl(op_name, fn, "CompositeImplicitAutograd")

            # Look up the newly defined op, e.g. torch.ops.<ns>.<op_name>.
            lib_namespace = lib.ns
            op = getattr(getattr(torch.ops, lib_namespace), op_name)
            # Register fn as the inductor decomposition for the op.
            register_decomposition([op])(fn)
            return op
        else:
            return fn

    return decorator
def get_model_size_in_bytes(model, ignore_embeddings=False):
    """
    Returns the model size in bytes. The option to ignore embeddings
    is useful for models with disproportionately large embeddings compared
    to other model parameters that get quantized/sparsified.

    Counts the parameters and buffers owned by ``model`` itself and by every
    submodule. Tensor subclasses that define ``__tensor_flatten__`` are sized
    by recursively summing their inner tensors.

    Args:
        model: the ``torch.nn.Module`` to measure.
        ignore_embeddings: if True, ``torch.nn.Embedding`` modules (including
            ``model`` itself) contribute nothing.

    Returns:
        Total size in bytes as an int.
    """

    def flat_size(tensor):
        # For tensor subclasses, element 0 of __tensor_flatten__() lists the
        # attribute names holding the inner tensors; sum their sizes.
        if hasattr(tensor, "__tensor_flatten__"):
            inner_names = tensor.__tensor_flatten__()[0]
            return sum(flat_size(getattr(tensor, name)) for name in inner_names)
        return tensor.numel() * tensor.element_size()

    def module_size(module):
        # Size of `module`'s own tensors plus those of all its submodules.
        if ignore_embeddings and isinstance(module, torch.nn.Embedding):
            return 0
        own_tensors = itertools.chain(
            module.parameters(recurse=False), module.buffers(recurse=False)
        )
        size = sum(flat_size(t) for t in own_tensors)
        return size + sum(module_size(child) for child in module.children())

    # Start from the root itself: previously only children were inspected, so
    # the root module's direct parameters/buffers (e.g. a bare nn.Linear) were
    # never counted.
    return module_size(model)
class UnwrapTensorSubclass(torch.nn.Module):
    """Parametrization that stores a (nested) tensor subclass as plain tensors.

    ``right_inverse`` flattens a tensor-subclass value into its plain
    ``torch.Tensor`` leaves and records in ``self.rebuild_stack`` how to rebuild
    it. ``forward`` reassembles the subclass from those leaves.

    NOTE(review): values are paired with attribute names positionally. This is
    correct when each (sub)tensor has at most one tensor-subclass inner tensor
    and that entry is listed after its plain inner tensors (e.g. a chain of
    wrappers). Sibling subclass-typed inner tensors, or a subclass-typed entry
    listed before a plain one, are not reassembled correctly.
    """

    def forward(self, *tensors):
        # `todo` is a value stack: it starts with the plain leaves, and each
        # rebuilt subclass is pushed on top so its parent can consume it.
        todo = list(tensors)
        # Rebuild innermost subclasses first (reverse of the flattening order).
        for tp, meta, inner_tensors in reversed(self.rebuild_stack):
            nb_tensor = len(inner_tensors)
            # The values for this node are the LAST nb_tensor entries. Split by
            # explicit index: `todo[-0:]` would select the whole list.
            split = len(todo) - nb_tensor
            inner_tensors = dict(zip(inner_tensors, todo[split:]))
            todo = todo[:split]
            rebuilt = tp.__tensor_unflatten__(inner_tensors, meta, None, None)
            todo.append(rebuilt)
        assert len(todo) == 1
        return todo[0]

    def right_inverse(self, tensor):
        """Flatten a tensor-subclass ``tensor`` into its plain ``torch.Tensor`` leaves."""
        assert type(tensor) is not torch.Tensor
        rebuild_stack = []
        plain_tensors = []
        todo = [tensor]
        while todo:
            obj = todo.pop()
            inner_tensors, metadata = obj.__tensor_flatten__()
            rebuild_stack.append((type(obj), metadata, inner_tensors))
            for attr_name in inner_tensors:
                val = getattr(obj, attr_name)
                if type(val) is torch.Tensor:
                    plain_tensors.append(val)
                else:
                    # Nested tensor subclass: flatten it in a later iteration.
                    assert isinstance(val, torch.Tensor)
                    todo.append(val)
        self.rebuild_stack = rebuild_stack
        return plain_tensors
def unwrap_tensor_subclass(model, filter_fn=None):
    """Unwraps (nested) tensor subclass in the model to plain tensors.

    This is a workaround to make a model with tensor subclass work with `torch.export.export`
    and `torch.aot_compile`; we hope this can be integrated into the compile stack soon.
    tracking issue: https://github.com/pytorch/ao/issues/345

    Every descendant ``nn.Linear`` / ``nn.Embedding`` whose ``weight`` is a
    strict ``torch.Tensor`` subclass (neither ``torch.Tensor`` nor
    ``nn.Parameter``) gets an ``UnwrapTensorSubclass`` parametrization on
    ``weight``. The model is modified in place.

    Args:
        model: root module to process recursively.
        filter_fn: accepted for API compatibility. NOTE(review): currently
            unused; it is neither consulted nor passed to recursive calls.

    Returns:
        ``model`` itself.
    """
    for _, submodule in model.named_children():
        if isinstance(submodule, (torch.nn.Linear, torch.nn.Embedding)) and hasattr(
            submodule, "weight"
        ):
            weight = submodule.weight
            weight_cls = type(weight)
            # Only true tensor subclasses need unwrapping.
            needs_unwrap = (
                weight_cls is not torch.Tensor
                and weight_cls is not torch.nn.Parameter
                and isinstance(weight, torch.Tensor)
                and issubclass(weight_cls, torch.Tensor)
            )
            if needs_unwrap:
                parametrize.register_parametrization(
                    submodule, "weight", UnwrapTensorSubclass()
                )
        unwrap_tensor_subclass(submodule)
    return model
def _is_float8_type(dtype: torch.dtype) -> bool:
fp8_types = {
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
}
return dtype in fp8_types
def parse_version(version_string):
    """Parse the leading ``X.Y.Z`` of ``version_string`` into ``[X, Y, Z]`` ints.

    Anything after the numeric triple (e.g. ``.dev2024...`` or ``+cu121``) is
    ignored. Raises ``ValueError`` if the string does not start with ``X.Y.Z``.
    """
    # Extract just the X.Y.Z part from the version string
    found = re.match(r"(\d+\.\d+\.\d+)", version_string)
    if not found:
        raise ValueError(f"Invalid version string format: {version_string}")
    # Distinct local name: avoids shadowing the module-level `version` import.
    numeric_part = found.group(1)
    return list(map(int, numeric_part.split(".")))
def compare_versions(v1, v2):
    """Three-way compare two version strings by their numeric ``X.Y.Z`` prefix.

    Returns 1 if ``v1`` is newer, -1 if older, 0 if equal (suffixes ignored).
    """
    left, right = parse_version(v1), parse_version(v2)
    if left > right:
        return 1
    if left < right:
        return -1
    return 0
def is_fbcode():
    """Return True when the torch build lacks ``torch.version.git_version``.

    Presumably this identifies an internal fbcode build of torch; NOTE(review):
    the heuristic relies only on that attribute being absent.
    """
    has_git_version = hasattr(torch.version, "git_version")
    return not has_git_version
def torch_version_at_least(min_version):
    """Return True if ``torch.__version__`` is at least ``min_version``.

    Only the numeric ``X.Y.Z`` prefixes are compared. Always True when
    ``is_fbcode()`` is true.
    """
    if is_fbcode():
        return True
    return compare_versions(torch.__version__, min_version) >= 0
# Feature-gate flags: True when the installed torch is at least the given
# release (computed once at import time via torch_version_at_least).
TORCH_VERSION_AT_LEAST_2_2 = torch_version_at_least("2.2.0")
TORCH_VERSION_AT_LEAST_2_3 = torch_version_at_least("2.3.0")
TORCH_VERSION_AT_LEAST_2_4 = torch_version_at_least("2.4.0")
TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0")
"""
Helper function for implementing aten op or torch function dispatch
and dispatching to these implementations.
"""
def _implements(cls, aten_ops_or_torch_fns):
    """Register a function as ``cls``'s implementation of aten ops / torch functions.

    The decorated function is stored in ``cls._ATEN_OP_OR_TORCH_FN_TABLE``, which
    ``_dispatch__torch_dispatch__`` / ``_dispatch__torch_function__`` consult.
    Accepts a single op or a list/tuple of ops. The handler is called as
    ``func(op, types, args, kwargs)``.

    Usage::

        class MyTensor(torch.Tensor):
            ...
            implements = classmethod(_implements)

        implements = MyTensor.implements

        @implements(torch.nn.functional.linear)
        def _(func, types, args, kwargs):
            ...

    Each class gets its own table, seeded with a copy of the nearest ancestor's
    table. Ops registered on a base class therefore remain available to
    subclasses, while a subclass's registrations never leak into the base or
    into sibling classes.
    """
    # `hasattr` would also find an ancestor's table through attribute
    # inheritance, making all subclasses register into (and overwrite entries
    # in) one shared dict. Check the class's own namespace instead.
    if "_ATEN_OP_OR_TORCH_FN_TABLE" not in cls.__dict__:
        cls._ATEN_OP_OR_TORCH_FN_TABLE = dict(
            getattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE", {})
        )

    if not isinstance(aten_ops_or_torch_fns, (list, tuple)):
        aten_ops_or_torch_fns = [aten_ops_or_torch_fns]

    def decorator(func):
        for op in aten_ops_or_torch_fns:

            @functools.wraps(op)
            def wrapper(f, types, args, kwargs):
                return func(f, types, args, kwargs)

            cls._ATEN_OP_OR_TORCH_FN_TABLE[op] = wrapper
        return func

    return decorator
def _dispatch__torch_function__(cls, func, types, args=(), kwargs=None):
    """Common ``__torch_function__`` that dispatches to ops registered via ``_implements``.

    If ``cls._ATEN_OP_OR_TORCH_FN_TABLE`` has a handler for ``func``, it is
    called as ``handler(func, types, args, kwargs)``. Otherwise ``func`` runs
    with torch-function subclass handling disabled.

    Usage::

        class MyTensor(torch.Tensor):
            ...
            __torch_function__ = classmethod(_dispatch__torch_function__)
    """
    if kwargs is None:
        kwargs = {}
    has_table = hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE")
    if has_table and func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
        handler = cls._ATEN_OP_OR_TORCH_FN_TABLE[func]
        return handler(func, types, args, kwargs)
    with torch._C.DisableTorchFunctionSubclass():
        return func(*args, **kwargs)
def _dispatch__torch_dispatch__(cls, func, types, args=(), kwargs=None):
    """Common ``__torch_dispatch__`` that dispatches to ops registered via ``_implements``.

    If ``cls._ATEN_OP_OR_TORCH_FN_TABLE`` has a handler for ``func``, it is
    called as ``handler(func, types, args, kwargs)``.

    Usage::

        class MyTensor(torch.Tensor):
            ...
            __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)

    Raises:
        NotImplementedError: if no handler is registered for ``func``.
    """
    # Normalise like _dispatch__torch_function__: a None kwargs previously
    # turned the intended NotImplementedError below into an AttributeError.
    kwargs = {} if kwargs is None else kwargs
    if (
        hasattr(cls, "_ATEN_OP_OR_TORCH_FN_TABLE")
        and func in cls._ATEN_OP_OR_TORCH_FN_TABLE
    ):
        return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

    arg_types = tuple(type(arg) for arg in args)
    kwarg_types = {k: type(arg) for k, arg in kwargs.items()}
    raise NotImplementedError(
        f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}"
    )
def _register_layout(tensor_class: Callable, layout_class: Callable):
    """Helper for layout registration, used to implement the ``register_layout``
    decorator of each tensor subclass (see aqt.py for example usage).

    Args:
        tensor_class: Tensor subclass type
        layout_class: the class type of subclass of `Layout`, e.g. `PlainLayout`

    Returns:
        A decorator that stores ``tensor_impl_class.from_plain`` in
        ``tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class]`` and returns the
        class unchanged. When ``TORCH_VERSION_AT_LEAST_2_5`` is true it also
        adds the layout and impl classes to torch's serialization safe globals.
    """
    # The table maps a layout class (e.g. TensorCoreTiledLayout) to a tensor_impl
    # constructor (e.g. TensorCoreTiledAQTTensorImpl.from_plain) that builds a
    # tensor_impl from plain data such as (quantized, unpacked) data, scale, zero_point.
    if not hasattr(tensor_class, "_LAYOUT_CONSTRUCTOR_TABLE"):
        tensor_class._LAYOUT_CONSTRUCTOR_TABLE = {}

    def register(tensor_impl_class):
        constructor_table = tensor_class._LAYOUT_CONSTRUCTOR_TABLE
        constructor_table[layout_class] = tensor_impl_class.from_plain
        if TORCH_VERSION_AT_LEAST_2_5:
            # Allow serialization to work for models that use this tensor impl subclass.
            torch.serialization.add_safe_globals([layout_class, tensor_impl_class])
        return tensor_impl_class

    return register
def _get_tensor_impl_constructor(
    tensor_class: Callable, layout_class: Callable
) -> Callable:
    """Look up the TensorImpl constructor (``TensorImplClass.from_plain``) that
    ``_register_layout`` stored for ``layout_class`` on ``tensor_class``.

    Args:
        tensor_class: Tensor subclass type
        layout_class: the class type of subclass of `Layout`, e.g. `PlainLayout`

    Returns:
        The registered tensor impl constructor for ``layout_class``.

    Raises:
        ValueError: if ``tensor_class`` has no table, or ``layout_class`` is not in it.
    """
    if not hasattr(tensor_class, "_LAYOUT_CONSTRUCTOR_TABLE"):
        raise ValueError(
            f"no registered tensor_impl class constructor for: {tensor_class}"
        )
    constructor_table = tensor_class._LAYOUT_CONSTRUCTOR_TABLE
    if layout_class in constructor_table:
        return constructor_table[layout_class]
    raise ValueError(
        f"layout_name: {layout_class} is not supported yet for {tensor_class}"
    )
"""
TorchAOBaseTensor is a util tensor subclass that provides commonly used functions, and should be inherited to define a new tensor subclass
"""
class TorchAOBaseTensor(torch.Tensor):
    """A util tensor subclass that provides commonly used functions.

    A new tensor subclass can inherit from it to get all the utility functions,
    and it should be inherited to define a new tensor subclass::

        class MyTensor(TorchAOBaseTensor):
            pass

    This includes:

    `_get_to_kwargs` that can get the kwargs for `to`::

        class MyTensor(TorchAOBaseTensor):
            def to(self, *args, **kwargs):
                kwargs = self._get_to_kwargs(*args, **kwargs)
                ...

    `implements`::

        implements = MyTensor.implements

        @implements(torch.nn.functional.linear)
        def _(func, types, args, kwargs):
            ...

    `register_layout`::

        register_layout = MyTensor.register_layout

        @register_layout(PlainLayout)
        class PlainAQTTensorImpl(...):
            ...

    `get_tensor_impl_constructor`::

        get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor
        # in constructor of MyTensor:
        tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
        tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)

    Subclasses must define:
    `__tensor_flatten__` and `__tensor_unflatten__`: used for tensor serialization/deserialization
    `__repr__`: used for tensor representation
    `_apply_fn_to_data`: used for applying a function to the data of the tensor
    """

    implements = classmethod(_implements)
    __torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
    __torch_function__ = classmethod(_dispatch__torch_function__)
    register_layout = classmethod(_register_layout)
    get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor)

    def _get_to_kwargs(self, *args, **kwargs):
        """Normalize ``Tensor.to``-style arguments into ``{"device": ..., "dtype": ...}``.

        Missing device/dtype default to ``self.device`` / ``self.dtype``. Any
        layout argument (positional ``torch.layout`` or ``layout=``) is dropped.
        """
        # `torch._C._nn._parse_to` can't handle `layout` argument.
        # `args` is a tuple, so build a filtered copy; the previous
        # `args.remove(arg)` raised AttributeError whenever a layout was passed.
        args = tuple(arg for arg in args if not isinstance(arg, torch.layout))
        kwargs.pop("layout", None)
        # ignoring `non_blocking` and `memory_format` args since these are not
        # very useful for most of the tensor subclasses
        # if in the future there are use cases that need these, we'd recommend
        # to override `_get_to_kwargs` and return these args
        device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype
        kwargs = {
            "device": device,
            "dtype": dtype,
        }
        return kwargs

    def __tensor_flatten__(self):
        raise NotImplementedError("Subclasses must implement __tensor_flatten__")

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        raise NotImplementedError("Subclasses must implement __tensor_unflatten__")

    def __repr__(self):
        raise NotImplementedError("Subclasses must implement __repr__")

    def _apply_fn_to_data(self, fn):
        raise NotImplementedError("Subclasses must implement _apply_fn_to_data")

    def get_layout(self):
        """Return ``self._layout`` if the instance has one, else ``None``."""
        if not hasattr(self, "_layout"):
            return None
        return self._layout
# Shorthands used by the default TorchAOBaseTensor op implementations below.
aten = torch.ops.aten
implements = TorchAOBaseTensor.implements
@implements(aten.detach.default)
def _(func, types, args, kwargs):
    """Default ``aten.detach.default``: apply ``torch.detach`` via ``_apply_fn_to_data``.

    The result is passed through ``return_and_correct_aliasing``.
    """
    detached = args[0]._apply_fn_to_data(torch.detach)
    return return_and_correct_aliasing(func, args, kwargs, detached)
@implements(aten.clone.default)
def _(func, types, args, kwargs):
    """Default ``aten.clone.default``: apply ``torch.clone`` via ``_apply_fn_to_data``.

    The result is passed through ``return_and_correct_aliasing``.
    """
    cloned = args[0]._apply_fn_to_data(torch.clone)
    return return_and_correct_aliasing(func, args, kwargs, cloned)
@implements(aten.t.default)
def _(func, types, args, kwargs):
    """Default ``aten.t.default``: apply ``torch.t`` via ``_apply_fn_to_data``.

    The result is passed through ``return_and_correct_aliasing``.
    """
    transposed = args[0]._apply_fn_to_data(torch.t)
    return return_and_correct_aliasing(func, args, kwargs, transposed)
@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
    """Default ``aten.slice.Tensor``: slice the tensor's data via ``_apply_fn_to_data``.

    The requested slice arguments (``dim``, ``start``, ``end``, ``step`` in
    ``args[1:]`` / ``kwargs``) are forwarded to ``func`` for each data tensor.
    Previously ``torch.slice`` was applied with no arguments, so the requested
    slice was dropped.

    NOTE(review): this assumes every data tensor can be sliced with the same
    arguments as the outer tensor. Subclasses whose inner tensors have
    different shapes (e.g. per-channel scales) should override this op.
    """
    slice_args = args[1:]
    slice_kwargs = kwargs or {}
    sliced = args[0]._apply_fn_to_data(
        lambda data: func(data, *slice_args, **slice_kwargs)
    )
    return return_and_correct_aliasing(func, args, kwargs, sliced)
def fill_defaults(args, n, defaults_tail):
    """
    __torch_dispatch__ doesn't guarantee the number of arguments you are
    passed (e.g., defaulted arguments are not passed), but it is usually
    convenient to pad out the argument list with defaults. This function
    helps you do that.

    Args:
        args: the list of positional arguments passed to __torch_dispatch__
        n: the number of arguments you are expecting to get
        defaults_tail: default values for the arguments, starting from the
            end of the list

    Returns:
        A new list: ``args`` padded to length ``n`` (``args`` is returned as a
        list unchanged if it already has ``n`` or more items).

    Raises:
        RuntimeError: if ``defaults_tail`` is too short to cover the missing arguments.

    Example:

        >>> fill_defaults([1, 2, 3], 5, [3, 4, 5])
        [1, 2, 3, 4, 5]
        >>> fill_defaults([1, 2, 3], 5, [None, None, None])
        [1, 2, 3, None, None]
    """
    # Position of defaults_tail[0] within the full argument list.
    first_default_pos = n - len(defaults_tail)
    if first_default_pos > len(args):
        raise RuntimeError("not enough defaults to fill arguments")
    padding = [defaults_tail[pos - first_default_pos] for pos in range(len(args), n)]
    return list(args) + padding
# Deprecated, will be deleted in the future
def _torch_version_at_least(min_version):
    """Deprecated check backing the ``TORCH_VERSION_AFTER_*`` flags.

    Compares the numeric ``X.Y.Z`` prefix of the installed ``torch``
    distribution version with ``min_version`` (suffixes such as ``.dev`` or
    ``+cu121`` are ignored). Always True when ``is_fbcode()`` is true.
    """
    # Compare numerically, consistent with torch_version_at_least. A plain
    # string comparison mis-orders versions lexicographically, e.g.
    # "2.10.0" < "2.5.0.dev" and "2.5.0+cu121" < "2.5.0.dev".
    return is_fbcode() or compare_versions(version("torch"), min_version) >= 0
def is_MI300():
    """Return True on a ROCm (HIP) build whose current GPU arch is gfx940/941/942.

    The arch is read from ``torch.cuda.get_device_properties().gcnArchName``,
    and any of those substrings counts as a match.
    """
    if not (torch.cuda.is_available() and torch.version.hip):
        return False
    gcn_arch = torch.cuda.get_device_properties().gcnArchName
    mi300_archs = ("gfx940", "gfx941", "gfx942")
    return any(arch in gcn_arch for arch in mi300_archs)
def is_sm_at_least_89():
    """Return True if a CUDA (non-ROCm) device with compute capability >= (8, 9) is available."""
    # bool(): `torch.version.cuda` is None on non-CUDA builds, and a bare
    # `and` chain would return that None instead of False.
    return bool(
        torch.cuda.is_available()
        and torch.version.cuda
        and torch.cuda.get_device_capability() >= (8, 9)
    )
def is_sm_at_least_90():
    """Return True if a CUDA (non-ROCm) device with compute capability >= (9, 0) is available."""
    # bool(): `torch.version.cuda` is None on non-CUDA builds, and a bare
    # `and` chain would return that None instead of False.
    return bool(
        torch.cuda.is_available()
        and torch.version.cuda
        and torch.cuda.get_device_capability() >= (9, 0)
    )
# Deprecated version flags (see _torch_version_at_least), kept for callers
# that still import them.
TORCH_VERSION_AFTER_2_2 = _torch_version_at_least("2.2.0.dev")
TORCH_VERSION_AFTER_2_3 = _torch_version_at_least("2.3.0.dev")
TORCH_VERSION_AFTER_2_4 = _torch_version_at_least("2.4.0.dev")
TORCH_VERSION_AFTER_2_5 = _torch_version_at_least("2.5.0.dev")