Skip to content

Commit

Permalink
style(nyz): add comments for rlhf dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
PaParaZz1 committed Feb 5, 2025
1 parent 6e0e14b commit 428c1a6
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 10 deletions.
62 changes: 59 additions & 3 deletions ding/utils/data/rlhf_offline_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,23 @@
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.distributed import get_rank
from transformers import AutoTokenizer
import torch
import torch.nn.functional as F


def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left", value: int = 0) -> torch.Tensor:
assert side in ("left", "right")
"""
Overview:
Pad sequences with ``value`` (zero by default) to create a batch tensor of uniform length.
Arguments:
- sequences (List[torch.Tensor]): A list of PyTorch tensors to be padded.
- side (str): The side to pad ('left' or 'right'), default is 'left'.
- value (int): The padding value to use, default is 0.
Returns:
- padded_sequences (torch.Tensor): A padded tensor of shape [batch_size, max_sequence_length].
"""
assert side in ("left", "right"), side
max_len = max(seq.size(-1) for seq in sequences)
padded_sequences = []
for seq in sequences:
Expand All @@ -22,12 +33,13 @@ class OfflineRLDataset(Dataset):
"""
Overview:
PyTorch Dataset for OfflineRL LLM training like KTO and DPO.
This dataset supports pure text input, as well as image, video, audio, etc.
"""

def __init__(
self,
dataset: Iterable[Dict],
tokenizer,
tokenizer: AutoTokenizer,
max_length: int,
input_key: str = "input",
extra_input_keys: List[str] = [],
Expand All @@ -39,6 +51,24 @@ def __init__(
num_processors: int = 8,
parallel_load: bool = True
) -> None:
"""
Overview:
Initialize the OfflineRLDataset.
Arguments:
- dataset (Iterable[Dict]): The iterable dataset object to be used, such as a list or a huggingface dataset.
- tokenizer (AutoTokenizer): The tokenizer to be used.
- max_length (int): The maximum length of the input.
- input_key (str): The key of the input, default is "input".
- extra_input_keys (List[str]): The extra input keys, such as "image", "video", "audio", etc.
- output_key (str): The key of the output, default is "output".
- label_key (str): The key of the label, default is "label".
- apply_chat_template (bool): Whether to apply the chat template, default is False.
- tokenizer_chat_template (str): The chat template to be used.
- input_template (str): The input template to be used.
- num_processors (int): The number of processors to be used, default is 8.
- parallel_load (bool): Whether to load the dataset in parallel in the `__init__` method, default is True.
Parallel loading is usually used for huggingface datasets.
"""
super().__init__()
self.tokenizer = tokenizer
self.max_length = max_length
Expand Down Expand Up @@ -98,7 +128,23 @@ def _preprocess_data(
output_key: str = "output",
label_key: str = "label",
apply_chat_template: Union[bool, Callable] = False,
) -> str:
) -> Dict[str, Any]:
"""
Overview:
Preprocess the data and return the processed data.
Arguments:
- data (Dict[str, Any]): The data to be processed.
- input_template (str): The input template to be used.
- input_key (str): The key of the input, default is "input".
- extra_input_keys (List[str]): The extra input keys, such as "image", "video", "audio", etc.
- output_key (str): The key of the output, default is "output".
- label_key (str): The key of the label, default is "label".
- apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
tokenizer's default template. If a Callable is provided, uses that function to apply the template \
(typically tokenizer.apply_chat_template).
Returns:
- processed_data (Dict[str, Any]): The processed data.
"""
label = data[label_key]
if extra_input_keys:
extra_inputs = {key: data[key] for key in extra_input_keys}
Expand Down Expand Up @@ -160,6 +206,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
Returns:
- item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
"""
# extra inputs: usually image, video, audio, etc.
if self.extra_input_keys:
extra_inputs = {key: getattr(self, key)[idx] for key in self.extra_input_keys}
else:
Expand All @@ -173,6 +220,14 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
}

def collate_fn(self, item_list: List[Dict[str, Union[torch.Tensor, int]]]):
"""
Overview:
Collate the items into a batch, which is used to create a batch for training.
Arguments:
- item_list (List[Dict[str, Union[torch.Tensor, int]]]): The list of items to be collated.
Returns:
- collated_items (Dict[str, Union[torch.Tensor, int]]): The collated items.
"""

def tokenizer(prompt: str, response: str):
text = (prompt + response).rstrip("\n")
Expand All @@ -191,6 +246,7 @@ def tokenizer(prompt: str, response: str):
inputs["attention_mask"][0][-1] = True
return inputs["input_ids"], inputs["attention_mask"]

# tot_extra_inputs: Dict[str, List[torch.Tensor]]
tot_ids, tot_masks, tot_labels, prompt_ids_lens, tot_extra_inputs = [], [], [], [], {}
for item in item_list:
input_ids, attention_mask = tokenizer(item["prompt"], item["response"])
Expand Down
17 changes: 10 additions & 7 deletions ding/utils/data/rlhf_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@
from tqdm import tqdm
from torch.utils.data import Dataset
from torch.distributed import get_rank
from transformers import AutoTokenizer


class OnlineRLDataset(Dataset):
"""
Overview:
PyTorch Dataset for OnlineRL LLM training like PPO.
This dataset currently supports only pure text input.
"""

def __init__(
self,
dataset: Iterable[Dict],
tokenizer,
tokenizer: AutoTokenizer,
input_key: str = "input",
apply_chat_template: bool = False,
input_template: str = None,
Expand All @@ -23,9 +25,9 @@ def __init__(
Initialize the OnlineRLDataset.
Arguments:
- dataset (Iterable[Dict]): The dataset to preprocess.
- tokenizer (): The tokenizer to preprocess the data.
- input_key (str): The key of the input data.
- apply_chat_template (bool): Whether to apply the chat template.
- tokenizer (AutoTokenizer): The tokenizer to preprocess the data.
- input_key (str): The key of the input data, default is "input".
- apply_chat_template (bool): Whether to apply the chat template, default is False.
- input_template (str): The template to format the data.
"""
super().__init__()
Expand Down Expand Up @@ -57,7 +59,7 @@ def __getitem__(self, idx: int) -> str:
"""
Overview:
Get the item at the given index.
Args:
Arguments:
- idx (int): The index of the item to get.
Returns:
- item (str): The item at the given index.
Expand All @@ -78,8 +80,9 @@ def _preprocess_data(
- data (Dict[str, Any]): The data to preprocess.
- input_template (str): The template to format the data.
- input_key (str): The key of the input data.
- apply_chat_template (Union[bool, Callable]): The function to apply the chat template, \
usually is the `tokenizer.apply_chat_template`.
- apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
tokenizer's default template. If a Callable is provided, uses that function to apply the template \
(typically tokenizer.apply_chat_template).
Returns:
- prompt (str): The formatted prompt.
"""
Expand Down

0 comments on commit 428c1a6

Please sign in to comment.