Skip to content

Commit 498f946

Browse files
author
Suffian Khan
authored
Keep all_finite tensor on CPU when using PyTorch Frontend (microsoft#5371)
1 parent c2c7839 commit 498f946

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

orttraining/orttraining/python/training/orttrainer.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -801,7 +801,12 @@ def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_de
801 801
outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc)
802 802
result = {}
803 803
for output_desc in outputs_desc_resolved:
804-
torch_tensor = torch.zeros(output_desc.shape, device=self.options.device.id,
804+
target_device = self.options.device.id
805+
if self.options.mixed_precision.enabled and output_desc.name == self.model_desc.all_finite.name:
806+
# Keep all finite flag on CPU to match backend implementation
807+
# This prevents CPU -> GPU -> CPU copies between frontend and backend
808+
target_device = 'cpu'
809+
torch_tensor = torch.zeros(output_desc.shape, device=target_device,
805 810
dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype)
806 811
iobinding.bind_output(output_desc.name, torch_tensor.device.type, _utils.get_device_index(self.options.device.id),
807 812
_utils.dtype_torch_to_numpy(torch_tensor.dtype),

0 commit comments

Comments
 (0)