-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
32 lines (24 loc) · 993 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""
We'll feed inputs into our network in batches.
So here are some tools for iterating over data in batches.
"""
from typing import Iterator, NamedTuple
import numpy as np
from vsdl.tensor import Tensor
class Batch(NamedTuple):
    """One mini-batch: a slice of inputs paired with the matching targets."""

    inputs: Tensor  # slice of the input data for this batch
    targets: Tensor  # corresponding slice of the target data
class DataIterator:
    """Base interface for objects that split a dataset into batches.

    Concrete subclasses override ``__call__`` to yield ``Batch`` items built
    from ``inputs`` and ``targets``.
    """

    def __call__(self, inputs: Tensor, targets: Tensor) -> Iterator[Batch]:
        """Yield batches drawn from ``inputs`` and ``targets``.

        Args:
            inputs: the input data to batch.
            targets: the target data aligned with ``inputs``.

        Raises:
            NotImplementedError: always; concrete iterators must override this.
        """
        # NotImplementedError is the conventional signal for a method that a
        # subclass is expected to provide.
        raise NotImplementedError
class BatchIterator(DataIterator):
    """Yield consecutive fixed-size batches, optionally in shuffled order.

    The data is cut along its first axis into contiguous slices of
    ``batch_size`` elements; the final slice is shorter when ``len(inputs)``
    is not a multiple of ``batch_size``. With ``shuffle=True`` the *order of
    the slices* is shuffled (using NumPy's global random state), while the
    elements inside each slice keep their original order.
    """

    def __init__(self, batch_size: int = 32, shuffle: bool = True) -> None:
        """Configure the batch size and whether batch order is shuffled.

        Args:
            batch_size: maximum number of elements per batch; must be >= 1.
            shuffle: if True, visit the batches in a random order on each call.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A zero step makes np.arange fail with an unhelpful error and a
            # negative step silently produces no batches; reject both up front.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.batch_size = batch_size
        self.shuffle = shuffle

    def __call__(self, inputs: Tensor, targets: Tensor) -> Iterator[Batch]:
        """Yield ``Batch(inputs[start:end], targets[start:end])`` per slice.

        Args:
            inputs: the input data; sliced along its first axis.
            targets: the target data; must have the same length as ``inputs``.

        Yields:
            One ``Batch`` per slice, in shuffled order if ``self.shuffle``.

        Raises:
            ValueError: (on first iteration, since this is a generator) if
                ``inputs`` and ``targets`` differ in length. Otherwise the
                slices would silently pair inputs with the wrong targets.
        """
        if len(inputs) != len(targets):
            raise ValueError(
                f"inputs and targets must have the same length, "
                f"got {len(inputs)} and {len(targets)}"
            )
        starts = np.arange(0, len(inputs), self.batch_size)
        if self.shuffle:
            # Shuffles which batch comes first, not the contents of a batch.
            np.random.shuffle(starts)
        for start in starts:
            end = start + self.batch_size
            yield Batch(inputs[start:end], targets[start:end])