diff --git a/src/nifreeze/data/dmri.py b/src/nifreeze/data/dmri.py index 878c72c..6e353ae 100644 --- a/src/nifreeze/data/dmri.py +++ b/src/nifreeze/data/dmri.py @@ -224,18 +224,24 @@ def to_nifti( filename : :obj:`os.pathlike` The output NIfTI file path. insert_b0 : :obj:`bool`, optional - Insert a :math:`b=0` at the front of the output NIfTI. + Insert a :math:`b=0` at the front of the output NIfTI and add the corresponding + null gradient value to the output bval/bvec files. bvals_dec_places : :obj:`int`, optional Decimal places to use when serializing b-values. bvecs_dec_places : :obj:`int`, optional Decimal places to use when serializing b-vectors. """ + bvecs = self.bvecs + bvals = self.bvals + if not insert_b0: # Parent's to_nifti to handle the primary NIfTI export. super().to_nifti(filename) else: data = np.concatenate((self.bzero[..., np.newaxis], self.dataobj), axis=-1) + bvecs = np.concatenate((np.zeros(3)[:, np.newaxis], bvecs), axis=-1) + bvals = np.concatenate((np.zeros(1), bvals)) nii = nb.Nifti1Image(data, self.affine, self.datahdr) if self.datahdr is None: nii.header.set_xyzt_units("mm") diff --git a/test/test_data_dmri.py b/test/test_data_dmri.py index a6afd79..9604939 100644 --- a/test/test_data_dmri.py +++ b/test/test_data_dmri.py @@ -95,40 +95,40 @@ def _serialize_dwi_data( ) -def test_load(datadir, tmp_path): +@pytest.mark.parametrize("insert_b0", (False, True)) +def test_load(datadir, tmp_path, insert_b0): """Check that the registration parameters for b=0 gives a good estimate of known affine""" dwi_h5 = load(datadir / "dwi.h5") dwi_nifti_path = tmp_path / "dwi.nii.gz" gradients_path = tmp_path / "dwi.tsv" - bvecs_path = tmp_path / "dwi.bvecs" - bvals_path = tmp_path / "dwi.bvals" - grad_table = np.hstack((np.zeros((4, 1)), dwi_h5.gradients)) - - dwi_h5.to_nifti(dwi_nifti_path, insert_b0=True) - np.savetxt(str(gradients_path), grad_table.T) - np.savetxt(str(bvecs_path), grad_table[:3]) - np.savetxt(str(bvals_path), grad_table[-1]) + dwi_h5.to_nifti(dwi_nifti_path, insert_b0=insert_b0) with pytest.raises(RuntimeError): load(dwi_nifti_path) - # Try loading NIfTI + gradients table - dwi_from_nifti1 = load(dwi_nifti_path, gradients_file=gradients_path) - - assert np.allclose(dwi_h5.dataobj, dwi_from_nifti1.dataobj) - assert np.allclose(dwi_h5.bzero, dwi_from_nifti1.bzero) - assert np.allclose(dwi_h5.gradients, dwi_from_nifti1.gradients) - # Try loading NIfTI + b-vecs/vals - dwi_from_nifti2 = load( + out_root = dwi_nifti_path.parent / dwi_nifti_path.name.replace("".join(dwi_nifti_path.suffixes), "") + bvecs_path = out_root.with_suffix(".bvec") + bvals_path = out_root.with_suffix(".bval") + dwi_from_nifti1 = load( dwi_nifti_path, bvec_file=bvecs_path, bval_file=bvals_path, ) + assert np.allclose(dwi_h5.dataobj, dwi_from_nifti1.dataobj) + assert np.allclose(dwi_h5.bzero, dwi_from_nifti1.bzero) + assert np.allclose(dwi_h5.gradients, dwi_from_nifti1.gradients, atol=1e-6) + + grad_table = np.hstack((np.zeros((4, 1)), dwi_h5.gradients)) + np.savetxt(str(gradients_path), grad_table.T) + + # Try loading NIfTI + gradients table + dwi_from_nifti2 = load(dwi_nifti_path, gradients_file=gradients_path) + assert np.allclose(dwi_h5.dataobj, dwi_from_nifti2.dataobj) assert np.allclose(dwi_h5.bzero, dwi_from_nifti2.bzero) assert np.allclose(dwi_h5.gradients, dwi_from_nifti2.gradients)