From 5b922bed576532ca22f9c1b0619a0e60074e1b27 Mon Sep 17 00:00:00 2001
From: Jakub Kopal <32540156+kopalja@users.noreply.github.com>
Date: Sun, 19 Jan 2025 10:28:30 +0100
Subject: [PATCH] Improve demo snippet

---
 README.md | 70 +++++++++++++++++++++++++++----------------------------
 1 file changed, 34 insertions(+), 36 deletions(-)

diff --git a/README.md b/README.md
index 93fdb80..841a1af 100644
--- a/README.md
+++ b/README.md
@@ -29,14 +29,6 @@ pip install git+https://github.com/kinit-sk/overshoot.git
 import torch
 from torchvision import datasets, transforms
 from overshoot import AdamO
-torch.manual_seed(21)
-
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
-train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
-test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
-train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
-test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
 
 class CNN(torch.nn.Module):
     def __init__(self):
@@ -49,47 +41,53 @@ class CNN(torch.nn.Module):
         x = torch.relu(self.fc1(x))
         return self.fc2(x)
 
-
 def train_test(model, optimizer):
+    torch.manual_seed(1) # Make the training process identical for both variants
+    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
+    train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
+    test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)
+    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
+    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
     criterion = torch.nn.CrossEntropyLoss()
-    # Training loop
-    for epoch in range(5):
+
+    for epoch in range(4):
         model.train()
         for images, labels in train_loader:
-            images, labels = images.to(device), labels.to(device)
-            # Forward pass
-            outputs = model(images)
-            loss = criterion(outputs, labels)
+            loss = criterion(model(images), labels)
             # Backward pass and optimization
             optimizer.zero_grad()
             loss.backward()
             optimizer.step()
-
-        print(f"Epoch [{epoch+1}/{5}], Loss: {loss.item():.4f}")
-    # Move weights to base variant
-    if isinstance(optimizer, AdamO):
-        optimizer.move_to_base()
+        # Move weights to base variant
+        if isinstance(optimizer, AdamO):
+            optimizer.move_to_base()
+
+        model.eval()
+        with torch.no_grad():
+            correct, total = 0, 0
+            for images, labels in test_loader:
+                outputs = model(images)
+                _, predicted = torch.max(outputs, 1)
+                total += labels.size(0)
+                correct += (predicted == labels).sum().item()
 
-    model.eval()
-    with torch.no_grad():
-        correct, total = 0, 0
-        for images, labels in test_loader:
-            images, labels = images.to(device), labels.to(device)
-            outputs = model(images)
-            _, predicted = torch.max(outputs, 1)
-            total += labels.size(0)
-            correct += (predicted == labels).sum().item()
-
-    print(f"Test Accuracy: {100 * correct / total:.2f}%")
+        # Move weights to overshoot variant
+        if isinstance(optimizer, AdamO):
+            optimizer.move_to_overshoot()
+
+        print(f"({epoch+1}/4) Test Accuracy: {100 * correct / total:.2f}%")
+
+# Initialize two identical models
+model1, model2 = CNN(), CNN()
+model2.load_state_dict(model1.state_dict())
+
 
 print("AdamW")
-model = CNN().to(device)
-train_test(model, torch.optim.AdamW(model.parameters()))
-
+train_test(model1, torch.optim.AdamW(model1.parameters()))
+
 print("AdamO")
-model = CNN().to(device)
-train_test(model, AdamO(model.parameters()))
+train_test(model2, AdamO(model2.parameters()))
 
 ```
 
 ## Benchmark Overshoot on various tasks
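Below is a minimal sketch that isolates the AdamO usage pattern the updated snippet relies on: gradient steps run on the overshoot weights, `move_to_base()` switches to the base weights for evaluation, and `move_to_overshoot()` switches back before training continues. It assumes only the calls that already appear in the diff (the `AdamO` constructor, `zero_grad`, `step`, `move_to_base`, `move_to_overshoot`); the tiny linear model and random data are placeholders, not part of the library.

```python
# Minimal sketch, not the library's reference example: a placeholder model and
# random data stand in for the MNIST setup used in the README snippet.
import torch
from overshoot import AdamO

model = torch.nn.Linear(10, 2)            # illustrative stand-in model
optimizer = AdamO(model.parameters())     # constructor as used in the README
criterion = torch.nn.CrossEntropyLoss()

inputs = torch.randn(64, 10)              # fake batch
targets = torch.randint(0, 2, (64,))

for step in range(3):
    # Gradient step on the overshoot weights
    model.train()
    loss = criterion(model(inputs), targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Evaluate on the base weights, then restore the overshoot weights
    optimizer.move_to_base()
    model.eval()
    with torch.no_grad():
        accuracy = (model(inputs).argmax(dim=1) == targets).float().mean().item()
    optimizer.move_to_overshoot()

    print(f"step {step + 1}: loss={loss.item():.4f}, accuracy={accuracy:.2f}")
```

This is the same toggle the patch moves inside the epoch loop, so each epoch's test accuracy is measured on the base weights while the optimizer keeps training from the overshoot weights.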