Skip to content

Commit 416c717

Browse files
authored
Merge pull request #284 from clamsproject/280-fix-textslicer
280-fix-textslicer
2 parents c42234f + 0f894a1 commit 416c717

File tree

4 files changed

+167
-9
lines changed

4 files changed

+167
-9
lines changed

mmif/serialize/mmif.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import warnings
1111
from collections import defaultdict
1212
from datetime import datetime
13-
from typing import List, Union, Optional, Dict, cast
13+
from typing import List, Union, Optional, Dict, cast, Iterator, Tuple
1414

1515
import jsonschema.validators
1616

@@ -609,7 +609,74 @@ def get_view_contains(self, at_types: Union[ThingTypesBase, str, List[Union[str,
609609
if at_types in view.metadata.contains:
610610
return view
611611
return None
612-
612+
613+
def _is_in_time_between(self, start: Union[int, float], end: Union[int, float], annotation: Annotation) -> bool:
614+
s, e = self.get_start(annotation), self.get_end(annotation)
615+
return (s < start < e) or (s > start and e < end) or (s < end < e)
616+
617+
def _handle_time_unit(self, input_unit: str, ann_unit: str,
618+
start: int, end: int) -> Tuple[Union[int, float, str], Union[int, float, str]]:
619+
from mmif.utils.timeunit_helper import convert
620+
start = convert(start, input_unit, ann_unit, 1)
621+
end = convert(end, input_unit, ann_unit, 1)
622+
return start, end
623+
624+
def get_annotations_between_time(self, start: int, end: int, time_unit: str = "milliseconds") -> Iterator[Annotation]:
625+
"""
626+
Version: 1.0
627+
Returns all 'Token' annotations aligned with 'TimeFrame' annotations sorted by start time within start and end time
628+
Note: this function only works for mmif object obtained from Whisper-wrapper
629+
630+
:param start: the start time
631+
:param end: the end time
632+
:param time_unit: the time unit, either string "milliseconds" or "seconds", defaults to "milliseconds"
633+
:return: a generator of 'Token' annotations
634+
"""
635+
assert start <= end, "Start time must be less than end time"
636+
assert start >= 0, "Start time must be greater than or equal to zero"
637+
assert end >= 0, "End time must be greater than or equal to zero"
638+
# 0. Initialize container and helper method
639+
valid_tf_anns = []
640+
tf_to_anns = defaultdict(list)
641+
642+
# 1. find all views that contain the type of TF
643+
views = self.get_all_views_contain([AnnotationTypes.TimeFrame, AnnotationTypes.Alignment])
644+
645+
# 2. For each view, extract annotations that satisfy conditions that are TF/TP and fall into time interval
646+
for view in views:
647+
# Make sure time unit stay at the same level
648+
start_time, end_time = self._handle_time_unit(time_unit, view.metadata.contains.get(AnnotationTypes.TimeFrame)["timeUnit"],
649+
start, end)
650+
tf_anns = view.get_annotations(at_type=AnnotationTypes.TimeFrame)
651+
al_anns = view.get_annotations(at_type=AnnotationTypes.Alignment)
652+
653+
# Select 'TimeFrame' annotations within given time interval
654+
for tf in tf_anns:
655+
if self._is_in_time_between(start_time, end_time, tf):
656+
valid_tf_anns.append(tf)
657+
658+
# Map 'TimeFrame' annotation to its aligned annotation
659+
for align in al_anns:
660+
source_id, target_id = align.get_property('source'), align.get_property('target')
661+
to_long_id = lambda x: x if self.id_delimiter in x else f'{view.id}{self.id_delimiter}{x}'
662+
try:
663+
source, target = view.get_annotation_by_id(source_id), view.get_annotation_by_id(target_id)
664+
if source in valid_tf_anns:
665+
tf_to_anns[to_long_id(source_id)].append(target)
666+
elif target in valid_tf_anns:
667+
tf_to_anns[to_long_id(target_id)].append(source)
668+
except KeyError:
669+
pass
670+
671+
# 3. For those extracted 'TimeFrame' annotations, sort them by their start time
672+
sort_tf_anns = sorted(valid_tf_anns, key=lambda x: self.get_start(x))
673+
674+
# 4. Yield all annotations aligned with sorted 'TimeFrame' annotations
675+
for tf_ann in sort_tf_anns:
676+
anns = tf_to_anns[tf_ann.long_id]
677+
for ann in anns:
678+
yield ann
679+
613680
def _get_linear_anchor_point(self, ann: Annotation, targets_sorted=False, start: bool = True) -> Union[int, float]:
614681
# TODO (krim @ 2/5/24): Update the return type once timeunits are unified to `ms` as integers (https://github.com/clamsproject/mmif/issues/192)
615682
"""

mmif/utils/text_document_helper.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import mmif
2+
from mmif import Annotation
3+
4+
5+
def slice_text(mmif_obj, start: int, end: int, unit: str = "milliseconds") -> str:
6+
token_type = "http://vocab.lappsgrid.org/Token"
7+
anns_found = mmif_obj.get_annotations_between_time(start, end, unit)
8+
tokens_sliced = []
9+
for ann in anns_found:
10+
if ann.is_type(token_type):
11+
tokens_sliced.append(ann.get_property('word'))
12+
return ' '.join(tokens_sliced)

tests/test_serialize.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,64 @@ def test_new_view_id(self):
342342
self.assertEqual(e_view.id, f'{p}4')
343343
self.assertEqual(len(mmif_obj.views), 5)
344344

345+
def test_get_annotations_between_time(self):
346+
token_type = "http://vocab.lappsgrid.org/Token"
347+
# Below tokens are obtained by 'jq' in CLI using command:
348+
# jq '[
349+
# .views[3].annotations |
350+
# .[] |
351+
# select(."@type"=="http://vocab.lappsgrid.org/Token")] |
352+
# sort_by(.properties.id | ltrimstr("t") | tonumber) |
353+
# map(.properties.text)' <examples>.json
354+
tokens_in_order = ["Hello",
355+
",",
356+
"this",
357+
"is",
358+
"Jim",
359+
"Lehrer",
360+
"with",
361+
"the",
362+
"NewsHour",
363+
"on",
364+
"PBS",
365+
".",
366+
"In",
367+
"the",
368+
"nineteen",
369+
"eighties",
370+
",",
371+
"barking",
372+
"dogs",
373+
"have",
374+
"increasingly",
375+
"become",
376+
"a",
377+
"problem",
378+
"in",
379+
"urban",
380+
"areas",
381+
"."]
382+
mmif_obj = Mmif(MMIF_EXAMPLES['everything'])
383+
384+
# Test case 1: All token annotations are selected
385+
selected_token_anns = mmif_obj.get_annotations_between_time(0, 22000)
386+
self.assertEqual(28, len(list(selected_token_anns)))
387+
for i, ann in enumerate(selected_token_anns):
388+
self.assertTrue(ann.is_type(token_type))
389+
self.assertEqual(tokens_in_order[i], ann.get_property("text"))
390+
391+
# Test case 2: No token annotation are selected
392+
selected_token_anns = mmif_obj.get_annotations_between_time(0, 5, time_unit="seconds")
393+
self.assertEqual(0, len(list(selected_token_anns)))
394+
395+
# Test case 3(a): Partial tokens are selected (involve partial overlap)
396+
selected_token_anns = mmif_obj.get_annotations_between_time(7, 10, time_unit="seconds")
397+
self.assertEqual(tokens_in_order[3:9], [ann.get_property("text") for ann in selected_token_anns])
398+
399+
# Test case 3(b): Partial tokens are selected (only full overlap)
400+
selected_token_anns = mmif_obj.get_annotations_between_time(11500, 14600)
401+
self.assertEqual(tokens_in_order[12:17], [ann.get_property("text") for ann in selected_token_anns])
402+
345403
def test_add_document(self):
346404
mmif_obj = Mmif(MMIF_EXAMPLES['everything'])
347405
med_obj = Document(FRACTIONAL_EXAMPLES['doc_only'])

tests/test_utils.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@
66
from mmif.utils import sequence_helper as sqh
77
from mmif.utils import timeunit_helper as tuh
88
from mmif.utils import video_document_helper as vdh
9+
from mmif.utils import text_document_helper as tdh
10+
from mmif.serialize import mmif
11+
from tests.mmif_examples import *
912

1013

1114
class TestTimeunitHelper(unittest.TestCase):
12-
1315
FPS = 30
14-
16+
1517
def test_convert(self):
1618
self.assertEqual(1000, tuh.convert(1, 's', 'ms', self.FPS))
1719
self.assertEqual(1.1, tuh.convert(1100, 'ms', 's', self.FPS))
@@ -35,7 +37,7 @@ def setUp(self):
3537
})
3638
self.video_doc.add_property('fps', self.fps)
3739
self.mmif_obj.add_document(self.video_doc)
38-
40+
3941
def test_extract_mid_frame(self):
4042
tf = self.a_view.new_annotation(AnnotationTypes.TimeFrame, start=100, end=200, timeUnit='frame', document='d1')
4143
self.assertEqual(150, vdh.get_mid_framenum(self.mmif_obj, tf))
@@ -92,11 +94,12 @@ def test_sample_frames(self):
9294
s_frame = vdh.second_to_framenum(self.video_doc, 3)
9395
e_frame = vdh.second_to_framenum(self.video_doc, 5)
9496
self.assertEqual(1, len(vdh.sample_frames(s_frame, e_frame, 60)))
95-
97+
9698
def test_convert_timepoint(self):
97-
timepoint_ann = self.a_view.new_annotation(AnnotationTypes.BoundingBox, timePoint=3, timeUnit='second', document='d1')
99+
timepoint_ann = self.a_view.new_annotation(AnnotationTypes.BoundingBox, timePoint=3, timeUnit='second',
100+
document='d1')
98101
self.assertEqual(vdh.convert(3, 's', 'f', self.fps), vdh.convert_timepoint(self.mmif_obj, timepoint_ann, 'f'))
99-
102+
100103
def test_convert_timeframe(self):
101104
self.a_view.metadata.new_contain(AnnotationTypes.TimeFrame, timeUnit='frame', document='d1')
102105
timeframe_ann = self.a_view.new_annotation(AnnotationTypes.TimeFrame, start=100, end=200)
@@ -105,7 +108,7 @@ def test_convert_timeframe(self):
105108

106109

107110
class TestSequenceHelper(unittest.TestCase):
108-
111+
109112
def test_validate_labelset(self):
110113
mmif_obj = Mmif(validate=False)
111114
view = mmif_obj.new_view()
@@ -172,5 +175,23 @@ def test_width_based_smoothing(self):
172175
sqh.smooth_outlying_short_intervals(scores, 1, 1))
173176

174177

178+
class TestTextDocHelper(unittest.TestCase):
179+
mmif_obj = Mmif(MMIF_EXAMPLES['everything'])
180+
181+
@pytest.mark.skip("The only valid test cases come from kalbi app which annotates wrong property")
182+
def test_slice_text(self):
183+
sliced_text_full_overlap = tdh.slice_text(self.mmif_obj, 11500, 14600)
184+
sliced_text_partial_overlap = tdh.slice_text(self.mmif_obj, 7, 10, unit="seconds")
185+
no_sliced_text = tdh.slice_text(self.mmif_obj, 0, 5000)
186+
full_sliced_text = tdh.slice_text(self.mmif_obj, 0, 22, unit="seconds")
187+
self.assertEqual("In the nineteen eighties ,", sliced_text_full_overlap)
188+
self.assertEqual("is Jim Lehrer with the NewsHour", sliced_text_partial_overlap)
189+
self.assertEqual("", no_sliced_text)
190+
self.assertEqual(
191+
"Hello , this is Jim Lehrer with the NewsHour on PBS . "
192+
"In the nineteen eighties , barking dogs have increasingly become a problem in urban areas .",
193+
full_sliced_text)
194+
195+
175196
if __name__ == '__main__':
176197
unittest.main()

0 commit comments

Comments
 (0)