From ac327d5e84b54c4bc3c0ffc23c7d582a99d4028e Mon Sep 17 00:00:00 2001
From: Edward Beeching
Date: Thu, 27 Feb 2025 21:58:39 +0100
Subject: [PATCH] =?UTF-8?q?=F0=9F=AA=AA=20Adds=20a=20more=20fine-grained?=
 =?UTF-8?q?=20profiling=20context=20(#2975)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* adds a more fine grained profiling context

* precommit

* fix reward func name

* add reward to RM name

* Update trl/extras/profiling.py

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

* some doc and fixes

---------

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec
---
 docs/source/_toctree.yml    |  2 ++
 docs/source/others.md       |  9 +++++
 trl/extras/profiling.py     | 70 +++++++++++++++++++++++++++++++------
 trl/trainer/grpo_trainer.py | 48 ++++++++++++++-----------
 4 files changed, 98 insertions(+), 31 deletions(-)
 create mode 100644 docs/source/others.md

diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 4755daaf6e..b287478a92 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -107,4 +107,6 @@
     title: Text Environments
   - local: script_utils
     title: Script Utilities
+  - local: others
+    title: Others
   title: API
diff --git a/docs/source/others.md b/docs/source/others.md
new file mode 100644
index 0000000000..bd89447e7b
--- /dev/null
+++ b/docs/source/others.md
@@ -0,0 +1,9 @@
+# Other
+
+## profiling_decorator
+
+[[autodoc]] extras.profiling.profiling_decorator
+
+## profiling_context
+
+[[autodoc]] extras.profiling.profiling_context
diff --git a/trl/extras/profiling.py b/trl/extras/profiling.py
index 1cb439e2cc..7798231749 100644
--- a/trl/extras/profiling.py
+++ b/trl/extras/profiling.py
@@ -12,30 +12,78 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 import functools
 import time
+from typing import Generator
 
-from transformers import is_wandb_available
+from transformers import Trainer, is_wandb_available
 
 
 if is_wandb_available():
     import wandb
 
 
-def profiling_decorator(func):
+@contextlib.contextmanager
+def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
     """
-    Decorator to profile a function and log the time taken to execute it.
+    A context manager function for profiling a block of code. Results are logged to Weights & Biases if enabled.
+
+    Args:
+        trainer (`~transformers.Trainer`):
+            Trainer object.
+        name (`str`):
+            Name of the block to be profiled. Used as a key in the logged dictionary.
+
+    Example:
+    ```python
+    from transformers import Trainer
+    from trl.extras.profiling import profiling_context
+
+    class MyTrainer(Trainer):
+        def some_method(self):
+            A = np.random.rand(1000, 1000)
+            B = np.random.rand(1000, 1000)
+            with profiling_context(self, "matrix_multiplication"):
+                # Code to profile: simulate a computationally expensive operation
+                result = A @ B  # Matrix multiplication
+    ```
+    """
+    start_time = time.perf_counter()
+    yield
+    end_time = time.perf_counter()
+    duration = end_time - start_time
+
+    if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
+        wandb.log({f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration})
+
+
+def profiling_decorator(func: callable) -> callable:
+    """
+    Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].
+
+    Args:
+        func (`callable`):
+            Function to be profiled.
+
+    Example:
+    ```python
+    from transformers import Trainer
+    from trl.extras.profiling import profiling_decorator
+
+    class MyTrainer(Trainer):
+        @profiling_decorator
+        def some_method(self):
+            A = np.random.rand(1000, 1000)
+            B = np.random.rand(1000, 1000)
+            # Code to profile: simulate a computationally expensive operation
+            result = A @ B
+    ```
     """
 
     @functools.wraps(func)
     def wrapper(self, *args, **kwargs):
-        start_time = time.perf_counter()
-        result = func(self, *args, **kwargs)
-        end_time = time.perf_counter()
-        duration = end_time - start_time
-
-        if "wandb" in self.args.report_to and wandb.run is not None and self.accelerator.is_main_process:
-            wandb.log({f"profiling/Time taken: {self.__class__.__name__}.{func.__name__}": duration})
-        return result
+        with profiling_context(self, func.__name__):
+            return func(self, *args, **kwargs)
 
     return wrapper
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 6173030343..fe13a12687 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -46,7 +46,7 @@
 from transformers.utils import is_peft_available
 
 from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
-from ..extras.profiling import profiling_decorator
+from ..extras.profiling import profiling_context, profiling_decorator
 from ..import_utils import is_rich_available, is_vllm_available
 from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
 from .callbacks import SyncRefModelCallback
@@ -729,9 +729,10 @@ def _generate_and_score_completions(
                 # num_generations outputs for each one. This is faster than generating outputs for each duplicate
                 # prompt individually.
                 ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
-                all_outputs = self.llm.generate(
-                    ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
-                )
+                with profiling_context(self, "vLLM.generate"):
+                    all_outputs = self.llm.generate(
+                        ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
+                    )
                 completion_ids = []
                 for outputs in all_outputs:
                     for output in outputs.outputs:
@@ -812,23 +813,30 @@ def _generate_and_score_completions(
             zip(self.reward_funcs, self.reward_processing_classes)
         ):
             if isinstance(reward_func, nn.Module):  # Module instead of PretrainedModel for compat with compiled models
-                if is_conversational(inputs[0]):
-                    messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
-                    texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
-                else:
-                    texts = [p + c for p, c in zip(prompts, completions)]
-                reward_inputs = reward_processing_class(
-                    texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
-                )
-                reward_inputs = super()._prepare_inputs(reward_inputs)
-                with torch.inference_mode():
-                    rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+                reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
             else:
-                # Repeat all input columns (but "prompt" and "completion") to match the number of generations
-                keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
-                reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
-                output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
-                rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
+                reward_func_name = reward_func.__name__
+            with profiling_context(self, reward_func_name):
+                if isinstance(
+                    reward_func, nn.Module
+                ):  # Module instead of PretrainedModel for compat with compiled models
+                    if is_conversational(inputs[0]):
+                        messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
+                        texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
+                    else:
+                        texts = [p + c for p, c in zip(prompts, completions)]
+                    reward_inputs = reward_processing_class(
+                        texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
+                    )
+                    reward_inputs = super()._prepare_inputs(reward_inputs)
+                    with torch.inference_mode():
+                        rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0]  # Shape (B*G,)
+                else:
+                    # Repeat all input columns (but "prompt" and "completion") to match the number of generations
+                    keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
+                    reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
+                    output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
+                    rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
 
         # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
        # completions may be distributed across processes
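
A minimal usage sketch of how the two utilities introduced in this patch compose (not part of the diff above): `MyTrainer` and its method names are hypothetical, and it assumes `numpy` is installed and the trainer is configured with `report_to="wandb"`. The import paths, signatures, and logged key format come straight from the diff.

```python
import numpy as np
from transformers import Trainer

from trl.extras.profiling import profiling_context, profiling_decorator


class MyTrainer(Trainer):  # hypothetical subclass, for illustration only
    @profiling_decorator  # logs under "profiling/Time taken: MyTrainer.slow_step"
    def slow_step(self):
        A = np.random.rand(1000, 1000)
        return A @ A

    def partial_step(self):
        # The context manager profiles just one block of a larger method;
        # this block logs under "profiling/Time taken: MyTrainer.matmul".
        with profiling_context(self, "matmul"):
            A = np.random.rand(1000, 1000)
            result = A @ A
        return result
```

With the `grpo_trainer.py` changes above, the same mechanism produces per-step timings such as `profiling/Time taken: GRPOTrainer.vLLM.generate`, plus one entry per reward function, named `reward <model name>` for `nn.Module` rewards and the callable's `__name__` otherwise.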