diff --git a/scripts/inner_ear/processing/filter_objects.py b/scripts/inner_ear/processing/filter_objects.py new file mode 100644 index 0000000..258ec58 --- /dev/null +++ b/scripts/inner_ear/processing/filter_objects.py @@ -0,0 +1,134 @@ +import os +from pathlib import Path +from tqdm import tqdm + +import h5py +import imageio.v3 as imageio +import numpy as np +from skimage.measure import label +from skimage.segmentation import relabel_sequential + +from synapse_net.file_utils import get_data_path +from parse_table import parse_table, get_data_root, _match_correction_folder, _match_correction_file + + +def _load_segmentation(seg_path): + ext = Path(seg_path).suffix + assert ext in (".h5", ".tif"), ext + if ext == ".tif": + seg = imageio.imread(seg_path) + else: + with h5py.File(seg_path, "r") as f: + seg = f["segmentation"][:] + return seg + + +def _save_segmentation(seg_path, seg): + ext = Path(seg_path).suffix + assert ext in (".h5", ".tif"), ext + if ext == ".tif": + imageio.imwrite(seg_path, seg, compression="zlib") + else: + with h5py.File(seg_path, "a") as f: + f.create_dataset("segmentation", data=seg, compression="gzip") + return seg + + +def _filter_n_objects(segmentation, num_objects): + # Create individual objects for all disconnected pieces. + segmentation = label(segmentation) + # Find object ids and sizes, excluding background. + ids, sizes = np.unique(segmentation, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + # Only keep the biggest 'num_objects' objects. + keep_ids = ids[np.argsort(sizes)[::-1]][:num_objects] + segmentation[~np.isin(segmentation, keep_ids)] = 0 + # Relabel the segmentation sequentially. + segmentation, _, _ = relabel_sequential(segmentation) + # Ensure that we have the correct number of objects. + n_ids = int(segmentation.max()) + assert n_ids == num_objects + return segmentation + + +def process_tomogram(folder, num_ribbon, num_pd): + data_path = get_data_path(folder) + output_folder = os.path.join(folder, "automatisch", "v2") + fname = Path(data_path).stem + + correction_folder = _match_correction_folder(folder) + + ribbon_path = _match_correction_file(correction_folder, "ribbon") + if not os.path.exists(ribbon_path): + ribbon_path = os.path.join(output_folder, f"{fname}_ribbon.h5") + assert os.path.exists(ribbon_path), ribbon_path + ribbon = _load_segmentation(ribbon_path) + + pd_path = _match_correction_file(correction_folder, "PD") + if not os.path.exists(pd_path): + pd_path = os.path.join(output_folder, f"{fname}_pd.h5") + assert os.path.exists(pd_path), pd_path + PD = _load_segmentation(pd_path) + + # Filter the ribbon and the PD. + print("Filtering number of ribbons:", num_ribbon) + ribbon = _filter_n_objects(ribbon, num_ribbon) + bkp_path_ribbon = ribbon_path + ".bkp" + os.rename(ribbon_path, bkp_path_ribbon) + _save_segmentation(ribbon_path, ribbon) + + print("Filtering number of PDs:", num_pd) + PD = _filter_n_objects(PD, num_pd) + bkp_path_pd = pd_path + ".bkp" + os.rename(pd_path, bkp_path_pd) + _save_segmentation(pd_path, PD) + + +def filter_objects(table, version): + for i, row in tqdm(table.iterrows(), total=len(table)): + folder = row["Local Path"] + if folder == "": + continue + + # We have to handle the segmentation without ribbon separately. + if row["PD vorhanden? 
"] == "nein": + continue + + n_pds = row["Anzahl PDs"] + if n_pds == "unklar": + n_pds = 1 + + n_pds = int(n_pds) + n_ribbons = int(row["Anzahl Ribbons"]) + if (n_ribbons == 2 and n_pds == 1): + print(f"The tomogram {folder} has {n_ribbons} ribbons and {n_pds} PDs.") + print("The structure post-processing for this case is not yet implemented and will be skipped.") + continue + + micro = row["EM alt vs. Neu"] + if micro == "beides": + process_tomogram(folder, n_ribbons, n_pds) + + folder_new = os.path.join(folder, "Tomo neues EM") + if not os.path.exists(folder_new): + folder_new = os.path.join(folder, "neues EM") + assert os.path.exists(folder_new), folder_new + process_tomogram(folder_new, n_ribbons, n_pds) + + elif micro == "alt": + process_tomogram(folder, n_ribbons, n_pds) + + elif micro == "neu": + process_tomogram(folder, n_ribbons, n_pds) + + +def main(): + data_root = get_data_root() + table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Übersicht.xlsx") + table = parse_table(table_path, data_root) + version = 2 + filter_objects(table, version) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/.gitignore b/scripts/otoferlin/.gitignore new file mode 100644 index 0000000..59d5ea8 --- /dev/null +++ b/scripts/otoferlin/.gitignore @@ -0,0 +1,4 @@ +data/ +sync_segmentation.sh +segmentation/ +results/ diff --git a/scripts/otoferlin/README.md b/scripts/otoferlin/README.md new file mode 100644 index 0000000..a96eda3 --- /dev/null +++ b/scripts/otoferlin/README.md @@ -0,0 +1,9 @@ +# Otoferlin Analysis + + +## Notes on improvements: + +- Try less than 20 exclude slices +- Update boundary post-proc (not robust when PD not found and selects wrong objects) +- Can we find fiducials with ilastik and mask them out? They are interfering with Ribbon finding. + - Alternative: just restrict the processing to a center crop by default. diff --git a/scripts/otoferlin/automatic_processing.py b/scripts/otoferlin/automatic_processing.py new file mode 100644 index 0000000..531a38c --- /dev/null +++ b/scripts/otoferlin/automatic_processing.py @@ -0,0 +1,201 @@ +import os + +import h5py +import numpy as np + +from skimage.measure import label +from skimage.segmentation import relabel_sequential + +from synapse_net.distance_measurements import measure_segmentation_to_object_distances +from synapse_net.file_utils import read_mrc +from synapse_net.inference.vesicles import segment_vesicles +from synapse_net.tools.util import get_model, compute_scale_from_voxel_size, _segment_ribbon_AZ +from tqdm import tqdm + +from common import get_all_tomograms, get_seg_path, get_adapted_model, load_segmentations + +# These are tomograms for which the sophisticated membrane processing fails. +# In this case, we just select the largest boundary piece. 
+SIMPLE_MEM_POSTPROCESSING = [ + "Otof_TDAKO1blockA_GridN5_2_rec.mrc", "Otof_TDAKO2blockC_GridF5_1_rec.mrc", "Otof_TDAKO2blockC_GridF5_2_rec.mrc", + "Bl6_NtoTDAWT1_blockH_GridF3_1_rec.mrc", "Bl6_NtoTDAWT1_blockH_GridG2_3_rec.mrc", "Otof_TDAKO1blockA_GridN5_5_rec.mrc", + "Otof_TDAKO2blockC_GridE2_1_rec.mrc", "Otof_TDAKO2blockC_GridE2_2_rec.mrc", + +] + + +def _get_center_crop(input_): + halo_xy = (600, 600) + bb_xy = tuple( + slice(max(sh // 2 - ha, 0), min(sh // 2 + ha, sh)) for sh, ha in zip(input_.shape[1:], halo_xy) + ) + bb = (np.s_[:],) + bb_xy + return bb, input_.shape + + +def _get_tiling(): + # tile = {"x": 768, "y": 768, "z": 48} + tile = {"x": 512, "y": 512, "z": 48} + halo = {"x": 128, "y": 128, "z": 8} + return {"tile": tile, "halo": halo} + + +def process_vesicles(mrc_path, output_path, process_center_crop): + key = "segmentation/vesicles" + if os.path.exists(output_path): + with h5py.File(output_path, "r") as f: + if key in f: + return + + input_, voxel_size = read_mrc(mrc_path) + if process_center_crop: + bb, full_shape = _get_center_crop(input_) + input_ = input_[bb] + + model = get_adapted_model() + scale = compute_scale_from_voxel_size(voxel_size, "ribbon") + print("Rescaling volume for vesicle segmentation with factor:", scale) + tiling = _get_tiling() + segmentation = segment_vesicles(input_, model=model, scale=scale, tiling=tiling) + + if process_center_crop: + full_seg = np.zeros(full_shape, dtype=segmentation.dtype) + full_seg[bb] = segmentation + segmentation = full_seg + + with h5py.File(output_path, "a") as f: + f.create_dataset(key, data=segmentation, compression="gzip") + + +def _simple_membrane_postprocessing(membrane_prediction): + seg = label(membrane_prediction) + ids, sizes = np.unique(seg, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + return (seg == ids[np.argmax(sizes)]).astype("uint8") + + +def process_ribbon_structures(mrc_path, output_path, process_center_crop): + key = "segmentation/ribbon" + with h5py.File(output_path, "r") as f: + if key in f: + return + vesicles = f["segmentation/vesicles"][:] + + input_, voxel_size = read_mrc(mrc_path) + if process_center_crop: + bb, full_shape = _get_center_crop(input_) + input_, vesicles = input_[bb], vesicles[bb] + assert input_.shape == vesicles.shape + + model_name = "ribbon" + model = get_model(model_name) + scale = compute_scale_from_voxel_size(voxel_size, model_name) + tiling = _get_tiling() + + segmentations, predictions = _segment_ribbon_AZ( + input_, model, tiling=tiling, scale=scale, verbose=True, extra_segmentation=vesicles, + return_predictions=True, n_slices_exclude=5, + ) + + # The distance based post-processing for membranes fails for some tomograms. + # In these cases, just choose the largest membrane piece. 
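+    # The affected tomograms are listed by file name in SIMPLE_MEM_POSTPROCESSING at the top of this file.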
+    fname = os.path.basename(mrc_path)
+    if fname in SIMPLE_MEM_POSTPROCESSING:
+        segmentations["membrane"] = _simple_membrane_postprocessing(predictions["membrane"])
+
+    if process_center_crop:
+        for name, seg in segmentations.items():
+            full_seg = np.zeros(full_shape, dtype=seg.dtype)
+            full_seg[bb] = seg
+            segmentations[name] = full_seg
+        for name, pred in predictions.items():
+            full_pred = np.zeros(full_shape, dtype=pred.dtype)
+            full_pred[bb] = pred
+            predictions[name] = full_pred
+
+    with h5py.File(output_path, "a") as f:
+        for name, seg in segmentations.items():
+            f.create_dataset(f"segmentation/{name}", data=seg, compression="gzip")
+            f.create_dataset(f"prediction/{name}", data=predictions[name], compression="gzip")
+
+
+def postprocess_vesicles(
+    mrc_path, output_path, process_center_crop, force=False
+):
+    key = "segmentation/veiscles_postprocessed"
+    with h5py.File(output_path, "r") as f:
+        if key in f and not force:
+            return
+        vesicles = f["segmentation/vesicles"][:]
+    if process_center_crop:
+        bb, full_shape = _get_center_crop(vesicles)
+        vesicles = vesicles[bb]
+    else:
+        bb = np.s_[:]
+
+    segs = load_segmentations(output_path)
+    ribbon = segs["ribbon"][bb]
+    membrane = segs["membrane"][bb]
+
+    # Filter out small vesicle fragments.
+    min_size = 5000
+    ids, sizes = np.unique(vesicles, return_counts=True)
+    ids, sizes = ids[1:], sizes[1:]
+    filter_ids = ids[sizes < min_size]
+    vesicles[np.isin(vesicles, filter_ids)] = 0
+
+    input_, voxel_size = read_mrc(mrc_path)
+    voxel_size = tuple(voxel_size[ax] for ax in "zyx")
+    input_ = input_[bb]
+
+    # Filter out all vesicles farther than 120 nm from the membrane or ribbon.
+    max_dist = 120
+    seg = (ribbon + membrane) > 0
+    distances, _, _, seg_ids = measure_segmentation_to_object_distances(vesicles, seg, resolution=voxel_size)
+    filter_ids = seg_ids[distances > max_dist]
+    vesicles[np.isin(vesicles, filter_ids)] = 0
+
+    vesicles, _, _ = relabel_sequential(vesicles)
+
+    if process_center_crop:
+        full_seg = np.zeros(full_shape, dtype=vesicles.dtype)
+        full_seg[bb] = vesicles
+        vesicles = full_seg
+    with h5py.File(output_path, "a") as f:
+        if key in f:
+            f[key][:] = vesicles
+        else:
+            f.create_dataset(key, data=vesicles, compression="gzip")
+
+
+def process_tomogram(mrc_path):
+    output_path = get_seg_path(mrc_path)
+    output_folder = os.path.split(output_path)[0]
+    os.makedirs(output_folder, exist_ok=True)
+
+    process_center_crop = True
+
+    process_vesicles(mrc_path, output_path, process_center_crop)
+    process_ribbon_structures(mrc_path, output_path, process_center_crop)
+    postprocess_vesicles(mrc_path, output_path, process_center_crop)
+
+
+def main():
+    tomograms = get_all_tomograms()
+    # for tomogram in tqdm(tomograms, desc="Process tomograms"):
+    #     process_tomogram(tomogram)
+
+    # Update the membrane postprocessing for the tomograms where this went wrong.
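+    # Note that this overwrites segmentation/membrane in the existing output files in-place,
+    # recomputing it from the stored prediction/membrane dataset.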
+    for tomo in tqdm(tomograms, desc="Fix membrane postprocessing"):
+        if os.path.basename(tomo) not in SIMPLE_MEM_POSTPROCESSING:
+            continue
+        seg_path = get_seg_path(tomo)
+        with h5py.File(seg_path, "r") as f:
+            pred = f["prediction/membrane"][:]
+        seg = _simple_membrane_postprocessing(pred)
+        with h5py.File(seg_path, "a") as f:
+            f["segmentation/membrane"][:] = seg
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/otoferlin/check_automatic_result.py b/scripts/otoferlin/check_automatic_result.py
new file mode 100644
index 0000000..4c4c46c
--- /dev/null
+++ b/scripts/otoferlin/check_automatic_result.py
@@ -0,0 +1,60 @@
+import os
+
+import h5py
+import napari
+import numpy as np
+
+from synapse_net.file_utils import read_mrc
+from skimage.exposure import equalize_adapthist
+from tqdm import tqdm
+
+from common import get_all_tomograms, get_seg_path, get_colormaps
+
+
+def check_automatic_result(mrc_path, version, use_clahe=False, center_crop=True, segmentation_group="segmentation"):
+    tomogram, _ = read_mrc(mrc_path)
+    if center_crop:
+        halo = (50, 512, 512)
+        bb = tuple(
+            slice(max(sh // 2 - ha, 0), min(sh // 2 + ha, sh)) for sh, ha in zip(tomogram.shape, halo)
+        )
+        tomogram = tomogram[bb]
+    else:
+        bb = np.s_[:]
+
+    if use_clahe:
+        print("Run CLAHE ...")
+        tomogram = equalize_adapthist(tomogram, clip_limit=0.03)
+        print("... done")
+
+    seg_path = get_seg_path(mrc_path, version)
+    segmentations, colormaps = {}, {}
+    if os.path.exists(seg_path):
+        with h5py.File(seg_path, "r") as f:
+            g = f[segmentation_group]
+            for name, ds in g.items():
+                segmentations[name] = ds[bb]
+                colormaps[name] = get_colormaps().get(name, None)
+
+    v = napari.Viewer()
+    v.add_image(tomogram)
+    for name, seg in segmentations.items():
+        v.add_labels(seg, name=name, colormap=colormaps.get(name))
+    v.title = os.path.basename(mrc_path)
+    napari.run()
+
+
+def main():
+    version = 2
+    tomograms = get_all_tomograms()
+    for i, tomogram in tqdm(
+        enumerate(tomograms), total=len(tomograms), desc="Visualize automatic segmentation results"
+    ):
+        print("Checking tomogram", tomogram)
+        check_automatic_result(tomogram, version)
+        # check_automatic_result(tomogram, version, segmentation_group="vesicles")
+        # check_automatic_result(tomogram, version, segmentation_group="prediction")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/otoferlin/check_structure_postprocessing.py b/scripts/otoferlin/check_structure_postprocessing.py
new file mode 100644
index 0000000..6b8d4de
--- /dev/null
+++ b/scripts/otoferlin/check_structure_postprocessing.py
@@ -0,0 +1,56 @@
+import os
+
+import h5py
+import napari
+import numpy as np
+
+from synapse_net.file_utils import read_mrc
+from tqdm import tqdm
+
+from common import get_seg_path, get_all_tomograms, get_colormaps, STRUCTURE_NAMES
+
+
+def check_structure_postprocessing(mrc_path, center_crop=True):
+    tomogram, _ = read_mrc(mrc_path)
+    if center_crop:
+        halo = (50, 512, 512)
+        bb = tuple(
+            slice(max(sh // 2 - ha, 0), min(sh // 2 + ha, sh)) for sh, ha in zip(tomogram.shape, halo)
+        )
+        tomogram = tomogram[bb]
+    else:
+        bb = np.s_[:]
+
+    seg_path = get_seg_path(mrc_path)
+    assert os.path.exists(seg_path)
+
+    segmentations, predictions, colormaps = {}, {}, {}
+    with h5py.File(seg_path, "r") as f:
+        g = f["segmentation"]
+        for name in STRUCTURE_NAMES:
+            segmentations[f"seg/{name}"] = g[name][bb]
+            colormaps[name] = get_colormaps().get(name, None)
+
+        g = f["prediction"]
+        for name in STRUCTURE_NAMES:
+            predictions[f"pred/{name}"] = g[name][bb]
+
+    v = napari.Viewer()
+    v.add_image(tomogram)
+    for name, seg in segmentations.items():
+        v.add_labels(seg, name=name, colormap=colormaps.get(name.split("/")[1]))
+    for name, pred in predictions.items():
+        v.add_labels(pred, name=name, colormap=colormaps.get(name.split("/")[1]), visible=False)
+    v.title = os.path.basename(mrc_path)
+    napari.run()
+
+
+def main():
+    tomograms = get_all_tomograms()
+    for i, tomogram in tqdm(enumerate(tomograms), total=len(tomograms), desc="Check structure postproc"):
+        print(tomogram)
+        check_structure_postprocessing(tomogram)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/otoferlin/common.py b/scripts/otoferlin/common.py
new file mode 100644
index 0000000..9dd0ca7
--- /dev/null
+++ b/scripts/otoferlin/common.py
@@ -0,0 +1,116 @@
+import os
+from glob import glob
+
+import imageio.v3 as imageio
+import h5py
+import pandas as pd
+from synapse_net.tools.util import load_custom_model
+
+
+# These are the files just for the test data.
+# INPUT_ROOT = "/home/ag-wichmann/data/test-data/tomograms"
+# OUTPUT_ROOT = "/home/ag-wichmann/data/test-data/segmentation"
+
+# These are the otoferlin tomograms.
+INPUT_ROOT = "/home/ag-wichmann/data/otoferlin/tomograms"
+OUTPUT_ROOT = "./segmentation"
+
+STRUCTURE_NAMES = ("ribbon", "PD", "membrane")
+
+# The version of the automatic segmentation. We have:
+# - version 1: using the default models for all structures and the initial version of post-processing.
+# - version 2: using the vesicle model adapted to the otoferlin data and the updated post-processing.
+VERSION = 2
+
+
+def get_adapted_model():
+    # Path on nhr.
+    # model_path = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/synaptic-reconstruction/scripts/otoferlin/domain_adaptation/checkpoints/otoferlin_da.pt"  # noqa
+    # Path on the Workstation.
+    model_path = "/home/ag-wichmann/Downloads/otoferlin_da.pt"
+    model = load_custom_model(model_path)
+    return model
+
+
+def get_folders():
+    if os.path.exists(INPUT_ROOT):
+        return INPUT_ROOT, OUTPUT_ROOT
+    root_in = "./data/tomograms"
+    assert os.path.exists(root_in)
+    return root_in, OUTPUT_ROOT
+
+
+def load_table():
+    table_path = "overview Otoferlin samples.xlsx"
+    table_mut = pd.read_excel(table_path, sheet_name="Mut")
+    table_wt = pd.read_excel(table_path, sheet_name="Wt")
+    table = pd.concat([table_mut, table_wt])
+    table = table[table["Einschluss? "] == "ja"]
+    return table
+
+
+def get_all_tomograms(restrict_to_good_tomos=False, restrict_to_nachgeb=False):
+    root, _ = get_folders()
+    tomograms = glob(os.path.join(root, "**", "*.mrc"), recursive=True)
+    tomograms += glob(os.path.join(root, "**", "*.rec"), recursive=True)
+    tomograms = sorted(tomograms)
+    if restrict_to_good_tomos:
+        table = load_table()
+        if restrict_to_nachgeb:
+            table = table[table["nachgebessert"] == "ja"]
+        fnames = [os.path.basename(row["File name"]) for _, row in table.iterrows()]
+        tomograms = [tomo for tomo in tomograms if os.path.basename(tomo) in fnames]
+        # assert len(tomograms) == len(table), f"{len(tomograms), len(table)}"
+    return tomograms
+
+
+def get_seg_path(mrc_path, version=VERSION):
+    input_root, output_root = get_folders()
+    rel_path = os.path.relpath(mrc_path, input_root)
+    rel_folder, fname = os.path.split(rel_path)
+    fname = os.path.splitext(fname)[0]
+    seg_path = os.path.join(output_root, f"v{version}", rel_folder, f"{fname}.h5")
+    return seg_path
+
+
+def get_colormaps():
+    pool_map = {
+        "RA-V": (0, 0.33, 0),
+        "MP-V": (1.0, 0.549, 0.0),
+        "Docked-V": (1, 1, 0),
+        None: "gray",
+    }
+    ribbon_map = {1: "red", 2: "red", None: (0, 0, 0, 0), 0: (0, 0, 0, 0)}
+    membrane_map = {1: "purple", None: (0, 0, 0, 0)}
+    pd_map = {1: "magenta", 2: "magenta", None: (0, 0, 0, 0)}
+    return {"pools": pool_map, "membrane": membrane_map, "PD": pd_map, "ribbon": ribbon_map}
+
+
+def load_segmentations(seg_path, verbose=True):
+    # Keep the typo in the name, as these are the hdf5 keys!
+    seg_names = {"vesicles": "veiscles_postprocessed"}
+    seg_names.update({name: name for name in STRUCTURE_NAMES})
+
+    segmentations = {}
+    correction_folder = os.path.join(os.path.split(seg_path)[0], "correction")
+    with h5py.File(seg_path, "r") as f:
+        g = f["segmentation"]
+        for out_name, name in seg_names.items():
+            correction_path = os.path.join(correction_folder, f"{name}.tif")
+            if os.path.exists(correction_path):
+                if verbose:
+                    print("Loading corrected", name, "segmentation from", correction_path)
+                segmentations[out_name] = imageio.imread(correction_path)
+            else:
+                segmentations[out_name] = g[f"{name}"][:]
+    return segmentations
+
+
+def to_condition(mrc_path):
+    fname = os.path.basename(mrc_path)
+    return "TDA KO" if fname.startswith("Otof") else "TDA WT"
+
+
+if __name__ == "__main__":
+    tomos = get_all_tomograms(restrict_to_good_tomos=True, restrict_to_nachgeb=True)
+    print("We have", len(tomos), "tomograms")
diff --git a/scripts/otoferlin/compare_vesicle_segmentation.py b/scripts/otoferlin/compare_vesicle_segmentation.py
new file mode 100644
index 0000000..555947d
--- /dev/null
+++ b/scripts/otoferlin/compare_vesicle_segmentation.py
@@ -0,0 +1,58 @@
+import os
+
+import h5py
+
+from skimage.exposure import equalize_adapthist
+from synapse_net.inference.vesicles import segment_vesicles
+from synapse_net.file_utils import read_mrc
+from synapse_net.tools.util import get_model, compute_scale_from_voxel_size, load_custom_model
+from tqdm import tqdm
+
+from common import get_all_tomograms, get_seg_path
+
+
+def compare_vesicles(tomo_path):
+    seg_path = get_seg_path(tomo_path)
+    seg_folder = os.path.split(seg_path)[0]
+    os.makedirs(seg_folder, exist_ok=True)
+
+    model_paths = {
+        "adapted_v1": "/mnt/vast-nhr/home/pape41/u12086/inner-ear-da.pt",
+        "adapted_v2": "./domain_adaptation/checkpoints/otoferlin_da.pt"
+    }
+    for model_type in ("vesicles_3d", "adapted_v1", "adapted_v2"):
+        for use_clahe in (False, True):
+            seg_key = f"vesicles/{model_type}"
+            if use_clahe:
+                seg_key
+= "_clahe" + + if os.path.exists(seg_path): + with h5py.File(seg_path, "r") as f: + if seg_key in f: + continue + + tomogram, voxel_size = read_mrc(tomo_path) + if use_clahe: + tomogram = equalize_adapthist(tomogram, clip_limit=0.03) + + if model_type == "vesicles_3d": + model = get_model(model_type) + scale = compute_scale_from_voxel_size(voxel_size, model_type) + else: + model_path = model_paths[model_type] + model = load_custom_model(model_path) + scale = compute_scale_from_voxel_size(voxel_size, "ribbon") + + seg = segment_vesicles(tomogram, model=model, scale=scale) + with h5py.File(seg_path, "a") as f: + f.create_dataset(seg_key, data=seg, compression="gzip") + + +def main(): + tomograms = get_all_tomograms() + for tomo in tqdm(tomograms): + compare_vesicles(tomo) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/correct_structure_segmentation.py b/scripts/otoferlin/correct_structure_segmentation.py new file mode 100644 index 0000000..5c2ff15 --- /dev/null +++ b/scripts/otoferlin/correct_structure_segmentation.py @@ -0,0 +1,34 @@ +from pathlib import Path + +import napari + +from synapse_net.file_utils import read_mrc +from common import get_all_tomograms, get_seg_path, load_segmentations, get_colormaps + + +def correct_structure_segmentation(mrc_path): + seg_path = get_seg_path(mrc_path) + + data, _ = read_mrc(mrc_path) + segmentations = load_segmentations(seg_path) + color_maps = get_colormaps() + + v = napari.Viewer() + v.add_image(data) + for name, seg in segmentations.items(): + if name == "vesicles": + name = "veiscles_postprocessed" + v.add_labels(seg, name=name, colormap=color_maps.get(name, None)) + fname = Path(mrc_path).stem + v.title = fname + napari.run() + + +def main(): + tomograms = get_all_tomograms() + for tomo in tomograms: + correct_structure_segmentation(tomo) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/correct_vesicle_pools.py b/scripts/otoferlin/correct_vesicle_pools.py new file mode 100644 index 0000000..e99d7a2 --- /dev/null +++ b/scripts/otoferlin/correct_vesicle_pools.py @@ -0,0 +1,162 @@ +import os + +import imageio.v3 as imageio +import napari +import numpy as np +import pandas as pd +from magicgui import magicgui + +from synapse_net.file_utils import read_mrc +from synapse_net.distance_measurements import load_distances +from skimage.measure import regionprops +from common import load_segmentations, get_seg_path, get_all_tomograms, get_colormaps, STRUCTURE_NAMES +from tqdm import tqdm + +import warnings +warnings.filterwarnings("ignore") + + +# FIXME: adding vesicles to pool doesn't work / messes with color map +def _create_pool_layer(seg, assignment_path): + assignments = pd.read_csv(assignment_path) + pools = np.zeros_like(seg) + + pool_colors = get_colormaps()["pools"] + colormap = {None: "gray", 0: (0, 0, 0, 0)} + + # Sorting of floats and ints by np.unique is weird. We better don't trust unique here + # It should not matter if one of the pools is empty. 
+ pool_names = ["RA-V", "MP-V", "Docked-V"] + + for pool_id, pool_name in enumerate(pool_names, 1): + if not isinstance(pool_name, str) and np.isnan(pool_name): + continue + pool_vesicle_ids = assignments[assignments.pool == pool_name].vesicle_id.values + pool_mask = np.isin(seg, pool_vesicle_ids) + pools[pool_mask] = pool_id + colormap[pool_id] = pool_colors[pool_name] + + return pools, colormap, assignments + + +def _update_assignments(vesicles, pool_correction, assignment_path): + old_assignments = pd.read_csv(assignment_path) + props = regionprops(vesicles, pool_correction) + + val_to_pool = {0: 0, 1: "RA-V", 2: "MP-V", 3: "Docked-V", 4: None} + corrected_pools = {prop.label: val_to_pool[int(prop.max_intensity)] for prop in props} + + new_assignments = [] + for _, row in old_assignments.iterrows(): + vesicle_id = row.vesicle_id + corrected_pool = corrected_pools[vesicle_id] + if corrected_pool != 0: + row.pool = corrected_pool + new_assignments.append(row) + new_assignments = pd.DataFrame(new_assignments) + new_assignments.to_csv(assignment_path, index=False) + + +def _create_outlier_mask(assignments, vesicles, output_folder): + distances = {} + for name in STRUCTURE_NAMES: + dist, _, _, ids = load_distances(os.path.join(output_folder, "distances", f"{name}.npz")) + distances[name] = {vid: dist for vid, dist in zip(ids, dist)} + + pool_criteria = { + "RA-V": {"ribbon": 80}, + "MP-V": {"PD": 100, "membrane": 50}, + "Docked-V": {"PD": 100, "membrane": 2}, + } + + vesicle_ids = assignments.vesicle_id.values + outlier_ids = [] + for pool in ("RA-V", "MP-V", "Docked-V"): + pool_ids = assignments[assignments.pool == pool].vesicle_id.values + for name in STRUCTURE_NAMES: + min_dist = pool_criteria[pool].get(name) + if min_dist is None: + continue + dist = distances[name] + assert len(dist) == len(vesicle_ids) + pool_outliers = [vid for vid in pool_ids if dist[vid] > min_dist] + if pool_outliers: + print("Pool:", pool, ":", name, ":", len(pool_outliers)) + outlier_ids.extend(pool_outliers) + + outlier_ids = np.unique(outlier_ids) + outlier_mask = np.isin(vesicles, outlier_ids).astype("uint8") + return outlier_mask + + +def correct_vesicle_pools(mrc_path, show_outliers, skip_if_no_outlier=False): + seg_path = get_seg_path(mrc_path) + + output_folder = os.path.split(seg_path)[0] + assignment_path = os.path.join(output_folder, "vesicle_pools.csv") + if not os.path.exists(assignment_path): + print("Skip", seg_path, "due to missing assignments") + return + + data, _ = read_mrc(mrc_path) + segmentations = load_segmentations(seg_path, verbose=False) + vesicles = segmentations["vesicles"] + + colormaps = get_colormaps() + pool_colors = colormaps["pools"] + correction_colors = { + 1: pool_colors["RA-V"], 2: pool_colors["MP-V"], 3: pool_colors["Docked-V"], 4: "Gray", None: "Gray" + } + + vesicle_pools, pool_colors, assignments = _create_pool_layer(vesicles, assignment_path) + if show_outliers: + outlier_mask = _create_outlier_mask(assignments, vesicles, output_folder) + else: + outlier_mask = None + + if skip_if_no_outlier and outlier_mask.sum() == 0: + return + + pool_correction_path = os.path.join(output_folder, "correction", "pool_correction.tif") + os.makedirs(os.path.join(output_folder, "correction"), exist_ok=True) + if os.path.exists(pool_correction_path): + pool_correction = imageio.imread(pool_correction_path) + else: + pool_correction = np.zeros_like(vesicles) + + v = napari.Viewer() + v.add_image(data) + v.add_labels(vesicle_pools, colormap=pool_colors) + v.add_labels(pool_correction, 
colormap=correction_colors) + v.add_labels(vesicles, visible=False) + for name in STRUCTURE_NAMES: + # v.add_labels(segmentations[name], name=name, visible=False, colormap=colormaps[name]) + v.add_labels(segmentations[name], name=name, visible=False) + + if outlier_mask is not None: + v.add_labels(outlier_mask) + + @magicgui(call_button="Update Pools") + def update_pools(viewer: napari.Viewer): + pool_data = viewer.layers["vesicle_pools"].data + vesicles = viewer.layers["vesicles"].data + pool_correction = viewer.layers["pool_correction"].data + _update_assignments(vesicles, pool_correction, assignment_path) + pool_data, pool_colors, _ = _create_pool_layer(vesicles, assignment_path) + viewer.layers["vesicle_pools"].data = pool_data + viewer.layers["vesicle_pools"].colormap = pool_colors + + v.window.add_dock_widget(update_pools) + v.title = os.path.basename(mrc_path) + + napari.run() + + +def main(): + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + for tomo in tqdm(tomograms): + correct_vesicle_pools(tomo, show_outliers=True, skip_if_no_outlier=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/domain_adaptation/train_domain_adaptation.py b/scripts/otoferlin/domain_adaptation/train_domain_adaptation.py new file mode 100644 index 0000000..99b32f7 --- /dev/null +++ b/scripts/otoferlin/domain_adaptation/train_domain_adaptation.py @@ -0,0 +1,66 @@ +import os +from glob import glob + +import h5py + +from synapse_net.file_utils import read_mrc +from synapse_net.training.domain_adaptation import mean_teacher_adaptation +from synapse_net.tools.util import compute_scale_from_voxel_size +from synapse_net.inference.util import _Scaler + + +# Apply rescaling, depending on what the segmentation comparison shows. +def preprocess_training_data(): + root = "../data/tomograms" + tomograms = glob(os.path.join(root, "**", "*.mrc"), recursive=True) + tomograms += glob(os.path.join(root, "**", "*.rec"), recursive=True) + tomograms = sorted(tomograms) + + train_folder = "./train_data" + os.makedirs(train_folder, exist_ok=True) + + all_paths = [] + for i, tomo_path in enumerate(tomograms): + out_path = os.path.join(train_folder, f"tomo{i}.h5") + if os.path.exists(out_path): + all_paths.append(out_path) + continue + + data, voxel_size = read_mrc(tomo_path) + scale = compute_scale_from_voxel_size(voxel_size, "ribbon") + print("Scale factor:", scale) + scaler = _Scaler(scale, verbose=True) + data = scaler.scale_input(data) + + with h5py.File(out_path, "a") as f: + f.create_dataset("raw", data=data, compression="gzip") + all_paths.append(out_path) + + train_paths, val_paths = all_paths[:-1], all_paths[-1:] + return train_paths, val_paths + + +def train_domain_adaptation(train_paths, val_paths): + model_path = "/mnt/vast-nhr/home/pape41/u12086/inner-ear-da.pt" + model_name = "adapted_otoferlin" + + patch_shape = [48, 384, 384] + mean_teacher_adaptation( + name=model_name, + unsupervised_train_paths=train_paths, + unsupervised_val_paths=val_paths, + raw_key="raw", + patch_shape=patch_shape, + source_checkpoint=model_path, + confidence_threshold=0.75, + n_iterations=int(2.5*1e4), + ) + + +def main(): + train_paths, val_paths = preprocess_training_data() + train_domain_adaptation(train_paths, val_paths) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/ensure_labeled_all_vesicles.py b/scripts/otoferlin/ensure_labeled_all_vesicles.py new file mode 100644 index 0000000..c32f8b9 --- /dev/null +++ b/scripts/otoferlin/ensure_labeled_all_vesicles.py @@ -0,0 +1,20 
@@ +from common import get_all_tomograms, get_seg_path, load_segmentations +from tqdm import tqdm +from skimage.measure import label +import numpy as np + + +def ensure_labeled(vesicles): + n_ids = len(np.unique(vesicles)) + n_ids_labeled = len(np.unique(label(vesicles))) + assert n_ids == n_ids_labeled, f"{n_ids}, {n_ids_labeled}" + + +def main(): + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + for tomogram in tqdm(tomograms, desc="Process tomograms"): + segmentations = load_segmentations(get_seg_path(tomogram)) + ensure_labeled(segmentations["vesicles"]) + + +main() diff --git a/scripts/otoferlin/export_results.py b/scripts/otoferlin/export_results.py new file mode 100644 index 0000000..f83ff3c --- /dev/null +++ b/scripts/otoferlin/export_results.py @@ -0,0 +1,133 @@ +import os +from datetime import datetime + +import numpy as np +import pandas as pd +from common import get_all_tomograms, get_seg_path, to_condition + +from synapse_net.distance_measurements import load_distances + + +def get_output_folder(): + output_root = "./results" + date = datetime.now().strftime("%Y%m%d") + + version = 1 + output_folder = os.path.join(output_root, f"{date}_{version}") + while os.path.exists(output_folder): + version += 1 + output_folder = os.path.join(output_root, f"{date}_{version}") + + os.makedirs(output_folder) + return output_folder + + +def _export_results(tomograms, result_path, result_extraction): + results = {} + for tomo in tomograms: + condition = to_condition(tomo) + res = result_extraction(tomo) + if condition in results: + results[condition].append(res) + else: + results[condition] = [res] + + for condition, res in results.items(): + res = pd.concat(res) + if os.path.exists(result_path): + with pd.ExcelWriter(result_path, engine="openpyxl", mode="a") as writer: + res.to_excel(writer, sheet_name=condition, index=False) + else: + res.to_excel(result_path, sheet_name=condition, index=False) + + +def load_measures(measure_path, min_radius=5): + measures = pd.read_csv(measure_path).dropna() + measures = measures[measures.radius > min_radius] + return measures + + +def count_vesicle_pools(measures, ribbon_id, tomo): + ribbon_measures = measures[measures.ribbon_id == ribbon_id] + pool_names, counts = np.unique(ribbon_measures.pool.values, return_counts=True) + pool_names, counts = pool_names.tolist(), counts.tolist() + pool_names.append("MP-V_all") + counts.append(counts[pool_names.index("MP-V")] + counts[pool_names.index("Docked-V")]) + res = {"tomogram": [os.path.basename(tomo)], "ribbon": ribbon_id} + res.update({k: v for k, v in zip(pool_names, counts)}) + return pd.DataFrame(res) + + +def export_vesicle_pools(tomograms, result_path): + + def result_extraction(tomo): + folder = os.path.split(get_seg_path(tomo))[0] + measure_path = os.path.join(folder, "vesicle_pools.csv") + measures = load_measures(measure_path) + ribbon_ids = pd.unique(measures.ribbon_id) + + results = [] + for ribbon_id in ribbon_ids: + res = count_vesicle_pools(measures, ribbon_id, tomo) + results.append(res) + return pd.concat(results) + + _export_results(tomograms, result_path, result_extraction) + + +def export_distances(tomograms, result_path): + def result_extraction(tomo): + folder = os.path.split(get_seg_path(tomo))[0] + measure_path = os.path.join(folder, "vesicle_pools.csv") + measures = load_measures(measure_path) + + measures = measures[measures.pool.isin(["MP-V", "Docked-V"])][["vesicle_id", "pool"]] + + # Load the distances to PD. 
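+        # load_distances returns (distances, endpoints1, endpoints2, seg_ids); the distances are
+        # given in nm and are ordered like seg_ids, so we re-key them by vesicle id below.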
+ pd_distances, _, _, seg_ids = load_distances(os.path.join(folder, "distances", "PD.npz")) + pd_distances = {sid: dist for sid, dist in zip(seg_ids, pd_distances)} + measures["distance-to-pd"] = [pd_distances[vid] for vid in measures.vesicle_id.values] + + # Load the distances to membrane. + mem_distances, _, _, seg_ids = load_distances(os.path.join(folder, "distances", "membrane.npz")) + mem_distances = {sid: dist for sid, dist in zip(seg_ids, mem_distances)} + measures["distance-to-membrane"] = [mem_distances[vid] for vid in measures.vesicle_id.values] + + measures = measures.drop(columns=["vesicle_id"]) + measures.insert(0, "tomogram", len(measures) * [os.path.basename(tomo)]) + + return measures + + _export_results(tomograms, result_path, result_extraction) + + +def export_diameter(tomograms, result_path): + def result_extraction(tomo): + folder = os.path.split(get_seg_path(tomo))[0] + measure_path = os.path.join(folder, "vesicle_pools.csv") + measures = load_measures(measure_path) + + measures = measures[measures.pool.isin(["MP-V", "Docked-V"])][["pool", "diameter"]] + measures.insert(0, "tomogram", len(measures) * [os.path.basename(tomo)]) + + return measures + + _export_results(tomograms, result_path, result_extraction) + + +def main(): + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + result_folder = get_output_folder() + + result_path = os.path.join(result_folder, "vesicle_pools.xlsx") + export_vesicle_pools(tomograms, result_path) + + result_path = os.path.join(result_folder, "distances.xlsx") + export_distances(tomograms, result_path) + + result_path = os.path.join(result_folder, "diameter.xlsx") + export_diameter(tomograms, result_path) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/export_to_imod.py b/scripts/otoferlin/export_to_imod.py new file mode 100644 index 0000000..35e5a72 --- /dev/null +++ b/scripts/otoferlin/export_to_imod.py @@ -0,0 +1,92 @@ +import os +from glob import glob + +from pathlib import Path +from subprocess import run + +import numpy as np +import pandas as pd + +from tqdm import tqdm +from synapse_net.imod.to_imod import write_segmentation_to_imod, write_segmentation_to_imod_as_points +from common import STRUCTURE_NAMES, get_all_tomograms, get_seg_path, load_segmentations + + +def check_imod(mrc_path, mod_path): + run(["imod", mrc_path, mod_path]) + + +def export_tomogram(mrc_path, force): + seg_path = get_seg_path(mrc_path) + output_folder = os.path.split(seg_path)[0] + assert os.path.exists(output_folder) + + # export_folder = os.path.join(output_folder, "imod") + tomo_name = Path(mrc_path).stem + export_folder = os.path.join(f"./imod/{tomo_name}") + if os.path.exists(export_folder) and not force: + return + + segmentations = load_segmentations(seg_path) + vesicles = segmentations["vesicles"] + + os.makedirs(export_folder, exist_ok=True) + + # Load the pool assignments and export the pools to IMOD. 
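+    # vesicle_pools.csv is written by pool_assignments_and_measurements.py and has one row per
+    # vesicle with (at least) the columns vesicle_id, pool, radius, and diameter.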
+ assignment_path = os.path.join(output_folder, "vesicle_pools.csv") + assignments = pd.read_csv(assignment_path) + + colors = { + "Docked-V": (255, 170, 127), # (1, 0.666667, 0.498039) + "RA-V": (0, 85, 0), # (0, 0.333333, 0) + "MP-V": (255, 170, 0), # (1, 0.666667, 0) + "ribbon": (255, 0, 0), + "PD": (255, 0, 255), # (1, 0, 1) + "membrane": (255, 170, 255), # 1, 0.666667, 1 + } + + pools = ['Docked-V', 'RA-V', 'MP-V'] + radius_factor = 0.85 + for pool in pools: + export_path = os.path.join(export_folder, f"{pool}.mod") + pool_ids = assignments[assignments.pool == pool].vesicle_id + pool_seg = vesicles.copy() + pool_seg[~np.isin(pool_seg, pool_ids)] = 0 + write_segmentation_to_imod_as_points( + mrc_path, pool_seg, export_path, min_radius=5, radius_factor=radius_factor, + color=colors.get(pool), name=pool, + ) + # check_imod(mrc_path, export_path) + + # Export the structures to IMOD. + for name in STRUCTURE_NAMES: + export_path = os.path.join(export_folder, f"{name}.mod") + color = colors.get(name) + write_segmentation_to_imod(mrc_path, segmentations[name], export_path, color=color) + # check_imod(mrc_path, export_path) + + # Join the model + all_mod_files = sorted(glob(os.path.join(export_folder, "*.mod"))) + export_path = os.path.join(export_folder, f"{tomo_name}.mod") + join_cmd = ["imodjoin"] + all_mod_files + [export_path] + run(join_cmd) + check_imod(mrc_path, export_path) + + +def main(): + force = True + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + tomograms_for_vis = [ + "Bl6_NtoTDAWT1_blockH_GridE4_1_rec.mrc", + "Otof_TDAKO1blockA_GridN5_6_rec.mrc", + ] + for tomogram in tqdm(tomograms, desc="Process tomograms"): + fname = os.path.basename(tomogram) + if fname not in tomograms_for_vis: + continue + print("Exporting:", fname) + export_tomogram(tomogram, force) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/filter_objects_and_measure.py b/scripts/otoferlin/filter_objects_and_measure.py new file mode 100644 index 0000000..1479bbe --- /dev/null +++ b/scripts/otoferlin/filter_objects_and_measure.py @@ -0,0 +1,81 @@ +import os +from tqdm import tqdm + +import numpy as np +from skimage.measure import label +from skimage.segmentation import relabel_sequential +from common import get_all_tomograms, get_seg_path, load_table, load_segmentations, STRUCTURE_NAMES +from synapse_net.distance_measurements import measure_segmentation_to_object_distances, load_distances +from synapse_net.file_utils import read_mrc + + +def _filter_n_objects(segmentation, num_objects): + # Create individual objects for all disconnected pieces. + segmentation = label(segmentation) + # Find object ids and sizes, excluding background. + ids, sizes = np.unique(segmentation, return_counts=True) + ids, sizes = ids[1:], sizes[1:] + # Only keep the biggest 'num_objects' objects. + keep_ids = ids[np.argsort(sizes)[::-1]][:num_objects] + segmentation[~np.isin(segmentation, keep_ids)] = 0 + # Relabel the segmentation sequentially. + segmentation, _, _ = relabel_sequential(segmentation) + # Ensure that we have the correct number of objects. + n_ids = int(segmentation.max()) + assert n_ids == num_objects + return segmentation + + +def filter_and_measure(mrc_path, seg_path, output_folder, force): + result_folder = os.path.join(output_folder, "distances") + if os.path.exists(result_folder) and not force: + return + + # Load the table to find out how many ribbons / PDs we have here. 
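+    # The expected numbers are taken from the "#ribbons" and "PD?" columns of the overview table.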
+ table = load_table() + table = table[table["File name"] == os.path.basename(mrc_path)] + assert len(table) == 1 + + num_ribbon = int(table["#ribbons"].values[0]) + num_pd = int(table["PD?"].values[0]) + + segmentations = load_segmentations(seg_path) + vesicles = segmentations["vesicles"] + structures = {name: segmentations[name] for name in STRUCTURE_NAMES} + + # Filter the ribbon and the PD. + print("Filtering number of ribbons:", num_ribbon) + structures["ribbon"] = _filter_n_objects(structures["ribbon"], num_ribbon) + print("Filtering number of PDs:", num_pd) + structures["PD"] = _filter_n_objects(structures["PD"], num_pd) + + _, resolution = read_mrc(mrc_path) + resolution = [resolution[ax] for ax in "zyx"] + + # Measure all the object distances. + for name in ("ribbon", "PD"): + seg = structures[name] + assert seg.sum() != 0, name + print("Compute vesicle distances to", name) + save_path = os.path.join(result_folder, f"{name}.npz") + measure_segmentation_to_object_distances(vesicles, seg, save_path=save_path, resolution=resolution) + + +def process_tomogram(mrc_path, force): + seg_path = get_seg_path(mrc_path) + output_folder = os.path.split(seg_path)[0] + assert os.path.exists(output_folder) + + # Measure the distances. + filter_and_measure(mrc_path, seg_path, output_folder, force) + + +def main(): + force = True + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + for tomogram in tqdm(tomograms, desc="Process tomograms"): + process_tomogram(tomogram, force) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/handle_ribbon_assignments.py b/scripts/otoferlin/handle_ribbon_assignments.py new file mode 100644 index 0000000..8ac3586 --- /dev/null +++ b/scripts/otoferlin/handle_ribbon_assignments.py @@ -0,0 +1,57 @@ +import os +import pandas as pd +from synapse_net.distance_measurements import load_distances + +from common import get_all_tomograms, get_seg_path, load_table + + +def _add_one_to_assignment(mrc_path): + seg_path = get_seg_path(mrc_path) + output_folder = os.path.split(seg_path)[0] + assert os.path.exists(output_folder) + assignment_path = os.path.join(output_folder, "vesicle_pools.csv") + + assignments = pd.read_csv(assignment_path) + assignments["ribbon_id"] = len(assignments) * [1] + assignments.to_csv(assignment_path, index=False) + + +def _update_assignments(mrc_path, num_ribbon): + print(mrc_path) + seg_path = get_seg_path(mrc_path) + output_folder = os.path.split(seg_path)[0] + assert os.path.exists(output_folder) + assignment_path = os.path.join(output_folder, "vesicle_pools.csv") + distance_path = os.path.join(output_folder, "distances", "ribbon.npz") + + _, _, _, seg_ids, object_ids = load_distances(distance_path, return_object_ids=True) + assert all(obj in range(1, num_ribbon + 1) for obj in object_ids) + + assignments = pd.read_csv(assignment_path) + assert len(assignments) == len(object_ids) + assert (seg_ids == assignments.vesicle_id.values).all() + assignments["ribbon_id"] = object_ids + assignments.to_csv(assignment_path, index=False) + + +def process_tomogram(mrc_path): + table = load_table() + table = table[table["File name"] == os.path.basename(mrc_path)] + assert len(table) == 1 + num_ribbon = int(table["#ribbons"].values[0]) + assert num_ribbon in (1, 2) + + if num_ribbon == 1: + _add_one_to_assignment(mrc_path) + else: + _update_assignments(mrc_path, num_ribbon) + + +def main(): + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + for tomogram in tomograms: + process_tomogram(tomogram) + + +if __name__ == 
"__main__": + main() diff --git a/scripts/otoferlin/make_figure_napari.py b/scripts/otoferlin/make_figure_napari.py new file mode 100644 index 0000000..ef515d1 --- /dev/null +++ b/scripts/otoferlin/make_figure_napari.py @@ -0,0 +1,79 @@ +import os + +import napari +import numpy as np +import pandas as pd + +from synapse_net.file_utils import read_mrc +from common import get_all_tomograms, get_seg_path, load_segmentations, STRUCTURE_NAMES + + +colors = { + "Docked-V": (255, 170, 127), # (1, 0.666667, 0.498039) + "RA-V": (0, 85, 0), # (0, 0.333333, 0) + "MP-V": (255, 170, 0), # (1, 0.666667, 0) + "ribbon": (255, 0, 0), + "PD": (255, 0, 255), # (1, 0, 1) + "membrane": (255, 170, 255), # 1, 0.666667, 1 +} + + +def plot_napari(mrc_path, rotate=False): + data, voxel_size = read_mrc(mrc_path) + voxel_size = tuple(voxel_size[ax] for ax in "zyx") + + seg_path = get_seg_path(mrc_path) + output_folder = os.path.split(seg_path)[0] + segmentations = load_segmentations(seg_path) + vesicles = segmentations["vesicles"] + + assignment_path = os.path.join(output_folder, "vesicle_pools.csv") + assignments = pd.read_csv(assignment_path) + + pools = np.zeros_like(vesicles) + pool_names = ["RA-V", "MP-V", "Docked-V"] + + pool_colors = {None: (0, 0, 0)} + for pool_id, pool_name in enumerate(pool_names, 1): + pool_vesicle_ids = assignments[assignments.pool == pool_name].vesicle_id.values + pool_mask = np.isin(vesicles, pool_vesicle_ids) + pools[pool_mask] = pool_id + color = colors.get(pool_name) + color = tuple(c / float(255) for c in color) + pool_colors[pool_id] = color + + if rotate: + data = np.rot90(data, k=3, axes=(1, 2)) + pools = np.rot90(pools, k=3, axes=(1, 2)) + segmentations = {name: np.rot90(segmentations[name], k=3, axes=(1, 2)) for name in STRUCTURE_NAMES} + + v = napari.Viewer() + v.add_image(data, scale=voxel_size) + v.add_labels(pools, colormap=pool_colors, scale=voxel_size) + for name in STRUCTURE_NAMES: + color = colors[name] + color = tuple(c / float(255) for c in color) + cmap = {1: color, None: (0, 0, 0)} + seg = (segmentations[name] > 0).astype("uint8") + v.add_labels(seg, colormap=cmap, scale=voxel_size, name=name) + v.scale_bar.visible = True + v.scale_bar.unit = "nm" + v.scale_bar.font_size = 18 + v.title = os.path.basename(mrc_path) + napari.run() + + +def main(): + tomograms = get_all_tomograms(restrict_to_good_tomos=True) + tomograms_for_vis = [ + "Bl6_NtoTDAWT1_blockH_GridE4_1_rec.mrc", + "Otof_TDAKO1blockA_GridN5_6_rec.mrc", + ] + for tomogram in tomograms: + fname = os.path.basename(tomogram) + if fname not in tomograms_for_vis: + continue + plot_napari(tomogram, rotate=fname.startswith("Otof")) + + +main() diff --git a/scripts/otoferlin/overview Otoferlin samples.xlsx b/scripts/otoferlin/overview Otoferlin samples.xlsx new file mode 100644 index 0000000..6380dfb Binary files /dev/null and b/scripts/otoferlin/overview Otoferlin samples.xlsx differ diff --git a/scripts/otoferlin/pool_assignments_and_measurements.py b/scripts/otoferlin/pool_assignments_and_measurements.py new file mode 100644 index 0000000..90a2d2e --- /dev/null +++ b/scripts/otoferlin/pool_assignments_and_measurements.py @@ -0,0 +1,125 @@ +import os + +import numpy as np +import pandas as pd + +from synapse_net.distance_measurements import measure_segmentation_to_object_distances, load_distances +from synapse_net.file_utils import read_mrc +from synapse_net.imod.to_imod import convert_segmentation_to_spheres +from skimage.measure import label +from tqdm import tqdm + +from common import STRUCTURE_NAMES, 
get_all_tomograms, get_seg_path, load_segmentations
+
+
+def ensure_labeled(vesicles):
+    n_ids = len(np.unique(vesicles))
+    n_ids_labeled = len(np.unique(label(vesicles)))
+    assert n_ids == n_ids_labeled, f"{n_ids}, {n_ids_labeled}"
+
+
+def measure_distances(mrc_path, seg_path, output_folder, force):
+    result_folder = os.path.join(output_folder, "distances")
+    if os.path.exists(result_folder) and not force:
+        return
+
+    # Get the voxel size.
+    _, voxel_size = read_mrc(mrc_path)
+    resolution = tuple(voxel_size[ax] for ax in "zyx")
+
+    # Load the segmentations.
+    segmentations = load_segmentations(seg_path)
+    vesicles = segmentations["vesicles"]
+    ensure_labeled(vesicles)
+    structures = {name: segmentations[name] for name in STRUCTURE_NAMES}
+
+    # Measure all the object distances.
+    os.makedirs(result_folder, exist_ok=True)
+    for name, seg in structures.items():
+        if seg.sum() == 0:
+            print(name, "was not found, skipping the distance computation.")
+            continue
+        print("Compute vesicle distances to", name)
+        save_path = os.path.join(result_folder, f"{name}.npz")
+        measure_segmentation_to_object_distances(vesicles, seg, save_path=save_path, resolution=resolution)
+
+
+def _measure_radii(seg_path):
+    segmentations = load_segmentations(seg_path)
+    vesicles = segmentations["vesicles"]
+    # The radius factor of 0.85 yields the best fit to vesicles in IMOD.
+    _, radii = convert_segmentation_to_spheres(vesicles, radius_factor=0.85)
+    return np.array(radii)
+
+
+def assign_vesicle_pools_and_measure_radii(seg_path, output_folder, force):
+    assignment_path = os.path.join(output_folder, "vesicle_pools.csv")
+    if os.path.exists(assignment_path) and not force:
+        return
+
+    distance_folder = os.path.join(output_folder, "distances")
+    distance_paths = {name: os.path.join(distance_folder, f"{name}.npz") for name in STRUCTURE_NAMES}
+    if not all(os.path.exists(path) for path in distance_paths.values()):
+        print("Skip vesicle pool assignment, because some distances are missing.")
+        print("This probably means that the corresponding structures were not found.")
+        return
+    distances = {name: load_distances(path) for name, path in distance_paths.items()}
+
+    # The distance criteria.
+    rav_ribbon_distance = 80  # nm
+    mpv_pd_distance = 100  # nm
+    mpv_mem_distance = 50  # nm
+    docked_pd_distance = 100  # nm
+    docked_mem_distance = 2  # nm
+
+    rav_distances, seg_ids = distances["ribbon"][0], np.array(distances["ribbon"][-1])
+    rav_ids = seg_ids[rav_distances < rav_ribbon_distance]
+
+    pd_distances, mem_distances = distances["PD"][0], distances["membrane"][0]
+    assert len(pd_distances) == len(mem_distances) == len(rav_distances)
+
+    mpv_ids = seg_ids[np.logical_and(pd_distances < mpv_pd_distance, mem_distances < mpv_mem_distance)]
+    docked_ids = seg_ids[np.logical_and(pd_distances < docked_pd_distance, mem_distances < docked_mem_distance)]
+
+    # Create a dictionary to map vesicle ids to their corresponding pool.
+    # (RA-V gets overwritten by MP-V, which is correct.)
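+    # dict.update overwrites earlier entries, so the precedence below is RA-V < MP-V < Docked-V:
+    # a vesicle that matches several criteria ends up in the most specific pool.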
+ pool_assignments = {vid: "RA-V" for vid in rav_ids} + pool_assignments.update({vid: "MP-V" for vid in mpv_ids}) + pool_assignments.update({vid: "Docked-V" for vid in docked_ids}) + + pool_values = [pool_assignments.get(vid, None) for vid in seg_ids] + radii = _measure_radii(seg_path) + assert len(radii) == len(pool_values) + + pool_assignments = pd.DataFrame({ + "vesicle_id": seg_ids, + "pool": pool_values, + "radius": radii, + "diameter": 2 * radii, + }) + pool_assignments.to_csv(assignment_path, index=False) + + +def process_tomogram(mrc_path, force): + seg_path = get_seg_path(mrc_path) + output_folder = os.path.split(seg_path)[0] + assert os.path.exists(output_folder) + + # Measure the distances. + measure_distances(mrc_path, seg_path, output_folder, force) + + # Assign the vesicle pools. + assign_vesicle_pools_and_measure_radii(seg_path, output_folder, force) + + # The surface area / volume for ribbon and PD will be done in a separate script. + + +def main(): + force = True + tomograms = get_all_tomograms(restrict_to_good_tomos=True, restrict_to_nachgeb=True) + for tomogram in tqdm(tomograms, desc="Process tomograms"): + process_tomogram(tomogram, force) + + +if __name__ == "__main__": + main() diff --git a/scripts/otoferlin/postprocess_vesicles.py b/scripts/otoferlin/postprocess_vesicles.py new file mode 100644 index 0000000..370d38c --- /dev/null +++ b/scripts/otoferlin/postprocess_vesicles.py @@ -0,0 +1,73 @@ +import os +from pathlib import Path +from shutil import copyfile + +import imageio.v3 as imageio +import napari +import h5py + +from skimage.measure import label +from tqdm import tqdm + +from common import get_all_tomograms, get_seg_path +from synapse_net.file_utils import read_mrc +from automatic_processing import postprocess_vesicles + +TOMOS = [ + "Otof_TDAKO2blockC_GridE2_1_rec", + "Otof_TDAKO1blockA_GridN5_3_rec", + "Otof_TDAKO1blockA_GridN5_5_rec", + "Bl6_NtoTDAWT1_blockH_GridG2_3_rec", +] + + +def postprocess(mrc_path, process_center_crop): + output_path = get_seg_path(mrc_path) + copyfile(output_path, output_path + ".bkp") + postprocess_vesicles( + mrc_path, output_path, process_center_crop=process_center_crop, force=True + ) + + tomo, _ = read_mrc(mrc_path) + with h5py.File(output_path, "r") as f: + ves = f["segmentation/veiscles_postprocessed"][:] + + v = napari.Viewer() + v.add_image(tomo) + v.add_labels(ves) + napari.run() + + +# Postprocess vesicles in specific tomograms, where this initially +# failed due to wrong structure segmentations. 
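+# Each output file is backed up to <path>.bkp by postprocess() above before it is overwritten,
+# and the new result is opened in napari for a quick visual check.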
+def redo_initial_postprocessing():
+    tomograms = get_all_tomograms()
+    for tomogram in tqdm(tomograms, desc="Process tomograms"):
+        fname = Path(tomogram).stem
+        if fname not in TOMOS:
+            continue
+        print("Postprocessing", fname)
+        postprocess(tomogram, process_center_crop=True)
+
+
+def label_all_vesicles():
+    tomograms = get_all_tomograms(restrict_to_good_tomos=True)
+    for mrc_path in tqdm(tomograms, desc="Process tomograms"):
+        output_path = get_seg_path(mrc_path)
+        output_folder = os.path.split(output_path)[0]
+        vesicle_path = os.path.join(output_folder, "correction", "veiscles_postprocessed.tif")
+        assert os.path.exists(vesicle_path), vesicle_path
+        copyfile(vesicle_path, vesicle_path + ".bkp")
+        vesicles = imageio.imread(vesicle_path)
+        vesicles = label(vesicles)
+        imageio.imwrite(vesicle_path, vesicles, compression="zlib")
+
+
+def main():
+    # redo_initial_postprocessing()
+    # Label all vesicle corrections to make sure every vesicle has its own id.
+    label_all_vesicles()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/otoferlin/update_radius_measurements.py b/scripts/otoferlin/update_radius_measurements.py
new file mode 100644
index 0000000..5c78b6e
--- /dev/null
+++ b/scripts/otoferlin/update_radius_measurements.py
@@ -0,0 +1,31 @@
+import os
+import pandas as pd
+from pool_assignments_and_measurements import _measure_radii
+
+from common import STRUCTURE_NAMES, get_all_tomograms, get_seg_path, load_segmentations
+from tqdm import tqdm
+
+
+def update_radii(mrc_path):
+    seg_path = get_seg_path(mrc_path)
+    output_folder = os.path.split(seg_path)[0]
+    assert os.path.exists(output_folder)
+    assignment_path = os.path.join(output_folder, "vesicle_pools.csv")
+    radii = _measure_radii(seg_path)
+
+    pool_assignments = pd.read_csv(assignment_path)
+    assert len(radii) == len(pool_assignments)
+    pool_assignments["radius"] = radii
+    pool_assignments["diameter"] = 2 * radii
+
+    pool_assignments.to_csv(assignment_path, index=False)
+
+
+def main():
+    tomograms = get_all_tomograms(restrict_to_good_tomos=True)
+    for tomogram in tqdm(tomograms, desc="Process tomograms"):
+        update_radii(tomogram)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/synapse_net/distance_measurements.py b/synapse_net/distance_measurements.py
index 4cf3181..8fa7ee8 100644
--- a/synapse_net/distance_measurements.py
+++ b/synapse_net/distance_measurements.py
@@ -226,6 +226,7 @@ def measure_segmentation_to_object_distances(
     resolution: Optional[Tuple[int, int, int]] = None,
     save_path: Optional[os.PathLike] = None,
     verbose: bool = False,
+    return_object_ids: bool = False,
 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     """Compute the distance betwen all objects in a segmentation and another object.
@@ -238,6 +239,7 @@
         resolution: The resolution / pixel size of the data.
         save_path: Path for saving the measurement results in numpy zipped format.
        verbose: Whether to print the progress of the distance computation.
+        return_object_ids: Whether to also return the object ids.

     Returns:
         The segmentation to object distances.
@@ -262,7 +264,10 @@ def measure_segmentation_to_object_distances(
         seg_ids=seg_ids,
         object_ids=object_ids,
     )
-    return distances, endpoints1, endpoints2, seg_ids
+    if return_object_ids:
+        return distances, endpoints1, endpoints2, seg_ids, object_ids
+    else:
+        return distances, endpoints1, endpoints2, seg_ids
@@ -292,12 +297,13 @@ def _extract_nearest_neighbors(pairwise_distances, seg_ids, n_neighbors, remove_duplicates=True):
 def load_distances(
-    measurement_path: os.PathLike
+    measurement_path: os.PathLike, return_object_ids: bool = False,
 ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     """Load the saved distacnes from a zipped numpy file.

     Args:
         measurement_path: The path where the distances where saved.
+        return_object_ids: Whether to also return the object ids.

     Returns:
         The segmentation to object distances.
@@ -308,7 +314,11 @@
     auto_dists = np.load(measurement_path)
     distances, seg_ids = auto_dists["distances"], list(auto_dists["seg_ids"])
     endpoints1, endpoints2 = auto_dists["endpoints1"], auto_dists["endpoints2"]
-    return distances, endpoints1, endpoints2, seg_ids
+    if return_object_ids:
+        object_ids = auto_dists["object_ids"]
+        return distances, endpoints1, endpoints2, seg_ids, object_ids
+    else:
+        return distances, endpoints1, endpoints2, seg_ids


 def create_pairwise_distance_lines(
diff --git a/synapse_net/ground_truth/shape_refinement.py b/synapse_net/ground_truth/shape_refinement.py
index 8c357ae..26e8e56 100644
--- a/synapse_net/ground_truth/shape_refinement.py
+++ b/synapse_net/ground_truth/shape_refinement.py
@@ -203,6 +203,7 @@ def refine_individual_vesicle_shapes(
     edge_map: np.ndarray,
     foreground_erosion: int = 4,
     background_erosion: int = 8,
+    compactness: float = 0.5,
 ) -> np.ndarray:
     """Refine vesicle shapes by fitting vesicles to a boundary map.
@@ -215,6 +216,8 @@
         You can use `edge_filter` to compute this based on the tomogram.
         foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
         background_erosion: By how many pixels the background should be eroded in the seeds.
+        compactness: The compactness parameter passed to the watershed function.
+            Higher compactness leads to more regularly sized vesicles.
     Returns:
         The refined vesicles.
     """
@@ -250,7 +253,7 @@ def fit_vesicle(prop):
         # Run seeded watershed to fit the shapes.
         seeds = fg_seed + 2 * bg_seed
-        seg[z] = watershed(hmap[z], seeds) == 1
+        seg[z] = watershed(hmap[z], seeds, compactness=compactness) == 1

         # import napari
         # v = napari.Viewer()
diff --git a/synapse_net/imod/to_imod.py b/synapse_net/imod/to_imod.py
index 5832213..99f2407 100644
--- a/synapse_net/imod/to_imod.py
+++ b/synapse_net/imod/to_imod.py
@@ -37,6 +37,7 @@ def write_segmentation_to_imod(
     segmentation: Union[str, np.ndarray],
     output_path: str,
     segmentation_key: Optional[str] = None,
+    color: Optional[Tuple[int, int, int]] = None,
 ) -> None:
     """Write a segmentation to a mod file as closed contour object(s).
@@ -45,6 +46,7 @@
         segmentation: The segmentation (either as numpy array or filepath to a .tif file).
         output_path: The output path where the mod file will be saved.
         segmentation_key: The key to the segmentation data in case the segmentation is stored in hdf5 files.
+        color: Optional color for the exported model.
     """
     cmd = "imodauto"
     cmd_path = shutil.which(cmd)
@@ -83,6 +85,10 @@
     # Run the command.
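+    # The optional color is appended below via imodauto's -co flag; the callers in
+    # scripts/otoferlin/export_to_imod.py pass 0-255 "R G B" component values.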
diff --git a/synapse_net/ground_truth/shape_refinement.py b/synapse_net/ground_truth/shape_refinement.py
index 8c357ae..26e8e56 100644
--- a/synapse_net/ground_truth/shape_refinement.py
+++ b/synapse_net/ground_truth/shape_refinement.py
@@ -203,6 +203,7 @@ def refine_individual_vesicle_shapes(
     edge_map: np.ndarray,
     foreground_erosion: int = 4,
     background_erosion: int = 8,
+    compactness: float = 0.5,
 ) -> np.ndarray:
     """Refine vesicle shapes by fitting vesicles to a boundary map.

@@ -215,6 +216,8 @@ def refine_individual_vesicle_shapes(
             You can use `edge_filter` to compute this based on the tomogram.
         foreground_erosion: By how many pixels the foreground should be eroded in the seeds.
         background_erosion: By how many pixels the background should be eroded in the seeds.
+        compactness: The compactness parameter passed to the watershed function.
+            Higher compactness leads to more regularly sized vesicles.
     Returns:
         The refined vesicles.
     """
@@ -250,7 +253,7 @@ def fit_vesicle(prop):

         # Run seeded watershed to fit the shapes.
         seeds = fg_seed + 2 * bg_seed
-        seg[z] = watershed(hmap[z], seeds) == 1
+        seg[z] = watershed(hmap[z], seeds, compactness=compactness) == 1

         # import napari
         # v = napari.Viewer()
diff --git a/synapse_net/imod/to_imod.py b/synapse_net/imod/to_imod.py
index 5832213..99f2407 100644
--- a/synapse_net/imod/to_imod.py
+++ b/synapse_net/imod/to_imod.py
@@ -37,6 +37,7 @@ def write_segmentation_to_imod(
     segmentation: Union[str, np.ndarray],
     output_path: str,
     segmentation_key: Optional[str] = None,
+    color: Optional[Tuple[int, int, int]] = None,
 ) -> None:
     """Write a segmentation to a mod file as closed contour object(s).

@@ -45,6 +46,7 @@ def write_segmentation_to_imod(
         segmentation: The segmentation (either as numpy array or filepath to a .tif file).
         output_path: The output path where the mod file will be saved.
         segmentation_key: The key to the segmentation data in case the segmentation is stored in hdf5 files.
+        color: Optional color for the exported model.
     """
     cmd = "imodauto"
     cmd_path = shutil.which(cmd)
@@ -83,6 +85,10 @@ def write_segmentation_to_imod(

     # Run the command.
     cmd_list = [cmd, "-E", "1", "-u", tmp_path, output_path]
+    if color is not None:
+        assert len(color) == 3
+        r, g, b = [str(co) for co in color]
+        cmd_list += ["-co", f"{r} {g} {b}"]
     run(cmd_list)

@@ -172,6 +178,7 @@ def write_points_to_imod(
     min_radius: Union[float, int],
     output_path: str,
     color: Optional[Tuple[int, int, int]] = None,
+    name: Optional[str] = None,
 ) -> None:
     """Write point annotations to a .mod file for IMOD.

@@ -182,6 +189,7 @@ def write_points_to_imod(
         min_radius: Minimum radius for export.
         output_path: Where to save the .mod file.
         color: Optional color for writing out the points.
+        name: Optional name for the exported model.
     """
     cmd = "point2model"
     cmd_path = shutil.which(cmd)
@@ -210,6 +218,8 @@ def _pad(inp, n=3):
         assert len(color) == 3
         r, g, b = [str(co) for co in color]
         cmd += ["-co", f"{r} {g} {b}"]
+    if name is not None:
+        cmd += ["-name", name]
     run(cmd)

@@ -222,6 +232,8 @@ def write_segmentation_to_imod_as_points(
     radius_factor: float = 1.0,
     estimate_radius_2d: bool = True,
     segmentation_key: Optional[str] = None,
+    color: Optional[Tuple[int, int, int]] = None,
+    name: Optional[str] = None,
 ) -> None:
     """Write segmentation results to .mod file with imod point annotations.

@@ -237,6 +249,8 @@ def write_segmentation_to_imod_as_points(
             the radius will be computed only in 2d rather than in 3d. This can lead to better results
             in case of deformation across the depth axis.
         segmentation_key: The key to the segmentation data in case the segmentation is stored in hdf5 files.
+        color: Optional color for writing out the points.
+        name: Optional name for the exported model.
     """

     # Read the resolution information from the mrcfile.
@@ -254,7 +268,7 @@ def write_segmentation_to_imod_as_points(
     )

     # Write the point annotations to imod.
-    write_points_to_imod(coordinates, radii, segmentation.shape, min_radius, output_path)
+    write_points_to_imod(coordinates, radii, segmentation.shape, min_radius, output_path, color=color, name=name)


 def _get_file_paths(input_path, ext=(".mrc", ".rec")):
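The new color and name arguments allow styling the exported IMOD objects directly from the export call. A sketch of the point export (file paths and values are hypothetical, the order of the three leading arguments is assumed from context, and IMOD's point2model must be available on the PATH):

from synapse_net.imod.to_imod import write_segmentation_to_imod_as_points

write_segmentation_to_imod_as_points(
    "tomogram.mrc", "vesicles.tif", "vesicles.mod",
    min_radius=10,
    color=(255, 0, 0),  # forwarded to point2model as -co "255 0 0"
    name="vesicles",    # forwarded to point2model as -name vesicles
)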
diff --git a/synapse_net/tools/util.py b/synapse_net/tools/util.py
index 1495112..a2113c5 100644
--- a/synapse_net/tools/util.py
+++ b/synapse_net/tools/util.py
@@ -59,7 +59,28 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
     return model


-def _segment_ribbon_AZ(image, model, tiling, scale, verbose, **kwargs):
+def _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons):
+    from synapse_net.inference.postprocessing import (
+        segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
+    )
+
+    ribbon = segment_ribbon(
+        predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons,
+        max_vesicle_distance=40,
+    )
+    PD = segment_presynaptic_density(
+        predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
+    )
+    ref_segmentation = PD if PD.sum() > 0 else ribbon
+    membrane = segment_membrane_distance_based(
+        predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
+    )
+
+    segmentations = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
+    return segmentations
+
+
+def _segment_ribbon_AZ(image, model, tiling, scale, verbose, return_predictions=False, **kwargs):
     # Parse additional keyword arguments from the kwargs.
     vesicles = kwargs.pop("extra_segmentation")
     threshold = kwargs.pop("threshold", 0.5)
@@ -70,31 +91,21 @@ def _segment_ribbon_AZ(image, model, tiling, scale, verbose, **kwargs):
         image, model=model, tiling=tiling, scale=scale, verbose=verbose, threshold=threshold, **kwargs
     )

-    # If the vesicles were passed then run additional post-processing.
+    # If no vesicle segmentation was passed, just return the predictions.
     if vesicles is None:
-        segmentation = predictions
+        if verbose:
+            print("Vesicle segmentation was not passed, WILL NOT run post-processing.")
+        segmentations = predictions

-    # Otherwise, just return the predictions.
+    # Otherwise, run the additional post-processing.
     else:
-        from synapse_net.inference.postprocessing import (
-            segment_ribbon, segment_presynaptic_density, segment_membrane_distance_based,
-        )
+        if verbose:
+            print("Vesicle segmentation was passed, WILL run post-processing.")
+        segmentations = _ribbon_AZ_postprocessing(predictions, vesicles, n_slices_exclude, n_ribbons)

-        ribbon = segment_ribbon(
-            predictions["ribbon"], vesicles, n_slices_exclude=n_slices_exclude, n_ribbons=n_ribbons,
-            max_vesicle_distance=40,
-        )
-        PD = segment_presynaptic_density(
-            predictions["PD"], ribbon, n_slices_exclude=n_slices_exclude, max_distance_to_ribbon=40,
-        )
-        ref_segmentation = PD if PD.sum() > 0 else ribbon
-        membrane = segment_membrane_distance_based(
-            predictions["membrane"], ref_segmentation, max_distance=500, n_slices_exclude=n_slices_exclude,
-        )
-
-        segmentation = {"ribbon": ribbon, "PD": PD, "membrane": membrane}
-
-    return segmentation
+    if return_predictions:
+        return segmentations, predictions
+    return segmentations


 def run_segmentation(
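A call sketch for the refactored function (hypothetical usage: image, model, and the vesicle segmentation are assumed to be loaded already, and keyword arguments such as n_slices_exclude, which are parsed from kwargs in lines not shown in this diff, are omitted):

# Returns the post-processed segmentations plus the raw network predictions.
segmentations, predictions = _segment_ribbon_AZ(
    image, model, tiling=None, scale=None, verbose=True,
    return_predictions=True, extra_segmentation=vesicles,
)
ribbon = segmentations["ribbon"]    # after post-processing
raw_ribbon = predictions["ribbon"]  # raw prediction, available via the new flag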