1
1
import unittest
2
2
3
+ import pytorch_lightning as pl
3
4
import torch
4
5
import torch .nn .functional as F
5
6
from torch .utils .data import DataLoader , TensorDataset
6
7
7
- import pytorch_lightning as pl
8
- from pytorch_lightning .metrics .functional import to_onehot
9
-
10
8
11
9
class LitDataModule (pl .LightningDataModule ):
12
10
@@ -16,10 +14,10 @@ def __init__(self, batch_size=16):
16
14
self .batch_size = batch_size
17
15
18
16
def setup (self , stage = None ):
19
- X_train = torch .rand (100 , 1 , 28 , 28 ). float ()
20
- y_train = to_onehot ( torch .randint (0 , 10 , size = (100 ,)), num_classes = 10 ). float ( )
17
+ X_train = torch .rand (100 , 1 , 28 , 28 )
18
+ y_train = torch .randint (0 , 10 , size = (100 ,))
21
19
X_valid = torch .rand (20 , 1 , 28 , 28 )
22
- y_valid = to_onehot ( torch .randint (0 , 10 , size = (20 ,)), num_classes = 10 ). float ( )
20
+ y_valid = torch .randint (0 , 10 , size = (20 ,))
23
21
24
22
self .train_ds = TensorDataset (X_train , y_train )
25
23
self .valid_ds = TensorDataset (X_valid , y_valid )
@@ -38,26 +36,23 @@ def __init__(self):
38
36
self .l1 = torch .nn .Linear (28 * 28 , 10 )
39
37
40
38
def forward (self , x ):
41
- return torch .relu (self .l1 (x .view (x .size (0 ), - 1 )))
39
+ return F .relu (self .l1 (x .view (x .size (0 ), - 1 )))
42
40
43
41
def training_step (self , batch , batch_idx ):
44
42
x , y = batch
45
43
y_hat = self (x )
46
- loss = F .binary_cross_entropy_with_logits (y_hat , y )
47
- result = pl .TrainResult (loss )
48
- result .log ('train_loss' , loss , on_epoch = True )
49
- return result
44
+ loss = F .cross_entropy (y_hat , y )
45
+ self .log ('train_loss' , loss )
46
+ return loss
50
47
51
48
def validation_step (self , batch , batch_idx ):
52
49
x , y = batch
53
50
y_hat = self (x )
54
- loss = F .binary_cross_entropy_with_logits (y_hat , y )
55
- result = pl .EvalResult (checkpoint_on = loss )
56
- result .log ('val_loss' , loss )
57
- return result
51
+ loss = F .cross_entropy (y_hat , y )
52
+ self .log ('val_loss' , loss )
58
53
59
54
def configure_optimizers (self ):
60
- return torch .optim .Adam (self .parameters (), lr = 0.02 )
55
+ return torch .optim .Adam (self .parameters (), lr = 1e-2 )
61
56
62
57
63
58
class TestPytorchLightning (unittest .TestCase ):
0 commit comments