Commit 6e0e14b

feature(nyz): add vision input support and fix bugs
1 parent: 1295987

File tree: 2 files changed (+59, -10 lines changed)
  ding/utils/data/rlhf_offline_dataset.py
  ding/utils/data/tests/test_rlhf_offline_dataset.py


ding/utils/data/rlhf_offline_dataset.py

Lines changed: 39 additions & 4 deletions
@@ -30,6 +30,7 @@ def __init__(
         tokenizer,
         max_length: int,
         input_key: str = "input",
+        extra_input_keys: List[str] = [],
         output_key: str = "output",
         label_key: str = "label",
         apply_chat_template: bool = False,
@@ -41,6 +42,7 @@ def __init__(
         super().__init__()
         self.tokenizer = tokenizer
         self.max_length = max_length
+        self.extra_input_keys = extra_input_keys

         if apply_chat_template:
             apply_chat_template = self.tokenizer.apply_chat_template
@@ -53,6 +55,7 @@ def __init__(
                 self._preprocess_data,
                 input_template=input_template,
                 input_key=input_key,
+                extra_input_keys=extra_input_keys,
                 output_key=output_key,
                 label_key=label_key,
                 apply_chat_template=apply_chat_template
@@ -67,29 +70,40 @@ def __init__(
             self.responses = processed_dataset["response"]
             self.labels = processed_dataset["label"]
             self.prompt_ids_lens = processed_dataset["prompt_ids_len"]
+            for key in extra_input_keys:
+                setattr(self, key, processed_dataset[key])
         else:
             self.prompts = []
             self.responses = []
             self.labels = []
             self.prompt_ids_lens = []
+            for key in extra_input_keys:
+                setattr(self, key, [])
             for data in tqdm(dataset, desc="Preprocessing data", disable=not get_rank() == 0):
                 processed_data = self._preprocess_data(data)
                 if processed_data["prompt"] is not None:
                     self.prompts.append(processed_data["prompt"])
                     self.responses.append(processed_data["response"])
                     self.labels.append(processed_data["label"])
                     self.prompt_ids_lens.append(processed_data["prompt_ids_len"])
+                    for key in extra_input_keys:
+                        getattr(self, key).append(processed_data[key])

     def _preprocess_data(
             self,
             data: Dict[str, Any],
             input_template: str = None,
             input_key: str = "input",
+            extra_input_keys: List[str] = [],
             output_key: str = "output",
             label_key: str = "label",
             apply_chat_template: Union[bool, Callable] = False,
     ) -> str:
         label = data[label_key]
+        if extra_input_keys:
+            extra_inputs = {key: data[key] for key in extra_input_keys}
+        else:
+            extra_inputs = {}

         if apply_chat_template:
             if output_key:
@@ -120,7 +134,13 @@ def _preprocess_data(
         if prompt_ids_len >= self.max_length - 2:
             prompt = None

-        return {"prompt": prompt, "response": response, "label": label, "prompt_ids_len": prompt_ids_len}
+        return {
+            "prompt": prompt,
+            "response": response,
+            "label": label,
+            "prompt_ids_len": prompt_ids_len,
+            **extra_inputs
+        }

     def __len__(self) -> int:
         """
@@ -135,14 +155,21 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
         """
         Overview:
             Get the item at the given index.
+        Arguments:
+            - idx (int): The index of the item to get.
         Returns:
             - item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
         """
+        if self.extra_input_keys:
+            extra_inputs = {key: getattr(self, key)[idx] for key in self.extra_input_keys}
+        else:
+            extra_inputs = {}
         return {
             "prompt": self.prompts[idx],
             "response": self.responses[idx],
             "label": self.labels[idx],
-            "prompt_ids_len": self.prompt_ids_lens[idx]
+            "prompt_ids_len": self.prompt_ids_lens[idx],
+            **extra_inputs
         }

     def collate_fn(self, item_list: List[Dict[str, Union[torch.Tensor, int]]]):
@@ -164,13 +191,17 @@ def tokenizer(prompt: str, response: str):
             inputs["attention_mask"][0][-1] = True
             return inputs["input_ids"], inputs["attention_mask"]

-        tot_ids, tot_masks, tot_labels, prompt_ids_lens = [], [], [], []
+        tot_ids, tot_masks, tot_labels, prompt_ids_lens, tot_extra_inputs = [], [], [], [], {}
         for item in item_list:
             input_ids, attention_mask = tokenizer(item["prompt"], item["response"])
             tot_ids.append(input_ids)
             tot_masks.append(attention_mask)
             tot_labels.append(item["label"])
             prompt_ids_lens.append(item["prompt_ids_len"])
+            for key in self.extra_input_keys:
+                if key not in tot_extra_inputs:
+                    tot_extra_inputs[key] = []
+                tot_extra_inputs[key].append(item[key])

         # add unmatched y'| x (used to estimate the KL divergence between policy and reference)
         for idx in range(len(item_list)):
@@ -180,7 +211,11 @@ def tokenizer(prompt: str, response: str):
             tot_masks.append(attention_mask)
             tot_labels.append(-1)
             prompt_ids_lens.append(item_list[idx]["prompt_ids_len"])
+            for key in self.extra_input_keys:
+                if key not in tot_extra_inputs:
+                    tot_extra_inputs[key] = []
+                tot_extra_inputs[key].append(item_list[idx][key])

         input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id)
         attention_mask = zero_pad_sequences(tot_masks, side="right")
-        return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens
+        return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens, tot_extra_inputs
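
Taken together, the changes above make collate_fn return a fifth element: a dict of the extra inputs collected per key in extra_input_keys. A minimal usage sketch of the new interface, assuming a toy in-memory dataset wrapped in a standard PyTorch DataLoader (the toy records, the placeholder image tensors, the batch size, and the dataset= keyword name are illustrative assumptions, not part of this commit):

    import torch
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer
    from ding.utils.data import OfflineRLDataset

    tokenizer = AutoTokenizer.from_pretrained("OpenGVLab/InternVL2_5-4B")

    # Toy stand-in records with an extra "image" field (illustrative only).
    toy_dataset = [
        {
            "query": "<img>" + "<IMG_CONTEXT>" * 10 + "</img>\nDescribe the image.",
            "response": "A cat sitting on a sofa.",
            "human_ranking": 0,
            "image": torch.zeros(3, 224, 224),  # placeholder pixel values
        }
        for _ in range(16)
    ]

    offline_dataset = OfflineRLDataset(
        dataset=toy_dataset,  # keyword name assumed; not shown in the hunks above
        tokenizer=tokenizer,
        max_length=1024,
        input_key="query",
        extra_input_keys=["image"],  # new argument introduced by this commit
        output_key="response",
        label_key="human_ranking",
    )

    loader = DataLoader(offline_dataset, batch_size=4, collate_fn=offline_dataset.collate_fn)
    for input_ids, attention_mask, labels, prompt_ids_lens, extra_inputs in loader:
        # extra_inputs is keyed by extra_input_keys; each entry holds 2 * batch_size items
        # because collate_fn also appends the unmatched y' | x samples.
        images = extra_inputs["image"]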

ding/utils/data/tests/test_rlhf_offline_dataset.py

Lines changed: 20 additions & 6 deletions
@@ -3,6 +3,11 @@
 from ding.utils.data import OfflineRLDataset
 from transformers import AutoTokenizer

+IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+IMG_START_TOKEN = '<img>'
+IMG_END_TOKEN = '</img>'
+IMG_CONTEXT_NUM = 10  # user-defined number of image patches in the context
+

 @pytest.fixture
 def dataset():
@@ -11,16 +16,18 @@ def dataset():
     # split pair data into two separate datasets
     hf_dataset_1 = hf_dataset.map(
         lambda x: {
-            "prompt": x["query"],
+            "query": f"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * IMG_CONTEXT_NUM}{IMG_END_TOKEN}\n{x['query']}",
+            "image": x["image"],
             "response": x["response"][0],
-            'human_ranking': x["human_ranking"][0]
+            "human_ranking": x["human_ranking"][0]
         }
     )
     hf_dataset_2 = hf_dataset.map(
         lambda x: {
-            "prompt": x["query"],
+            "query": f"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * IMG_CONTEXT_NUM}{IMG_END_TOKEN}\n{x['query']}",
+            "image": x["image"],
             "response": x["response"][1],
-            'human_ranking': x["human_ranking"][0]
+            "human_ranking": x["human_ranking"][1]
         }
     )
     # combine two datasets
@@ -33,7 +40,7 @@ def dataset():
 @pytest.fixture
 def tokenizer():
     # Load a tokenizer
-    return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B")
+    return AutoTokenizer.from_pretrained("OpenGVLab/InternVL2_5-4B")


 @pytest.mark.unittest
@@ -44,6 +51,7 @@ def test_offline_rl_dataset_initialization(dataset, tokenizer):
         tokenizer=tokenizer,
         max_length=1024,
         input_key="query",
+        extra_input_keys=["image"],
         output_key="response",
         label_key="human_ranking"
     )
@@ -53,6 +61,7 @@ def test_offline_rl_dataset_initialization(dataset, tokenizer):
         tokenizer=tokenizer,
         max_length=256,
         input_key="query",
+        extra_input_keys=["image"],
         output_key="response",
         label_key="human_ranking"
     )
@@ -68,6 +77,7 @@ def test_offline_rl_dataset_item_retrieval(dataset, tokenizer):
         tokenizer=tokenizer,
         max_length=256,
         input_key="query",
+        extra_input_keys=["image"],
         output_key="response",
         label_key="human_ranking"
     )
@@ -76,6 +86,7 @@ def test_offline_rl_dataset_item_retrieval(dataset, tokenizer):
     assert "response" in item
     assert "label" in item
     assert "prompt_ids_len" in item
+    assert "image" in item
     print(item)


@@ -92,8 +103,11 @@ def test_offline_rl_dataset_collate_fn(dataset, tokenizer):
     )
     B = 10
     item_list = [offline_dataset[i] for i in range(B)]
-    input_ids, attention_mask, labels, prompt_ids_lens = offline_dataset.collate_fn(item_list)
+    input_ids, attention_mask, labels, prompt_ids_lens, extra_inputs = offline_dataset.collate_fn(item_list)
     assert input_ids.size(0) == len(item_list) * 2  # because of the unmatched y'| x
     assert attention_mask.size(0) == len(item_list) * 2
     assert labels.size(0) == len(item_list) * 2
     assert len(prompt_ids_lens) == len(item_list) * 2
+    for key in offline_dataset.extra_input_keys:
+        assert key in extra_inputs
+        assert extra_inputs[key].size(0) == len(item_list) * 2
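
For reference, the updated fixture wraps every query with InternVL-style image placeholder tokens before tokenization. A small self-contained sketch of what the wrapped prompt looks like (the sample query text is illustrative):

    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'
    IMG_CONTEXT_NUM = 10  # user-defined number of image patches in the context

    query = "What is in this picture?"  # illustrative sample
    wrapped = f"{IMG_START_TOKEN}{IMG_CONTEXT_TOKEN * IMG_CONTEXT_NUM}{IMG_END_TOKEN}\n{query}"
    # wrapped == "<img>" + "<IMG_CONTEXT>" * 10 + "</img>" + "\nWhat is in this picture?"
    # i.e. the image placeholder block is prepended on its own line before the text query.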
