diff --git a/README.md b/README.md
index 3991753..ca1ada1 100644
--- a/README.md
+++ b/README.md
@@ -28,7 +28,6 @@ pip install git+https://github.com/kinit-sk/overshoot.git
 ```python
 import torch
 from torchvision import datasets, transforms
-from torch.optim import AdamW, SGD
 from overshoot import AdamO, SGDO
 
 class MLP(torch.nn.Module):
@@ -43,7 +42,6 @@ class MLP(torch.nn.Module):
         x = torch.relu(x)
         return self.fc2(x)
 
-
 def train_test(model, optimizer):
     torch.manual_seed(42) # Make training process same
     transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
@@ -63,7 +61,7 @@ def train_test(model, optimizer):
     # Move weights to base variant
     if isinstance(optimizer, (AdamO, SGDO)):
         optimizer.move_to_base()
-    
+
     model.eval()
     with torch.no_grad():
         correct = 0
@@ -76,21 +74,19 @@ def train_test(model, optimizer):
     if isinstance(optimizer, (AdamO, SGDO)):
         optimizer.move_to_overshoot()
 
-
-
 # Init four equal models
 models = [MLP() for _ in range(4)]
 for m in models[1:]:
     m.load_state_dict(models[0].state_dict())
 
 print("AdamW")
-train_test(models[0], AdamW(models[0].parameters()))
+train_test(models[0], torch.optim.AdamW(models[0].parameters()))
 
 print("AdamO (AdamW + overshoot)")
 train_test(models[1], AdamO(models[1].parameters(), overshoot=5))
 
 print("SGD")
-train_test(models[2], SGD(models[2].parameters(), lr=0.01, momentum=0.9))
+train_test(models[2], torch.optim.SGD(models[2].parameters(), lr=0.01, momentum=0.9))
 
 print("SGD (SGD + overshoot)")
 train_test(models[3], SGDO(models[3].parameters(), lr=0.01, momentum=0.9, overshoot=5))
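
Note on the pattern the README example relies on: `move_to_base()` swaps the model to the base-optimizer weight variant before evaluation, and `move_to_overshoot()` swaps the overshot weights back in before training resumes. Below is a minimal sketch of that per-epoch loop, assuming `AdamO` accepts the standard `torch.optim.AdamW` constructor arguments plus the `overshoot` keyword, as the diff suggests; the synthetic data and `lr=1e-3` here are illustrative only.

```python
# Minimal sketch of the eval-time weight swap from the README example.
# Assumption: AdamO subclasses torch.optim.Optimizer and takes AdamW-style
# arguments; only `overshoot` and the move_to_* calls come from the package.
import torch
from overshoot import AdamO

model = torch.nn.Linear(10, 2)
optimizer = AdamO(model.parameters(), lr=1e-3, overshoot=5)

for epoch in range(3):
    model.train()
    x, y = torch.randn(32, 10), torch.randint(0, 2, (32,))  # illustrative data
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    optimizer.move_to_base()       # evaluate with the base (non-overshot) weights
    model.eval()
    with torch.no_grad():
        acc = (model(x).argmax(1) == y).float().mean()
    optimizer.move_to_overshoot()  # restore overshot weights before training resumes
```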