Skip to content

Commit 4251a54

Browse files
jenniew, Jack-Khuu, songhappy
authored
Add XPU support for AOT inductor (#1503)
* add xpu * add xpu device * update * profile * update install * update * update * update * update * update * update --------- Co-authored-by: Jack-Khuu <[email protected]> Co-authored-by: Guoqiong <[email protected]>
1 parent d3cd165 commit 4251a54

File tree

3 files changed

+17
-7
lines changed

3 files changed

+17
-7
lines changed

Diff for: install/install_requirements.sh

+12-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ then
8181
REQUIREMENTS_TO_INSTALL=(
8282
torch=="2.7.0.${PYTORCH_NIGHTLY_VERSION}"
8383
torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
84-
torchtune=="0.6.0"
84+
#torchtune=="0.6.0" # no 0.6.0 on xpu nightly
8585
)
8686
else
8787
REQUIREMENTS_TO_INSTALL=(
@@ -115,6 +115,17 @@ fi
115115
"${REQUIREMENTS_TO_INSTALL[@]}"
116116
)
117117

118+
# Temporarily install torchtune nightly from the cpu nightly link, since there is no torchtune nightly for xpu yet
119+
# TODO: Change to install torchtune from xpu nightly link, once torchtune xpu nightly is ready
120+
if [[ -x "$(command -v xpu-smi)" ]];
121+
then
122+
(
123+
set -x
124+
$PIP_EXECUTABLE install --extra-index-url "https://download.pytorch.org/whl/nightly/cpu" \
125+
torchtune=="0.6.0.${TUNE_NIGHTLY_VERSION}"
126+
)
127+
fi
128+
118129
# For torchao need to install from github since nightly build doesn't have macos build.
119130
# TODO: Remove this and install nightly build, once it supports macos
120131
(

Diff for: torchchat/cli/builder.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from torchchat.utils.build_utils import (
3030
device_sync,
3131
is_cpu_device,
32-
is_cuda_or_cpu_device,
32+
is_cuda_or_cpu_or_xpu_device,
3333
name_to_dtype,
3434
)
3535
from torchchat.utils.measure_time import measure_time
@@ -539,7 +539,7 @@ def _initialize_model(
539539
_set_gguf_kwargs(builder_args, is_et=is_pte, context="generate")
540540

541541
if builder_args.dso_path:
542-
if not is_cuda_or_cpu_device(builder_args.device):
542+
if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
543543
print(
544544
f"Cannot load specified DSO to {builder_args.device}. Attempting to load model to CPU instead"
545545
)
@@ -573,7 +573,7 @@ def do_nothing(max_batch_size, max_seq_length):
573573
raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
574574

575575
elif builder_args.aoti_package_path:
576-
if not is_cuda_or_cpu_device(builder_args.device):
576+
if not is_cuda_or_cpu_or_xpu_device(builder_args.device):
577577
print(
578578
f"Cannot load specified PT2 to {builder_args.device}. Attempting to load model to CPU instead"
579579
)

Diff for: torchchat/utils/build_utils.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,5 @@ def get_device(device) -> str:
303303
def is_cpu_device(device) -> bool:
304304
return device == "" or str(device) == "cpu"
305305

306-
307-
def is_cuda_or_cpu_device(device) -> bool:
308-
return is_cpu_device(device) or ("cuda" in str(device))
306+
def is_cuda_or_cpu_or_xpu_device(device) -> bool:
307+
return is_cpu_device(device) or ("cuda" in str(device)) or ("xpu" in str(device))

0 commit comments

Comments
 (0)