Skip to content

Commit 24baf4f

Browse files
committed
train
1 parent ba0b289 commit 24baf4f

File tree

76 files changed

+32
-22
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+32
-22
lines changed
-574 Bytes
Binary file not shown.
-2.73 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.
-1.4 KB
Binary file not shown.
-5.36 KB
Binary file not shown.
-5.01 KB
Binary file not shown.
-13.4 KB
Binary file not shown.
-3.58 KB
Binary file not shown.
-1.83 KB
Binary file not shown.
-805 Bytes
Binary file not shown.
-7.58 KB
Binary file not shown.
-906 Bytes
Binary file not shown.
-4.09 KB
Binary file not shown.

data/base_datasets.py

+30-20
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,10 @@ def __init__(self, args):
3131
self.video_decode_backend = args.video_decode_backend
3232
self.num_frames = args.num_frames
3333
self.text_type = args.text_type
34-
self.chatgpt = self.text_type == 'polish_mplug'
34+
self.total_text = ['raw', 'mplug', 'polish_mplug', 'sound_mplug'] + [f'ofa{i}' for i in range(8)]
35+
self.weight = [0.2, 0.2, 0.2, 0.2] + [0.2 / 8] * 8
3536
self.title = self.text_type == 'raw'
36-
self.data_root = '/A_Youtube'
37+
self.data_root = '/apdcephfs_cq3/share_1311970/A_Youtube'
3738
with open(args.train_data, 'r') as f:
3839
self.id2title_folder_caps = json.load(f)
3940
self.ids = list(self.id2title_folder_caps.keys())[:args.train_num_samples]
@@ -58,26 +59,26 @@ def __len__(self):
5859
return len(self.ids)
5960
# return self.id2title_folder_caps.shape[0]
6061

62+
6163
def __getitem__(self, idx):
6264
id = self.ids[idx]
6365
folder = self.id2title_folder_caps[id]['folder']
6466
try:
65-
text_output = self.get_text(id)
67+
text_output, ofa_number = self.get_text(id)
6668
input_ids, attention_mask = text_output['input_ids'], text_output['attention_mask']
6769
if self.clip_type == 'vl':
6870
matched_modality = self.get_video(id, folder)
6971
elif self.clip_type == 'al':
7072
matched_modality = self.get_audio(id, folder)
7173
elif self.clip_type == 'dl':
72-
matched_modality = self.get_depth(id, folder)
74+
matched_modality = self.get_depth(id, folder, ofa_number)
7375
elif self.clip_type == 'tl':
74-
matched_modality = self.get_thermal(id, folder)
76+
matched_modality = self.get_thermal(id, folder, ofa_number)
7577
return matched_modality['pixel_values'], input_ids, attention_mask
7678
except Exception as error_msg:
7779
logging.info(f"Failed at {id} with \"{error_msg}\"")
7880
return self.__getitem__(random.randint(0, self.__len__()-1))
7981

