Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Otoferlin #80

Draft
wants to merge 35 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
456f99a
Add first version of otoferlin processing code
constantinpape Dec 6, 2024
f10df21
Enable local file paths for otoferlin experiments
constantinpape Dec 6, 2024
c97c47a
Merge branch 'main' into otoferlin
constantinpape Dec 7, 2024
5048422
Update visualization scripts
constantinpape Dec 7, 2024
964cf2c
Add vesicle comparison and domain adaptation
constantinpape Dec 7, 2024
10181ce
Update inference for otoferlin WIP
constantinpape Dec 8, 2024
b6e223f
Update otoferlin inference WIP
constantinpape Dec 8, 2024
fd11ee8
Add script for structure postprocessing
constantinpape Dec 8, 2024
e7288b5
More ribbon inference updates
constantinpape Dec 8, 2024
224ef6c
Fix issues in structure processing
constantinpape Dec 8, 2024
451ff69
Implement vesicle post-processing
constantinpape Dec 8, 2024
3bcca74
Add first version of correction script
constantinpape Dec 8, 2024
70a5856
Update otoferlin correction script
constantinpape Dec 9, 2024
03dc031
Add path to DA model on the WS
constantinpape Dec 9, 2024
69d5021
Implement vesicle pool correction WIP
constantinpape Dec 10, 2024
f630ee1
Update otoferlin analysis
constantinpape Dec 10, 2024
ea84a3d
Add vesicle postprocessing scripts
constantinpape Dec 11, 2024
9b8be05
Update vesicle pool correction script
constantinpape Dec 11, 2024
6eb5d12
Implement vesicle labeling
constantinpape Dec 11, 2024
504144c
merge on WS
constantinpape Dec 11, 2024
e99ab19
Add overview table
constantinpape Dec 11, 2024
7cdfb4d
Merge branch 'otoferlin' of https://github.com/computational-cell-ana…
constantinpape Dec 11, 2024
0b602ec
Fix table loading
constantinpape Dec 11, 2024
81ff67f
Merge branch 'otoferlin' of https://github.com/computational-cell-ana…
constantinpape Dec 11, 2024
079444f
Fix table
constantinpape Dec 11, 2024
ae83e74
Merge branch 'otoferlin' of https://github.com/computational-cell-ana…
constantinpape Dec 11, 2024
5a22bc4
More name fixes
constantinpape Dec 11, 2024
66f296f
Merge branch 'otoferlin' of https://github.com/computational-cell-ana…
constantinpape Dec 11, 2024
adbfdc9
Update otoferlin analysis
constantinpape Dec 11, 2024
791f635
Implement IMOD export
constantinpape Dec 11, 2024
c4fe694
Update otoferlin analyse
constantinpape Dec 12, 2024
0bdf4c1
Finish imodexport and implement napari vis
constantinpape Dec 12, 2024
a44ab05
Update figure script
constantinpape Dec 12, 2024
5d4f7cb
Update otoferlin analysis
constantinpape Dec 14, 2024
6c3c431
Add object filter logic
constantinpape Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions scripts/inner_ear/processing/filter_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
import os
from pathlib import Path
from tqdm import tqdm

import h5py
import imageio.v3 as imageio
import numpy as np
from skimage.measure import label
from skimage.segmentation import relabel_sequential

from synapse_net.file_utils import get_data_path
from parse_table import parse_table, get_data_root, _match_correction_folder, _match_correction_file


def _load_segmentation(seg_path):
ext = Path(seg_path).suffix
assert ext in (".h5", ".tif"), ext
if ext == ".tif":
seg = imageio.imread(seg_path)
else:
with h5py.File(seg_path, "r") as f:
seg = f["segmentation"][:]
return seg


def _save_segmentation(seg_path, seg):
ext = Path(seg_path).suffix
assert ext in (".h5", ".tif"), ext
if ext == ".tif":
imageio.imwrite(seg_path, seg, compression="zlib")
else:
with h5py.File(seg_path, "a") as f:
f.create_dataset("segmentation", data=seg, compression="gzip")
return seg


