Skip to content

Commit 43bb3f1

Browse files
committed
improved alignment caching mechanism
1 parent 0692957 commit 43bb3f1

File tree

6 files changed

+96
-93
lines changed

6 files changed

+96
-93
lines changed

mmif/serialize/annotation.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ def __init__(self, anno_obj: Optional[Union[bytes, str, dict]] = None, *_) -> No
6060
self._required_attributes = ["_type", "properties"]
6161
super().__init__(anno_obj)
6262

63+
def __hash__(self):
64+
return hash(self.serialize())
65+
6366
def _deserialize(self, input_dict: dict) -> None:
6467
self.at_type = input_dict.pop('_type', '')
6568
# TODO (krim @ 6/1/21): If annotation IDs must follow a certain string format,
@@ -70,21 +73,31 @@ def _deserialize(self, input_dict: dict) -> None:
7073
for k, v in self.properties.items():
7174
self._add_prop_aliases(k, v)
7275

73-
def _cache_alignment(self, alignment_id: str, alignedto_id: str) -> None:
76+
def _cache_alignment(self, alignment_ann: 'Annotation', alignedto_ann: 'Annotation') -> None:
7477
"""
7578
Cache alignment information. This cache will not be serialized. Both arguments must be Annotation
7679
objects.
7780
:param alignment_ann: the Alignment annotation that has this annotation on one side
7881
:param alignedto_ann: the annotation that this annotation is aligned to (other side of Alignment)
7982
"""
80-
self._alignments[alignment_id] = alignedto_id
83+
self._alignments[alignment_ann] = alignedto_ann
8184

82-
def aligned_to_by(self, alignment_id: str) -> Optional[str]:
85+
def aligned_to_by(self, alignment: 'Annotation') -> Optional['Annotation']:
8386
"""
8487
Retrieve the annotation that this annotation is aligned to.
85-
:param alignment_id: ID if the Alignment annotation
88+
:param alignment: Alignment annotation that has this annotation on one side
89+
"""
90+
return self._alignments.get(alignment)
91+
92+
def get_all_aligned(self) -> Iterator['Annotation']:
8693
"""
87-
return self._alignments.get(alignment_id)
94+
Generator to iterate through all alignments and aligned annotations.
95+
:return: yields the alignment annotation and the aligned annotation in order
96+
"""
97+
for alignment, aligned in self._alignments.items():
98+
yield alignment
99+
yield aligned
100+
88101

89102
def _add_prop_aliases(self, key_to_add, val_to_add):
90103
"""

mmif/serialize/mmif.py

Lines changed: 54 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import warnings
1111
from collections import defaultdict
1212
from datetime import datetime
13-
from typing import List, Union, Optional, Dict, cast, Iterator, Tuple
13+
from typing import List, Union, Optional, Dict, cast, Iterator
1414

1515
import jsonschema.validators
1616

@@ -235,14 +235,39 @@ def _deserialize(self, input_dict: dict) -> None:
235235
# add quick access to `start` and `end` values if the annotation is using `targets` property
236236
if 'targets' in ann.properties:
237237
if 'start' in ann.properties or 'end' in ann.properties:
238-
raise ValueError(f"Annotation {ann.id} (in view {view.id}) has `targes` and `start`/`end/` "
238+
raise ValueError(f"Annotation {ann.id} (in view {view.id}) has `targets` and `start`/`end/` "
239239
f"properties at the same time. Annotation anchors are ambiguous.")
240240
ann._props_ephemeral['start'] = self._get_linear_anchor_point(ann, start=True)
241241
ann._props_ephemeral['end'] = self._get_linear_anchor_point(ann, start=False)
242242

243243
## caching alignments
244244
if ann.at_type == AnnotationTypes.Alignment:
245-
view._cache_alignment(ann)
245+
self._cache_alignment(ann)
246+
247+
def _cache_alignment(self, alignment_ann: Annotation):
248+
view = self.views.get(alignment_ann.parent)
249+
if view is None:
250+
warnings.warn(f"Alignment {alignment_ann.long_id} doesn't have a parent view, but it should.", RuntimeWarning)
251+
return
252+
253+
## caching alignments
254+
def _desprately_search_annotation_object(ann_short_id):
255+
ann_long_id = f"{view.id}{self.id_delimiter}{ann_short_id}"
256+
try:
257+
return self.__getitem__(ann_long_id)
258+
except KeyError:
259+
return self.__getitem__(ann_short_id)
260+
261+
if all(map(lambda x: x in alignment_ann.properties, ('source', 'target'))):
262+
source_ann = _desprately_search_annotation_object(alignment_ann.get('source'))
263+
target_ann = _desprately_search_annotation_object(alignment_ann.get('target'))
264+
if isinstance(source_ann, Annotation) and isinstance(target_ann, Annotation):
265+
source_ann._cache_alignment(alignment_ann, target_ann)
266+
target_ann._cache_alignment(alignment_ann, source_ann)
267+
else:
268+
warnings.warn(
269+
f"Alignment {alignment_ann.long_id} has `source` and `target` properties that do not point to Annotation objects.",
270+
RuntimeWarning)
246271

247272
def generate_capital_annotations(self):
248273
"""
@@ -566,19 +591,16 @@ def get_all_views_with_error(self) -> List[View]:
566591

