File tree 3 files changed +17
-7
lines changed
3 files changed +17
-7
lines changed Original file line number Diff line number Diff line change 81
81
REQUIREMENTS_TO_INSTALL=(
82
82
torch==" 2.7.0.${PYTORCH_NIGHTLY_VERSION} "
83
83
torchvision==" 0.22.0.${VISION_NIGHTLY_VERSION} "
84
- torchtune==" 0.6.0"
84
+ # torchtune=="0.6.0" # no 0.6.0 on xpu nightly
85
85
)
86
86
else
87
87
REQUIREMENTS_TO_INSTALL=(
115
115
" ${REQUIREMENTS_TO_INSTALL[@]} "
116
116
)
117
117
118
+ # Temporatory instal torchtune nightly from cpu nightly link since no torchtune nightly for xpu now
119
+ # TODO: Change to install torchtune from xpu nightly link, once torchtune xpu nightly is ready
120
+ if [[ -x " $( command -v xpu-smi) " ]];
121
+ then
122
+ (
123
+ set -x
124
+ $PIP_EXECUTABLE install --extra-index-url " https://download.pytorch.org/whl/nightly/cpu" \
125
+ torchtune==" 0.6.0.${TUNE_NIGHTLY_VERSION} "
126
+ )
127
+ fi
128
+
118
129
# For torchao need to install from github since nightly build doesn't have macos build.
119
130
# TODO: Remove this and install nightly build, once it supports macos
120
131
(
Original file line number Diff line number Diff line change 29
29
from torchchat .utils .build_utils import (
30
30
device_sync ,
31
31
is_cpu_device ,
32
- is_cuda_or_cpu_device ,
32
+ is_cuda_or_cpu_or_xpu_device ,
33
33
name_to_dtype ,
34
34
)
35
35
from torchchat .utils .measure_time import measure_time
@@ -539,7 +539,7 @@ def _initialize_model(
539
539
_set_gguf_kwargs (builder_args , is_et = is_pte , context = "generate" )
540
540
541
541
if builder_args .dso_path :
542
- if not is_cuda_or_cpu_device (builder_args .device ):
542
+ if not is_cuda_or_cpu_or_xpu_device (builder_args .device ):
543
543
print (
544
544
f"Cannot load specified DSO to { builder_args .device } . Attempting to load model to CPU instead"
545
545
)
@@ -573,7 +573,7 @@ def do_nothing(max_batch_size, max_seq_length):
573
573
raise RuntimeError (f"Failed to load AOTI compiled { builder_args .dso_path } " )
574
574
575
575
elif builder_args .aoti_package_path :
576
- if not is_cuda_or_cpu_device (builder_args .device ):
576
+ if not is_cuda_or_cpu_or_xpu_device (builder_args .device ):
577
577
print (
578
578
f"Cannot load specified PT2 to { builder_args .device } . Attempting to load model to CPU instead"
579
579
)
Original file line number Diff line number Diff line change @@ -303,6 +303,5 @@ def get_device(device) -> str:
303
303
def is_cpu_device (device ) -> bool :
304
304
return device == "" or str (device ) == "cpu"
305
305
306
-
307
- def is_cuda_or_cpu_device (device ) -> bool :
308
- return is_cpu_device (device ) or ("cuda" in str (device ))
306
+ def is_cuda_or_cpu_or_xpu_device (device ) -> bool :
307
+ return is_cpu_device (device ) or ("cuda" in str (device )) or ("xpu" in str (device ))
You can’t perform that action at this time.
0 commit comments