def _filter_n_objects(segmentation, num_objects):
    """Keep only the 'num_objects' largest connected components of the segmentation.

    The result is relabeled sequentially (ids 1..num_objects). Raises an
    AssertionError if fewer than 'num_objects' components are present.
    """
    # Split the input into connected components so each disconnected piece is its own object.
    segmentation = label(segmentation)
    # Measure object sizes, dropping the background entry at index 0.
    object_ids, object_sizes = np.unique(segmentation, return_counts=True)
    object_ids, object_sizes = object_ids[1:], object_sizes[1:]
    # Determine the ids of the largest objects and erase everything else.
    size_order = np.argsort(object_sizes)[::-1]
    keep_ids = object_ids[size_order][:num_objects]
    segmentation[~np.isin(segmentation, keep_ids)] = 0
    # Relabel sequentially and verify the requested object count remains.
    segmentation, _, _ = relabel_sequential(segmentation)
    assert int(segmentation.max()) == num_objects
    return segmentation


def process_tomogram(folder, num_ribbon, num_pd):
    """Filter the ribbon and PD segmentation of one tomogram to the expected object counts.

    Manually corrected segmentations take precedence over the automatic results
    in 'automatisch/v2'. The previous file is kept as a '.bkp' backup before the
    filtered segmentation is written in its place.
    """
    data_path = get_data_path(folder)
    output_folder = os.path.join(folder, "automatisch", "v2")
    fname = Path(data_path).stem
    correction_folder = _match_correction_folder(folder)

    def _resolve(structure_name, suffix):
        # Prefer a manual correction file; otherwise fall back to the automatic result.
        path = _match_correction_file(correction_folder, structure_name)
        if not os.path.exists(path):
            path = os.path.join(output_folder, f"{fname}_{suffix}.h5")
        assert os.path.exists(path), path
        return path

    # Load both segmentations up front, so nothing is modified if one is missing.
    ribbon_path = _resolve("ribbon", "ribbon")
    ribbon = _load_segmentation(ribbon_path)
    pd_path = _resolve("PD", "pd")
    pd_seg = _load_segmentation(pd_path)

    # Filter the ribbon and the PD, backing up the previous files.
    print("Filtering number of ribbons:", num_ribbon)
    ribbon = _filter_n_objects(ribbon, num_ribbon)
    os.rename(ribbon_path, ribbon_path + ".bkp")
    _save_segmentation(ribbon_path, ribbon)

    print("Filtering number of PDs:", num_pd)
    pd_seg = _filter_n_objects(pd_seg, num_pd)
    os.rename(pd_path, pd_path + ".bkp")
    _save_segmentation(pd_path, pd_seg)


def filter_objects(table, version):
    """Apply the object-count filtering to every tomogram listed in the table.

    Rows without a local copy or without a PD are skipped, as is the (not yet
    supported) combination of two ribbons and a single PD.

    NOTE(review): 'version' is currently unused — confirm whether it should
    select the output folder version.
    """
    for _, row in tqdm(table.iterrows(), total=len(table)):
        folder = row["Local Path"]
        if folder == "":
            continue

        # We have to handle the segmentation without ribbon separately.
        if row["PD vorhanden? "] == "nein":
            continue

        # Treat an unclear PD count as a single PD.
        n_pds = row["Anzahl PDs"]
        n_pds = int(1 if n_pds == "unklar" else n_pds)

        n_ribbons = int(row["Anzahl Ribbons"])
        if n_ribbons == 2 and n_pds == 1:
            print(f"The tomogram {folder} has {n_ribbons} ribbons and {n_pds} PDs.")
            print("The structure post-processing for this case is not yet implemented and will be skipped.")
            continue

        micro = row["EM alt vs. Neu"]
        if micro == "beides":
            # Tomograms recorded on both microscopes have an extra sub-folder
            # for the data from the new microscope.
            process_tomogram(folder, n_ribbons, n_pds)
            folder_new = os.path.join(folder, "Tomo neues EM")
            if not os.path.exists(folder_new):
                folder_new = os.path.join(folder, "neues EM")
            assert os.path.exists(folder_new), folder_new
            process_tomogram(folder_new, n_ribbons, n_pds)
        elif micro in ("alt", "neu"):
            process_tomogram(folder, n_ribbons, n_pds)