567592
get_views_with_error = get_all_views_with_error
568593

569-
def get_all_views_contain(self, at_types: Union[ThingTypesBase, str, List[Union[str, ThingTypesBase]]]) -> List[View]:
594+
def get_all_views_contain(self, *at_types: Union[ThingTypesBase, str]) -> List[View]:
570595
"""
571596
Returns the list of all views in the MMIF if given types
572597
are present in that view's 'contains' metadata.
573598
574599
:param at_types: one or more types to check for. When more than one type is given, all types must be found.
575600
:return: the list of views that contain the type
576601
"""
577-
if isinstance(at_types, list):
578-
return [view for view in self.views
579-
if all(map(lambda x: x in view.metadata.contains, at_types))]
580-
else:
581-
return [view for view in self.views if at_types in view.metadata.contains]
602+
return [view for view in self.views
603+
if all(map(lambda x: x in view.metadata.contains, at_types))]
582604

583605
get_views_contain = get_all_views_contain
584606

@@ -621,35 +643,20 @@ def get_view_contains(self, at_types: Union[ThingTypesBase, str, List[Union[str,
621643
return view
622644
return None
623645

624-
def _is_in_time_range(self, ann: Annotation, start: Union[int, float], end: Union[int, float]) -> bool:
646+
def _is_in_time_range(self, ann: Annotation, range_s: Union[int, float], range_e: Union[int, float]) -> bool:
625647
"""
626-
Checks if the annotation is anchored within the given time range.
648+
Checks if the annotation is anchored within the given time range. Any overlap is considered included.
627649
628-
:param ann: the Annotation object to check
629-
:param start: the start time point in milliseconds
630-
:param end: the end time point in milliseconds
650+
:param ann: the Annotation object to check, must be time-based itself or anchored to time-based annotations
651+
:param range_s: the start time point of the range (in milliseconds)
652+
:param range_e: the end time point of the range (in milliseconds)
631653
632654
:return: True if the annotation is anchored within the time range, False otherwise
633655
"""
634-
s, e = self.get_start(ann), self.get_end(ann)
635-
return (s < start < e) or (s < end < e) or (s > start and e < end)
636-
637-
def _handle_time_unit(self, input_unit: str, ann_unit: str,
638-
start: int, end: int) -> Tuple[Union[int, float, str], Union[int, float, str]]:
639-
"""
640-
Helper method to convert time unit defined by user to the unit in mmif object.
656+
ann_s, ann_e = self.get_start(ann), self.get_end(ann)
657+
return (ann_s < range_s < ann_e) or (ann_s < range_e < ann_e) or (ann_s > range_s and ann_e < range_e)
641658

642-
:param input_unit: the time unit defined by user
643-
:param ann_unit: the time unit in mmif object
644-
:param start: the start time point in the unit of `input_unit`
645-
:param end: the end time point in the unit of `input_unit`
646-
647-
:return: the start and end time points in the unit of `ann_unit`
648-
"""
649-
from mmif.utils.timeunit_helper import convert
650-
return convert(start, input_unit, ann_unit, 1), convert(end, input_unit, ann_unit, 1)
651-
652-
def get_annotations_between_time(self, start: Union[int, float], end: Union[int, float],
659+
def get_annotations_between_time(self, start: Union[int, float], end: Union[int, float],
653660
time_unit: str = "ms") -> Iterator[Annotation]:
654661
"""
655662
Finds annotations that are anchored between the given time points.
@@ -662,34 +669,24 @@ def get_annotations_between_time(self, start: Union[int, float], end: Union[int,
662669
assert start < end, f"Start time point must be smaller than the end time point, given {start} and {end}"
663670
assert start >= 0, f"Start time point must be non-negative, given {start}"
664671
assert end >= 0, f"End time point must be non-negative, given {end}"
672+
673+
from mmif.utils.timeunit_helper import convert
665674

666-
tf_in_range = []
667-
tf_to_anns = defaultdict(list)
675+
time_anchors_in_range = []
668676

669-
# Runtime: O(V * (TF * AL))
670-
for view in self.get_all_views_contain([AnnotationTypes.TimeFrame, AnnotationTypes.Alignment]):
677+
for view in self.get_all_views_contain(AnnotationTypes.TimeFrame) + self.get_all_views_contain(AnnotationTypes.TimePoint):
671678
time_unit_in_view = view.metadata.contains.get(AnnotationTypes.TimeFrame)["timeUnit"]
672-
start_time, end_time = self._handle_time_unit(time_unit, time_unit_in_view, start, end)
673-
674-
tf_anns = view.get_annotations(AnnotationTypes.TimeFrame)
675-
al_anns = view.get_annotations(AnnotationTypes.Alignment)
676-
677-
for tf_ann in tf_anns:
678-
if self._is_in_time_range(tf_ann, start_time, end_time):
679-
tf_in_range.append(tf_ann)
680-
tf_to_anns[self.get_start(tf_ann)] = []
681-
682-
for al_ann in al_anns:
683-
for tf in tf_in_range:
684-
target_ann_long_id = tf.aligned_to_by(al_ann.long_id)
685-
if target_ann_long_id:
686-
tf_to_anns[self.get_start(tf)].append(view.get_annotation_by_id(target_ann_long_id))
687-
break
688-
689-
# Runtime: O(TF + AL)
690-
for start_point, anns in dict(sorted(tf_to_anns.items())).items():
691-
for ann in anns:
692-
yield ann
679+
680+
start_time = convert(start, time_unit, time_unit_in_view, 1)
681+
end_time = convert(end, time_unit, time_unit_in_view, 1)
682+
for ann in view.get_annotations():
683+
if ann.at_type in (AnnotationTypes.TimeFrame, AnnotationTypes.TimePoint) and self._is_in_time_range(ann, start_time, end_time):
684+
time_anchors_in_range.append(ann)
685+
time_anchors_in_range.sort(key=lambda x: self.get_start(x))
686+
for time_anchor in time_anchors_in_range:
687+
yield time_anchor
688+
for aligned in time_anchor.get_all_aligned():
689+
yield aligned
693690

694691
def _get_linear_anchor_point(self, ann: Annotation, targets_sorted=False, start: bool = True) -> Union[int, float]:
695692
# TODO (krim @ 2/5/24): Update the return type once timeunits are unified to `ms` as integers (https://github.com/clamsproject/mmif/issues/192)

mmif/serialize/view.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,9 @@ def add_annotation(self, annotation: 'Annotation', overwrite=False) -> 'Annotati
122122
self.annotations.append(annotation, overwrite)
123123
self.new_contain(annotation.at_type)
124124
if annotation.at_type == AnnotationTypes.Alignment:
125-
self._cache_alignment(annotation)
125+
self._parent_mmif._cache_alignment(annotation)
126126
return annotation
127127

128-
def _cache_alignment(self, alignent_ann: 'Annotation'):
129-
if all(map(lambda x: x in alignent_ann.properties, ('source', 'target'))):
130-
source_ann = self.get_annotation_by_id(alignent_ann.get('source'))
131-
target_ann = self.get_annotation_by_id(alignent_ann.get('target'))
132-
source_ann._cache_alignment(alignent_ann.long_id, target_ann.long_id)
133-
target_ann._cache_alignment(alignent_ann.long_id, source_ann.long_id)
134-
135128
def new_textdocument(self, text: str, lang: str = "en", did: Optional[str] = None,
136129
overwrite=False, **properties) -> 'Document':
137130
"""

mmif/utils/video_document_helper.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import importlib
2+
import math
23
import warnings
34
from typing import List, Union, Tuple
4-
import math
55

66
import mmif
77
from mmif import Annotation, Document, Mmif
88
from mmif.utils.timeunit_helper import convert
9-
from mmif.vocabulary import DocumentTypes, AnnotationTypes
9+
from mmif.vocabulary import DocumentTypes
1010

1111
for cv_dep in ('cv2', 'ffmpeg', 'PIL'):
1212
try:
@@ -212,19 +212,19 @@ def convert_timepoint(mmif: Mmif, timepoint: Annotation, out_unit: str) -> Union
212212
return convert(timepoint.get_property('timePoint'), in_unit, out_unit, get_framerate(vd))
213213

214214

215-
def convert_timeframe(mmif: Mmif, time_frame: Annotation, out_unit: str) -> Union[Tuple[Union[int, float, str], Union[int, float, str]]]:
215+
def convert_timeframe(mmif: Mmif, time_frame: Annotation, out_unit: str) -> Tuple[Union[int, float, str], Union[int, float, str]]:
216216
"""
217217
Converts start and end points in a ``TimeFrame`` annotation to a different time unit.
218218
219219
:param mmif: :py:class:`~mmif.serialize.mmif.Mmif` instance
220220
:param time_frame: :py:class:`~mmif.serialize.annotation.Annotation` instance that holds a time interval annotation (``"@type": ".../TimeFrame/..."``)
221221
:param out_unit: time unit to which the point is converted
222-
:return: tuple of frame numbers (integer) or seconds/milliseconds (float) of input start and end
222+
:return: tuple of frame numbers, seconds/milliseconds, or ISO notation of TimeFrame's start and end
223223
"""
224224
in_unit = time_frame.get_property('timeUnit')
225225
vd = mmif[time_frame.get_property('document')]
226-
return convert(mmif.get_start(time_frame), in_unit, out_unit, get_framerate(vd)), \
227-
convert(mmif.get_end(time_frame), in_unit, out_unit, get_framerate(vd))
226+
fps = get_framerate(vd)
227+
return convert(time_frame.get_property('start'), in_unit, out_unit, fps), convert(time_frame.get_property('end'), in_unit, out_unit, fps)
228228

229229

230230
def framenum_to_second(video_doc: Document, frame: int):

tests/test_serialize.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def test_document_location_helpers(self):
222222
def test_document_location_helpers_http(self):
223223
new_doc = Document()
224224
new_doc.id = "d1"
225-
new_doc.location = f"https://www.gnu.org/licenses/gpl-3.0.txt"
225+
new_doc.location = f"https://example.com/"
226226
self.assertEqual(new_doc.location_scheme(), 'https')
227227
try:
228228
path = new_doc.location_path()
@@ -277,11 +277,11 @@ def test_get_all_views_contain(self):
277277
self.assertEqual(2, len(views))
278278
views = mmif_obj.get_all_views_contain('http://vocab.lappsgrid.org/SemanticTag')
279279
self.assertEqual(1, len(views))
280-
views = mmif_obj.get_views_contain([
280+
views = mmif_obj.get_views_contain(
281281
AnnotationTypes.TimeFrame,
282282
DocumentTypes.TextDocument,
283283
AnnotationTypes.Alignment,
284-
])
284+
)
285285
self.assertEqual(1, len(views))
286286
views = mmif_obj.get_all_views_contain(not_existing_attype)
287287
self.assertEqual(0, len(views))
@@ -324,16 +324,16 @@ def test_get_alignments(self):
324324
self.assertEqual(1, len(views_and_alignments))
325325
self.assertTrue('v6' in views_and_alignments)
326326

327-
def test_alignment_caching(self):
327+
def test_cache_alignment(self):
328328
mmif_obj = Mmif(MMIF_EXAMPLES['everything'])
329329
views_and_alignments = mmif_obj.get_alignments(DocumentTypes.TextDocument, AnnotationTypes.TimeFrame)
330330
for vid, alignments in views_and_alignments.items():
331331
v = mmif_obj.get_view_by_id(vid)
332332
for alignment in alignments:
333333
s = v.get_annotation_by_id(alignment.get('source'))
334334
t = v.get_annotation_by_id(alignment.get('target'))
335-
self.assertTrue(s.aligned_to_by(alignment.long_id).endswith(t.long_id))
336-
self.assertTrue(t.aligned_to_by(alignment.long_id).endswith(s.long_id))
335+
self.assertTrue(s.aligned_to_by(alignment).long_id.endswith(t.long_id))
336+
self.assertTrue(t.aligned_to_by(alignment).long_id.endswith(s.long_id))
337337

338338
def test_new_view_id(self):
339339
p = Mmif.view_prefix
@@ -393,23 +393,24 @@ def test_get_annotations_between_time(self):
393393
mmif_obj = Mmif(MMIF_EXAMPLES['everything'])
394394

395395
# Test case 1: All token annotations are selected
396-
selected_token_anns = mmif_obj.get_annotations_between_time(0, 22000)
397-
self.assertEqual(28, len(list(selected_token_anns)))
396+
selected_token_anns = [ann for ann in mmif_obj.get_annotations_between_time(0, 22000) if ann.is_type(token_type)]
397+
self.assertEqual(28, len(selected_token_anns))
398398
for i, ann in enumerate(selected_token_anns):
399-
self.assertTrue(ann.is_type(token_type))
400-
self.assertEqual(tokens_in_order[i], ann.get_property("text"))
399+
self.assertEqual(tokens_in_order[i], ann.get_property("word"))
401400

402401
# Test case 2: No token annotations are selected
403-
selected_token_anns = mmif_obj.get_annotations_between_time(0, 5, time_unit="seconds")
404-
self.assertEqual(0, len(list(selected_token_anns)))
402+
selected_token_anns = list(mmif_obj.get_annotations_between_time(0, 5, time_unit="seconds"))
403+
self.assertEqual(4, len(list(selected_token_anns)))
404+
for ann in selected_token_anns:
405+
self.assertFalse(ann.is_type(token_type))
405406

406407
# Test case 3(a): Partial tokens are selected (involve partial overlap)
407408
selected_token_anns = mmif_obj.get_annotations_between_time(7, 10, time_unit="seconds")
408-
self.assertEqual(tokens_in_order[3:9], [ann.get_property("text") for ann in selected_token_anns])
409+
self.assertEqual(tokens_in_order[3:9], [ann.get_property("word") for ann in selected_token_anns])
409410

410411
# Test case 3(b): Partial tokens are selected (only full overlap)
411412
selected_token_anns = mmif_obj.get_annotations_between_time(11500, 14600)
412-
self.assertEqual(tokens_in_order[12:17], [ann.get_property("text") for ann in selected_token_anns])
413+
self.assertEqual(tokens_in_order[12:17], [ann.get_property("word") for ann in selected_token_anns])
413414

414415
def test_add_document(self):
415416
mmif_obj = Mmif(MMIF_EXAMPLES['everything'])

tests/test_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44

55
from mmif import Mmif, Document, AnnotationTypes
66
from mmif.utils import sequence_helper as sqh
7+
from mmif.utils import text_document_helper as tdh
78
from mmif.utils import timeunit_helper as tuh
89
from mmif.utils import video_document_helper as vdh
9-
from mmif.utils import text_document_helper as tdh
10-
from mmif.serialize import mmif
1110
from tests.mmif_examples import *
1211

1312

@@ -178,7 +177,7 @@ def test_width_based_smoothing(self):
178177
class TestTextDocHelper(unittest.TestCase):
179178
mmif_obj = Mmif(MMIF_EXAMPLES['everything'])
180179

181-
@pytest.mark.skip("The only valid test cases come from kalbi app which annotates wrong property")
180+
@pytest.mark.skip("The only valid test cases come from kaldi app which annotates wrong property")
182181
def test_slice_text(self):
183182
sliced_text_full_overlap = tdh.slice_text(self.mmif_obj, 11500, 14600)
184183
sliced_text_partial_overlap = tdh.slice_text(self.mmif_obj, 7, 10, unit="seconds")

0 commit comments

Comments
 (0)