Skip to content

Commit 6d23968

Browse files
authored
Merge pull request #1390 from MouseLand/train_float32
Train float32, disable bfloat16 training.
2 parents 17fb25f + 1783aea commit 6d23968

File tree

2 files changed

+35
-22
lines changed

2 files changed

+35
-22
lines changed

cellpose/train.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
319319
Train the network with images for segmentation.
320320
321321
Args:
322-
net (object): The network model to train. If `net` is a bfloat16 model on MPS, it will be converted to float32 for training. The saved models will be in float32, but the original model will be returned in bfloat16 for consistency. CUDA/CPU will train in bfloat16 if that is the provided net dtype.
322+
net (object): The network model to train. If `net` is a bfloat16 model, it will be converted to float32 for training. The saved models will be in float32, but the original model will be returned in its original dtype for consistency.
323323
train_data (List[np.ndarray], optional): List of arrays (2D or 3D) - images for training. Defaults to None.
324324
train_labels (List[np.ndarray], optional): List of arrays (2D or 3D) - labels for train_data, where 0=no masks; 1,2,...=mask labels. Defaults to None.
325325
train_files (List[str], optional): List of strings - file names for images in train_data (to save flows for future runs). Defaults to None.
@@ -356,13 +356,11 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
356356

357357
device = net.device
358358

359-
original_net_dtype = None
360-
if device.type == 'mps' and net.dtype == torch.bfloat16:
359+
original_net_dtype = net.dtype
360+
if net.dtype == torch.bfloat16:
361361
# NOTE: this produces a side effect of returning a network that is not of a guaranteed dtype
362-
original_net_dtype = torch.bfloat16
363-
train_logger.warning("Training with bfloat16 on MPS is not supported, using float32 network instead")
362+
train_logger.info(">>> converting bfloat16 network to float32 for training")
364363
net.dtype = torch.float32
365-
net.to(torch.float32)
366364

367365
scale_range = 0.5 if scale_range is None else scale_range
368366

@@ -462,11 +460,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
462460
X = torch.from_numpy(imgi).to(device)
463461
lbl = torch.from_numpy(lbl).to(device)
464462

465-
if X.dtype != net.dtype:
466-
X = X.to(net.dtype)
467-
lbl = lbl.to(net.dtype)
468-
469-
y = net(X)[0]
463+
with torch.autocast(device_type=device.type, dtype=net.dtype):
464+
y = net(X)[0]
470465
loss = _loss_fn_seg(lbl, y, device)
471466
if y.shape[1] > 3:
472467
loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
@@ -510,11 +505,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
510505
X = torch.from_numpy(imgi).to(device)
511506
lbl = torch.from_numpy(lbl).to(device)
512507

513-
if X.dtype != net.dtype:
514-
X = X.to(net.dtype)
515-
lbl = lbl.to(net.dtype)
516-
517-
y = net(X)[0]
508+
with torch.autocast(device_type=device.type, dtype=net.dtype):
509+
y = net(X)[0]
518510
loss = _loss_fn_seg(lbl, y, device)
519511
if y.shape[1] > 3:
520512
loss3 = _loss_fn_class(lbl, y, class_weights=class_weights)
@@ -539,9 +531,8 @@ def train_seg(net, train_data=None, train_labels=None, train_files=None,
539531
net.save_model(filename0)
540532

541533
net.save_model(filename)
542-
543-
if original_net_dtype is not None:
534+
if original_net_dtype != torch.float32:
535+
train_logger.info(f">>> converting network back to {original_net_dtype} after training")
544536
net.dtype = original_net_dtype
545-
net.to(original_net_dtype)
546537

547538
return filename, train_losses, test_losses

cellpose/vit_sam.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4,
4949
for blk in self.encoder.blocks:
5050
blk.window_size = 0
5151

52-
self.dtype = dtype
53-
if self.dtype != torch.float32:
54-
self = self.to(self.dtype)
52+
self._dtype = dtype
53+
if dtype != torch.float32:
54+
self.dtype = dtype
5555

5656
def forward(self, x):
5757
# same progression as SAM until readout
@@ -90,6 +90,7 @@ def load_model(self, PATH, device, strict = False):
9090
if w2_data == None:
9191
raise ValueError('This model does not appear to be a CP4 model. CP3 models are not compatible with CP4.')
9292

93+
# models are always saved as float32
9394
if keys[0][:7] == "module.":
9495
from collections import OrderedDict
9596
new_state_dict = OrderedDict()
@@ -103,6 +104,27 @@ def load_model(self, PATH, device, strict = False):
103104
if self.dtype != torch.float32:
104105
self = self.to(self.dtype)
105106

107+
@property
108+
def dtype(self):
109+
"""
110+
Get the data type of the model.
111+
112+
Returns:
113+
torch.dtype: The data type of the model.
114+
"""
115+
return self._dtype
116+
117+
@dtype.setter
118+
def dtype(self, value):
119+
"""
120+
Set the data type of the model.
121+
122+
Args:
123+
value (torch.dtype): The data type to set for the model.
124+
"""
125+
if self._dtype != value:
126+
self.to(value)
127+
self._dtype = value
106128

107129
@property
108130
def device(self):

0 commit comments

Comments
 (0)