diff --git a/lib/smi.py b/lib/smi.py index 8d147237..7a9e9fb8 100644 --- a/lib/smi.py +++ b/lib/smi.py @@ -520,12 +520,11 @@ def fit2D4D_LLS_RealSphHarm_wSorting_norm_var(self, dwi, mask, rank1sl=True): if np.sum(l_max >= ll) > 1: id_useful_shells = id_shells[l_max>=ll] id_current_band = (l_all == ll) - for kk in range(n_voxels): - slm_voxel_raw = s_lm_clusters_all[id_current_band, kk, :][:, id_useful_shells] - slm_voxel_dn = self.low_rank_denoising(slm_voxel_raw, 1) - slm_dn[id_current_band, kk, :][:, id_useful_shells] = slm_voxel_dn - - sl_dn[int(ll/2),kk,id_useful_shells] = np.sqrt(np.sum(slm_voxel_dn**2, axis=0)) + slm_voxels_raw = s_lm_clusters_all[id_current_band, :, :][:, :, id_useful_shells] + slm_voxels_dn = self.low_rank_denoising(slm_voxels_raw.transpose(1,0,2), 1).transpose(1,0,2) + slm_dn[id_current_band, :, :][:, :, id_useful_shells] + sl_dn[int(ll/2),:,id_useful_shells] = np.sqrt(np.sum(slm_voxels_dn**2, axis=0)).T + s_lm_clusters_all = slm_dn.copy() s_l_clusters_all = sl_dn.copy() @@ -568,9 +567,11 @@ def fit2D4D_LLS_RealSphHarm_wSorting_norm_var(self, dwi, mask, rank1sl=True): def low_rank_denoising(self, X, p): u,s,v = np.linalg.svd(X, full_matrices=False) - s_dn = np.zeros_like(s) - s_dn[:p] = s[:p] - return u @ np.diag(s_dn) @ v + s_dn = np.zeros((s.shape[0], s.shape[1], s.shape[1])) + diag_inds = np.diag_indices(X.shape[2]) + s_dn[:,diag_inds[0][:p],diag_inds[1][:p]] = s[:,:p] + u_s = np.einsum('ijk,ikl->ijl', u, s_dn) + return np.einsum('ijk,ikl->ijl', u_s, v) def group_dwi_in_shells_b_beta_te(self): @@ -916,8 +917,6 @@ def standard_model_mlfit_rot_invs(self, rot_invs, sigma_norm_limits): np.divide( sigma_normalized, s0_lowest_te, out=sigma_normalized, where=s0_lowest_te != 0 ) - #rot_invs_normalized = (rot_invs_normalized / s0_lowest_te).T - #sigma_normalized = (sigma_normalized / s0_lowest_te).T shells = self.table_4d[0,:] beta = self.table_4d[1,:] @@ -950,14 +949,14 @@ def standard_model_mlfit_rot_invs(self, rot_invs, sigma_norm_limits): ) sigma_noise_norm_levels_ids = np.digitize( sigma_normalized, sigma_noise_norm_levels_edges - ) - 1 - + ) + sigma_noise_norm_levels_ids[sigma_normalized < sigma_noise_norm_levels_edges[0]] = 0 - sigma_noise_norm_levels_ids[sigma_normalized > sigma_noise_norm_levels_edges[-1]] = self.n_levels - 1 + sigma_noise_norm_levels_ids[sigma_normalized > sigma_noise_norm_levels_edges[-1]] = self.n_levels + 1 sigma_noise_norm_levels_mean = 1/2 * ( sigma_noise_norm_levels_edges[1:] + sigma_noise_norm_levels_edges[:-1] ) - + degree_kernel = 3 x_fit_norm = self.compute_extended_moments( rot_invs_normalized[:, keep_rot_invs_kernel], degree_kernel @@ -1009,17 +1008,17 @@ def standard_model_mlfit_rot_invs(self, rot_invs, sigma_norm_limits): rotinvs_train_norm = rotinvs_train / rotinvs_train[:,[0]] - for i in range(self.n_levels): + for i in range(1, len(sigma_noise_norm_levels_edges)): flag_current_noise_level = sigma_noise_norm_levels_ids == i if not np.any(flag_current_noise_level): continue - sigma_rotinvs_training = sigma_noise_norm_levels_mean[i] / sigma_ndirs_factor + sigma_rotinvs_training = sigma_noise_norm_levels_mean[i-1] / sigma_ndirs_factor meas_rotinvs_train = (rotinvs_train_norm + sigma_rotinvs_training * np.random.standard_normal((rotinvs_train_norm.shape)) ) - + x_train = self.compute_extended_moments( meas_rotinvs_train[:, keep_rot_invs_kernel], degree=degree_kernel)