Skip to content

Commit

Permalink
Try to make tensorboard properly optional
Browse files Browse the repository at this point in the history
  • Loading branch information
RedTachyon committed Jan 16, 2024
1 parent 322c736 commit 37fb9f3
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 8 deletions.
3 changes: 1 addition & 2 deletions coltra/policy_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from numpy.random import Generator
import torch
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from typarse import BaseConfig

from coltra.agents import Agent
Expand Down Expand Up @@ -114,7 +113,7 @@ def train_on_data(
data_dict: dict[str, OnPolicyRecord],
shape: tuple[int, int],
step: int = 0,
writer: Optional[SummaryWriter] = None,
writer: Optional["SummaryWriter"] = None,
) -> dict[str, float]:
"""
Performs a single update step with PPO on the given batch of data.
Expand Down
1 change: 0 additions & 1 deletion coltra/research/policy_fusion/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .policy_fusion import JointModel
8 changes: 5 additions & 3 deletions coltra/trainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import shortuuid
import torch
import yaml
from torch.utils.tensorboard import SummaryWriter
from tqdm import trange

from coltra.collectors import collect_crowd_data, collect_family_data
Expand Down Expand Up @@ -67,8 +66,10 @@ def __init__(
)

# Setup tensorboard
self.writer: Optional[SummaryWriter]
# self.writer: Optional[SummaryWriter]
if self.config.tensorboard_name:
from torch.utils.tensorboard import SummaryWriter

dt_string = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if use_uuid:
dt_string += "_" + shortuuid.uuid()
Expand Down Expand Up @@ -242,8 +243,9 @@ def __init__(
)

# Setup tensorboard
self.writer: Optional[SummaryWriter]
if self.config.tensorboard_name:
from torch.utils.tensorboard import SummaryWriter

dt_string = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
if use_uuid:
dt_string += "_" + shortuuid.uuid()
Expand Down
4 changes: 2 additions & 2 deletions coltra/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torch.optim.adamax import Adamax
from torch.optim.sgd import SGD

from torch.utils.tensorboard import SummaryWriter
# from torch.utils.tensorboard import SummaryWriter

from coltra.buffers import Observation, Action

Expand All @@ -37,7 +37,7 @@
def write_dict(
metrics: dict[str, Union[int, float]],
step: int,
writer: Optional[SummaryWriter] = None,
writer: Optional["SummaryWriter"] = None,
):
"""Writes a dictionary to a tensorboard SummaryWriter"""
if writer is not None:
Expand Down

0 comments on commit 37fb9f3

Please sign in to comment.