Commit 232ae25

cellmap_flow/

1 parent 6d779c9 commit 232ae25

File tree

1 file changed: +14 -7 lines changed

cellmap_flow/inferencer.py

Lines changed: 14 additions & 7 deletions
@@ -7,6 +7,7 @@
     ScriptModelConfig,
 )
 from cellmap_flow.norm.input_normalize import MinMaxNormalizer
+from funlib.geometry import Coordinate
 from funlib.persistence import Array
 import logging
 
@@ -28,19 +29,24 @@ def predict(read_roi, write_roi, config, **kwargs):
     device = kwargs.get("device")
     if device is None:
         raise ValueError("device must be provided in kwargs")
+
+    use_half_prediction = kwargs.get("use_half_prediction", False)
 
     raw_input = idi.to_ndarray_ts(read_roi)
     raw_input = config.input_normalizer.normalize(raw_input)
     raw_input = np.expand_dims(raw_input, (0, 1))
 
     with torch.no_grad():
-        raw_input_torch = torch.from_numpy(raw_input).float().half()
+        raw_input_torch = torch.from_numpy(raw_input).float()
+        if use_half_prediction:
+            raw_input_torch = raw_input_torch.half()
         raw_input_torch = raw_input_torch.to(device, non_blocking=True)
         return config.model.forward(raw_input_torch).detach().cpu().numpy()[0]
 
 
 class Inferencer:
-    def __init__(self, model_config: ModelConfig):
+    def __init__(self, model_config: ModelConfig, use_half_prediction=False):
+        self.use_half_prediction = use_half_prediction
         self.model_config = model_config
         # condig is lazy so one call is needed to get the config
         _ = self.model_config.config
@@ -49,8 +55,8 @@ def __init__(self, model_config: ModelConfig):
             self.model_config.config, "write_shape"
         ):
             self.context = (
-                self.model_config.config.read_shape
-                - self.model_config.config.write_shape
+                Coordinate(self.model_config.config.read_shape)
+                - Coordinate(self.model_config.config.write_shape)
             ) / 2
 
         self.optimize_model()
@@ -66,7 +72,7 @@ def __init__(self, model_config: ModelConfig):
             logger.warning("No predict function provided, using default")
             self.model_config.config.predict = predict
 
-    def optimize_model(self, use_half_prediction=True):
+    def optimize_model(self):
         if not hasattr(self.model_config.config, "model"):
             logger.error("Model is not loaded, cannot optimize")
             return
@@ -80,9 +86,10 @@ def optimize_model(self, use_half_prediction=True):
             self.device = torch.device("cpu")
             logger.error("No GPU available, using CPU")
         self.model_config.config.model.to(self.device)
-        if use_half_prediction:
+        if self.use_half_prediction:
             self.model_config.config.model.half()
         print(f"Using device: {self.device}")
+        # DIDN'T WORK with unet model
         # if torch.__version__ >= "2.0":
         #     self.model_config.config.model = torch.compile(self.model_config.config.model)
         #     print("Model compiled")
@@ -109,7 +116,7 @@ def process_chunk_basic(self, idi, roi):
 
         input_roi = output_roi.grow(self.context, self.context)
         result = self.model_config.config.predict(
-            input_roi, output_roi, self.model_config.config, idi=idi, device=self.device
+            input_roi, output_roi, self.model_config.config, idi=idi, device=self.device, use_half_prediction=self.use_half_prediction
         )
         write_data = self.model_config.config.normalize_output(result)
 
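For orientation, here is a minimal usage sketch of the behaviour this commit introduces; it is not part of the diff. The names model_config, idi, and roi are hypothetical placeholders, and the import path is assumed from the file location cellmap_flow/inferencer.py.

# Hypothetical usage sketch (assumption: model_config is a ModelConfig whose
# lazy .config exposes read_shape, write_shape, and a loaded torch model;
# idi and roi are the image-data interface and region of interest used
# elsewhere in cellmap_flow).
from cellmap_flow.inferencer import Inferencer

# Half precision is now opt-in: the flag is stored on the instance,
# optimize_model() calls .half() on the model only when it is set, and the
# default predict() receives it through **kwargs as use_half_prediction.
inferencer = Inferencer(model_config, use_half_prediction=True)

# The padding context is half the read/write shape difference; wrapping both
# shapes in funlib.geometry.Coordinate makes the subtraction and the division
# by 2 element-wise even if the configured shapes are plain tuples or lists.
print(inferencer.context)

# process_chunk_basic() grows the output ROI by the context on each side and
# forwards use_half_prediction to predict(), which casts the input tensor to
# half precision before running the model.
inferencer.process_chunk_basic(idi, roi)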