Skip to content

Commit

Permalink
Update README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
kopalja authored Feb 2, 2025
1 parent 6ae979b commit f4a22e0
Showing 1 changed file with 3 additions and 7 deletions.
10 changes: 3 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ pip install git+https://github.com/kinit-sk/overshoot.git
```python
import torch
from torchvision import datasets, transforms
from torch.optim import AdamW, SGD
from overshoot import AdamO, SGDO

class MLP(torch.nn.Module):
Expand All @@ -43,7 +42,6 @@ class MLP(torch.nn.Module):
x = torch.relu(x)
return self.fc2(x)


def train_test(model, optimizer):
torch.manual_seed(42) # Make training process same
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
Expand All @@ -63,7 +61,7 @@ def train_test(model, optimizer):
# Move weights to base variant
if isinstance(optimizer, (AdamO, SGDO)):
optimizer.move_to_base()

model.eval()
with torch.no_grad():
correct = 0
Expand All @@ -76,21 +74,19 @@ def train_test(model, optimizer):
if isinstance(optimizer, (AdamO, SGDO)):
optimizer.move_to_overshoot()



# Init four equal models
models = [MLP() for _ in range(4)]
for m in models[1:]:
m.load_state_dict(models[0].state_dict())

print("AdamW")
train_test(models[0], AdamW(models[0].parameters()))
train_test(models[0], torch.optim.AdamW(models[0].parameters()))

print("AdamO (AdamW + overshoot)")
train_test(models[1], AdamO(models[1].parameters(), overshoot=5))

print("SGD")
train_test(models[2], SGD(models[2].parameters(), lr=0.01, momentum=0.9))
train_test(models[2], torch.optim.SGD(models[2].parameters(), lr=0.01, momentum=0.9))

print("SGDO (SGD + overshoot)")
train_test(models[3], SGDO(models[3].parameters(), lr=0.01, momentum=0.9, overshoot=5))
Expand Down

0 comments on commit f4a22e0

Please sign in to comment.