
Commit 42408aa

Thiago Crepaldi and Rayan Krishnan (rayankrish) authored
Add new PyTorch front-end (microsoft#4815)
* Add ORTTrainerOptions class for the new pytorch frontend (microsoft#4382)
  Add ORTTrainerOptions class and some placeholders
* Add _ORTTrainerModelDesc to perform validation for model description (microsoft#4416)
* Add Loss Scaler classes to the new frontend (microsoft#4306)
* Add TrainStepInfo used on the new frontend API (microsoft#4256)
* Add Optimizer classes to the new frontend (microsoft#4280)
* Add LRScheduler implementation (microsoft#4357)
* Add basic ORTTrainer API (microsoft#4435)
  This PR presents the public API for ORTTrainer for the short-term development. It also validates and saves input parameters, which will be used in the next stages, such as building the ONNX model, post-processing the model and configuring the training session.
* Add opset_version into ORTTrainerOptions and change type of ORTTrainer.loss_fn (microsoft#4592)
* Update ModelDescription and minor fix on ORTTrainer ctor (microsoft#4605)
  This PR keeps the public API intact, but changes how the model description is stored on the backend. Currently, users create a dict with two lists of tuples. One list, called 'inputs', holds tuples in the format tuple(name, shape). The second list, called 'outputs', holds tuples in either tuple(name, shape) or tuple(name, shape, is_loss) format. With this PR, when this dict is passed into ORTTrainer, it is fully validated as usual. However, tuples are internally replaced by namedtuples, and all output tuples take the tuple(name, shape, is_loss) format instead of is_loss being optionally present. In addition to that normalization of the internal representation (which eases coding), two internal methods were created to replace a namedtuple(name, shape) with namedtuple(name, shape, dtype) or namedtuple(name, shape, is_loss, dtype), depending on whether the tuple is an input or an output. This is necessary because ORTTrainer finds out the data type of each input/output during model export to ONNX. Finally, a minor fix was done on ORTTrainer: it could initialize ORTTrainerOptions incorrectly when options=None.
* Rename input name for test
* Add ONNX Model Export to New Frontend (microsoft#4612)
* Create training session + minor improvements (microsoft#4668)
* Save ONNX model in file (microsoft#4671)
* Add eval step (microsoft#4674)
* Add train_step (microsoft#4677)
* Add LR Scheduler (microsoft#4694)
* Add deterministic compute tests (microsoft#4716)
* Add legacy vs experimental ORTTrainer accuracy comparison (microsoft#4727)
* Add Mixed precision/LossScaler + several fixes (microsoft#4739)
  In addition to the mixed precision/loss scaler code, this PR includes:
  * Fix CUDA training
  * Add optimization_step into TrainStepInfo class
  * Refactor LRScheduler to use optimization_step instead of step
  * Update several default values at ORTTrainerOptions
  * Add initial Gradient Accumulation support (untested)
  * Fix ONNX model post processing
  * Refactor unit tests
* Add ONNX BERT example + minor fixes (microsoft#4757)
* Fix training issue when passing ONNX file into ORTTrainer
* Add Dynamic Shape support (microsoft#4758)
* Update DeepSpeed Zero Stage option to a separate option group (microsoft#4772)
* Add support to fetches (microsoft#4777)
* Add Gradient Accumulation Steps support (microsoft#4793)
* Fix Dynamic Axes feature and add unit test (microsoft#4795)
* Add frozen weights test (microsoft#4807)
* Move new pytorch front-end to 'experimental' namespace (microsoft#4814)
* Fix build

Co-authored-by: Thiago Crepaldi <[email protected]>
Co-authored-by: Rayan-Krishnan <[email protected]>
Co-authored-by: Rayan Krishnan <t-rakr@OrtDevTest2v100.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
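For illustration, a model description in the dict format described for microsoft#4605 could look like the sketch below; every name and shape is hypothetical, and only the structure (an 'inputs' list of tuple(name, shape) and an 'outputs' list with an optional is_loss flag) comes from the commit message above.

    # Hypothetical model description following the format described above
    model_desc = {
        'inputs': [
            ('x', ['batch_size', 784]),            # tuple(name, shape)
            ('target', ['batch_size'])
        ],
        'outputs': [
            ('loss', [], True),                    # tuple(name, shape, is_loss)
            ('probabilities', ['batch_size', 10])  # is_loss may be omitted
        ]
    }
    # Internally ORTTrainer validates this dict, converts the tuples to namedtuples,
    # and later adds the dtype it discovers during ONNX export.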
1 parent 5eec4f6 commit 42408aa

28 files changed: +3740 −4 lines

cmake/onnxruntime_python.cmake

+27
@@ -173,6 +173,15 @@ if (onnxruntime_ENABLE_TRAINING)
   file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS
     "${ORTTRAINING_SOURCE_DIR}/python/training/*.py"
   )
+  file(GLOB onnxruntime_python_root_srcs CONFIGURE_DEPENDS
+    "${ORTTRAINING_SOURCE_DIR}/python/experimental/*.py"
+  )
+  file(GLOB onnxruntime_python_amp_srcs CONFIGURE_DEPENDS
+    "${ORTTRAINING_SOURCE_DIR}/python/experimental/amp/*.py"
+  )
+  file(GLOB onnxruntime_python_optim_srcs CONFIGURE_DEPENDS
+    "${ORTTRAINING_SOURCE_DIR}/python/experimental/optim/*.py"
+  )
 else()
   file(GLOB onnxruntime_python_capi_training_srcs CONFIGURE_DEPENDS
     "${ONNXRUNTIME_ROOT}/python/training/*.py"
@@ -260,6 +269,24 @@ add_custom_command(
   $<TARGET_FILE_DIR:${test_data_target}>
 )

+if (onnxruntime_ENABLE_TRAINING)
+  add_custom_command(
+    TARGET onnxruntime_pybind11_state POST_BUILD
+    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/experimental
+    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/experimental/amp
+    COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/experimental/optim
+    COMMAND ${CMAKE_COMMAND} -E copy
+      ${onnxruntime_python_root_srcs}
+      $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/experimental/
+    COMMAND ${CMAKE_COMMAND} -E copy
+      ${onnxruntime_python_amp_srcs}
+      $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/experimental/amp/
+    COMMAND ${CMAKE_COMMAND} -E copy
+      ${onnxruntime_python_optim_srcs}
+      $<TARGET_FILE_DIR:${test_data_target}>/onnxruntime/experimental/optim/
+  )
+endif()
+
 if (onnxruntime_USE_DNNL)
   add_custom_command(
     TARGET onnxruntime_pybind11_state POST_BUILD

dockerfiles/Dockerfile.training

+3
@@ -47,6 +47,9 @@ RUN conda install -y \
     pip install \
       onnx=="${ONNX_VERSION}"

+# install cerberus for the new pytorch front-end
+RUN pip install cerberus
+
 # build ucx suite
 # note: openmpi will not select ucx without multithreading enabled
 ARG UCX_VERSION

onnxruntime/__init__.py

+12
@@ -13,9 +13,21 @@
 from onnxruntime.capi._pybind_state import get_all_providers, get_available_providers, get_device, set_seed, \
     RunOptions, SessionOptions, set_default_logger_severity, NodeArg, ModelMetadata, GraphOptimizationLevel, \
     ExecutionMode, OrtDevice, SessionIOBinding
+
+try:
+    from onnxruntime.capi._pybind_state import set_cuda_mem_limit, set_cuda_device_id
+except ImportError:
+    pass
+
 from onnxruntime.capi.session import InferenceSession, IOBinding
 from onnxruntime.capi import onnxruntime_validation

 from onnxruntime.capi.training import *  # noqa: F403

+# TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end
+try:
+    from . import experimental
+except ImportError:
+    pass
+
 onnxruntime_validation.check_distro_info()
@@ -0,0 +1,10 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from onnxruntime.capi._pybind_state import TrainingParameters
from onnxruntime.capi.training.training_session import TrainingSession

from .orttrainer_options import ORTTrainerOptions
from .orttrainer import ORTTrainer, TrainStepInfo
from . import amp, optim, model_desc_validation
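For orientation, a minimal sketch of how the new namespace is reached once the wheel is built with training enabled; only names exported above are used, and the printed default comes from the DynamicLossScaler defaults later in this commit.

    from onnxruntime import experimental
    from onnxruntime.experimental import amp, optim, ORTTrainer, ORTTrainerOptions

    # amp ships the loss scalers added by this commit; the dynamic scaler
    # starts at 1 << 16 by default (see loss_scaler.py below)
    scaler = amp.DynamicLossScaler()
    print(scaler.loss_scale)  # 65536.0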
@@ -0,0 +1,179 @@
import importlib.util
import numpy as np
import os
import sys
import torch


def get_device_index(device):
    '''Returns device index from a device'''

    if type(device) == str:
        # Could be 'cuda:0', 'cuda:1', or 'cpu'. With cpu, set index=0
        device = torch.device(device)
    return 0 if device.index is None else device.index


def get_device_index_from_input(input):
    '''Returns device index from an input PyTorch Tensor'''

    if isinstance(input, (list, tuple)):
        device_index = get_device_index(input[0].device)
    else:
        device_index = get_device_index(input.device)
    return device_index


def get_all_gradients_finite_name_from_session(session):
    '''Find the all_gradients_finite node on the Session graph and return its name'''

    nodes = [x for x in session._outputs_meta if 'all_gradients_finite' in x.name]
    if len(nodes) != 1:
        raise RuntimeError("'all_gradients_finite' node not found within training session")
    return nodes[0].name


def get_gradient_accumulation_name_from_session(session):
    '''Find the Group_Accumulated_Gradients node on the Session graph and return its name'''

    nodes = [x for x in session._outputs_meta if 'Group_Accumulated_Gradients' in x.name]
    if len(nodes) != 1:
        raise RuntimeError("'Group_Accumulated_Gradients' node not found within training session")
    return nodes[0].name


def dtype_torch_to_numpy(torch_dtype):
    '''Converts PyTorch types to Numpy types

    Also must map to types accepted by:
        MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type)

    References:
        https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html
        https://pytorch.org/docs/stable/tensors.html
    '''
    if torch_dtype == torch.float64 or torch_dtype == torch.double:
        return np.float64
    elif torch_dtype == torch.float32 or torch_dtype == torch.float:
        return np.float32
    elif torch_dtype == torch.float16 or torch_dtype == torch.half or torch_dtype == torch.bfloat16:
        # NOTE: numpy doesn't support bfloat16
        return np.float16
    elif torch_dtype == torch.int64 or torch_dtype == torch.long:
        return np.longlong  # np.int64 doesn't work!?
    elif torch_dtype == torch.int32 or torch_dtype == torch.int:
        return np.int32
    elif torch_dtype == torch.int16 or torch_dtype == torch.short:
        return np.int16
    elif torch_dtype == torch.int8:
        return np.int8
    elif torch_dtype == torch.uint8:
        return np.uint8
    elif torch_dtype == torch.complex32 or torch_dtype == torch.complex64:
        # NOTE: numpy doesn't support complex32
        return np.complex64
    elif torch_dtype == torch.complex128 or torch_dtype == torch.cdouble:
        return np.complex128
    elif torch_dtype == torch.bool:
        return np.bool_
    else:
        raise ValueError(
            f'torch_dtype ({str(torch_dtype)}) type is not supported by Numpy')


def dtype_onnx_to_torch(onnx_type):
    '''Converts ONNX types to PyTorch types

    Reference: https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto (enum DataType)
               https://pytorch.org/docs/stable/tensors.html
    '''
    onnx_types = ['UNDEFINED', 'FLOAT', 'UINT8', 'INT8', 'UINT16', 'INT16', 'INT32', 'INT64', 'STRING',
                  'BOOL', 'FLOAT16', 'DOUBLE', 'UINT32', 'UINT64', 'COMPLEX64', 'COMPLEX128', 'BFLOAT16']

    if isinstance(onnx_type, int):
        assert onnx_type < len(onnx_types), "Invalid onnx_type integer"
    elif isinstance(onnx_type, str):
        onnx_type = onnx_type.upper()
        assert onnx_type in onnx_types, "Invalid onnx_type string"
        onnx_type = onnx_types.index(onnx_type)
    else:
        raise ValueError(
            "'onnx_type' must be an ONNX type represented by either a string or integer")

    if onnx_type == 0:
        return None
    elif onnx_type == 1:
        return torch.float
    elif onnx_type >= 2 and onnx_type <= 3:
        # NOTE: Pytorch doesn't support uint8
        return torch.int8
    elif onnx_type >= 4 and onnx_type <= 5:
        # NOTE: Pytorch doesn't support uint16
        return torch.int16
    elif onnx_type == 6 or onnx_type == 12:
        # NOTE: Pytorch doesn't support uint32
        return torch.int32
    elif onnx_type == 7 or onnx_type == 13:
        # NOTE: Pytorch doesn't support uint64
        return torch.int64
    elif onnx_type == 8:
        return str
    elif onnx_type == 9:
        return torch.bool
    elif onnx_type == 10:
        return torch.float16
    elif onnx_type == 11:
        return torch.double
    elif onnx_type == 14:
        return torch.complex64
    elif onnx_type == 15:
        return torch.complex128
    elif onnx_type == 16:
        return torch.bfloat16


def static_vars(**kwargs):
    r'''Decorator to add :py:attr:`kwargs` as static vars to 'func'

    Example:

    .. code-block:: python

        >>> @static_vars(counter=0)
        ... def myfunc():
        ...     myfunc.counter += 1
        ...     return myfunc.counter
        ...
        >>> print(myfunc())
        1
        >>> print(myfunc())
        2
        >>> print(myfunc())
        3
        >>> myfunc.counter = 100
        >>> print(myfunc())
        101
    '''
    def decorate(func):
        for k in kwargs:
            setattr(func, k, kwargs[k])
        return func
    return decorate


def import_module_from_file(file_path, module_name=None):
    '''Import a Python module from a file into the interpreter'''

    assert isinstance(file_path, str) and os.path.exists(file_path),\
        "'file_path' must be a full path string with the python file to load"
    assert module_name is None or isinstance(module_name, str) and module_name,\
        "'module_name' must be a string with the python module name to load"

    if not module_name:
        module_name = os.path.basename(file_path).split('.')[0]

    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
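A short usage sketch for the type-mapping helpers above; the return values follow directly from the branches shown, and the only assumption is that the module above is importable (its path is not captured in this view).

    import numpy as np
    import torch

    # Assumes dtype_torch_to_numpy / dtype_onnx_to_torch from the module above are in scope.
    # PyTorch dtype -> NumPy type, as accepted by NumpyTypeToOnnxRuntimeType
    assert dtype_torch_to_numpy(torch.float32) is np.float32
    assert dtype_torch_to_numpy(torch.long) is np.longlong

    # ONNX type -> PyTorch dtype, given either the enum integer or its string name
    assert dtype_onnx_to_torch('FLOAT') is torch.float    # enum value 1
    assert dtype_onnx_to_torch(11) is torch.double        # 'DOUBLE'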
@@ -0,0 +1 @@
from .loss_scaler import LossScaler, DynamicLossScaler
@@ -0,0 +1,114 @@
class LossScaler(object):
    r"""Base class for implementing custom loss scaler strategies

    Once the scaler is configured, no user intervention is needed to update loss scale during training.

    Note:
        This class should not be instantiated directly; use it as an abstract base class for custom loss scaling strategies.
    """

    def __init__(self, loss_scale):
        self._input_name = None
        self._loss_scale = loss_scale

    @property
    def input_name(self):
        return self._input_name

    @input_name.setter
    def input_name(self, input_name):
        assert isinstance(input_name, str), "'input_name' must be a string"
        assert input_name is None or len(input_name) > 0, "'input_name' cannot be empty"
        self._input_name = input_name

    @property
    def loss_scale(self):
        return self._loss_scale

    @loss_scale.setter
    def loss_scale(self, loss_scale):
        assert isinstance(loss_scale, float) and loss_scale > 0, "'loss_scale' must be a positive float"
        self._loss_scale = loss_scale

    def reset(self):
        r"""Resets loss scaler internal state"""
        raise NotImplementedError

    def update(self, train_step_info):
        r"""Updates the loss scale based on user input and training session info

        Args:
            train_step_info (TrainStepInfo): last step state information

        Returns:
            Updated loss scale (float)
        """
        raise NotImplementedError


class DynamicLossScaler(LossScaler):
    r"""Default implementation of the :py:class:`.LossScaler` class, used for mixed precision

    This loss scaler starts from an initial scale, which is doubled every time a certain number of
    (stable) training steps is performed without exploding gradients (overflow or infinity).
    When at least one gradient explodes, the loss scale is divided by 2.

    Users can use this class in two ways:

        1. Enable mixed precision without setting a loss scaler class. Default values are used
        2. Enable mixed precision and instantiate this class to override default arguments

    Static loss scaling can be achieved by setting :py:attr:`.automatic_update` to :py:obj:`False`
    and not performing manual :py:meth:`update` in the train loop.

    Args:
        automatic_update (bool, default is True): boolean switch that allows :py:meth:`ORTTrainer.train_step`
            to automatically perform loss scaling. If False, an explicit call to :py:meth:`.update` must be done by the user,
            otherwise static loss scaling is performed.
        loss_scale (default is 1 << 16): A float that represents current loss scale
        up_scale_window (int, default is 2000): number of stable train steps before doubling loss scale
        min_loss_scale (float, default is 1): min value for the loss scale. Used when loss scale is decreased
        max_loss_scale (float, default is 1 << 24): max value for the loss scale. Used when loss scale is increased

    Example with default values:
        .. code-block:: python

            scaler1 = amp.DynamicLossScaler()
            print(f'Default loss scale is {scaler1.loss_scale}')

    Example with user specified values:
        .. code-block:: python

            scaler2 = amp.DynamicLossScaler(loss_scale=1<<8)
            print(f'Custom loss scale is {scaler2.loss_scale}')
    """

    def __init__(self, automatic_update=True,
                 loss_scale=float(1 << 16),
                 up_scale_window=2000,
                 min_loss_scale=1.0,
                 max_loss_scale=float(1 << 24)):
        super().__init__(loss_scale)
        self.automatic_update = automatic_update
        self.up_scale_window = up_scale_window
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale

        self._initial_loss_scale = loss_scale
        self._stable_steps_count = 0

    def reset(self):
        self.loss_scale = self._initial_loss_scale
        self._stable_steps_count = 0

    def update(self, train_step_info):
        if train_step_info.all_finite:
            self._stable_steps_count += 1

            if self._stable_steps_count >= self.up_scale_window:
                self.loss_scale = min(self.max_loss_scale, self.loss_scale * 2)
                self._stable_steps_count = 0
        else:
            self.loss_scale = max(self.min_loss_scale, self.loss_scale / 2)
            self._stable_steps_count = 0
        return self.loss_scale
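A small usage sketch of the dynamic scaling loop above; TrainStepInfo is stood in for by a hypothetical namedtuple exposing only the all_finite flag that update() reads.

    from collections import namedtuple

    # Hypothetical stand-in for TrainStepInfo; update() only reads .all_finite here
    FakeStepInfo = namedtuple('FakeStepInfo', ['all_finite'])

    scaler = DynamicLossScaler(up_scale_window=2)
    print(scaler.loss_scale)                       # 65536.0, i.e. 1 << 16
    scaler.update(FakeStepInfo(all_finite=True))
    scaler.update(FakeStepInfo(all_finite=True))   # two stable steps -> scale doubles
    print(scaler.loss_scale)                       # 131072.0
    scaler.update(FakeStepInfo(all_finite=False))  # overflow -> scale is halved
    print(scaler.loss_scale)                       # 65536.0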
