# data.py
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from env import Env_tsp
from config import Config, load_pkl, pkl_parser
class Generator(Dataset):
    """Map-style dataset over nodes pre-generated by the environment.

    All samples are produced once, at construction time, by calling
    ``env.get_batch_nodes(cfg.n_samples)``; indexing afterwards only reads
    from that stored result.
    """

    def __init__(self, cfg, env):
        """Generate and store the samples.

        Args:
            cfg: Configuration object; only ``cfg.n_samples`` is read here.
            env: Environment object exposing ``get_batch_nodes(n)``.
        """
        super().__init__()
        # NOTE(review): presumably a tensor whose first dimension indexes
        # samples (``__len__`` relies on ``.size(0)``) -- confirm in env.
        self.data = env.get_batch_nodes(cfg.n_samples)

    def __getitem__(self, idx):
        """Return the stored sample at position ``idx``."""
        return self.data[idx]

    def __len__(self):
        """Return the number of stored samples (size of dimension 0)."""
        return self.data.size(0)
if __name__ == '__main__':
    # Smoke test: build the dataset from the config loaded via the path that
    # pkl_parser() supplies, then print the shape of one sample and one batch.
    cfg = load_pkl(pkl_parser().path)
    env = Env_tsp(cfg)
    dataset = Generator(cfg, env)

    # Shape of a single sample, fetched through the dataset's own iteration.
    sample = next(iter(dataset))
    print(sample.size())

    dataloader = DataLoader(dataset, batch_size=cfg.batch, shuffle=True)
    # Only the first batch is inspected; an empty loader prints nothing.
    for batch in dataloader:
        print(batch.size())
        break