23
23
import math
24
24
from typing import Tuple
25
25
import numpy as np
26
- from httomo_backends .cufft import CufftType , cufft_estimate_1d
26
+ from httomo_backends .cufft import CufftType , cufft_estimate_1d , cufft_estimate_2d
27
27
28
28
__all__ = [
29
29
"_calc_memory_bytes_FBP3d_tomobar" ,
30
+ "_calc_memory_bytes_LPRec3d_tomobar" ,
30
31
"_calc_memory_bytes_SIRT3d_tomobar" ,
31
32
"_calc_memory_bytes_CGLS3d_tomobar" ,
32
33
"_calc_output_dim_FBP2d_astra" ,
33
34
"_calc_output_dim_FBP3d_tomobar" ,
35
+ "_calc_output_dim_LPRec3d_tomobar" ,
34
36
"_calc_output_dim_SIRT3d_tomobar" ,
35
37
"_calc_output_dim_CGLS3d_tomobar" ,
36
38
]
@@ -58,6 +60,10 @@ def _calc_output_dim_FBP3d_tomobar(non_slice_dims_shape, **kwargs):
58
60
return __calc_output_dim_recon (non_slice_dims_shape , ** kwargs )
59
61
60
62
63
+ def _calc_output_dim_LPRec3d_tomobar (non_slice_dims_shape , ** kwargs ):
64
+ return __calc_output_dim_recon (non_slice_dims_shape , ** kwargs )
65
+
66
+
61
67
def _calc_output_dim_SIRT3d_tomobar (non_slice_dims_shape , ** kwargs ):
62
68
return __calc_output_dim_recon (non_slice_dims_shape , ** kwargs )
63
69
@@ -71,12 +77,17 @@ def _calc_memory_bytes_FBP3d_tomobar(
71
77
dtype : np .dtype ,
72
78
** kwargs ,
73
79
) -> Tuple [int , int ]:
74
- det_height = non_slice_dims_shape [0 ]
75
- det_width = non_slice_dims_shape [1 ]
80
+ if "detector_pad" in kwargs :
81
+ detector_pad = kwargs ["detector_pad" ]
82
+ else :
83
+ detector_pad = 0
84
+
85
+ angles_tot = non_slice_dims_shape [0 ]
86
+ det_width = non_slice_dims_shape [1 ] + 2 * detector_pad
76
87
SLICES = 200 # dummy multiplier+divisor to pass large batch size threshold
77
88
78
89
# 1. input
79
- input_slice_size = np . prod ( non_slice_dims_shape ) * dtype .itemsize
90
+ input_slice_size = ( angles_tot * det_width ) * dtype .itemsize
80
91
81
92
########## FFT / filter / IFFT (filtersync_cupy)
82
93
@@ -85,13 +96,13 @@ def _calc_memory_bytes_FBP3d_tomobar(
85
96
cufft_estimate_1d (
86
97
nx = det_width ,
87
98
fft_type = CufftType .CUFFT_R2C ,
88
- batch = det_height * SLICES ,
99
+ batch = angles_tot * SLICES ,
89
100
)
90
101
/ SLICES
91
102
)
92
103
93
104
# 3. RFFT output size (proj_f in code)
94
- proj_f_slice = det_height * (det_width // 2 + 1 ) * np .complex64 ().itemsize
105
+ proj_f_slice = angles_tot * (det_width // 2 + 1 ) * np .complex64 ().itemsize
95
106
96
107
# 4. Filter size (independent of number of slices)
97
108
filter_size = (det_width // 2 + 1 ) * np .float32 ().itemsize
@@ -101,7 +112,7 @@ def _calc_memory_bytes_FBP3d_tomobar(
101
112
cufft_estimate_1d (
102
113
nx = det_width ,
103
114
fft_type = CufftType .CUFFT_C2R ,
104
- batch = det_height * SLICES ,
115
+ batch = angles_tot * SLICES ,
105
116
)
106
117
/ SLICES
107
118
)
@@ -117,9 +128,7 @@ def _calc_memory_bytes_FBP3d_tomobar(
117
128
118
129
# 6. we swap the axes before passing data to Astra in ToMoBAR
119
130
# https://github.com/dkazanc/ToMoBAR/blob/54137829b6326406e09f6ef9c95eb35c213838a7/tomobar/methodsDIR_CuPy.py#L135
120
- pre_astra_input_swapaxis_slice = (
121
- np .prod (non_slice_dims_shape ) * np .float32 ().itemsize
122
- )
131
+ pre_astra_input_swapaxis_slice = (angles_tot * det_width ) * np .float32 ().itemsize
123
132
124
133
# 7. astra backprojection will generate an output array
125
134
# https://github.com/dkazanc/ToMoBAR/blob/54137829b6326406e09f6ef9c95eb35c213838a7/tomobar/astra_wrappers/astra_base.py#L524
@@ -145,29 +154,227 @@ def _calc_memory_bytes_FBP3d_tomobar(
145
154
# so it does not add to the memory overall
146
155
147
156
# We assume for safety here that one FFT plan is not freed and one is freed
148
- tot_memory_bytes = (
157
+ tot_memory_bytes = int (
149
158
projection_mem_size + filtersync_size - ifftplan_slice_size + recon_output_size
150
159
)
151
160
152
161
# this account for the memory used for filtration AND backprojection.
153
162
return (tot_memory_bytes , fixed_amount )
154
163
155
164
165
+ def _calc_memory_bytes_LPRec3d_tomobar (
166
+ non_slice_dims_shape : Tuple [int , int ],
167
+ dtype : np .dtype ,
168
+ ** kwargs ,
169
+ ) -> Tuple [int , int ]:
170
+ # Based on: https://github.com/dkazanc/ToMoBAR/pull/112/commits/4704ecdc6ded3dd5ec0583c2008aa104f30a8a39
171
+
172
+ if "detector_pad" in kwargs :
173
+ detector_pad = kwargs ["detector_pad" ]
174
+ else :
175
+ detector_pad = 0
176
+
177
+ angles_tot = non_slice_dims_shape [0 ]
178
+ DetectorsLengthH_prepad = non_slice_dims_shape [1 ]
179
+ DetectorsLengthH = non_slice_dims_shape [1 ] + 2 * detector_pad
180
+ SLICES = 200 # dummy multiplier+divisor to pass large batch size threshold
181
+ _CENTER_SIZE_MIN = 192 # must be divisible by 8
182
+
183
+ n = DetectorsLengthH
184
+
185
+ odd_horiz = False
186
+ if (n % 2 ) != 0 :
187
+ n = n + 1 # dealing with the odd horizontal detector size
188
+ odd_horiz = True
189
+
190
+ eps = 1e-4 # accuracy of usfft
191
+ mu = - np .log (eps ) / (2 * n * n )
192
+ m = int (
193
+ np .ceil (
194
+ 2 * n * 1 / np .pi * np .sqrt (- mu * np .log (eps ) + (mu * n ) * (mu * n ) / 4 )
195
+ )
196
+ )
197
+
198
+ center_size = 6144
199
+ center_size = min (center_size , n * 2 + m * 2 )
200
+
201
+ oversampling_level = 2 # at least 2 or larger required
202
+ ne = oversampling_level * n
203
+ padding_m = ne // 2 - n // 2
204
+
205
+ if "angles" in kwargs :
206
+ angles = kwargs ["angles" ]
207
+ sorted_theta_cpu = np .sort (angles )
208
+ theta_full_range = abs (sorted_theta_cpu [angles_tot - 1 ] - sorted_theta_cpu [0 ])
209
+ angle_range_pi_count = 1 + int (np .ceil (theta_full_range / math .pi ))
210
+ angle_range_pi_count += 1 # account for difference from actual algorithm
211
+ else :
212
+ angle_range_pi_count = 1 + int (
213
+ np .ceil (2 )
214
+ ) # assume a 2 * PI projection angle range
215
+
216
+ chunk_count = 4
217
+
218
+ output_dims = __calc_output_dim_recon (non_slice_dims_shape , ** kwargs )
219
+ if odd_horiz :
220
+ output_dims = tuple (x + 1 for x in output_dims )
221
+
222
+ in_slice_size = (angles_tot * DetectorsLengthH ) * dtype .itemsize
223
+ padded_in_slice_size = angles_tot * n * np .float32 ().itemsize
224
+
225
+ theta_size = angles_tot * np .float32 ().itemsize
226
+ filter_size = (n // 2 + 1 ) * np .float32 ().itemsize
227
+ rfftfreq_size = filter_size
228
+ scaled_filter_size = filter_size
229
+
230
+ tmp_p_input_slice = angles_tot * n * np .float32 ().itemsize
231
+
232
+ padded_tmp_p_input_slice = angles_tot * (n + padding_m * 2 ) * np .float32 ().itemsize
233
+ rfft_plan_slice_size = (
234
+ cufft_estimate_1d (
235
+ nx = n + padding_m * 2 ,
236
+ fft_type = CufftType .CUFFT_R2C ,
237
+ batch = angles_tot * SLICES ,
238
+ )
239
+ / SLICES
240
+ )
241
+ rfft_result_size = angles_tot * (n + padding_m * 2 ) * np .complex64 ().itemsize
242
+ filtered_rfft_result_size = rfft_result_size
243
+ irfft_plan_slice_size = (
244
+ cufft_estimate_1d (
245
+ nx = (n + padding_m * 2 ),
246
+ fft_type = CufftType .CUFFT_C2R ,
247
+ batch = angles_tot * SLICES ,
248
+ )
249
+ / SLICES
250
+ )
251
+ irfft_scratch_memory_size = filtered_rfft_result_size * 2
252
+ irfft_result_size = angles_tot * (n + padding_m * 2 ) * np .float32 ().itemsize
253
+
254
+ datac_size = angles_tot * n * np .complex64 ().itemsize / 2
255
+ fde_size = (2 * m + 2 * n ) * (2 * m + 2 * n ) * np .complex64 ().itemsize / 2
256
+ fft_plan_slice_size = (
257
+ cufft_estimate_1d (nx = n , fft_type = CufftType .CUFFT_C2C , batch = angles_tot * SLICES )
258
+ / SLICES
259
+ )
260
+ fft_result_size = datac_size
261
+
262
+ sorted_theta_indices_size = angles_tot * np .int64 ().itemsize
263
+ sorted_theta_size = angles_tot * np .float32 ().itemsize
264
+ angle_range_size = (
265
+ center_size * center_size * (1 + angle_range_pi_count * 2 ) * np .int16 ().itemsize
266
+ )
267
+
268
+ recon_output_size = (
269
+ DetectorsLengthH_prepad * DetectorsLengthH_prepad * np .float32 ().itemsize
270
+ )
271
+ ifft2_plan_slice_size = (
272
+ cufft_estimate_2d (
273
+ nx = (2 * m + 2 * n ), ny = (2 * m + 2 * n ), fft_type = CufftType .CUFFT_C2C
274
+ )
275
+ / 2
276
+ )
277
+ circular_mask_size = np .prod (output_dims ) / 2 * np .int64 ().itemsize * 4
278
+ after_recon_swapaxis_slice = recon_output_size
279
+
280
+ tot_memory_bytes = 0
281
+ current_tot_memory_bytes = 0
282
+
283
+ fixed_amount = 0
284
+ current_fixed_amount = 0
285
+
286
+ def add_to_memory_counters (amount , per_slice : bool ):
287
+ nonlocal tot_memory_bytes
288
+ nonlocal current_tot_memory_bytes
289
+ nonlocal fixed_amount
290
+ nonlocal current_fixed_amount
291
+
292
+ if per_slice :
293
+ current_tot_memory_bytes += amount
294
+ tot_memory_bytes = max (tot_memory_bytes , current_tot_memory_bytes )
295
+ else :
296
+ current_fixed_amount += amount
297
+ fixed_amount = max (fixed_amount , current_fixed_amount )
298
+
299
+ add_to_memory_counters (in_slice_size , True )
300
+ add_to_memory_counters (padded_in_slice_size , True )
301
+
302
+ add_to_memory_counters (theta_size , False )
303
+ if center_size >= _CENTER_SIZE_MIN :
304
+ add_to_memory_counters (sorted_theta_indices_size , False )
305
+ add_to_memory_counters (sorted_theta_size , False )
306
+ add_to_memory_counters (angle_range_size , False )
307
+ add_to_memory_counters (filter_size , False )
308
+ add_to_memory_counters (rfftfreq_size , False )
309
+ add_to_memory_counters (scaled_filter_size , False )
310
+
311
+ add_to_memory_counters (tmp_p_input_slice , True )
312
+
313
+ add_to_memory_counters (rfft_plan_slice_size / chunk_count * 2 , True )
314
+ add_to_memory_counters (irfft_plan_slice_size / chunk_count * 2 , True )
315
+ # add_to_memory_counters(irfft_scratch_memory_size / chunk_count, True)
316
+ for _ in range (0 , chunk_count ):
317
+ add_to_memory_counters (padded_tmp_p_input_slice / chunk_count , True )
318
+
319
+ add_to_memory_counters (rfft_result_size / chunk_count , True )
320
+ add_to_memory_counters (filtered_rfft_result_size / chunk_count , True )
321
+ add_to_memory_counters (- rfft_result_size / chunk_count , True )
322
+ add_to_memory_counters (- padded_tmp_p_input_slice / chunk_count , True )
323
+
324
+ add_to_memory_counters (irfft_scratch_memory_size / chunk_count , True )
325
+ add_to_memory_counters (- irfft_scratch_memory_size / chunk_count , True )
326
+ add_to_memory_counters (irfft_result_size / chunk_count , True )
327
+ add_to_memory_counters (- filtered_rfft_result_size / chunk_count , True )
328
+
329
+ add_to_memory_counters (- irfft_result_size / chunk_count , True )
330
+
331
+ add_to_memory_counters (- padded_in_slice_size , True )
332
+ add_to_memory_counters (- filter_size , False )
333
+ add_to_memory_counters (- rfftfreq_size , False )
334
+ add_to_memory_counters (- scaled_filter_size , False )
335
+
336
+ add_to_memory_counters (datac_size , True )
337
+ add_to_memory_counters (fde_size , True )
338
+ add_to_memory_counters (- tmp_p_input_slice , True )
339
+ add_to_memory_counters (fft_plan_slice_size , True )
340
+ add_to_memory_counters (fft_result_size , True )
341
+ add_to_memory_counters (- datac_size , True )
342
+
343
+ add_to_memory_counters (- fft_result_size , True )
344
+
345
+ add_to_memory_counters (ifft2_plan_slice_size / chunk_count * 2 , True )
346
+ for _ in range (0 , chunk_count ):
347
+ add_to_memory_counters (fde_size / chunk_count , True )
348
+ add_to_memory_counters (- fde_size / chunk_count , True )
349
+
350
+ add_to_memory_counters (recon_output_size , True )
351
+ add_to_memory_counters (- fde_size , True )
352
+ add_to_memory_counters (circular_mask_size , False )
353
+ add_to_memory_counters (after_recon_swapaxis_slice , True )
354
+
355
+ return (tot_memory_bytes * 1.05 , fixed_amount + 250 * 1024 * 1024 )
356
+
156
357
def _calc_memory_bytes_SIRT3d_tomobar (
157
358
non_slice_dims_shape : Tuple [int , int ],
158
359
dtype : np .dtype ,
159
360
** kwargs ,
160
361
) -> Tuple [int , int ]:
161
- DetectorsLengthH = non_slice_dims_shape [1 ]
362
+
363
+ if "detector_pad" in kwargs :
364
+ detector_pad = kwargs ["detector_pad" ]
365
+ else :
366
+ detector_pad = 0
367
+ anglesnum = non_slice_dims_shape [0 ]
368
+ DetectorsLengthH = non_slice_dims_shape [1 ] + 2 * detector_pad
162
369
# calculate the output shape
163
370
output_dims = _calc_output_dim_SIRT3d_tomobar (non_slice_dims_shape , ** kwargs )
164
371
165
- in_data_size = np . prod ( non_slice_dims_shape ) * dtype .itemsize
372
+ in_data_size = ( anglesnum * DetectorsLengthH ) * dtype .itemsize
166
373
out_data_size = np .prod (output_dims ) * dtype .itemsize
167
374
168
375
astra_projection = 2.5 * (in_data_size + out_data_size )
169
376
170
- tot_memory_bytes = 2 * in_data_size + 2 * out_data_size + astra_projection
377
+ tot_memory_bytes = int ( 2 * in_data_size + 2 * out_data_size + astra_projection )
171
378
return (tot_memory_bytes , 0 )
172
379
173
380
@@ -176,14 +383,20 @@ def _calc_memory_bytes_CGLS3d_tomobar(
176
383
dtype : np .dtype ,
177
384
** kwargs ,
178
385
) -> Tuple [int , int ]:
179
- DetectorsLengthH = non_slice_dims_shape [1 ]
386
+ if "detector_pad" in kwargs :
387
+ detector_pad = kwargs ["detector_pad" ]
388
+ else :
389
+ detector_pad = 0
390
+
391
+ anglesnum = non_slice_dims_shape [0 ]
392
+ DetectorsLengthH = non_slice_dims_shape [1 ] + 2 * detector_pad
180
393
# calculate the output shape
181
394
output_dims = _calc_output_dim_CGLS3d_tomobar (non_slice_dims_shape , ** kwargs )
182
395
183
- in_data_size = np . prod ( non_slice_dims_shape ) * dtype .itemsize
396
+ in_data_size = ( anglesnum * DetectorsLengthH ) * dtype .itemsize
184
397
out_data_size = np .prod (output_dims ) * dtype .itemsize
185
398
186
399
astra_projection = 2.5 * (in_data_size + out_data_size )
187
400
188
- tot_memory_bytes = 2 * in_data_size + 2 * out_data_size + astra_projection
401
+ tot_memory_bytes = int ( 2 * in_data_size + 2 * out_data_size + astra_projection )
189
402
return (tot_memory_bytes , 0 )
0 commit comments