Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add first train batch to train unit for extracting example inputs #971

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion torchtnt/framework/unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from typing import Any, cast, Dict, Generic, Iterator, TypeVar, Union

import torch
from pyre_extensions import none_throws
from torchtnt.framework._unit_utils import (
_find_optimizers_for_module,
_step_requires_iterator,
)

from torchtnt.framework.state import State
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import _is_fsdp_module, FSDPOptimizerWrapper
Expand Down Expand Up @@ -312,6 +312,7 @@ def on_train_epoch_end(self, state: State) -> None:
def __init__(self) -> None:
super().__init__()
self.train_progress = Progress()
self.first_train_batch: TTrainData | None = None

def on_train_start(self, state: State) -> None:
"""Hook called before training starts.
Expand All @@ -329,6 +330,14 @@ def on_train_epoch_start(self, state: State) -> None:
"""
pass

@property
def first_train_batch(self) -> TTrainData:
return none_throws(self.first_train_batch)

@first_train_batch.setter
def first_train_batch(self, data: TTrainData) -> None:
self.first_train_batch = data

@abstractmethod
# pyre-fixme[3]: Return annotation cannot be `Any`.
def train_step(self, state: State, data: TTrainData) -> Any:
Expand Down
Loading