Skip to content

Commit 3d04abd

Browse files
authored
Merge pull request #63 from DiamondLightSource/tomobarpadding
Changes to accomodate horizontal padding
2 parents acba4eb + 3da3464 commit 3d04abd

File tree

2 files changed

+98
-42
lines changed

2 files changed

+98
-42
lines changed

httomo_backends/methods_database/packages/backends/httomolibgpu/supporting_funcs/recon/algorithm.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,17 @@ def _calc_memory_bytes_FBP3d_tomobar(
7777
dtype: np.dtype,
7878
**kwargs,
7979
) -> Tuple[int, int]:
80-
det_height = non_slice_dims_shape[0]
81-
det_width = non_slice_dims_shape[1]
80+
if "detector_pad" in kwargs:
81+
detector_pad = kwargs["detector_pad"]
82+
else:
83+
detector_pad = 0
84+
85+
angles_tot = non_slice_dims_shape[0]
86+
det_width = non_slice_dims_shape[1] + 2 * detector_pad
8287
SLICES = 200 # dummy multiplier+divisor to pass large batch size threshold
8388

8489
# 1. input
85-
input_slice_size = np.prod(non_slice_dims_shape) * dtype.itemsize
90+
input_slice_size = (angles_tot * det_width) * dtype.itemsize
8691

8792
########## FFT / filter / IFFT (filtersync_cupy)
8893

@@ -91,13 +96,13 @@ def _calc_memory_bytes_FBP3d_tomobar(
9196
cufft_estimate_1d(
9297
nx=det_width,
9398
fft_type=CufftType.CUFFT_R2C,
94-
batch=det_height * SLICES,
99+
batch=angles_tot * SLICES,
95100
)
96101
/ SLICES
97102
)
98103

