Commit 838b74d

Add Ascend NPU support (#1758)

1 parent: 2e99bb3
5 files changed: +114, -16 lines

src/axolotl/utils/bench.py (14 additions, 1 deletion)

@@ -4,6 +4,9 @@
 import pynvml
 import torch
 from pynvml.nvml import NVMLError
+from transformers.utils.import_utils import is_torch_npu_available
+
+from axolotl.utils.distributed import get_device_type


 def check_cuda_device(default_value):

@@ -53,6 +56,12 @@ def mps_memory_usage_all():
     return usage, reserved - usage, 0


+def npu_memory_usage_all(device=0):
+    usage = torch.npu.memory_allocated(device) / 1024.0**3
+    reserved = torch.npu.memory_reserved(device) / 1024.0**3
+    return usage, reserved - usage, 0
+
+
 @check_cuda_device(0.0)
 def gpu_memory_usage_smi(device=0):
     if isinstance(device, torch.device):

@@ -69,8 +78,11 @@ def gpu_memory_usage_smi(device=0):


 def log_gpu_memory_usage(log, msg, device):
+    cur_device = get_device_type()
     if torch.backends.mps.is_available():
         usage, cache, misc = mps_memory_usage_all()
+    elif "npu" in str(cur_device) and is_torch_npu_available():
+        usage, cache, misc = npu_memory_usage_all(device)
     else:
         usage, cache, misc = gpu_memory_usage_all(device)
     extras = []

@@ -79,6 +91,7 @@ def log_gpu_memory_usage(log, msg, device):
     if misc > 0:
         extras.append(f"+{misc:.03f}GB misc")
     log.info(
-        f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
+        f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
+        stacklevel=2,
     )
     return usage, cache, misc
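For reference, a minimal usage sketch (not part of the commit) of the updated log_gpu_memory_usage on an Ascend host; it assumes axolotl and torch_npu are installed, and the logger setup plus device index 0 are illustrative:

```python
import logging

from axolotl.utils.bench import log_gpu_memory_usage

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# On an NPU host get_device_type() resolves to "npu", so npu_memory_usage_all()
# is used and the log line is prefixed with the device type instead of "GPU".
usage, cache, misc = log_gpu_memory_usage(log, "after model load", 0)
print(f"allocated={usage:.3f}GB cache={cache:.3f}GB misc={misc:.3f}GB")
```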

src/axolotl/utils/config/__init__.py (7 additions, 1 deletion)

@@ -5,6 +5,7 @@

 import torch
 from transformers.utils import is_torch_bf16_gpu_available
+from transformers.utils.import_utils import is_torch_npu_available

 from axolotl.integrations.config import merge_input_args
 from axolotl.utils.bench import log_gpu_memory_usage

@@ -29,7 +30,10 @@ def get_device():
             if torch.backends.mps.is_available():
                 return "mps"

-            raise SystemError("No CUDA/mps device found")
+            if is_torch_npu_available():
+                return f"npu:{cfg.local_rank}"
+
+            raise SystemError("No CUDA/mps/npu device found")
         except Exception:  # pylint: disable=broad-exception-caught
             return "cpu"

@@ -39,6 +43,8 @@ def get_device():
     else:
         if cfg.device.startswith("cuda"):
             cfg.device_map = {"": torch.cuda.current_device()}
+        elif cfg.device.startswith("npu"):
+            cfg.device_map = {"npu": torch.npu.current_device()}
         else:
             cfg.device_map = {"": cfg.device}
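For context, a hypothetical standalone version of the device-resolution order that get_device now follows; the real function is nested inside axolotl's config normalization and reads cfg.local_rank, and the CUDA branch shown here is assumed from the surrounding (unchanged) code:

```python
import torch
from transformers.utils.import_utils import is_torch_npu_available


def resolve_device(local_rank: int = 0) -> str:
    # Fallback order mirroring get_device: CUDA, then MPS, then NPU, else CPU.
    try:
        if torch.cuda.is_available():
            return f"cuda:{local_rank}"
        if torch.backends.mps.is_available():
            return "mps"
        if is_torch_npu_available():
            return f"npu:{local_rank}"
        raise SystemError("No CUDA/mps/npu device found")
    except Exception:  # pylint: disable=broad-exception-caught
        return "cpu"


print(resolve_device())  # e.g. "npu:0" on an Ascend host, "cpu" otherwise
```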

src/axolotl/utils/config/models/input/v0_4_1/__init__.py (35 additions, 0 deletions)

@@ -19,6 +19,7 @@
 )
 from transformers import SchedulerType
 from transformers.training_args import OptimizerNames
+from transformers.utils.import_utils import is_torch_npu_available

 from axolotl.utils.config.models.internals import GPUCapabilities

@@ -1433,6 +1434,40 @@ def check_torch_compile_deepspeed(cls, data):
             )
         return data

+    @model_validator(mode="before")
+    @classmethod
+    def check_npu_config(cls, data):
+        if is_torch_npu_available():
+            # check attention config
+            attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
+            for attn in attn_list:
+                if data.get(attn):
+                    raise NotImplementedError(
+                        f"{attn} is currently not supported in Ascend npu, please disable this configuration."
+                    )
+
+            # check quant config
+            if data.get("optimizer") is not None and "bit" in data.get("optimizer"):
+                optimizer = data.get("optimizer")
+                raise NotImplementedError(
+                    f"{optimizer} is currently not supported in Ascend npu, choose another one please."
+                )
+
+            quant_list = ["load_in_8bit", "load_in_4bit"]
+            for quant in quant_list:
+                if data.get(quant):
+                    raise NotImplementedError(
+                        f"Quantification is currently not supported in Ascend npu, please disable {quant}."
+                    )
+
+            # check dtype config
+            if data.get("tf32"):
+                raise NotImplementedError(
+                    "tf32 dtype is currently not supported in Ascend npu, please disable this configuration"
+                )
+
+        return data
+

 class AxolotlConfigWCapabilities(AxolotlInputConfig):
     """wrapper to valdiate gpu capabilities with the configured options"""

src/axolotl/utils/distributed.py (44 additions, 8 deletions)

@@ -9,10 +9,44 @@
 import torch
 import torch.distributed as dist
 from accelerate import PartialState
+from transformers.utils.import_utils import (
+    is_torch_cuda_available,
+    is_torch_mps_available,
+    is_torch_npu_available,
+)

 distributed_state = None  # pylint: disable=invalid-name


+def get_device_type():
+    device = torch.device("cpu")
+    if is_torch_cuda_available():
+        device = torch.device("cuda")
+    elif is_torch_mps_available():
+        device = torch.device("mps")
+    elif is_torch_npu_available():
+        device = torch.device("npu")
+    return device
+
+
+def get_device_count():
+    cur_device = get_device_type()
+    if "cuda" in str(cur_device):
+        return torch.cuda.device_count()
+    if "npu" in str(cur_device):
+        return torch.npu.device_count()
+    return 1
+
+
+def get_current_device():
+    cur_device = get_device_type()
+    if "cuda" in str(cur_device):
+        return torch.cuda.current_device()
+    if "npu" in str(cur_device):
+        return torch.npu.current_device()
+    return 0
+
+
 def is_distributed():
     """
     Check if distributed training is initialized.

@@ -91,7 +125,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     if not is_distributed():
         return [value_scalar]
     value_tensor = torch.tensor(
-        value_scalar, device=torch.cuda.current_device()
+        value_scalar, device=f"{get_device_type()}:{get_current_device()}"
     ).float()

     if not is_main_process():

@@ -115,13 +149,14 @@ def broadcast_dict(vals: dict):
     if not is_distributed():
         return vals

+    cur_device = get_device_type()
     if is_main_process():
         data_byte = pickle.dumps(vals)
-        data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
-        data_size = torch.IntTensor([len(data_byte)]).to("cuda")
+        data_tensor = torch.ByteTensor(list(data_byte)).to(cur_device)
+        data_size = torch.IntTensor([len(data_byte)]).to(cur_device)
     else:
-        data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
-        data_size = torch.IntTensor([0]).to("cuda")
+        data_tensor = torch.empty([1024], dtype=torch.uint8, device=cur_device)
+        data_size = torch.IntTensor([0]).to(cur_device)

     dist.broadcast(data_size, 0)
     if not is_main_process():

@@ -150,14 +185,15 @@ def compute_and_broadcast(fn):  # pylint: disable=invalid-name
     Returns:
     - The computed value (int or float).
     """
+    cur_device = f"{get_device_type()}:{get_current_device()}"
     if is_main_process():
         value_scalar = fn()
         value_tensor = torch.tensor(
-            value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
+            value_scalar, device=cur_device, dtype=torch.float32
        )
     else:
         value_tensor = torch.tensor(
-            0.0, device=torch.cuda.current_device(), dtype=torch.float32
+            0.0, device=cur_device, dtype=torch.float32
         )  # Placeholder tensor

     # Broadcast the tensor to all processes.

@@ -184,7 +220,7 @@ def gather_from_all_ranks(fn, world_size=1):  # pylint: disable=invalid-name
     """
     value_scalar = fn()
     value_tensor = torch.tensor(
-        value_scalar, device=torch.cuda.current_device()
+        value_scalar, device=f"{get_device_type()}:{get_current_device()}"
     ).float()

     # Placeholder tensor for gathering results
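A short sketch of how the new helpers compose into a device string for tensor placement, which is what replaces the hard-coded torch.cuda.current_device() calls above (assumes axolotl is installed, plus torch_npu for the NPU path):

```python
import torch

from axolotl.utils.distributed import (
    get_current_device,
    get_device_count,
    get_device_type,
)

# e.g. "cuda:0" on a GPU box, "npu:0" on an Ascend box, "cpu:0" otherwise
device = f"{get_device_type()}:{get_current_device()}"
print(f"{get_device_count()} device(s) of type {get_device_type()}")

# Tensors that used to be pinned to torch.cuda.current_device() now land on
# whichever accelerator is present.
value = torch.tensor(1.0, device=device, dtype=torch.float32)
print(value.device)
```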

src/axolotl/utils/models.py (14 additions, 6 deletions)

@@ -55,7 +55,7 @@
 from axolotl.utils.bench import log_gpu_memory_usage
 from axolotl.utils.chat_templates import get_chat_template_from_config
 from axolotl.utils.dict import DictDefault
-from axolotl.utils.distributed import zero_only
+from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
 from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
 from axolotl.utils.lora_embeddings import get_linear_embedding_layers
 from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant

@@ -570,7 +570,8 @@ def set_device_map_config(self) -> None:
             )

             max_memory = {}
-            for i in range(torch.cuda.device_count()):
+            num_device = get_device_count()
+            for i in range(num_device):
                 max_memory[i] = gpu_memory_limit
             max_memory["cpu"] = "256GiB"  # something sufficiently large to fit anything

@@ -595,8 +596,11 @@ def set_device_map_config(self) -> None:
         self.model_kwargs["device_map"] = device_map
         self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype

-        if torch.backends.mps.is_available():
+        cur_device = get_device_type()
+        if "mps" in str(cur_device):
             self.model_kwargs["device_map"] = "mps:0"
+        elif "npu" in str(cur_device):
+            self.model_kwargs["device_map"] = "npu:0"

         # TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
         # if cfg.rl:

@@ -1050,7 +1054,11 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
         self.ajust_model_config()

         # log device memory usage
-        if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
+        if hasattr(self.model, "device") and self.model.device.type in (
+            "cuda",
+            "mps",
+            "npu",
+        ):
             log_gpu_memory_usage(LOG, "after model load", self.model.device)

         # make sure these are fp32 per Ramesh et al. (2021)

@@ -1118,9 +1126,9 @@ def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
             and not skip_move_to_device
         ):
             # TODO revaldate this conditional
-            self.model.to(f"cuda:{self.cfg.local_rank}")
+            self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")

-        if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
+        if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
             setattr(self.model, "is_parallelizable", True)
             setattr(self.model, "model_parallel", True)