Skip to content

Commit a5102ed

Browse files
authored
Merge pull request #884 from rohitgr7/package/pl_10
Update pytorch-lightning 1.0
2 parents 83bb841 + e582fec commit a5102ed

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

tests/test_pytorch_lightning.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
import unittest
22

3+
import pytorch_lightning as pl
34
import torch
45
import torch.nn.functional as F
56
from torch.utils.data import DataLoader, TensorDataset
67

7-
import pytorch_lightning as pl
8-
from pytorch_lightning.metrics.functional import to_onehot
9-
108

119
class LitDataModule(pl.LightningDataModule):
1210

@@ -16,10 +14,10 @@ def __init__(self, batch_size=16):
1614
self.batch_size = batch_size
1715

1816
def setup(self, stage=None):
19-
X_train = torch.rand(100, 1, 28, 28).float()
20-
y_train = to_onehot(torch.randint(0, 10, size=(100,)), num_classes=10).float()
17+
X_train = torch.rand(100, 1, 28, 28)
18+
y_train = torch.randint(0, 10, size=(100,))
2119
X_valid = torch.rand(20, 1, 28, 28)
22-
y_valid = to_onehot(torch.randint(0, 10, size=(20,)), num_classes=10).float()
20+
y_valid = torch.randint(0, 10, size=(20,))
2321

2422
self.train_ds = TensorDataset(X_train, y_train)
2523
self.valid_ds = TensorDataset(X_valid, y_valid)
@@ -38,26 +36,23 @@ def __init__(self):
3836
self.l1 = torch.nn.Linear(28 * 28, 10)
3937

4038
def forward(self, x):
41-
return torch.relu(self.l1(x.view(x.size(0), -1)))
39+
return F.relu(self.l1(x.view(x.size(0), -1)))
4240

4341
def training_step(self, batch, batch_idx):
4442
x, y = batch
4543
y_hat = self(x)
46-
loss = F.binary_cross_entropy_with_logits(y_hat, y)
47-
result = pl.TrainResult(loss)
48-
result.log('train_loss', loss, on_epoch=True)
49-
return result
44+
loss = F.cross_entropy(y_hat, y)
45+
self.log('train_loss', loss)
46+
return loss
5047

5148
def validation_step(self, batch, batch_idx):
5249
x, y = batch
5350
y_hat = self(x)
54-
loss = F.binary_cross_entropy_with_logits(y_hat, y)
55-
result = pl.EvalResult(checkpoint_on=loss)
56-
result.log('val_loss', loss)
57-
return result
51+
loss = F.cross_entropy(y_hat, y)
52+
self.log('val_loss', loss)
5853

5954
def configure_optimizers(self):
60-
return torch.optim.Adam(self.parameters(), lr=0.02)
55+
return torch.optim.Adam(self.parameters(), lr=1e-2)
6156

6257

6358
class TestPytorchLightning(unittest.TestCase):

0 commit comments

Comments
 (0)