def main():
    """Run the object filtering for all tomograms in the overview table."""
    data_root = get_data_root()
    table_path = os.path.join(data_root, "Electron-Microscopy-Susi", "Übersicht.xlsx")
    filter_objects(parse_table(table_path, data_root), version=2)


if __name__ == "__main__":
main()
4 changes: 4 additions & 0 deletions scripts/otoferlin/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
data/
sync_segmentation.sh
segmentation/
results/
9 changes: 9 additions & 0 deletions scripts/otoferlin/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Otoferlin Analysis


## Notes on improvements:

- Try excluding fewer than 20 slices
- Update the boundary post-processing (it is not robust when the PD is not found and then selects the wrong objects)
- Can we find fiducials with ilastik and mask them out? They are interfering with ribbon finding.
- Alternative: just restrict the processing to a center crop by default.
201 changes: 201 additions & 0 deletions scripts/otoferlin/automatic_processing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
import os

import h5py
import numpy as np

from skimage.measure import label
from skimage.segmentation import relabel_sequential

from synapse_net.distance_measurements import measure_segmentation_to_object_distances
from synapse_net.file_utils import read_mrc
from synapse_net.inference.vesicles import segment_vesicles
from synapse_net.tools.util import get_model, compute_scale_from_voxel_size, _segment_ribbon_AZ
from tqdm import tqdm

from common import get_all_tomograms, get_seg_path, get_adapted_model, load_segmentations

# These are tomograms for which the sophisticated membrane processing fails.
# In this case, we just select the largest boundary piece.
# NOTE: entries are matched against the basename of the mrc path, see
# 'os.path.basename(mrc_path)' checks further below in this file.
SIMPLE_MEM_POSTPROCESSING = [
    "Otof_TDAKO1blockA_GridN5_2_rec.mrc", "Otof_TDAKO2blockC_GridF5_1_rec.mrc", "Otof_TDAKO2blockC_GridF5_2_rec.mrc",
    "Bl6_NtoTDAWT1_blockH_GridF3_1_rec.mrc", "Bl6_NtoTDAWT1_blockH_GridG2_3_rec.mrc", "Otof_TDAKO1blockA_GridN5_5_rec.mrc",
    "Otof_TDAKO2blockC_GridE2_1_rec.mrc", "Otof_TDAKO2blockC_GridE2_2_rec.mrc",

]


