|
10 | 10 | import warnings
|
11 | 11 | from collections import defaultdict
|
12 | 12 | from datetime import datetime
|
13 |
| -from typing import List, Union, Optional, Dict, cast |
| 13 | +from typing import List, Union, Optional, Dict, cast, Iterator, Tuple |
14 | 14 |
|
15 | 15 | import jsonschema.validators
|
16 | 16 |
|
@@ -609,7 +609,74 @@ def get_view_contains(self, at_types: Union[ThingTypesBase, str, List[Union[str,
|
609 | 609 | if at_types in view.metadata.contains:
|
610 | 610 | return view
|
611 | 611 | return None
|
612 |
| - |
| 612 | + |
| 613 | + def _is_in_time_between(self, start: Union[int, float], end: Union[int, float], annotation: Annotation) -> bool: |
| 614 | + s, e = self.get_start(annotation), self.get_end(annotation) |
| 615 | + return (s < start < e) or (s > start and e < end) or (s < end < e) |
| 616 | + |
| 617 | + def _handle_time_unit(self, input_unit: str, ann_unit: str, |
| 618 | + start: int, end: int) -> Tuple[Union[int, float, str], Union[int, float, str]]: |
| 619 | + from mmif.utils.timeunit_helper import convert |
| 620 | + start = convert(start, input_unit, ann_unit, 1) |
| 621 | + end = convert(end, input_unit, ann_unit, 1) |
| 622 | + return start, end |
| 623 | + |
| 624 | + def get_annotations_between_time(self, start: int, end: int, time_unit: str = "milliseconds") -> Iterator[Annotation]: |
| 625 | + """ |
| 626 | + Version: 1.0 |
| 627 | + Returns all 'Token' annotations aligned with 'TimeFrame' annotations sorted by start time within start and end time |
| 628 | + Note: this function only works for mmif object obtained from Whisper-wrapper |
| 629 | +
|
| 630 | + :param start: the start time |
| 631 | + :param end: the end time |
| 632 | + :param time_unit: the time unit, either string "milliseconds" or "seconds", defaults to "milliseconds" |
| 633 | + :return: a generator of 'Token' annotations |
| 634 | + """ |
| 635 | + assert start <= end, "Start time must be less than end time" |
| 636 | + assert start >= 0, "Start time must be greater than or equal to zero" |
| 637 | + assert end >= 0, "End time must be greater than or equal to zero" |
| 638 | + # 0. Initialize container and helper method |
| 639 | + valid_tf_anns = [] |
| 640 | + tf_to_anns = defaultdict(list) |
| 641 | + |
| 642 | + # 1. find all views that contain the type of TF |
| 643 | + views = self.get_all_views_contain([AnnotationTypes.TimeFrame, AnnotationTypes.Alignment]) |
| 644 | + |
| 645 | + # 2. For each view, extract annotations that satisfy conditions that are TF/TP and fall into time interval |
| 646 | + for view in views: |
| 647 | + # Make sure time unit stay at the same level |
| 648 | + start_time, end_time = self._handle_time_unit(time_unit, view.metadata.contains.get(AnnotationTypes.TimeFrame)["timeUnit"], |
| 649 | + start, end) |
| 650 | + tf_anns = view.get_annotations(at_type=AnnotationTypes.TimeFrame) |
| 651 | + al_anns = view.get_annotations(at_type=AnnotationTypes.Alignment) |
| 652 | + |
| 653 | + # Select 'TimeFrame' annotations within given time interval |
| 654 | + for tf in tf_anns: |
| 655 | + if self._is_in_time_between(start_time, end_time, tf): |
| 656 | + valid_tf_anns.append(tf) |
| 657 | + |
| 658 | + # Map 'TimeFrame' annotation to its aligned annotation |
| 659 | + for align in al_anns: |
| 660 | + source_id, target_id = align.get_property('source'), align.get_property('target') |
| 661 | + to_long_id = lambda x: x if self.id_delimiter in x else f'{view.id}{self.id_delimiter}{x}' |
| 662 | + try: |
| 663 | + source, target = view.get_annotation_by_id(source_id), view.get_annotation_by_id(target_id) |
| 664 | + if source in valid_tf_anns: |
| 665 | + tf_to_anns[to_long_id(source_id)].append(target) |
| 666 | + elif target in valid_tf_anns: |
| 667 | + tf_to_anns[to_long_id(target_id)].append(source) |
| 668 | + except KeyError: |
| 669 | + pass |
| 670 | + |
| 671 | + # 3. For those extracted 'TimeFrame' annotations, sort them by their start time |
| 672 | + sort_tf_anns = sorted(valid_tf_anns, key=lambda x: self.get_start(x)) |
| 673 | + |
| 674 | + # 4. Yield all annotations aligned with sorted 'TimeFrame' annotations |
| 675 | + for tf_ann in sort_tf_anns: |
| 676 | + anns = tf_to_anns[tf_ann.long_id] |
| 677 | + for ann in anns: |
| 678 | + yield ann |
| 679 | + |
613 | 680 | def _get_linear_anchor_point(self, ann: Annotation, targets_sorted=False, start: bool = True) -> Union[int, float]:
|
614 | 681 | # TODO (krim @ 2/5/24): Update the return type once timeunits are unified to `ms` as integers (https://github.com/clamsproject/mmif/issues/192)
|
615 | 682 | """
|
|
0 commit comments