99104
# 3. RFFT output size (proj_f in code)
100-
proj_f_slice = det_height * (det_width // 2 + 1) * np.complex64().itemsize
105+
proj_f_slice = angles_tot * (det_width // 2 + 1) * np.complex64().itemsize
101106

102107
# 4. Filter size (independent of number of slices)
103108
filter_size = (det_width // 2 + 1) * np.float32().itemsize
@@ -107,7 +112,7 @@ def _calc_memory_bytes_FBP3d_tomobar(
107112
cufft_estimate_1d(
108113
nx=det_width,
109114
fft_type=CufftType.CUFFT_C2R,
110-
batch=det_height * SLICES,
115+
batch=angles_tot * SLICES,
111116
)
112117
/ SLICES
113118
)
@@ -123,9 +128,7 @@ def _calc_memory_bytes_FBP3d_tomobar(
123128

124129
# 6. we swap the axes before passing data to Astra in ToMoBAR
125130
# https://github.com/dkazanc/ToMoBAR/blob/54137829b6326406e09f6ef9c95eb35c213838a7/tomobar/methodsDIR_CuPy.py#L135
126-
pre_astra_input_swapaxis_slice = (
127-
np.prod(non_slice_dims_shape) * np.float32().itemsize
128-
)
131+
pre_astra_input_swapaxis_slice = (angles_tot * det_width) * np.float32().itemsize
129132

130133
# 7. astra backprojection will generate an output array
131134
# https://github.com/dkazanc/ToMoBAR/blob/54137829b6326406e09f6ef9c95eb35c213838a7/tomobar/astra_wrappers/astra_base.py#L524
@@ -151,7 +154,7 @@ def _calc_memory_bytes_FBP3d_tomobar(
151154
# so it does not add to the memory overall
152155

153156
# We assume for safety here that one FFT plan is not freed and one is freed
154-
tot_memory_bytes = (
157+
tot_memory_bytes = int(
155158
projection_mem_size + filtersync_size - ifftplan_slice_size + recon_output_size
156159
)
157160

@@ -166,8 +169,14 @@ def _calc_memory_bytes_LPRec3d_tomobar(
166169
) -> Tuple[int, int]:
167170
# Based on: https://github.com/dkazanc/ToMoBAR/pull/112/commits/4704ecdc6ded3dd5ec0583c2008aa104f30a8a39
168171

172+
if "detector_pad" in kwargs:
173+
detector_pad = kwargs["detector_pad"]
174+
else:
175+
detector_pad = 0
176+
169177
angles_tot = non_slice_dims_shape[0]
170-
DetectorsLengthH = non_slice_dims_shape[1]
178+
DetectorsLengthH_prepad = non_slice_dims_shape[1]
179+
DetectorsLengthH = non_slice_dims_shape[1] + 2 * detector_pad
171180
SLICES = 200 # dummy multiplier+divisor to pass large batch size threshold
172181
_CENTER_SIZE_MIN = 192 # must be divisible by 8
173182

@@ -210,7 +219,7 @@ def _calc_memory_bytes_LPRec3d_tomobar(
210219
if odd_horiz:
211220
output_dims = tuple(x + 1 for x in output_dims)
212221

213-
in_slice_size = np.prod(non_slice_dims_shape) * dtype.itemsize
222+
in_slice_size = (angles_tot * DetectorsLengthH) * dtype.itemsize
214223
padded_in_slice_size = angles_tot * n * np.float32().itemsize
215224

216225
theta_size = angles_tot * np.float32().itemsize
@@ -256,7 +265,9 @@ def _calc_memory_bytes_LPRec3d_tomobar(
256265
center_size * center_size * (1 + angle_range_pi_count * 2) * np.int16().itemsize
257266
)
258267

259-
recon_output_size = DetectorsLengthH * DetectorsLengthH * np.float32().itemsize
268+
recon_output_size = (
269+
DetectorsLengthH_prepad * DetectorsLengthH_prepad * np.float32().itemsize
270+
)
260271
ifft2_plan_slice_size = (
261272
cufft_estimate_2d(
262273
nx=(2 * m + 2 * n), ny=(2 * m + 2 * n), fft_type=CufftType.CUFFT_C2C
@@ -342,24 +353,28 @@ def add_to_memory_counters(amount, per_slice: bool):
342353
add_to_memory_counters(after_recon_swapaxis_slice, True)
343354

344355
return (tot_memory_bytes * 1.05, fixed_amount + 250 * 1024 * 1024)
345-
# return (tot_memory_bytes, fixed_amount)
346-
347356

348357
def _calc_memory_bytes_SIRT3d_tomobar(
349358
non_slice_dims_shape: Tuple[int, int],
350359
dtype: np.dtype,
351360
**kwargs,
352361
) -> Tuple[int, int]:
353-
DetectorsLengthH = non_slice_dims_shape[1]
362+
363+
if "detector_pad" in kwargs:
364+
detector_pad = kwargs["detector_pad"]
365+
else:
366+
detector_pad = 0
367+
anglesnum = non_slice_dims_shape[0]
368+
DetectorsLengthH = non_slice_dims_shape[1] + 2 * detector_pad
354369
# calculate the output shape
355370
output_dims = _calc_output_dim_SIRT3d_tomobar(non_slice_dims_shape, **kwargs)
356371

357-
in_data_size = np.prod(non_slice_dims_shape) * dtype.itemsize
372+
in_data_size = (anglesnum * DetectorsLengthH) * dtype.itemsize
358373
out_data_size = np.prod(output_dims) * dtype.itemsize
359374

360375
astra_projection = 2.5 * (in_data_size + out_data_size)
361376

362-
tot_memory_bytes = 2 * in_data_size + 2 * out_data_size + astra_projection
377+
tot_memory_bytes = int(2 * in_data_size + 2 * out_data_size + astra_projection)
363378
return (tot_memory_bytes, 0)
364379

365380

@@ -368,14 +383,20 @@ def _calc_memory_bytes_CGLS3d_tomobar(
368383
dtype: np.dtype,
369384
**kwargs,
370385
) -> Tuple[int, int]:
371-
DetectorsLengthH = non_slice_dims_shape[1]
386+
if "detector_pad" in kwargs:
387+
detector_pad = kwargs["detector_pad"]
388+
else:
389+
detector_pad = 0
390+
391+
anglesnum = non_slice_dims_shape[0]
392+
DetectorsLengthH = non_slice_dims_shape[1] + 2 * detector_pad
372393
# calculate the output shape
373394
output_dims = _calc_output_dim_CGLS3d_tomobar(non_slice_dims_shape, **kwargs)
374395

375-
in_data_size = np.prod(non_slice_dims_shape) * dtype.itemsize
396+
in_data_size = (anglesnum * DetectorsLengthH) * dtype.itemsize
376397
out_data_size = np.prod(output_dims) * dtype.itemsize
377398

378399
astra_projection = 2.5 * (in_data_size + out_data_size)
379400

380-
tot_memory_bytes = 2 * in_data_size + 2 * out_data_size + astra_projection
401+
tot_memory_bytes = int(2 * in_data_size + 2 * out_data_size + astra_projection)
381402
return (tot_memory_bytes, 0)

tests/test_httomolibgpu.py

Lines changed: 56 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -528,17 +528,24 @@ def test_data_sampler_memoryhook(slices, newshape, interpolation, ensure_clean_m
528528

529529

530530
@pytest.mark.cupy
531+
@pytest.mark.parametrize("padding_detx", [0, 10, 100, 200])
531532
@pytest.mark.parametrize("projections", [1801, 3601])
532533
@pytest.mark.parametrize("slices", [7, 11, 15])
533534
@pytest.mark.parametrize("detectorX", [1200, 2560])
534535
def test_recon_FBP3d_tomobar_memoryhook(
535-
slices, detectorX, projections, ensure_clean_memory, mocker: MockerFixture
536+
slices,
537+
detectorX,
538+
projections,
539+
padding_detx,
540+
ensure_clean_memory,
541+
mocker: MockerFixture,
536542
):
537543
data = cp.random.random_sample((projections, slices, detectorX), dtype=np.float32)
538544
kwargs = {}
539545
kwargs["angles"] = np.linspace(
540546
0.0 * np.pi / 180.0, 180.0 * np.pi / 180.0, data.shape[0]
541547
)
548+
kwargs["detector_pad"] = padding_detx
542549
kwargs["center"] = 500
543550
kwargs["recon_size"] = detectorX
544551
kwargs["recon_mask_radius"] = 0.8
@@ -579,61 +586,88 @@ def test_recon_FBP3d_tomobar_memoryhook(
579586

580587

581588
@pytest.mark.cupy
582-
# @pytest.mark.parametrize("projections", [1801])
583-
# @pytest.mark.parametrize("detX_size", [2560])
584-
# @pytest.mark.parametrize("slices", [15])
585-
# @pytest.mark.parametrize("projection_angle_range", [(0, np.pi)])
586-
587-
589+
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100])
588590
@pytest.mark.parametrize("projections", [1500, 1801, 2560])
589591
@pytest.mark.parametrize("detX_size", [2560])
590592
@pytest.mark.parametrize("slices", [3, 4, 5, 10, 15, 20])
591593
@pytest.mark.parametrize("projection_angle_range", [(0, np.pi)])
592-
593-
# @pytest.mark.parametrize("projections", [1500, 1801, 2560])
594-
# @pytest.mark.parametrize("detX_size", [2560])
595-
# @pytest.mark.parametrize("slices", [3, 4, 5, 10])
596-
# @pytest.mark.parametrize("projection_angle_range", [(0, np.pi)])
597594
def test_recon_LPRec3d_tomobar_0_pi_memoryhook(
598-
slices, detX_size, projections, projection_angle_range, ensure_clean_memory
595+
slices,
596+
detX_size,
597+
projections,
598+
projection_angle_range,
599+
padding_detx,
600+
ensure_clean_memory,
599601
):
600602
__test_recon_LPRec3d_tomobar_memoryhook_common(
601-
slices, detX_size, projections, projection_angle_range, ensure_clean_memory
603+
slices,
604+
detX_size,
605+
projections,
606+
projection_angle_range,
607+
padding_detx,
608+
ensure_clean_memory,
602609
)
603610

604611

605612
@pytest.mark.full
606613
@pytest.mark.cupy
614+
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100])
607615
@pytest.mark.parametrize("projections", [1500, 1801, 2560, 3601])
608616
@pytest.mark.parametrize("detX_size", [2560])
609617
@pytest.mark.parametrize("slices", [3, 4, 5, 10, 15, 20])
610618
@pytest.mark.parametrize("projection_angle_range", [(0, np.pi)])
611619
def test_recon_LPRec3d_tomobar_0_pi_memoryhook_full(
612-
slices, detX_size, projections, projection_angle_range, ensure_clean_memory
620+
slices,
621+
detX_size,
622+
projections,
623+
projection_angle_range,
624+
padding_detx,
625+
ensure_clean_memory,
613626
):
614627
__test_recon_LPRec3d_tomobar_memoryhook_common(
615-
slices, detX_size, projections, projection_angle_range, ensure_clean_memory
628+
slices,
629+
detX_size,
630+
projections,
631+
projection_angle_range,
632+
padding_detx,
633+
ensure_clean_memory,
616634
)
617635

618636

619637
@pytest.mark.full
620638
@pytest.mark.cupy
639+
@pytest.mark.parametrize("padding_detx", [0, 10, 50, 100])
621640
@pytest.mark.parametrize("projections", [1500, 1801, 2560, 3601])
622641
@pytest.mark.parametrize("detX_size", [2560])
623642
@pytest.mark.parametrize("slices", [3, 4, 5, 10, 15, 20])
624643
@pytest.mark.parametrize(
625644
"projection_angle_range", [(0, np.pi), (0, 2 * np.pi), (-np.pi / 2, np.pi / 2)]
626645
)
627646
def test_recon_LPRec3d_tomobar_memoryhook_full(
628-
slices, detX_size, projections, projection_angle_range, ensure_clean_memory
647+
slices,
648+
detX_size,
649+
projections,
650+
projection_angle_range,
651+
padding_detx,
652+
ensure_clean_memory,
629653
):
630654
__test_recon_LPRec3d_tomobar_memoryhook_common(
631-
slices, detX_size, projections, projection_angle_range, ensure_clean_memory
655+
slices,
656+
detX_size,
657+
projections,
658+
projection_angle_range,
659+
padding_detx,
660+
ensure_clean_memory,
632661
)
633662

634663

635664
def __test_recon_LPRec3d_tomobar_memoryhook_common(
636-
slices, detX_size, projections, projection_angle_range, ensure_clean_memory
665+
slices,
666+
detX_size,
667+
projections,
668+
projection_angle_range,
669+
padding_detx,
670+
ensure_clean_memory,
637671
):
638672
angles_number = projections
639673
data = cp.random.random_sample((angles_number, slices, detX_size), dtype=np.float32)
@@ -642,6 +676,7 @@ def __test_recon_LPRec3d_tomobar_memoryhook_common(
642676
projection_angle_range[0], projection_angle_range[1], data.shape[0]
643677
)
644678
kwargs["center"] = 1280
679+
kwargs["detector_pad"] = padding_detx
645680
kwargs["recon_size"] = detX_size
646681
kwargs["recon_mask_radius"] = 0.8
647682

@@ -687,9 +722,9 @@ def __test_recon_LPRec3d_tomobar_memoryhook_common(
687722
if slices <= 3:
688723
assert percents_relative_maxmem <= 75
689724
elif slices <= 5:
690-
assert percents_relative_maxmem <= 60
725+
assert percents_relative_maxmem <= 63
691726
else:
692-
assert percents_relative_maxmem <= 47
727+
assert percents_relative_maxmem <= 50
693728

694729

695730
@pytest.mark.cupy

0 commit comments

Comments
 (0)