Commit: Create package
kopalja committed Jan 17, 2025
1 parent f111e2a commit 7a993bc
Showing 25 changed files with 134 additions and 116 deletions.
2 changes: 1 addition & 1 deletion main.py
@@ -6,7 +6,7 @@

from misc import (get_model_size, init_dataset, init_model, supported_datasets,
                  supported_models)
-from optimizers import optimizers_map
+from misc import optimizers_map
from train import OvershootTrainer
from trainer_configs import get_trainer_config

109 changes: 108 additions & 1 deletion misc.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional

import numpy as np
from peft import LoraConfig, TaskType, get_peft_model
@@ -18,6 +18,15 @@
from models.vae import VAE
from trainer_configs import *

+from overshoot.sgd_overshoot import SGDO
+from overshoot.adamw_overshoot_delayed import AdamO as OvershootAdamW_delayed
+
+from optimizers_old.backups2.sgdo_adaptive import SGDO as SGDO_adaptive
+from optimizers_old.backups2.adamw_overshoot_replication import AdamW as OvershootAdamW_replication
+from optimizers_old.backups2.adamw_overshoot_full_approximation import AdamW as OvershootAdamW_full_approximation
+from optimizers_old.backups2.adamw_overshoot_denom_approximation import AdamW as OvershootAdamW_denom_approximation
+from optimizers_old.backups2.adamw_overshoot_adaptive import AdamW as OvershootAdamW_adaptive

supported_datasets = [
    "mnist",
    "f-mnist",
@@ -41,6 +50,25 @@
"minilm",
]

+optimizers_map = {
+    "sgd": torch.optim.SGD,
+    "sgd_momentum": torch.optim.SGD,
+    "sgd_nesterov": torch.optim.SGD,
+    "sgd_overshoot": SGDO,
+    "sgd_adaptive": SGDO_adaptive,
+    "adam": torch.optim.Adam,
+    "adamW": torch.optim.AdamW,
+    "adam_zero": torch.optim.Adam,
+    "adamW_zero": torch.optim.AdamW,
+    "nadam": torch.optim.NAdam,
+    "adamW_overshoot_replication": OvershootAdamW_replication,
+    "adamW_overshoot_full_approximation": OvershootAdamW_full_approximation,
+    "adamW_overshoot_denom_approximation": OvershootAdamW_denom_approximation,
+    "adamW_overshoot_delayed": OvershootAdamW_delayed,
+    "adamW_overshoot_adaptive": OvershootAdamW_adaptive,
+    "rmsprop": torch.optim.RMSprop,
+}


def init_dataset(dataset_name: str, model_name: Optional[str], seed: Optional[int] = None):
    if dataset_name == "mnist":
@@ -196,3 +224,82 @@ def get_model_size(model: torch.nn.Module):
    buffer_size = sum(p.numel() for p in model.buffers()) * 4
    size_all_mb = (param_size + buffer_size) / 1024 / 1024
    return round(size_all_mb, 2)

+def create_optimizer(opt_name: str, param_groups, overshoot_factor: float, lr: float, config, foreach: Optional[bool] = None) -> torch.optim.Optimizer:
+    if opt_name == "nadam":
+        opt = optimizers_map[opt_name](
+            param_groups,
+            lr=lr,
+            betas=(config.adam_beta1, config.adam_beta2),
+            momentum_decay=1000000000000000000000000,  # Huge constant to effectively turn off momentum decay
+            weight_decay=config.weight_decay,
+            decoupled_weight_decay=True,
+            foreach=foreach,
+        )
+    elif opt_name == "adamW_overshoot_delayed":
+        opt = optimizers_map[opt_name](
+            param_groups,
+            lr=lr,
+            betas=(config.adam_beta1, config.adam_beta2),
+            weight_decay=config.weight_decay,
+            overshoot=overshoot_factor,
+            overshoot_delay=config.overshoot_delay,
+            foreach=foreach,
+        )
+    elif opt_name == "adamW_overshoot_adaptive":
+        opt = optimizers_map[opt_name](
+            param_groups,
+            lr=lr,
+            betas=(config.adam_beta1, config.adam_beta2),
+            weight_decay=config.weight_decay,
+            cosine_target=config.target_cosine_similarity,
+            foreach=foreach,
+        )
+    elif opt_name.startswith("adamW_overshoot"):
+        opt = optimizers_map[opt_name](
+            param_groups,
+            lr=lr,
+            betas=(config.adam_beta1, config.adam_beta2),
+            weight_decay=config.weight_decay,
+            overshoot=overshoot_factor,
+            foreach=foreach,
+        )
+    elif "adam" in opt_name:
+        # Zero out beta1 for the "*_zero" variants; multiplying by the
+        # boolean leaves beta1 unchanged for all other adam variants.
+        config.adam_beta1 *= "zero" not in opt_name
+        opt = optimizers_map[opt_name](
+            param_groups,
+            lr=lr,
+            betas=(config.adam_beta1, config.adam_beta2),
+            weight_decay=config.weight_decay,
+            foreach=foreach,
+        )
+    elif "sgd_adaptive" in opt_name:
+        opt = optimizers_map[opt_name](
+            param_groups,
+            lr=lr,
+            momentum=config.sgd_momentum,
+            weight_decay=config.weight_decay,
+            cosine_target=config.target_cosine_similarity,
+            foreach=foreach,
+        )
+    elif "sgd_overshoot" in opt_name:
+        opt = optimizers_map[opt_name](
+            param_groups,
+            lr=lr,
+            momentum=config.sgd_momentum,
+            weight_decay=config.weight_decay,
+            overshoot=overshoot_factor,
+            foreach=foreach,
+        )
+    elif "sgd" in opt_name:
+        opt = optimizers_map[opt_name](
+            param_groups,
+            lr=lr,
+            momentum=0 if opt_name == "sgd" else config.sgd_momentum,
+            weight_decay=config.weight_decay,
+            nesterov="nesterov" in opt_name,
+            foreach=foreach,
+        )
+    else:
+        raise Exception(f"Optimizer {opt_name} not recognized.")
+    return opt
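
For context, a minimal sketch (not part of the commit) of how the helpers now living in misc.py would be wired into a training step. The model, config values, and overshoot factor below are hypothetical stand-ins, and SGDO is assumed to keep the torch.optim.SGD keyword interface it forks:

import torch
from misc import create_optimizer

class Cfg:
    # Hypothetical config exposing the fields create_optimizer reads.
    sgd_momentum = 0.9
    weight_decay = 0.01
    adam_beta1, adam_beta2 = 0.9, 0.999

model = torch.nn.Linear(10, 2)
opt = create_optimizer("sgd_overshoot", model.parameters(),
                       overshoot_factor=3.0, lr=1e-3, config=Cfg())

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
torch.nn.functional.cross_entropy(model(x), y).backward()
opt.step()
opt.zero_grad()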
111 changes: 0 additions & 111 deletions optimizers/__init__.py

This file was deleted.

File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file added overshoot/__init__.py
File renamed without changes.
1 change: 0 additions & 1 deletion optimizers/sgd_overshoot.py → overshoot/sgd_overshoot.py
@@ -442,7 +442,6 @@ def _fused_sgd(
        momentum=momentum,
        lr=lr,
        dampening=dampening,
-        nesterov=nesterov,
        maximize=maximize,
        is_first_step=is_first_step,
        grad_scale=device_grad_scale,
23 changes: 23 additions & 0 deletions setup.py
@@ -0,0 +1,23 @@
+from setuptools import setup, find_packages
+
+setup(
+    name="overshoot",
+    version="0.1.0",
+    description="Overshoot version of SGD and AdamW optimizers",
+    long_description=open("README.md").read(),
+    long_description_content_type="text/markdown",
+    author="Jakub Kopal",
+    author_email="[email protected]",
+    url="https://github.com/kinit-sk/overshoot",
+    license="MIT",
+    packages=find_packages(),
+    install_requires=[
+        "torch>=2.4.0",
+    ],
+    classifiers=[
+        "Programming Language :: Python :: 3",
+        "License :: OSI Approved :: MIT License",
+        "Operating System :: OS Independent",
+    ],
+    python_requires=">=3.9",
+)
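
With this setup.py in place, the renamed overshoot/ directory becomes an installable package (e.g. running "pip install ." from the repo root), after which the optimizers import under the new top-level name, matching the paths used in misc.py above:

from overshoot.sgd_overshoot import SGDO
from overshoot.adamw_overshoot_delayed import AdamO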
2 changes: 1 addition & 1 deletion train.py
@@ -14,7 +14,7 @@
from trainer_configs import DefaultConfig
from custom_datasets import UnifiedDatasetInterface
from misc import compute_model_distance, get_gpu_stats
-from optimizers import create_optimizer
+from misc import create_optimizer

# ------------------------------------------------------------------------------
torch.cuda.empty_cache()
2 changes: 1 addition & 1 deletion train_with_pl.py
@@ -13,7 +13,7 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader

-from optimizers import optimizers_map
+from optimizers.overshoot import optimizers_map


from misc import init_dataset, init_model, get_gpu_stats, compute_model_distance, supported_datasets, supported_models
