Skip to content

Commit

Permalink
🪪 Adds a more fine-grained profiling context (huggingface#2975)
Browse files Browse the repository at this point in the history
* adds a more fine-grained profiling context

* precommit

* fix reward func name

* add reward to RM name

* Update trl/extras/profiling.py

Co-authored-by: Quentin Gallouédec <[email protected]>

* some doc and fixes

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
  • Loading branch information
3 people authored Feb 27, 2025
1 parent c0854c3 commit ac327d5
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 31 deletions.
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,6 @@
title: Text Environments
- local: script_utils
title: Script Utilities
- local: others
title: Others
title: API
9 changes: 9 additions & 0 deletions docs/source/others.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Others

## profiling_decorator

[[autodoc]] extras.profiling.profiling_decorator

## profiling_context

[[autodoc]] extras.profiling.profiling_context
70 changes: 59 additions & 11 deletions trl/extras/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,30 +12,78 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import functools
import time
from typing import Generator

from transformers import is_wandb_available
from transformers import Trainer, is_wandb_available


if is_wandb_available():
import wandb


def profiling_decorator(func):
@contextlib.contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
    """
    A context manager function for profiling a block of code. Results are logged to Weights & Biases if enabled.

    The elapsed time is measured with `time.perf_counter` and logged under the key
    `"profiling/Time taken: <TrainerClassName>.<name>"`. Logging only happens when `"wandb"` is listed in
    `trainer.args.report_to`, a wandb run is active, and the current process is the main process. If the profiled
    block raises, the exception propagates and nothing is logged.

    Args:
        trainer (`~transformers.Trainer`):
            Trainer object.
        name (`str`):
            Name of the block to be profiled. Used as a key in the logged dictionary.

    Example:
    ```python
    import numpy as np
    from transformers import Trainer
    from trl.extras.profiling import profiling_context

    class MyTrainer(Trainer):
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            with profiling_context(self, "matrix_multiplication"):
                # Code to profile: simulate a computationally expensive operation
                result = A @ B  # Matrix multiplication
    ```
    """
    start_time = time.perf_counter()
    # Deliberately no try/finally: a block that raises is not timed or logged.
    yield
    end_time = time.perf_counter()
    duration = end_time - start_time

    # `report_to` is checked first: `wandb` is only imported at module level when it is available, so it must not
    # be touched unless W&B reporting was requested.
    if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
        wandb.log({f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration})


def profiling_decorator(func: callable) -> callable:
    """
    Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].

    The decorated function is expected to be a method: its first positional argument (`self`) is passed to
    `profiling_context` as the trainer, and `func.__name__` is used as the name of the profiled block.

    Args:
        func (`callable`):
            Function to be profiled.

    Returns:
        `callable`: Wrapper that runs `func` inside `profiling_context` and returns its result unchanged.

    Example:
    ```python
    import numpy as np
    from transformers import Trainer
    from trl.extras.profiling import profiling_decorator

    class MyTrainer(Trainer):
        @profiling_decorator
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            # Code to profile: simulate a computationally expensive operation
            result = A @ B
    ```
    """

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Timing and logging live in `profiling_context`, so the decorator and the context manager share a
        # single implementation and the same logging key format.
        with profiling_context(self, func.__name__):
            return func(self, *args, **kwargs)

    return wrapper
48 changes: 28 additions & 20 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
from transformers.utils import is_peft_available

from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from ..extras.profiling import profiling_decorator
from ..extras.profiling import profiling_context, profiling_decorator
from ..import_utils import is_rich_available, is_vllm_available
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
Expand Down Expand Up @@ -729,9 +729,10 @@ def _generate_and_score_completions(
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
# prompt individually.
ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
all_outputs = self.llm.generate(
ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
)
with profiling_context(self, "vLLM.generate"):
all_outputs = self.llm.generate(
ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
)
completion_ids = []
for outputs in all_outputs:
for output in outputs.outputs:
Expand Down Expand Up @@ -812,23 +813,30 @@ def _generate_and_score_completions(
zip(self.reward_funcs, self.reward_processing_classes)
):
if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
reward_func_name = reward_func.__name__
with profiling_context(self, reward_func_name):
if isinstance(
reward_func, nn.Module
): # Module instead of PretrainedModel for compat with compiled models
if is_conversational(inputs[0]):
messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
else:
texts = [p + c for p, c in zip(prompts, completions)]
reward_inputs = reward_processing_class(
texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
)
reward_inputs = super()._prepare_inputs(reward_inputs)
with torch.inference_mode():
rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
else:
# Repeat all input columns (but "prompt" and "completion") to match the number of generations
keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)

# Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
# completions may be distributed across processes
Expand Down

0 comments on commit ac327d5

Please sign in to comment.