import torch.utils.data as data
from PIL import Image
import PIL
import os
import os.path
import pickle
import random
import numpy as np
import pandas as pd
import torch
import torchvision

class TextDataset(data.Dataset):
    """Dataset pairing each image, at two resolutions, with one text embedding.

    Three pickles are read from ``os.path.join(data_dir, split)``:

    * ``filenames.pickle``: one key per sample. Images are read from
      ``'<img_dir>/<key>.jpg'``.
    * ``char-CNN-RNN-embeddings.pickle``: converted to a NumPy array and
      indexed as ``[sample, caption, :]`` in ``__getitem__``.
    * ``class_info.pickle`` (optional): per-sample class ids. When absent,
      ``np.arange(len(filenames))`` is used instead.

    Args:
        img_dir: Directory containing the ``.jpg`` images.
        data_dir: Root directory holding one sub-directory per split.
        transform1: Optional callable applied to the first resized image.
        transform2: Optional callable applied to the second resized image.
        split: Name of the split sub-directory under ``data_dir``.
        imsize1: Size parameter for the first image (see ``get_img``).
        imsize2: Size parameter for the second image (see ``get_img``).
    """

    def __init__(self, img_dir, data_dir, transform1, transform2, split, imsize1, imsize2):
        self.img_dir = img_dir
        self.transform1 = transform1
        self.transform2 = transform2
        self.imsize1 = imsize1
        self.imsize2 = imsize2
        # Not read anywhere in this class; kept in case external code uses it.
        self.data = []
        self.data_dir = data_dir
        split_dir = os.path.join(data_dir, split)

        self.filenames = self.load_filenames(split_dir)
        self.embeddings = self.load_embedding(split_dir)
        self.class_id = self.load_class_id(split_dir, len(self.filenames))

    @staticmethod
    def _load_pickle(path):
        """Unpickle ``path``, decoding Python 2 byte strings as latin-1.

        latin-1 maps every byte to a code point, so pickles written by
        Python 2 (including NumPy arrays) load under Python 3 instead of
        failing with ``UnicodeDecodeError`` under the default ASCII decoding.
        Python 3 pickles are unaffected.

        NOTE(review): pickle can execute arbitrary code; only load trusted files.
        """
        with open(path, 'rb') as f:
            return pickle.load(f, encoding='latin1')

    def get_img(self, img_path):
        """Load ``img_path`` as RGB and return two resized, transformed copies.

        Copy ``i`` is resized (bilinear) to a square of side
        ``int(imsize_i * 76 / 64)``. ``transform_i`` is then applied if it is
        not ``None``.

        Returns:
            ``(img1, img2)``, built from ``imsize1``/``transform1`` and
            ``imsize2``/``transform2`` respectively.
        """
        # The context manager releases the file handle even if decoding fails;
        # convert() loads the pixel data, so ``img`` stays usable afterwards.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        # Loads larger than imsize (ratio 76/64); presumably so the transforms
        # can crop back down to imsize - TODO confirm against the transforms.
        load_size1 = int(self.imsize1 * 76 / 64)
        load_size2 = int(self.imsize2 * 76 / 64)
        img1 = img.resize((load_size1, load_size1), PIL.Image.BILINEAR)
        img2 = img.resize((load_size2, load_size2), PIL.Image.BILINEAR)
        if self.transform1 is not None:
            img1 = self.transform1(img1)
        if self.transform2 is not None:
            img2 = self.transform2(img2)
        return img1, img2

    def load_embedding(self, data_dir):
        """Load ``char-CNN-RNN-embeddings.pickle`` from ``data_dir``.

        Returns:
            The unpickled object converted with ``np.array``. Its shape is
            printed as a side effect.
        """
        path = os.path.join(data_dir, 'char-CNN-RNN-embeddings.pickle')
        embeddings = np.array(self._load_pickle(path))
        print('embeddings: ', embeddings.shape)
        return embeddings

    def load_class_id(self, data_dir, total_num):
        """Return per-sample class ids for the split in ``data_dir``.

        Reads ``class_info.pickle`` when it exists. Otherwise every sample
        gets its own id: ``np.arange(total_num)``.
        """
        path = os.path.join(data_dir, 'class_info.pickle')
        if os.path.isfile(path):
            return self._load_pickle(path)
        return np.arange(total_num)

    def load_filenames(self, data_dir):
        """Load the per-sample image keys from ``filenames.pickle`` in ``data_dir``."""
        filepath = os.path.join(data_dir, 'filenames.pickle')
        filenames = self._load_pickle(filepath)
        print('Load filenames from: %s (%d)' % (filepath, len(filenames)))
        return filenames

    def __getitem__(self, index):
        """Return ``(img1, img2, embedding)`` for sample ``index``.

        ``embedding`` is one row of ``self.embeddings[index]``, chosen
        uniformly at random with Python's ``random`` module. Repeated calls
        with the same index can therefore return different embeddings.
        """
        key = self.filenames[index]
        embeddings = self.embeddings[index, :, :]
        img1, img2 = self.get_img('%s/%s.jpg' % (self.img_dir, key))
        # Draw after loading the image, preserving the original call order.
        embedding_ix = random.randint(0, embeddings.shape[0] - 1)
        embedding = embeddings[embedding_ix, :]
        return img1, img2, embedding

    def __len__(self):
        """Number of samples, i.e. the number of loaded filenames."""
        return len(self.filenames)