Skip to content

Commit

Permalink
Improve demo snippet
Browse files Browse the repository at this point in the history
  • Loading branch information
kopalja authored Jan 19, 2025
1 parent c684b32 commit 5b922be
Showing 1 changed file with 34 additions and 36 deletions.
70 changes: 34 additions & 36 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,6 @@ pip install git+https://github.com/kinit-sk/overshoot.git
import torch
from torchvision import datasets, transforms
from overshoot import AdamO
torch.manual_seed(21)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

class CNN(torch.nn.Module):
def __init__(self):
Expand All @@ -49,47 +41,53 @@ class CNN(torch.nn.Module):
x = torch.relu(self.fc1(x))
return self.fc2(x)


def train_test(model, optimizer):
torch.manual_seed(1) # Make training process same for both variant
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()
# Training loop
for epoch in range(5):

for epoch in range(4):
model.train()
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
loss = criterion(model(images), labels)
# Backward pass and optimization
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(f"Epoch [{epoch+1}/{5}], Loss: {loss.item():.4f}")

# Move weights to base variant
if isinstance(optimizer, AdamO):
optimizer.move_to_base()
# Move weights to base variant
if isinstance(optimizer, AdamO):
optimizer.move_to_base()

model.eval()
with torch.no_grad():
correct, total = 0, 0
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

model.eval()
with torch.no_grad():
correct, total = 0, 0
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print(f"Test Accuracy: {100 * correct / total:.2f}%")
# Move weights to overshoot variant
if isinstance(optimizer, AdamO):
optimizer.move_to_overshoot()

print(f"({epoch+1}/5) Test Accuracy: {100 * correct / total:.2f}%")


# Init two equal models
model1, model2 = CNN(), CNN()
model2.load_state_dict(model1.state_dict())

print("AdamW")
model = CNN().to(device)
train_test(model, torch.optim.AdamW(model.parameters()))

train_test(model1, torch.optim.AdamW(model1.parameters()))

print("AdamO")
model = CNN().to(device)
train_test(model, AdamO(model.parameters()))
train_test(model2, AdamO(model2.parameters()))
```
## Benchmark Overshoot on various tasks

Expand Down

0 comments on commit 5b922be

Please sign in to comment.