def _get_center_crop(input_):
halo_xy = (600, 600)
bb_xy = tuple(
slice(max(sh // 2 - ha, 0), min(sh // 2 + ha, sh)) for sh, ha in zip(input_.shape[1:], halo_xy)
)
bb = (np.s_[:],) + bb_xy
return bb, input_.shape


def _get_tiling():
# tile = {"x": 768, "y": 768, "z": 48}
tile = {"x": 512, "y": 512, "z": 48}
halo = {"x": 128, "y": 128, "z": 8}
return {"tile": tile, "halo": halo}


def process_vesicles(mrc_path, output_path, process_center_crop):
    """Segment vesicles in a tomogram and store the result in the output h5 file.

    The computation is skipped if 'segmentation/vesicles' already exists in the
    output file. If process_center_crop is set, only a central xy-crop is
    processed and the result is padded back to the full volume shape with zeros.
    """
    key = "segmentation/vesicles"
    # Skip tomograms that were already processed.
    if os.path.exists(output_path):
        with h5py.File(output_path, "r") as f:
            if key in f:
                return

    input_, voxel_size = read_mrc(mrc_path)
    if process_center_crop:
        bb, full_shape = _get_center_crop(input_)
        input_ = input_[bb]

    model = get_adapted_model()
    # NOTE(review): the scale is computed for the "ribbon" model's voxel size —
    # presumably the adapted vesicle model was trained at that resolution; confirm.
    scale = compute_scale_from_voxel_size(voxel_size, "ribbon")
    print("Rescaling volume for vesicle segmentation with factor:", scale)
    tiling = _get_tiling()
    segmentation = segment_vesicles(input_, model=model, scale=scale, tiling=tiling)

    if process_center_crop:
        # Paste the cropped result back into a zero-initialized full-size volume.
        full_seg = np.zeros(full_shape, dtype=segmentation.dtype)
        full_seg[bb] = segmentation
        segmentation = full_seg

    with h5py.File(output_path, "a") as f:
        f.create_dataset(key, data=segmentation, compression="gzip")


def _simple_membrane_postprocessing(membrane_prediction):
    """Select the largest connected component of the membrane prediction.

    Returns a binary uint8 mask of the largest component. An all-zero
    prediction yields an all-zero mask instead of raising (the previous
    implementation crashed on np.argmax over an empty size array).
    """
    seg = label(membrane_prediction)
    # Object ids and sizes, excluding the background entry at index 0.
    ids, sizes = np.unique(seg, return_counts=True)
    ids, sizes = ids[1:], sizes[1:]
    # Guard against an empty prediction, for which no component exists.
    if len(ids) == 0:
        return np.zeros_like(membrane_prediction, dtype="uint8")
    return (seg == ids[np.argmax(sizes)]).astype("uint8")


def process_ribbon_structures(mrc_path, output_path, process_center_crop):
    """Segment the ribbon synapse structures and store segmentations and predictions.

    Skips the tomogram if 'segmentation/ribbon' already exists in the output.
    Uses the vesicle segmentation (which must already be present in the output
    file) as extra input. If process_center_crop is set, only a central xy-crop
    is processed and all results are padded back to the full volume shape.
    """
    key = "segmentation/ribbon"
    with h5py.File(output_path, "r") as f:
        if key in f:
            return
        vesicles = f["segmentation/vesicles"][:]

    input_, voxel_size = read_mrc(mrc_path)
    if process_center_crop:
        bb, full_shape = _get_center_crop(input_)
        input_, vesicles = input_[bb], vesicles[bb]
        assert input_.shape == vesicles.shape

    model_name = "ribbon"
    model = get_model(model_name)
    scale = compute_scale_from_voxel_size(voxel_size, model_name)
    tiling = _get_tiling()

    segmentations, predictions = _segment_ribbon_AZ(
        input_, model, tiling=tiling, scale=scale, verbose=True, extra_segmentation=vesicles,
        return_predictions=True, n_slices_exclude=5,
    )

    # The distance based post-processing for membranes fails for some tomograms.
    # In these cases, just choose the largest membrane piece.
    fname = os.path.basename(mrc_path)
    if fname in SIMPLE_MEM_POSTPROCESSING:
        segmentations["membrane"] = _simple_membrane_postprocessing(predictions["membrane"])

    if process_center_crop:
        # Paste segmentations and predictions back into zero-initialized full-size volumes.
        for name, seg in segmentations.items():
            full_seg = np.zeros(full_shape, dtype=seg.dtype)
            full_seg[bb] = seg
            segmentations[name] = full_seg
        for name, pred in predictions.items():
            # BUGFIX: use the prediction's own dtype here. The previous code used
            # 'seg.dtype' (the stale loop variable from the segmentation loop above),
            # which would truncate floating point predictions to an integer dtype.
            full_pred = np.zeros(full_shape, dtype=pred.dtype)
            full_pred[bb] = pred
            predictions[name] = full_pred

    with h5py.File(output_path, "a") as f:
        for name, seg in segmentations.items():
            f.create_dataset(f"segmentation/{name}", data=seg, compression="gzip")
            f.create_dataset(f"prediction/{name}", data=predictions[name], compression="gzip")


def postprocess_vesicles(
    mrc_path, output_path, process_center_crop, force=False
):
    """Filter the vesicle segmentation by size and by distance to ribbon and membrane.

    Vesicles smaller than 5000 voxels or farther than 120 nm from the combined
    ribbon + membrane mask are removed. The result is relabeled sequentially and
    written to the output file. Existing results are only overwritten if 'force'.
    """
    # NOTE(review): the dataset key contains a typo ("veiscles"); it is kept
    # unchanged because downstream consumers read the same key.
    key = "segmentation/veiscles_postprocessed"
    with h5py.File(output_path, "r") as f:
        if key in f and not force:
            return
        vesicles = f["segmentation/vesicles"][:]
    if process_center_crop:
        bb, full_shape = _get_center_crop(vesicles)
        vesicles = vesicles[bb]
    else:
        bb = np.s_[:]

    segs = load_segmentations(output_path)
    ribbon = segs["ribbon"][bb]
    membrane = segs["membrane"][bb]

    # Filter out small vesicle fragments.
    min_size = 5000
    ids, sizes = np.unique(vesicles, return_counts=True)
    ids, sizes = ids[1:], sizes[1:]
    filter_ids = ids[sizes < min_size]
    vesicles[np.isin(vesicles, filter_ids)] = 0

    # The tomogram data itself is not needed here; we only read the voxel size.
    # (The previous version also cropped the unused image data.)
    _, voxel_size = read_mrc(mrc_path)
    voxel_size = tuple(voxel_size[ax] for ax in "zyx")

    # Filter out all vesicles farther than 120 nm from the membrane or ribbon.
    max_dist = 120
    seg = (ribbon + membrane) > 0
    distances, _, _, seg_ids = measure_segmentation_to_object_distances(vesicles, seg, resolution=voxel_size)
    filter_ids = seg_ids[distances > max_dist]
    vesicles[np.isin(vesicles, filter_ids)] = 0

    vesicles, _, _ = relabel_sequential(vesicles)

    if process_center_crop:
        # Paste the cropped result back into a zero-initialized full-size volume.
        full_seg = np.zeros(full_shape, dtype=vesicles.dtype)
        full_seg[bb] = vesicles
        vesicles = full_seg
    with h5py.File(output_path, "a") as f:
        if key in f:
            f[key][:] = vesicles
        else:
            f.create_dataset(key, data=vesicles, compression="gzip")


def process_tomogram(mrc_path):
    """Run the full automatic segmentation pipeline for a single tomogram."""
    output_path = get_seg_path(mrc_path)
    os.makedirs(os.path.split(output_path)[0], exist_ok=True)

    # Restrict all steps to a central crop of the volume.
    process_center_crop = True
    process_vesicles(mrc_path, output_path, process_center_crop)
    process_ribbon_structures(mrc_path, output_path, process_center_crop)
    postprocess_vesicles(mrc_path, output_path, process_center_crop)


def main():
    """Re-run the simple membrane postprocessing for the affected tomograms.

    The full pipeline run is deliberately disabled below; this entry point only
    fixes the membrane segmentation for the tomograms listed in
    SIMPLE_MEM_POSTPROCESSING, using their stored membrane predictions.
    """
    tomograms = get_all_tomograms()
    # for tomogram in tqdm(tomograms, desc="Process tomograms"):
    #     process_tomogram(tomogram)

    # Update the membrane postprocessing for the tomograms where this went wrong.
    # (Typos in the progress message were fixed: "membrame postprocesing".)
    for tomo in tqdm(tomograms, desc="Fix membrane postprocessing"):
        if os.path.basename(tomo) not in SIMPLE_MEM_POSTPROCESSING:
            continue
        seg_path = get_seg_path(tomo)
        with h5py.File(seg_path, "r") as f:
            pred = f["prediction/membrane"][:]
        seg = _simple_membrane_postprocessing(pred)
        with h5py.File(seg_path, "a") as f:
            f["segmentation/membrane"][:] = seg


if __name__ == "__main__":
main()
Loading