diff --git a/mmif/utils/video_document_helper.py b/mmif/utils/video_document_helper.py index c14689e5..ec3cebac 100644 --- a/mmif/utils/video_document_helper.py +++ b/mmif/utils/video_document_helper.py @@ -6,7 +6,7 @@ import mmif from mmif import Annotation, Document, Mmif from mmif.utils.timeunit_helper import convert -from mmif.vocabulary import DocumentTypes +from mmif.vocabulary import DocumentTypes, AnnotationTypes for cv_dep in ('cv2', 'ffmpeg', 'PIL'): try: @@ -127,6 +127,42 @@ def extract_mid_frame(mmif: Mmif, time_frame: Annotation, as_PIL: bool = False): return extract_frames_as_images(vd, [get_mid_framenum(mmif, time_frame)], as_PIL=as_PIL)[0] +def get_representative_framenum(mmif: Mmif, time_frame: Annotation): + """ + Calculates the representative frame number from an annotation. + + :param mmif: :py:class:`~mmif.serialize.mmif.Mmif` instance + :param time_frame: :py:class:`~mmif.serialize.annotation.Annotation` instance that holds a time interval annotation containing a `representatives` property (``"@type": ".../TimeFrame/..."``) + :return: representative frame number as an integer + """ + if 'representatives' not in time_frame.properties: + raise ValueError(f'The time frame {time_frame.id} does not have a representative.') + timeunit = time_frame.get_property('timeUnit') + video_document = mmif[time_frame.get_property('document')] + fps = get_framerate(video_document) + representatives = time_frame.get_property('representatives') + top_representative_id = representatives[0] + try: + representative_timepoint_anno = mmif[time_frame._parent_view_id+time_frame.id_delimiter+top_representative_id] + except KeyError: + raise ValueError(f'Representative timepoint {top_representative_id} not found in any view.') + return convert(representative_timepoint_anno.get_property('timePoint'), timeunit, 'frame', fps) + + +def extract_representative_frame(mmif: Mmif, time_frame: Annotation, as_PIL: bool = False): + """ + Extracts the representative frame of an annotation as a numpy ndarray or PIL Image. + + :param mmif: :py:class:`~mmif.serialize.mmif.Mmif` instance + :param time_frame: :py:class:`~mmif.serialize.annotation.Annotation` instance that holds a time interval annotation (``"@type": ".../TimeFrame/..."``) + :param as_PIL: return :py:class:`~PIL.Image.Image` instead of :py:class:`~numpy.ndarray` + :return: frame as a :py:class:`numpy.ndarray` or :py:class:`PIL.Image.Image` + """ + video_document = mmif[time_frame.get_property('document')] + rep_frame_num = get_representative_framenum(mmif, time_frame) + return extract_frames_as_images(video_document, [rep_frame_num], as_PIL=as_PIL)[0] + + def sample_frames(start_frame: int, end_frame: int, sample_rate: float = 1) -> List[int]: """ Helper function to sample frames from a time interval. diff --git a/tests/test_utils.py b/tests/test_utils.py index 332787e6..150bd71f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -44,6 +44,23 @@ def test_extract_mid_frame(self): tf = self.a_view.new_annotation(AnnotationTypes.TimeFrame, start=0, end=3, timeUnit='seconds', document='d1') self.assertEqual(vdh.convert(1.5, 's', 'f', self.fps), vdh.get_mid_framenum(self.mmif_obj, tf)) + def test_extract_representative_frame(self): + tp = self.a_view.new_annotation(AnnotationTypes.TimePoint, timePoint=1500, timeUnit='milliseconds', document='d1') + tf = self.a_view.new_annotation(AnnotationTypes.TimeFrame, start=1000, end=2000, timeUnit='milliseconds', document='d1') + tf.add_property('representatives', [tp.id]) + rep_frame_num = vdh.get_representative_framenum(self.mmif_obj, tf) + expected_frame_num = vdh.millisecond_to_framenum(self.video_doc, tp.get_property('timePoint')) + self.assertEqual(expected_frame_num, rep_frame_num) + # check there is an error if no representatives + tf = self.a_view.new_annotation(AnnotationTypes.TimeFrame, start=1000, end=2000, timeUnit='milliseconds', document='d1') + with pytest.raises(ValueError): + vdh.get_representative_framenum(self.mmif_obj, tf) + # check there is an error if there is a representative referencing a timepoint that + # does not exist + tf.add_property('representatives', ['fake_tp_id']) + with pytest.raises(ValueError): + vdh.get_representative_framenum(self.mmif_obj, tf) + def test_get_framerate(self): self.assertAlmostEqual(29.97, vdh.get_framerate(self.video_doc), places=0)