Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion ALMEval_code/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
TemporalReasoningDataset,
TemporalReasoningGivenCaptionDataset,
TemporalReasoningGivenUncutAudioDataset,
TemporalReasoningSingleAudioDataset,
TemporalReasoningGivenCaptionSingleAudioDataset,
TemporalReasoningGivenUncutSingleAudioDataset,
SpatialReasoningDataset,
SpatialReasoningChannelwiseDataset,
PerceptionDataset,
Expand All @@ -15,6 +18,10 @@
'tr': TemporalReasoningDataset,
'tr_cap': TemporalReasoningGivenCaptionDataset,
'tr_uncut': TemporalReasoningGivenUncutAudioDataset,
# Single-audio temporal variants: 3 clips are concatenated into one audio (2s gaps).
'tr_single': TemporalReasoningSingleAudioDataset, # no extra context
'tr_cap_single': TemporalReasoningGivenCaptionSingleAudioDataset, # + global caption
'tr_uncut_single': TemporalReasoningGivenUncutSingleAudioDataset, # + uncut reference audio
# Spatial Reasoning
'sr': SpatialReasoningDataset,
'sr_ch': SpatialReasoningChannelwiseDataset,
Expand Down Expand Up @@ -71,4 +78,4 @@ def build_dataset(dataset_names: list, dataset_root: str) -> list:
dataset_class = DATASET_REGISTRY[alias]
dataset_objects.append(dataset_class(dataset_root=dataset_root))

return dataset_objects
return dataset_objects
53 changes: 48 additions & 5 deletions ALMEval_code/datasets/starbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import random
from loguru import logger
import numpy as np
from pydub import AudioSegment

class MCQBaseDataset(Dataset):
DATASET_ALIAS = None
Expand Down Expand Up @@ -284,11 +285,20 @@ def parse_multi_choice_response(self, response: str, options: list) -> str:
# fallback: no valid match found
return 'Z'

def merge_pydub(audio_paths, out_path, silence_sec=2.0):
    """Concatenate several audio files into a single WAV file.

    A block of silence of ``silence_sec`` seconds is appended after *every*
    clip — including the last one, so the exported file ends with a trailing
    silence gap.

    Args:
        audio_paths: Iterable of paths to audio files readable by pydub.
        out_path: Destination path for the exported WAV file.
        silence_sec: Gap length in seconds inserted after each clip.

    Returns:
        The ``out_path`` that was written.
    """
    gap = AudioSegment.silent(duration=silence_sec * 1000)  # pydub measures duration in milliseconds
    merged = AudioSegment.empty()
    for clip_path in audio_paths:
        merged = merged + AudioSegment.from_file(clip_path) + gap
    merged.export(out_path, format="wav")
    return out_path


class TemporalReasoningDataset(MCQBaseDataset):
DATASET_ALIAS="tr"
JSON_PATH="meta_info/holistic_reasoning_temporal.json"
MULTI_AUDIO = True

def _perm_by_roundrobin(self, rotate_id: int, tid: str) -> list[int]:
PERMS_123 = [
Expand Down Expand Up @@ -351,13 +361,28 @@ def build_prompt(self, line: int | dict, **kwargs) -> dict:
prompts.extend(self._build_auxiliary_prompts(line))

assert len(audio_paths) == 3, "Temporal reasoning expects 3 audio clips."
for i in range(len(audio_paths)):

if self.MULTI_AUDIO:
for i in range(len(audio_paths)):
prompts.extend(
[
{'type':'text', 'value':f'\nclip {i+1}:'},
{'type':'audio', 'value': os.path.join(self.dataset_root, audio_paths[i])}
]
)
else:
rel_dir = os.path.dirname(audio_paths[0]).replace('starbench_audios/', 'starbench_audios_cat2sec/')
audio_abs_paths = [os.path.join(self.dataset_root, audio_paths[i]) for i in range(len(audio_paths))]
concat_path = os.path.join(self.dataset_root, rel_dir, f"{shuffled_seg_order}.wav")
os.makedirs(os.path.dirname(concat_path), exist_ok=True)
merge_pydub(audio_abs_paths, concat_path)
prompts.extend(
[
{'type':'text', 'value':f'\nclip {i+1}:'},
{'type':'audio', 'value': os.path.join(self.dataset_root, audio_paths[i])}
{'type':'text', 'value':"Below are three audio segments (clip 1, clip 2, and clip 3) concatenated with a 2-second gap between them."},
{'type':'audio', 'value': concat_path}
]
)
)

