
Commit 2335ec2

update transformers imports for deepspeed and is_torch_xla_available (#2012)
* change deepspeed to integrations.deepspeed
* add version check and change tpu to xla
* add version check
1 parent 29f23f1 commit 2335ec2
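The gating relies on optimum's `check_if_transformers_greater` helper. Its implementation is not part of this diff; a minimal sketch of what such a version check can look like (an assumption for illustration, comparing against the installed `transformers` release with `packaging`):

from packaging import version

import transformers


def check_if_transformers_greater(target_version: str) -> bool:
    # Illustrative re-implementation only: report whether the installed
    # transformers release is at least `target_version`.
    return version.parse(transformers.__version__) >= version.parse(target_version)

With a check like this, the imports below can take the new `transformers.integrations.deepspeed` and `is_torch_xla_available` paths on recent releases and fall back to the legacy locations otherwise.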

2 files changed: 26 additions, 5 deletions


optimum/onnxruntime/trainer.py

Lines changed: 20 additions & 4 deletions
@@ -55,7 +55,6 @@
 from torch.utils.data import Dataset, RandomSampler
 from transformers.data.data_collator import DataCollator
 from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
-from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
 from transformers.modeling_utils import PreTrainedModel, unwrap_model
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 from transformers.trainer import Trainer
@@ -81,10 +80,10 @@
     is_apex_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
-    is_torch_tpu_available,
 )

 from ..utils import logging
+from ..utils.import_utils import check_if_transformers_greater
 from .training_args import ORTOptimizerNames, ORTTrainingArguments
 from .utils import (
     is_onnxruntime_training_available,
@@ -94,8 +93,25 @@
 if is_apex_available():
     from apex import amp

-if is_torch_tpu_available(check_device=False):
-    import torch_xla.core.xla_model as xm
+if check_if_transformers_greater("4.33"):
+    from transformers.integrations.deepspeed import (
+        deepspeed_init,
+        deepspeed_load_checkpoint,
+        is_deepspeed_zero3_enabled,
+    )
+else:
+    from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
+
+if check_if_transformers_greater("4.39"):
+    from transformers.utils import is_torch_xla_available
+
+    if is_torch_xla_available():
+        import torch_xla.core.xla_model as xm
+else:
+    from transformers.utils import is_torch_tpu_available
+
+    if is_torch_tpu_available(check_device=False):
+        import torch_xla.core.xla_model as xm

 if TYPE_CHECKING:
     import optuna
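Both branches of the XLA/TPU gate bind the same module alias `xm`, so the rest of the trainer keeps using it unchanged; only the availability check differs between transformers releases. A small usage sketch, assuming `torch_xla` is installed on an XLA host (not part of this commit):

import torch_xla.core.xla_model as xm

# Ask the XLA runtime for the current device (e.g. a TPU core); calls like
# this are what the availability checks above guard at import time.
device = xm.xla_device()
print(device)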

optimum/onnxruntime/trainer_seq2seq.py

Lines changed: 6 additions & 1 deletion
@@ -19,10 +19,10 @@
 import torch
 from torch import nn
 from torch.utils.data import Dataset
-from transformers.deepspeed import is_deepspeed_zero3_enabled
 from transformers.trainer_utils import PredictionOutput
 from transformers.utils import is_accelerate_available, logging

+from ..utils.import_utils import check_if_transformers_greater
 from .trainer import ORTTrainer


@@ -33,6 +33,11 @@
         "The package `accelerate` is required to use the ORTTrainer. Please install it following https://huggingface.co/docs/accelerate/basic_tutorials/install."
     )

+if check_if_transformers_greater("4.33"):
+    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
+else:
+    from transformers.deepspeed import is_deepspeed_zero3_enabled
+
 logger = logging.get_logger(__name__)

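For context (not part of this diff), `is_deepspeed_zero3_enabled` matters for a seq2seq trainer because ZeRO-3 shards model weights across ranks, so every rank has to enter generation the same number of times. A hedged sketch of the typical pattern, with the same import fallback expressed as a try/except:

try:  # transformers >= 4.33
    from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
except ImportError:  # older transformers releases
    from transformers.deepspeed import is_deepspeed_zero3_enabled

gen_kwargs = {"max_length": 128, "num_beams": 4}
# Under ZeRO-3, generation must stay synchronized across GPUs; a rank that
# finished early would otherwise hang the others waiting for weight shards.
gen_kwargs["synced_gpus"] = True if is_deepspeed_zero3_enabled() else False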
