Skip to content

Commit 7b19ddd

Browse files
authored
Merge branch 'DeepLabCut:main' into main
2 parents 0623bad + dd0ef5a commit 7b19ddd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+340
-163
lines changed

deeplabcut/benchmark/base.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,7 @@ def _validate_predictions(self, name: str, predictions: dict) -> dict:
132132
"individuals were detected in those images."
133133
)
134134

135-
return {
136-
img: predictions.get(img, tuple())
137-
for img in test_images
138-
}
135+
return {img: predictions.get(img, tuple()) for img in test_images}
139136

140137

141138
@dataclasses.dataclass

deeplabcut/benchmark/metrics.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,7 @@ def calc_rmse_from_obj(
234234
kpts.pop(ind)
235235

236236
test_objects = {
237-
k: v
238-
for k, v in eval_results_obj.items()
239-
if k in gt["annotations"].keys()
237+
k: v for k, v in eval_results_obj.items() if k in gt["annotations"].keys()
240238
}
241239
if len(gt["annotations"]) != len(test_objects):
242240
gt_images = list(gt["annotations"].keys())

deeplabcut/create_project/new.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,9 @@ def create_new_project(
276276
cfg_file["x2"] = 640
277277
cfg_file["y1"] = 277
278278
cfg_file["y2"] = 624
279-
cfg_file[
280-
"batch_size"
281-
] = 8 # batch size during inference (video - analysis); see https://www.biorxiv.org/content/early/2018/10/30/457242
279+
cfg_file["batch_size"] = (
280+
8 # batch size during inference (video - analysis); see https://www.biorxiv.org/content/early/2018/10/30/457242
281+
)
282282
cfg_file["corner2move2"] = (50, 50)
283283
cfg_file["move2corner"] = True
284284
cfg_file["skeleton_color"] = "black"

deeplabcut/generate_training_dataset/multiple_individuals_trainingsetmanipulation.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -272,9 +272,7 @@ def create_multianimaltraining_dataset(
272272
# Loading the encoder (if necessary downloading from TF)
273273
dlcparent_path = auxiliaryfunctions.get_deeplabcut_path()
274274
defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
275-
model_path = auxfun_models.check_for_weights(
276-
net_type, Path(dlcparent_path)
277-
)
275+
model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path))
278276

279277
if Shuffles is None:
280278
Shuffles = range(1, num_shuffles + 1, 1)
@@ -425,9 +423,9 @@ def create_multianimaltraining_dataset(
425423
"multi_step": [[1e-4, 7500], [5 * 1e-5, 12000], [1e-5, 200000]],
426424
"save_iters": 10000,
427425
"display_iters": 500,
428-
"num_idchannel": len(cfg["individuals"])
429-
if cfg.get("identity", False)
430-
else 0,
426+
"num_idchannel": (
427+
len(cfg["individuals"]) if cfg.get("identity", False) else 0
428+
),
431429
"crop_size": list(crop_size),
432430
"crop_sampling": crop_sampling,
433431
}

deeplabcut/generate_training_dataset/trainingsetmanipulation.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -988,9 +988,7 @@ def create_training_dataset(
988988
defaultconfigfile = os.path.join(dlcparent_path, "pose_cfg.yaml")
989989
elif posecfg_template:
990990
defaultconfigfile = posecfg_template
991-
model_path = auxfun_models.check_for_weights(
992-
net_type, Path(dlcparent_path)
993-
)
991+
model_path = auxfun_models.check_for_weights(net_type, Path(dlcparent_path))
994992

995993
if Shuffles is None:
996994
Shuffles = range(1, num_shuffles + 1)

deeplabcut/gui/tabs/extract_frames.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ def _set_page(self):
9595

9696
self.main_layout.addWidget(
9797
_create_label_widget(
98-
"Frame extraction from a video subset (optional for automatic extraction)", "font:bold"
98+
"Frame extraction from a video subset (optional for automatic extraction)",
99+
"font:bold",
99100
)
100101
)
101102
self.video_selection_widget = VideoSelectionWidget(self.root, self)
@@ -206,7 +207,9 @@ def extract_frames(self):
206207
return
207208
first_video = videos[0]
208209
if len(videos) > 1:
209-
self.root.writer.write(f"Only the first video ({first_video}) will be opened.")
210+
self.root.writer.write(
211+
f"Only the first video ({first_video}) will be opened."
212+
)
210213
video_path_in_folder = self._check_symlink(first_video)
211214
_ = launch_napari(str(video_path_in_folder))
212215
return

deeplabcut/gui/tabs/modelzoo.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def _set_page(self):
7272

7373
tooltip_label = QtWidgets.QLabel()
7474
tooltip_label.setPixmap(
75-
QPixmap(os.path.join(BASE_DIR, "assets", "icons", "help2.png")).scaledToWidth(30)
75+
QPixmap(
76+
os.path.join(BASE_DIR, "assets", "icons", "help2.png")
77+
).scaledToWidth(30)
7678
)
7779
tooltip_label.setToolTip(
7880
"Approximate animal sizes in pixels, for spatial pyramid search. If left blank, defaults to video height +/- 50 pixels",

deeplabcut/gui/tracklet_toolbox.py

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from deeplabcut.utils.auxfun_videos import VideoReader
2020
from deeplabcut.utils.auxiliaryfunctions import attempt_to_make_folder
2121
from matplotlib.path import Path
22-
from matplotlib.widgets import Slider, LassoSelector, Button, CheckButtons
22+
from matplotlib.widgets import Slider, LassoSelector, Button, CheckButtons, TextBox
2323
from PySide6.QtWidgets import QMessageBox
2424
from PySide6.QtCore import QMutex
2525

@@ -327,6 +327,9 @@ def __init__(self, manager, videoname, trail_len=50):
327327

328328
self.dps = []
329329

330+
self.swap_id1 = None
331+
self.swap_id2 = None
332+
330333
def _prepare_canvas(self, manager, fig):
331334
params = {
332335
"keymap.save": "s",
@@ -358,7 +361,7 @@ def _prepare_canvas(self, manager, fig):
358361

359362
img = self.video.read_frame()
360363
self.im = self.ax1.imshow(img)
361-
self.scat = self.ax1.scatter([], [], s=self.dotsize ** 2, picker=True)
364+
self.scat = self.ax1.scatter([], [], s=self.dotsize**2, picker=True)
362365
self.scat.set_offsets(manager.xy[:, 0])
363366
self.scat.set_color(self.colors)
364367
self.trails = sum(
@@ -374,6 +377,7 @@ def _prepare_canvas(self, manager, fig):
374377
)
375378
self.vline_x = self.ax2.axvline(0, 0, 1, c="k", ls=":")
376379
self.vline_y = self.ax3.axvline(0, 0, 1, c="k", ls=":")
380+
377381
custom_lines = [
378382
plt.Line2D([0], [0], color=self.cmap(i), lw=4)
379383
for i in range(len(manager.individuals))
@@ -420,10 +424,15 @@ def _prepare_canvas(self, manager, fig):
420424
self.ax_flag = self.fig.add_axes([0.75, 0.1, 0.05, 0.03])
421425
self.ax_save = self.fig.add_axes([0.80, 0.1, 0.05, 0.03])
422426
self.ax_help = self.fig.add_axes([0.85, 0.1, 0.05, 0.03])
427+
self.ax_swap = self.fig.add_axes([0.90, 0.1, 0.05, 0.03]) # New button
428+
423429
self.save_button = Button(self.ax_save, "Save", color="darkorange")
424430
self.save_button.on_clicked(self.save)
425431
self.help_button = Button(self.ax_help, "Help")
426432
self.help_button.on_clicked(self.display_help)
433+
self.swap_button = Button(self.ax_swap, "Swap") # New button
434+
self.swap_button.on_clicked(self.swap_tracklets) # Placeholder action
435+
427436
self.drag_toggle = CheckButtons(self.ax_drag, ["Drag"])
428437
self.drag_toggle.on_clicked(self.toggle_draggable_points)
429438
self.flag_button = Button(self.ax_flag, "Flag")
@@ -441,9 +450,75 @@ def _prepare_canvas(self, manager, fig):
441450
self.ax1_background = self.fig.canvas.copy_from_bbox(self.ax1.bbox)
442451
self.fig.show()
443452

453+
# Create dropdowns for selecting tracklets to swap, placing them near the swap button
454+
self.ax_dropdown1 = self.fig.add_axes([0.9, 0.15, 0.05, 0.03])
455+
self.ax_dropdown2 = self.fig.add_axes([0.9, 0.20, 0.05, 0.03])
456+
self.textbox1 = TextBox(self.ax_dropdown1, "ID 1")
457+
self.textbox2 = TextBox(self.ax_dropdown2, "ID 2")
458+
self.textbox1.on_submit(self.set_swap_id1)
459+
self.textbox2.on_submit(self.set_swap_id2)
460+
444461
def show(self, fig=None):
445462
self._prepare_canvas(self.manager, fig)
446463

464+
def swap_tracklets(self, event):
465+
if self.swap_id1 is not None and self.swap_id2 is not None:
466+
467+
# Get tracklet indices for each individual
468+
inds1 = [
469+
k
470+
for k in range(len(self.manager.tracklet2id))
471+
if self.manager.tracklet2id[k] == self.swap_id1
472+
]
473+
inds2 = [
474+
k
475+
for k in range(len(self.manager.tracklet2id))
476+
if self.manager.tracklet2id[k] == self.swap_id2
477+
]
478+
479+
print(f"Swapping tracklets {self.swap_id1} and {self.swap_id2}")
480+
481+
# Frames to swap
482+
frames = []
483+
if len(self.cuts) == 2:
484+
frames = list(range(min(self.cuts), max(self.cuts) + 1))
485+
elif len(self.cuts) == 1:
486+
frames = [self.cuts[0]]
487+
else:
488+
frames = list(range(self.curr_frame, self.manager.nframes))
489+
490+
# Swap the tracklets
491+
for i in range(min(len(inds1), len(inds2))):
492+
self.manager.swap_tracklets(inds1[i], inds2[i], frames)
493+
self.display_traces()
494+
self.slider.set_val(self.curr_frame)
495+
496+
def set_swap_id1(self, val):
497+
# check that the input is a valid from the list of individuals
498+
if int(val) in self.manager.tracklet2id:
499+
self.swap_id1 = int(val)
500+
print("ID 1 set.")
501+
else:
502+
print(
503+
f"Invalid ID. Please select a valid ID from the list of individuals: {set(self.manager.tracklet2id)}"
504+
)
505+
self.swap_id1 = None
506+
507+
def set_swap_id2(self, val):
508+
# check that the input is a valid from the list of individuals
509+
if int(val) in self.manager.tracklet2id:
510+
self.swap_id2 = int(val)
511+
print("ID 2 set.")
512+
else:
513+
print(
514+
f"Invalid ID. Please select a valid ID from the list of individuals: {set(self.manager.tracklet2id)}"
515+
)
516+
self.swap_id2 = None
517+
518+
def terminate(self, event):
519+
plt.close(self.fig)
520+
self.player.terminate()
521+
447522
def fill_shaded_areas(self):
448523
self.clean_collections()
449524
if self.picked_pair:
@@ -587,9 +662,9 @@ def on_press(self, event):
587662
if len(self.cuts) > 1:
588663
self.cuts.sort()
589664
if self.picked_pair:
590-
self.manager.tracklet_swaps[self.picked_pair][
591-
self.cuts
592-
] = ~self.manager.tracklet_swaps[self.picked_pair][self.cuts]
665+
self.manager.tracklet_swaps[self.picked_pair][self.cuts] = (
666+
~self.manager.tracklet_swaps[self.picked_pair][self.cuts]
667+
)
593668
self.fill_shaded_areas()
594669
self.cuts = []
595670
for line in self.ax_slider.lines:
@@ -807,7 +882,7 @@ def on_change(self, val):
807882

808883
def update_dotsize(self, val):
809884
self.dotsize = val
810-
self.scat.set_sizes([self.dotsize ** 2])
885+
self.scat.set_sizes([self.dotsize**2])
811886

812887
@staticmethod
813888
def calc_distance(x1, y1, x2, y2):

deeplabcut/gui/window.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,7 @@ def _check_for_updates(silent=True):
5959
_ = msg.addButton("Skip", msg.RejectRole)
6060
msg.exec_()
6161
if msg.clickedButton() is update_btn:
62-
subprocess.check_call(
63-
[sys.executable, "-m", *command]
64-
)
62+
subprocess.check_call([sys.executable, "-m", *command])
6563

6664

6765
class MainWindow(QMainWindow):

deeplabcut/modelzoo/api/superanimal_inference.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ def video_inference(
271271
customized_test_config="",
272272
):
273273
if superanimal_name not in MODELOPTIONS:
274-
raise ValueError(f"{superanimal_name} not available. Available ones are: {MODELOPTIONS}. If you are confident `superanimal_name` is right, try updating `dlclibrary` with `pip install -U dlclibrary`.")
274+
raise ValueError(
275+
f"{superanimal_name} not available. Available ones are: {MODELOPTIONS}. If you are confident `superanimal_name` is right, try updating `dlclibrary` with `pip install -U dlclibrary`."
276+
)
275277

276278
dlc_root_path = auxiliaryfunctions.get_deeplabcut_path()
277279

0 commit comments

Comments
 (0)