Skip to content

Commit 3ec226e

Browse files
committed
Update 06-optim-save-load
1 parent 8e9ad1c commit 3ec226e

File tree

1 file changed

+12
-12
lines changed
  • tutorials/01-basics/06-optim-save-load

1 file changed

+12
-12
lines changed

Diff for: tutorials/01-basics/06-optim-save-load/main.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
# =============================== #
2-
# Optimization #
3-
# =============================== #
1+
# ========================================= #
2+
# Optimization #
3+
# ========================================= #
44

55
# Prerequisite Code
66

@@ -65,9 +65,9 @@ def forward(self, x):
6565
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
6666

6767

68-
# =============================== #
69-
# Full Implementation #
70-
# =============================== #
68+
# ========================================= #
69+
# Full Implementation #
70+
# ========================================= #
7171

7272
def train_loop(dataloader, model, loss_fn, optimizer):
7373
size = len(dataloader.dataset)
@@ -101,19 +101,19 @@ def test_loop(dataloader, model, loss_fn):
101101
print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
102102

103103

104-
# =============================== #
105-
# Training #
106-
# =============================== #
104+
# ========================================= #
105+
# Training #
106+
# ========================================= #
107107

108108
for t in range(epochs):
109109
print(f"Epoch {t + 1}\n-------------------------------")
110110
train_loop(train_dataloader, model, loss_fn, optimizer)
111111
test_loop(test_dataloader, model, loss_fn)
112112
print("Done!")
113113

114-
# =============================== #
115-
# Save and Load the Model #
116-
# =============================== #
114+
# ========================================= #
115+
# Save and Load the Model #
116+
# ========================================= #
117117

118118
model = models.vgg16(pretrained=True)
119119
torch.save(model.state_dict(), 'model_weights.pth') # save the model

0 commit comments

Comments
 (0)