7
7
ScriptModelConfig ,
8
8
)
9
9
from cellmap_flow .norm .input_normalize import MinMaxNormalizer
10
+ from funlib .geometry import Coordinate
10
11
from funlib .persistence import Array
11
12
import logging
12
13
@@ -28,19 +29,24 @@ def predict(read_roi, write_roi, config, **kwargs):
28
29
device = kwargs .get ("device" )
29
30
if device is None :
30
31
raise ValueError ("device must be provided in kwargs" )
32
+
33
+ use_half_prediction = kwargs .get ("use_half_prediction" , False )
31
34
32
35
raw_input = idi .to_ndarray_ts (read_roi )
33
36
raw_input = config .input_normalizer .normalize (raw_input )
34
37
raw_input = np .expand_dims (raw_input , (0 , 1 ))
35
38
36
39
with torch .no_grad ():
37
- raw_input_torch = torch .from_numpy (raw_input ).float ().half ()
40
+ raw_input_torch = torch .from_numpy (raw_input ).float ()
41
+ if use_half_prediction :
42
+ raw_input_torch = raw_input_torch .half ()
38
43
raw_input_torch = raw_input_torch .to (device , non_blocking = True )
39
44
return config .model .forward (raw_input_torch ).detach ().cpu ().numpy ()[0 ]
40
45
41
46
42
47
class Inferencer :
43
- def __init__ (self , model_config : ModelConfig ):
48
+ def __init__ (self , model_config : ModelConfig , use_half_prediction = False ):
49
+ self .use_half_prediction = use_half_prediction
44
50
self .model_config = model_config
45
51
# condig is lazy so one call is needed to get the config
46
52
_ = self .model_config .config
@@ -49,8 +55,8 @@ def __init__(self, model_config: ModelConfig):
49
55
self .model_config .config , "write_shape"
50
56
):
51
57
self .context = (
52
- self .model_config .config .read_shape
53
- - self .model_config .config .write_shape
58
+ Coordinate ( self .model_config .config .read_shape )
59
+ - Coordinate ( self .model_config .config .write_shape )
54
60
) / 2
55
61
56
62
self .optimize_model ()
@@ -66,7 +72,7 @@ def __init__(self, model_config: ModelConfig):
66
72
logger .warning ("No predict function provided, using default" )
67
73
self .model_config .config .predict = predict
68
74
69
- def optimize_model (self , use_half_prediction = True ):
75
+ def optimize_model (self ):
70
76
if not hasattr (self .model_config .config , "model" ):
71
77
logger .error ("Model is not loaded, cannot optimize" )
72
78
return
@@ -80,9 +86,10 @@ def optimize_model(self, use_half_prediction=True):
80
86
self .device = torch .device ("cpu" )
81
87
logger .error ("No GPU available, using CPU" )
82
88
self .model_config .config .model .to (self .device )
83
- if use_half_prediction :
89
+ if self . use_half_prediction :
84
90
self .model_config .config .model .half ()
85
91
print (f"Using device: { self .device } " )
92
+ # DIDN'T WORK with unet model
86
93
# if torch.__version__ >= "2.0":
87
94
# self.model_config.config.model = torch.compile(self.model_config.config.model)
88
95
# print("Model compiled")
@@ -109,7 +116,7 @@ def process_chunk_basic(self, idi, roi):
109
116
110
117
input_roi = output_roi .grow (self .context , self .context )
111
118
result = self .model_config .config .predict (
112
- input_roi , output_roi , self .model_config .config , idi = idi , device = self .device
119
+ input_roi , output_roi , self .model_config .config , idi = idi , device = self .device , use_half_prediction = self . use_half_prediction
113
120
)
114
121
write_data = self .model_config .config .normalize_output (result )
115
122
0 commit comments