Skip to content
This repository was archived by the owner on May 14, 2024. It is now read-only.

Add WandB support to 1.5 training #279

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2022,7 +2022,26 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="enable logging and output TensorBoard log to this directory / ログ出力を有効にしてこのディレクトリにTensorBoard用のログを出力する",
)
# --- logging / experiment-tracking options ---

# Which tracker backend(s) to use; "all" enables TensorBoard and WandB together.
parser.add_argument(
    "--log_with",
    choices=["tensorboard", "wandb", "all"],
    default=None,
    type=str,
    help="what logging tool(s) to use (if 'all', TensorBoard and WandB are both used) / ログ出力に使用するツール (allを指定するとTensorBoardとWandBの両方が使用される)",
)

# Optional string prepended to each log directory name.
parser.add_argument(
    "--log_prefix",
    default=None,
    type=str,
    help="add prefix for each log directory / ログディレクトリ名の先頭に追加する文字列",
)

# Tracker name; when omitted, each script falls back to its own default name.
parser.add_argument(
    "--log_tracker_name",
    default=None,
    type=str,
    help="name of tracker to use for logging, default is script-specific default name / ログ出力に使用するtrackerの名前、省略時はスクリプトごとのデフォルト名",
)

# Optional WandB API key used to log in before training starts.
parser.add_argument(
    "--wandb_api_key",
    default=None,
    type=str,
    help="specify WandB API key to log in before starting training (optional). / WandB APIキーを指定して学習開始前にログインする(オプション)",
)
parser.add_argument(
"--noise_offset",
type=float,
Expand Down Expand Up @@ -2234,6 +2253,11 @@ def read_config_from_file(args: argparse.Namespace, parser: argparse.ArgumentPar
print(f"Config file already exists. Aborting... / 出力先の設定ファイルが既に存在します: {config_path}")
exit(1)

# remove unnecessary keys
# argparse.Namespace supports attribute access and `in`, but not item deletion:
# `del args[key]` raises TypeError, so the attribute is removed with delattr.
for key in ["config_file", "output_config", "wandb_api_key"]:
    if key in args:
        delattr(args, key)

# convert args to dictionary
args_dict = vars(args)

Expand Down Expand Up @@ -2689,12 +2713,35 @@ def prepare_accelerator(args: argparse.Namespace):
log_prefix = "" if args.log_prefix is None else args.log_prefix
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())

# Decide which tracker backend(s) are handed to the Accelerator.
log_with = args.log_with
if log_with is None:
    # No explicit choice: use TensorBoard when a logging directory is set,
    # otherwise leave tracking disabled.
    log_with = "tensorboard" if logging_dir is not None else None
else:
    # TensorBoard needs a directory to write its event files into.
    if logging_dir is None and log_with in ["tensorboard", "all"]:
        raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")

    if log_with in ["wandb", "all"]:
        # Imported lazily so wandb is only required when it is requested.
        try:
            import wandb
        except ImportError:
            raise ImportError("No wandb / wandb がインストールされていないようです")

        if logging_dir is not None:
            # Expose the logging directory to WandB through its WANDB_DIR variable.
            os.makedirs(logging_dir, exist_ok=True)
            os.environ["WANDB_DIR"] = logging_dir

        api_key = args.wandb_api_key
        if api_key is not None:
            wandb.login(key=api_key)


# Create the Accelerator with the tracker selection and logging directory
# computed earlier in this function.
accelerator = Accelerator(
    log_with=log_with,
    logging_dir=logging_dir,
    mixed_precision=args.mixed_precision,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
)

# NOTE(review): this concatenation assumes args.output_name is a str; it would
# raise TypeError if output_name is None -- confirm against the argument default.
tracker_project_name = "Kohya-ss_" + args.output_name
accelerator.init_trackers(project_name=tracker_project_name)

# accelerateの互換性問題を解決する
accelerator_0_15 = True
Expand Down Expand Up @@ -3146,6 +3193,18 @@ def sample_images(
)

image.save(os.path.join(save_dir, img_filename))

# Send the sample image to WandB only when a wandb tracker is available.
try:
    wandb_tracker = accelerator.get_tracker("wandb")
except Exception:
    # Tracker lookup fails when wandb logging is disabled; treat that as
    # "nothing to do". `Exception` (not a bare `except`) lets
    # KeyboardInterrupt / SystemExit propagate.
    wandb_tracker = None

if wandb_tracker is not None:
    try:
        # wandb availability is already checked once up front, so this import
        # is not expected to fail here.
        import wandb

        wandb_tracker.log({f"sample_{i}": wandb.Image(image)})
    except Exception as e:
        # Best-effort logging: a WandB failure must not abort sampling, but
        # it is reported rather than silently discarded.
        print(f"failed to log sample image to WandB: {e}")

# clear pipeline and cache to reduce vram usage
del pipeline
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,6 @@ lion-pytorch==0.0.6
# for network module
# locon==0.0.4
lycoris-lora==0.1.4
wandb==0.15.0
# for kohya_ss library
.