-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
26 lines (24 loc) · 852 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""
Here's a function that can train a neural net
"""
from typing import Optional

from vsdl.tensor import Tensor
from vsdl.nn import NeuralNet
from vsdl.loss import Loss, MSE
from vsdl.optim import Optimizer, SGD
from vsdl.data import DataIterator, BatchIterator


def train(net: NeuralNet,
          inputs: Tensor,
          targets: Tensor,
          num_epochs: int = 5000,
          iterator: Optional[DataIterator] = None,
          loss: Optional[Loss] = None,
          optimizer: Optional[Optimizer] = None) -> None:
    """Train ``net`` on ``(inputs, targets)`` for ``num_epochs`` epochs.

    Each epoch walks the batches produced by ``iterator(inputs, targets)``.
    For every batch it:

    1. runs a forward pass;
    2. adds the batch loss to the epoch total;
    3. back-propagates the loss gradient through ``net``;
    4. calls ``optimizer.step(net)``.

    The summed loss is printed once per epoch as ``epoch epoch_loss``.

    Args:
        net: network providing ``forward(inputs)`` and ``backward(grad)``;
            passed to ``optimizer.step`` after each backward pass.
        inputs: training inputs handed to ``iterator``.
        targets: training targets handed to ``iterator``.
        num_epochs: number of full passes over the data.
        iterator: callable returning an iterable of batches that expose
            ``.inputs`` and ``.targets``. Defaults to a fresh ``BatchIterator()``.
        loss: object providing ``loss(predicted, actual)`` (summed into the
            epoch loss) and ``grad(predicted, actual)`` (fed to
            ``net.backward``). Defaults to a fresh ``MSE()``.
        optimizer: object whose ``step(net)`` is called once per batch.
            Defaults to a fresh ``SGD()``.
    """
    # Default-argument expressions are evaluated only once, when the function
    # is defined. Instance defaults would therefore be shared by every call,
    # and any state they keep would carry over between training runs, so
    # build them here instead.
    if iterator is None:
        iterator = BatchIterator()
    if loss is None:
        loss = MSE()
    if optimizer is None:
        optimizer = SGD()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for batch in iterator(inputs, targets):
            predicted = net.forward(batch.inputs)
            epoch_loss += loss.loss(predicted, batch.targets)
            grad = loss.grad(predicted, batch.targets)
            net.backward(grad)
            optimizer.step(net)
        # Report once per epoch, after every batch has been processed.
        print(epoch, epoch_loss)