msg = {
"meta":{
"id": line['id'],
Expand All @@ -379,6 +404,8 @@ def build_prompt(self, line: int | dict, **kwargs) -> dict:
class TemporalReasoningGivenCaptionDataset(TemporalReasoningDataset):
DATASET_ALIAS = "tr_cap"
JSON_PATH = "meta_info/holistic_reasoning_temporal.json"
MULTI_AUDIO = True

def _build_auxiliary_prompts(self, line: dict) -> list:
"""Overrides the hook to add a caption."""
caption_prompt = f"Here is a caption that describes the full, uncut audio scene to help you reconstruct the original context: {line['global_caption']}\n Below are 3 audio clips:\n"
Expand All @@ -395,6 +422,7 @@ def build_prompt(self, line: dict, **kwargs) -> dict:
class TemporalReasoningGivenUncutAudioDataset(TemporalReasoningDataset):
DATASET_ALIAS = "tr_uncut"
JSON_PATH = "meta_info/holistic_reasoning_temporal.json"
MULTI_AUDIO = True

def _build_auxiliary_prompts(self, line: dict) -> list:
"""Overrides the hook to add the uncut audio."""
Expand All @@ -410,6 +438,22 @@ def build_prompt(self, line: dict, **kwargs) -> dict:
return msg


class TemporalReasoningSingleAudioDataset(TemporalReasoningDataset):
    """Temporal reasoning variant that presents ONE concatenated audio.

    Identical to ``TemporalReasoningDataset`` except the three shuffled
    clips are merged into a single audio file (with silence gaps) rather
    than being passed as three separate audio prompts.
    """
    DATASET_ALIAS = "tr_single"
    # Use one merged audio prompt (clip1+clip2+clip3 with 2s silence) instead of 3 separate clips.
    MULTI_AUDIO = False

class TemporalReasoningGivenCaptionSingleAudioDataset(TemporalReasoningGivenCaptionDataset):
    """Single-audio temporal reasoning with a global caption as extra context.

    Behaves like ``TemporalReasoningGivenCaptionDataset`` but merges the
    three clips into one audio prompt instead of sending them separately.
    """
    DATASET_ALIAS = "tr_cap_single"
    # Same as tr_single, but includes global_caption context in the prompt.
    MULTI_AUDIO = False

class TemporalReasoningGivenUncutSingleAudioDataset(TemporalReasoningGivenUncutAudioDataset):
    """Single-audio temporal reasoning with the uncut reference audio.

    Behaves like ``TemporalReasoningGivenUncutAudioDataset`` but merges the
    three clips into one audio prompt instead of sending them separately.
    """
    DATASET_ALIAS = "tr_uncut_single"
    # Same as tr_single, but also provides the uncut full-scene reference audio.
    MULTI_AUDIO = False




class RotationBasedDataset(MCQBaseDataset):
Expand Down Expand Up @@ -603,4 +647,3 @@ class PerceptionNonSpatialDataset(PerceptionDataset):




37 changes: 29 additions & 8 deletions ALMEval_code/models/audio_flamingo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import tempfile
import torch
from .base import BaseModel
from datasets.starbench import merge_pydub

# Audio Flamingo 3 HF version, you'd better update transformers >= 5.0
class AudioFlamingo3HF(BaseModel):
Expand Down Expand Up @@ -54,17 +56,33 @@ def __init__(self, model_path='nvidia/audio-flamingo-3-hf', thinking=False, **kw
def generate_inner(self, msgs):
meta = msgs.get('meta', None)
prompts = msgs.get('prompts', None)
content = []

# Collect text and audio items separately
text_parts = []
audio_paths = []
for x in prompts:
if x['type'] == 'text':
content.append({"type": "text", "text": x['value']})
text_parts.append(x['value'])
elif x['type'] == 'audio':
content.append({"type": "audio", "path": x['value']})
audio_paths.append(x['value'])

# HF processor requires text:audio = 1:1
# If multiple audios, merge them into one with 2s silence gaps
merged_path = None
if len(audio_paths) > 1:
merged_path = tempfile.NamedTemporaryFile(suffix='.wav', delete=False).name
merge_pydub(audio_paths, merged_path)
audio_paths = [merged_path]

# Build content: all text first, then single audio
combined_text = "\n".join(text_parts)
if self.thinking:
content.append({
"type": "text",
"text": "Please think and reason about the input audio before you respond.",
})
combined_text += "\nPlease think and reason about the input audio before you respond."

content = [
{"type": "text", "text": combined_text},
{"type": "audio", "path": audio_paths[0]},
]
messages = [{'role': 'user', 'content': content}]

inputs = self.processor.apply_chat_template(
Expand All @@ -91,5 +109,8 @@ def generate_inner(self, msgs):
output = self.processor.batch_decode(
model_output.sequences[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print('【audio flamingo3 output】:', output)
if merged_path and os.path.exists(merged_path):
os.remove(merged_path)

print('【audio flamingo3 output】:', output)
return output