From 428c1a6fbdc26c0d7050943f0cf8248fa0e226e8 Mon Sep 17 00:00:00 2001
From: PaParaZz1
Date: Wed, 5 Feb 2025 19:24:57 +0800
Subject: [PATCH] style(nyz): add comments for rlhf dataset

---
 ding/utils/data/rlhf_offline_dataset.py | 62 +++++++++++++++++++++++--
 ding/utils/data/rlhf_online_dataset.py  | 17 ++++---
 2 files changed, 69 insertions(+), 10 deletions(-)

diff --git a/ding/utils/data/rlhf_offline_dataset.py b/ding/utils/data/rlhf_offline_dataset.py
index 976cf23a4a..2b95010f1d 100644
--- a/ding/utils/data/rlhf_offline_dataset.py
+++ b/ding/utils/data/rlhf_offline_dataset.py
@@ -3,12 +3,23 @@
 from tqdm import tqdm
 from torch.utils.data import Dataset
 from torch.distributed import get_rank
+from transformers import AutoTokenizer
 import torch
 import torch.nn.functional as F
 
 
 def zero_pad_sequences(sequences: List[torch.Tensor], side: str = "left", value: int = 0) -> torch.Tensor:
-    assert side in ("left", "right")
+    """
+    Overview:
+        Pad sequences with zeros to create a batch tensor of uniform length.
+    Arguments:
+        - sequences (List[torch.Tensor]): A list of PyTorch tensors to be padded.
+        - side (str): The side to pad ('left' or 'right'), default is 'left'.
+        - value (int): The padding value to use, default is 0.
+    Returns:
+        - padded_sequences (torch.Tensor): A padded tensor of shape [batch_size, max_sequence_length].
+    """
+    assert side in ("left", "right"), side
     max_len = max(seq.size(-1) for seq in sequences)
     padded_sequences = []
     for seq in sequences:
@@ -22,12 +33,13 @@ class OfflineRLDataset(Dataset):
     """
     Overview:
         PyTorch Dataset for OfflineRL LLM training like KTO and DPO.
+        This dataset supports pure text input, as well as extra inputs such as image, video and audio.
     """
 
     def __init__(
             self,
             dataset: Iterable[Dict],
-            tokenizer,
+            tokenizer: AutoTokenizer,
             max_length: int,
             input_key: str = "input",
             extra_input_keys: List[str] = [],
@@ -39,6 +51,24 @@ def __init__(
             num_processors: int = 8,
             parallel_load: bool = True
     ) -> None:
+        """
+        Overview:
+            Initialize the OfflineRLDataset.
+        Arguments:
+            - dataset (Iterable[Dict]): The iterable dataset object to be used, such as a list or HuggingFace dataset.
+            - tokenizer (AutoTokenizer): The tokenizer to be used.
+            - max_length (int): The maximum length of the input.
+            - input_key (str): The key of the input, default is "input".
+            - extra_input_keys (List[str]): The extra input keys, such as "image", "video", "audio", etc.
+            - output_key (str): The key of the output, default is "output".
+            - label_key (str): The key of the label, default is "label".
+            - apply_chat_template (bool): Whether to apply the chat template, default is False.
+            - tokenizer_chat_template (str): The chat template to be used.
+            - input_template (str): The input template to be used.
+            - num_processors (int): The number of processors to be used, default is 8.
+            - parallel_load (bool): Whether to load the dataset in parallel in the `__init__` method, default is True.
+                Parallel loading is usually used for HuggingFace datasets.
+        """
         super().__init__()
         self.tokenizer = tokenizer
         self.max_length = max_length
@@ -98,7 +128,23 @@ def _preprocess_data(
             output_key: str = "output",
             label_key: str = "label",
             apply_chat_template: Union[bool, Callable] = False,
-    ) -> str:
+    ) -> Dict[str, Any]:
+        """
+        Overview:
+            Preprocess one data sample and return the processed result.
+        Arguments:
+            - data (Dict[str, Any]): The data to be processed.
+            - input_template (str): The input template to be used.
+            - input_key (str): The key of the input, default is "input".
+            - extra_input_keys (List[str]): The extra input keys, such as "image", "video", "audio", etc.
+            - output_key (str): The key of the output, default is "output".
+            - label_key (str): The key of the label, default is "label".
+            - apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
+                tokenizer's default template. If a Callable is provided, uses that function to apply the template \
+                (typically tokenizer.apply_chat_template).
+        Returns:
+            - processed_data (Dict[str, Any]): The processed data.
+        """
         label = data[label_key]
         if extra_input_keys:
             extra_inputs = {key: data[key] for key in extra_input_keys}
@@ -160,6 +206,7 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
         Returns:
             - item (Dict[str, Union[torch.Tensor, int]]): The item at the given index.
         """
+        # extra inputs: usually image, video, audio, etc.
         if self.extra_input_keys:
             extra_inputs = {key: getattr(self, key)[idx] for key in self.extra_input_keys}
         else:
@@ -173,6 +220,14 @@ def __getitem__(self, idx: int) -> Dict[str, Union[torch.Tensor, int]]:
         }
 
     def collate_fn(self, item_list: List[Dict[str, Union[torch.Tensor, int]]]):
+        """
+        Overview:
+            Collate the sampled items into a batch for training.
+        Arguments:
+            - item_list (List[Dict[str, Union[torch.Tensor, int]]]): The list of items to be collated.
+        Returns:
+            - collated_items (Dict[str, Union[torch.Tensor, int]]): The collated batch.
+        """
 
         def tokenizer(prompt: str, response: str):
             text = (prompt + response).rstrip("\n")
@@ -191,6 +246,7 @@ def tokenizer(prompt: str, response: str):
             inputs["attention_mask"][0][-1] = True
             return inputs["input_ids"], inputs["attention_mask"]
 
+        # tot_extra_inputs: Dict[str, List[torch.Tensor]]
         tot_ids, tot_masks, tot_labels, prompt_ids_lens, tot_extra_inputs = [], [], [], [], {}
         for item in item_list:
             input_ids, attention_mask = tokenizer(item["prompt"], item["response"])
diff --git a/ding/utils/data/rlhf_online_dataset.py b/ding/utils/data/rlhf_online_dataset.py
index b192fe455e..d307f09a32 100644
--- a/ding/utils/data/rlhf_online_dataset.py
+++ b/ding/utils/data/rlhf_online_dataset.py
@@ -2,18 +2,20 @@
 from tqdm import tqdm
 from torch.utils.data import Dataset
 from torch.distributed import get_rank
+from transformers import AutoTokenizer
 
 
 class OnlineRLDataset(Dataset):
     """
     Overview:
         PyTorch Dataset for OnlineRL LLM training like PPO.
+        Currently, this dataset only supports pure text input.
     """
 
     def __init__(
             self,
             dataset: Iterable[Dict],
-            tokenizer,
+            tokenizer: AutoTokenizer,
             input_key: str = "input",
             apply_chat_template: bool = False,
             input_template: str = None,
@@ -23,9 +25,9 @@
             Initialize the OnlineRLDataset.
         Arguments:
             - dataset (torch.utils.data.Dataset): The dataset to preprocess.
-            - tokenizer (): The tokenizer to preprocess the data.
-            - input_key (str): The key of the input data.
-            - apply_chat_template (bool): Whether to apply the chat template.
+            - tokenizer (AutoTokenizer): The tokenizer to preprocess the data.
+            - input_key (str): The key of the input data, default is "input".
+            - apply_chat_template (bool): Whether to apply the chat template, default is False.
             - input_template (str): The template to format the data.
         """
         super().__init__()
@@ -57,7 +59,7 @@
         """
         Overview:
             Get the item at the given index.
-        Args:
+        Arguments:
             - idx (int): The index of the item to get.
         Returns:
             - item (str): The item at the given index.
@@ -78,8 +80,9 @@
             - data (Dict[str, Any]): The data to preprocess.
             - input_template (str): The template to format the data.
             - input_key (str): The key of the input data.
-            - apply_chat_template (Union[bool, Callable]): The function to apply the chat template, \
-                usually is the `tokenizer.apply_chat_template`.
+            - apply_chat_template (Union[bool, Callable]): Controls chat template application. If True, uses the \
+                tokenizer's default template. If a Callable is provided, uses that function to apply the template \
+                (typically tokenizer.apply_chat_template).
         Returns:
             - prompt (str): The formatted prompt.
         """
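
Reviewer note: below is a minimal usage sketch of the two entry points documented by this patch, assuming a single-process (non-distributed) run. The example tensors, prompts, and tokenizer checkpoint are illustrative assumptions, not part of the patch.

import torch
from transformers import AutoTokenizer
from ding.utils.data.rlhf_offline_dataset import zero_pad_sequences
from ding.utils.data.rlhf_online_dataset import OnlineRLDataset

# Left-pad two variable-length sequences to the batch max length (3 here);
# the shorter sequence is filled with the default padding value 0 on the left.
batch = zero_pad_sequences([torch.tensor([1, 2]), torch.tensor([3, 4, 5])], side="left")
print(batch)  # expected: tensor([[0, 1, 2], [3, 4, 5]]), shape [2, 3]

# Build the online dataset from a plain list of dicts keyed by "input".
# The checkpoint name is an arbitrary example that ships a chat template.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = OnlineRLDataset(
    dataset=[{"input": "What is RLHF?"}, {"input": "Explain PPO in one sentence."}],
    tokenizer=tokenizer,
    input_key="input",
    apply_chat_template=True,  # wrap each prompt with the tokenizer's chat template
)
print(dataset[0])  # the first formatted prompt string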