diff --git a/ALMEval_code/datasets/__init__.py b/ALMEval_code/datasets/__init__.py index d6a1969..24043c4 100644 --- a/ALMEval_code/datasets/__init__.py +++ b/ALMEval_code/datasets/__init__.py @@ -2,6 +2,9 @@ TemporalReasoningDataset, TemporalReasoningGivenCaptionDataset, TemporalReasoningGivenUncutAudioDataset, + TemporalReasoningSingleAudioDataset, + TemporalReasoningGivenCaptionSingleAudioDataset, + TemporalReasoningGivenUncutSingleAudioDataset, SpatialReasoningDataset, SpatialReasoningChannelwiseDataset, PerceptionDataset, @@ -15,6 +18,10 @@ 'tr': TemporalReasoningDataset, 'tr_cap': TemporalReasoningGivenCaptionDataset, 'tr_uncut': TemporalReasoningGivenUncutAudioDataset, + # Single-audio temporal variants: 3 clips are concatenated into one audio (2s gaps). + 'tr_single': TemporalReasoningSingleAudioDataset, # no extra context + 'tr_cap_single': TemporalReasoningGivenCaptionSingleAudioDataset, # + global caption + 'tr_uncut_single': TemporalReasoningGivenUncutSingleAudioDataset, # + uncut reference audio # Spatial Reasoning 'sr': SpatialReasoningDataset, 'sr_ch': SpatialReasoningChannelwiseDataset, @@ -71,4 +78,4 @@ def build_dataset(dataset_names: list, dataset_root: str) -> list: dataset_class = DATASET_REGISTRY[alias] dataset_objects.append(dataset_class(dataset_root=dataset_root)) - return dataset_objects \ No newline at end of file + return dataset_objects diff --git a/ALMEval_code/datasets/starbench.py b/ALMEval_code/datasets/starbench.py index d24e1b8..d018c40 100644 --- a/ALMEval_code/datasets/starbench.py +++ b/ALMEval_code/datasets/starbench.py @@ -9,6 +9,7 @@ import random from loguru import logger import numpy as np +from pydub import AudioSegment class MCQBaseDataset(Dataset): DATASET_ALIAS = None @@ -284,11 +285,20 @@ def parse_multi_choice_response(self, response: str, options: list) -> str: # fallback: no valid match found return 'Z' +def merge_pydub(audio_paths, out_path, silence_sec=2.0): + combined = AudioSegment.empty() + silence = 
AudioSegment.silent(duration=silence_sec * 1000) # pydub measures duration in milliseconds + + for path in audio_paths: + combined += AudioSegment.from_file(path) + silence + combined.export(out_path, format="wav") + return out_path class TemporalReasoningDataset(MCQBaseDataset): DATASET_ALIAS="tr" JSON_PATH="meta_info/holistic_reasoning_temporal.json" + MULTI_AUDIO = True def _perm_by_roundrobin(self, rotate_id: int, tid: str) -> list[int]: PERMS_123 = [ @@ -351,13 +361,28 @@ def build_prompt(self, line: int | dict, **kwargs) -> dict: prompts.extend(self._build_auxiliary_prompts(line)) assert len(audio_paths) == 3, "Temporal reasoning expects 3 audio clips." - for i in range(len(audio_paths)): + + if self.MULTI_AUDIO: + for i in range(len(audio_paths)): + prompts.extend( + [ + {'type':'text', 'value':f'\nclip {i+1}:'}, + {'type':'audio', 'value': os.path.join(self.dataset_root, audio_paths[i])} + ] + ) + else: + rel_dir = os.path.dirname(audio_paths[0]).replace('starbench_audios/', 'starbench_audios_cat2sec/') + audio_abs_paths = [os.path.join(self.dataset_root, audio_paths[i]) for i in range(len(audio_paths))] + concat_path = os.path.join(self.dataset_root, rel_dir, f"{shuffled_seg_order}.wav") + os.makedirs(os.path.dirname(concat_path), exist_ok=True) + merge_pydub(audio_abs_paths, concat_path) prompts.extend( [ - {'type':'text', 'value':f'\nclip {i+1}:'}, - {'type':'audio', 'value': os.path.join(self.dataset_root, audio_paths[i])} + {'type':'text', 'value':"Below are three audio segments (clip 1, clip 2, and clip 3) concatenated with a 2-second gap between them."}, + {'type':'audio', 'value': concat_path} ] - ) + ) + msg = { "meta":{ "id": line['id'], @@ -379,6 +404,8 @@ def build_prompt(self, line: int | dict, **kwargs) -> dict: class TemporalReasoningGivenCaptionDataset(TemporalReasoningDataset): DATASET_ALIAS = "tr_cap" JSON_PATH = "meta_info/holistic_reasoning_temporal.json" + MULTI_AUDIO = True + def _build_auxiliary_prompts(self, line: dict) -> list: """Overrides the hook to add a 
caption.""" caption_prompt = f"Here is a caption that describes the full, uncut audio scene to help you reconstruct the original context: {line['global_caption']}\n Below are 3 audio clips:\n" @@ -395,6 +422,7 @@ def build_prompt(self, line: dict, **kwargs) -> dict: class TemporalReasoningGivenUncutAudioDataset(TemporalReasoningDataset): DATASET_ALIAS = "tr_uncut" JSON_PATH = "meta_info/holistic_reasoning_temporal.json" + MULTI_AUDIO = True def _build_auxiliary_prompts(self, line: dict) -> list: """Overrides the hook to add the uncut audio.""" @@ -410,6 +438,22 @@ def build_prompt(self, line: dict, **kwargs) -> dict: return msg +class TemporalReasoningSingleAudioDataset(TemporalReasoningDataset): + DATASET_ALIAS = "tr_single" + # Use one merged audio prompt (clip1+clip2+clip3 with 2s silence) instead of 3 separate clips. + MULTI_AUDIO = False + +class TemporalReasoningGivenCaptionSingleAudioDataset(TemporalReasoningGivenCaptionDataset): + DATASET_ALIAS = "tr_cap_single" + # Same as tr_single, but includes global_caption context in the prompt. + MULTI_AUDIO = False + +class TemporalReasoningGivenUncutSingleAudioDataset(TemporalReasoningGivenUncutAudioDataset): + DATASET_ALIAS = "tr_uncut_single" + # Same as tr_single, but also provides the uncut full-scene reference audio. 
+ MULTI_AUDIO = False + + class RotationBasedDataset(MCQBaseDataset): @@ -603,4 +647,3 @@ class PerceptionNonSpatialDataset(PerceptionDataset): - diff --git a/ALMEval_code/models/audio_flamingo.py b/ALMEval_code/models/audio_flamingo.py index ed125c9..73cf418 100644 --- a/ALMEval_code/models/audio_flamingo.py +++ b/ALMEval_code/models/audio_flamingo.py @@ -1,6 +1,8 @@ import os +import tempfile import torch from .base import BaseModel +from datasets.starbench import merge_pydub # Audio Flamingo 3 HF version, you'd better update transformers >= 5.0 class AudioFlamingo3HF(BaseModel): @@ -54,17 +56,33 @@ def __init__(self, model_path='nvidia/audio-flamingo-3-hf', thinking=False, **kw def generate_inner(self, msgs): meta = msgs.get('meta', None) prompts = msgs.get('prompts', None) - content = [] + + # Collect text and audio items separately + text_parts = [] + audio_paths = [] for x in prompts: if x['type'] == 'text': - content.append({"type": "text", "text": x['value']}) + text_parts.append(x['value']) elif x['type'] == 'audio': - content.append({"type": "audio", "path": x['value']}) + audio_paths.append(x['value']) + + # HF processor requires text:audio = 1:1 + # If multiple audios, merge them into one with 2s silence gaps + merged_path = None + if len(audio_paths) > 1: + merged_path = tempfile.NamedTemporaryFile(suffix='.wav', delete=False).name + merge_pydub(audio_paths, merged_path) + audio_paths = [merged_path] + + # Build content: all text first, then single audio + combined_text = "\n".join(text_parts) if self.thinking: - content.append({ - "type": "text", - "text": "Please think and reason about the input audio before you respond.", - }) + combined_text += "\nPlease think and reason about the input audio before you respond." 
+ + content = [ + {"type": "text", "text": combined_text}, + {"type": "audio", "path": audio_paths[0]}, + ] messages = [{'role': 'user', 'content': content}] inputs = self.processor.apply_chat_template( @@ -91,5 +109,8 @@ def generate_inner(self, msgs): output = self.processor.batch_decode( model_output.sequences[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True, clean_up_tokenization_spaces=False )[0] - print('【audio flamingo3 output】:', output) + if merged_path and os.path.exists(merged_path): + os.remove(merged_path) + + print('【audio flamingo3 output】:', output) return output