80-
8182
def get_video(self, id, folder):
8283
video_path = opj(self.data_root, folder, f'{id}.mp4')
8384
video = load_and_transform_video(video_path, self.video_transform,
@@ -90,11 +91,11 @@ def get_audio(self, id, folder):
9091
if os.path.exists(audio_path):
9192
pass
9293
else:
93-
audio_path = audio_path[:-4] + '.m4a'
94+
audio_path = audio_path[:-4] + '.wav'
9495
if os.path.exists(audio_path):
9596
pass
9697
else:
97-
audio_path = audio_path[:-4] + '.wav'
98+
audio_path = audio_path[:-4] + '.m4a'
9899
if not os.path.exists(audio_path):
99100
# self.audio_error_file.write(audio_path[:-4] + '\n')
100101
raise FileNotFoundError(f'Not found audio file at \'{audio_path[:-4]}\' with .mp3 .m4a .wav')
@@ -110,30 +111,39 @@ def get_audio(self, id, folder):
110111
return audio
111112

112113
def get_text(self, id):
113-
text = self.id2title_folder_caps[id][self.text_type]
114-
text_output = load_and_transform_text(text, self.tokenizer, title=self.title)
115-
return text_output
114+
if self.text_type != 'mix':
115+
text = self.id2title_folder_caps[id][self.text_type]
116+
text_output = load_and_transform_text(text, self.tokenizer, title=self.title)
117+
return text_output, None
118+
else:
119+
text_type = random.choices(self.total_text, self.weight)[0]
120+
ofa_number = None
121+
if text_type.startswith('ofa'):
122+
ofa_number = int(text_type[-1])
123+
text = self.id2title_folder_caps[id]['ofa'][ofa_number]
124+
else:
125+
text = self.id2title_folder_caps[id][text_type]
126+
text_output = load_and_transform_text(text, self.tokenizer, title=text_type=='raw')
127+
return text_output, ofa_number
116128

117-
def get_depth(self, id, folder):
129+
def get_depth(self, id, folder, ofa_number):
118130
depth_folder = opj(self.data_root, folder, f'{id}_depth_f8glpn_folder')
119-
# random_id = random.randint(0, 7)
120-
random_id = 3
131+
random_id = random.randint(0, 7) if ofa_number is None else ofa_number
132+
# random_id = 3
121133
depth_path = os.path.join(depth_folder, f'{random_id}.png')
122134
depth = load_and_transform_depth(depth_path, self.depth_transform)
123135
return depth
124136

125-
def get_thermal(self, id, folder):
126-
thermal_folder = opj(self.data_root, folder, f'{id}_thermal_f8_folder')
127-
# random_id = random.randint(0, 7)
128-
random_id = 3
137+
def get_thermal(self, id, folder, ofa_number):
138+
thermal_folder = opj(self.data_root, folder, f'{id}_thermal_folder')
139+
random_id = random.randint(0, 7) if ofa_number is None else ofa_number
140+
# random_id = 3
129141
thermal_path = os.path.join(thermal_folder, f'{random_id}.jpg')
130142
thermal = load_and_transform_thermal(thermal_path, self.thermal_transform)
131143
return thermal
132144

133145

134146

135-
136-
137147
if __name__ == '__main__':
138148
parser = argparse.ArgumentParser('Pre-training', add_help=False)
139149
parser.add_argument('--num_frames', default=8, type=float, help='')
-574 Bytes
Binary file not shown.
-2.65 KB
Binary file not shown.
-1.34 KB
Binary file not shown.
-3.35 KB
Binary file not shown.
-3.38 KB
Binary file not shown.
-4.92 KB
Binary file not shown.
-3.71 KB
Binary file not shown.
-6.45 KB
Binary file not shown.
-1.67 KB
Binary file not shown.
-9.45 KB
Binary file not shown.
-245 Bytes
Binary file not shown.
-9.31 KB
Binary file not shown.
-677 Bytes
Binary file not shown.
-6.4 KB
Binary file not shown.
-4.98 KB
Binary file not shown.
-13.5 KB
Binary file not shown.
Binary file not shown.
-2.75 KB
Binary file not shown.
-13.8 KB
Binary file not shown.
Binary file not shown.
-3.95 KB
Binary file not shown.
-8.59 KB
Binary file not shown.
-4.04 KB
Binary file not shown.
-19.9 KB
Binary file not shown.
-2.96 KB
Binary file not shown.
-159 Bytes
Binary file not shown.
Binary file not shown.
Binary file not shown.

scripts/thermal_language/train.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,4 @@ TORCH_DISTRIBUTED_DEBUG=DETAIL HF_DATASETS_OFFLINE=1 TRANSFORMERS_OFFLINE=1 torc
2121
--precision "amp" --workers 10 --video-decode-backend "imgs" \
2222
--save-frequency 1 --log-every-n-steps 20 --report-to "tensorboard" --resume "latest" \
2323
--do_eval \
24-
--val_t_cls_data "LLVIP" "FLIRV1" "FLIRV2" "LSOTB"
24+
--val_t_cls_data "LLVIP" "FLIRV1" "FLIRV2"
-574 Bytes
Binary file not shown.
-2.87 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.
-1.4 KB
Binary file not shown.
-127 Bytes
Binary file not shown.
-3.34 KB
Binary file not shown.
-2.56 KB
Binary file not shown.
-999 Bytes
Binary file not shown.
-12.8 KB
Binary file not shown.
-577 Bytes
Binary file not shown.
-1.73 KB
Binary file not shown.

training/params.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def parse_args(args):
3232
parser.add_argument("--languagebind_weight", type=str, default='', help="",)
3333
parser.add_argument("--num-frames", type=int, default=8, help="",)
3434
parser.add_argument("--clip-type", type=str, default="", choices=['vl', 'al', 'dl', 'tl'], help="",)
35-
parser.add_argument("--text-type", type=str, default="", help="'raw', 'ofa', 'mplug', 'chatgpt'",)
35+
parser.add_argument("--text-type", type=str, default="chatgpt", help="'raw', 'ofa', 'mplug', 'polish_mplug'",)
3636
parser.add_argument("--add-time-attn", default=False, action="store_true", help="")
3737
parser.add_argument("--unlock-time-attn", default=False, action="store_true", help="")
3838
parser.add_argument("--coef-lr", type=float, default=1e-4, help="")
-2.64 KB
Binary file not shown.
-4.91 KB
Binary file not shown.
-15.5 KB
Binary file not shown.
-2.64 KB
Binary file not shown.
-1.45 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.
-12.8 KB
Binary file not shown.
-4.44 KB
Binary file not shown.
-18.7 KB
Binary file not shown.
-36.5 KB
Binary file not shown.
Binary file not shown.
-3.11 KB
Binary file not shown.
Binary file not shown.
Binary file not shown.
-3.2 KB
Binary file not shown.
-2.32 KB
Binary file not shown.
-5.39 KB
Binary file not shown.
-2.32 KB
Binary file not shown.

0 commit comments

Comments
 (0)