From 2abd3c34757a417f3d08a3940000f92c5bbe697f Mon Sep 17 00:00:00 2001
From: vinhdc10998
Date: Mon, 10 May 2021 11:22:36 +0700
Subject: [PATCH] Fix gru and add multi model

---
Reviewer notes. Everything between the `---` line above and the first
`diff --git` header is commentary only and is not part of the commit.

  Summary of what the hunks below do
  * model/gru_model.py: the GRU input projection no longer multiplies by a
    tensor loaded with np.load from
    model/features/region_<region>_model_features.npy; it now goes through
    nn.Linear(self.input_dim, self.feature_size, bias=True).
    NOTE(review): presumably this makes the projection trainable and adds
    `linear.*` entries to the state dict, so checkpoints saved before this
    change would no longer load - confirm.
  * model/multi_model.py (new): MultiModel builds two GRUModel instances
    ('Lower' and 'Higher'), loads their weights from
    model_config['lower_path'] / model_config['higher_path'], sets
    requires_grad = False on all of their parameters, and adds one
    nn.Linear(num_classes*2, num_classes) per output (num_outputs of them).
    forward() concatenates the stacked outputs of both sub-models on the
    last dim and feeds slice `index` to self.linear[index].
  * train.py: run() picks SingleModel for 'Lower'/'Higher' and MultiModel
    for 'Hybrid'; for 'Hybrid' it sets gamma = 0 and fills in the two
    checkpoint paths from args.model_dir (the `Best_` prefixed files when
    args.best_model is truthy).

  Points to confirm before merging
  * get_gru_layer keeps only checkpoint keys that contain the substring
    'gru' and drops the first 9 characters of each kept key.
    NOTE(review): the 9 looks like the length of an attribute prefix in the
    saved SingleModel state dict - confirm, and consider stripping a named
    prefix instead of a magic slice; any key with 'gru' elsewhere in its
    name is also kept and sliced.
  * get_gru_layer calls torch.load on the given path.
    NOTE(review): torch.load unpickles the file, so only trusted
    checkpoints should be passed.
  * `assert type_model in TYPE_MODEL` is the only validation in
    MultiModel.__init__, and the default type_model=None always fails it.
    Asserts are removed under `python -O`.
  * forward() returns `logit` built with torch.cat(logit_list, dim=0) but
    `pred` built from torch.stack(logit_list), so the two results do not
    share the same layout - confirm that the loss/metric code expects this.
  * forward() applies torch.stack to what the sub-models return, so it
    assumes GRUModel.forward returns a sequence of equally shaped tensors
    for both 'Lower' and 'Higher' - confirm.
  * The frozen sub-models are never switched to eval() in this file.
    NOTE(review): if GRUModel contains dropout or normalisation layers they
    will still run in training mode while the hybrid head trains - confirm.
  * train.py: `gamma` and `model` are now assigned only inside the
    'Lower'/'Higher' and 'Hybrid' branches. Any other args.model_type
    reaches count_parameters(model) with `model` unbound and raises there.
  * train.py: the 'Hybrid' branch writes 'lower_path' and 'higher_path'
    into the model_config dict passed in by the caller.
  * train.py: the new `from model.multi_model import MultiModel` is placed
    above `import os`; model/multi_model.py ends with whitespace-only lines
    and no trailing newline.

  NOTE(review): this copy of the patch had lost its line breaks. The line
  structure below was rebuilt from the hunk headers; the position of blank
  context lines and the Python indentation are a best-effort reconstruction
  and should be checked against the original commit before `git am`.

 model/gru_model.py   |  4 ++--
 model/multi_model.py | 55 ++++++++++++++++++++++++++++++++++++++++++++
 train.py             | 18 +++++++++++++--
 3 files changed, 73 insertions(+), 4 deletions(-)
 create mode 100644 model/multi_model.py

diff --git a/model/gru_model.py b/model/gru_model.py
index 1cd666a..d9ca46c 100644
--- a/model/gru_model.py
+++ b/model/gru_model.py
@@ -17,7 +17,7 @@ def __init__(self, model_config, device, type_model):
         self.type_model = type_model
         self.device = device
 
-        self._features = torch.tensor(np.load(f'model/features/region_{self.region}_model_features.npy')).to(self.device)
+        self.linear = nn.Linear(self.input_dim, self.feature_size, bias=True)
 
         self.gru = nn.ModuleList(self._create_gru_cell(
             self.feature_size,
@@ -53,7 +53,7 @@ def forward(self, x):
         '''
         batch_size = x.shape[0]
         _input = torch.swapaxes(x, 0, 1)
-        gru_inputs = torch.matmul(_input, self._features)
+        gru_inputs = self.linear(_input)
         outputs, _ = self._compute_gru(self.gru, gru_inputs, batch_size)
 
         logit_list = []
diff --git a/model/multi_model.py b/model/multi_model.py
new file mode 100644
index 0000000..ab71eb2
--- /dev/null
+++ b/model/multi_model.py
@@ -0,0 +1,55 @@
+import torch
+from torch import nn
+from .gru_model import GRUModel
+from torch.nn import functional as F
+
+TYPE_MODEL = ['Hybrid']
+class MultiModel(nn.Module):
+    def __init__(self,model_config, device, type_model=None):
+        super(MultiModel,self).__init__()
+        assert type_model in TYPE_MODEL
+        self.num_classes = model_config['num_classes']
+        self.num_outputs = model_config['num_outputs']
+        self.type_model = type_model
+
+        self.lowerModel = GRUModel(model_config, device, type_model='Lower')
+        self.higherModel = GRUModel(model_config, device, type_model='Higher')
+
+        self.lowerModel.load_state_dict(self.get_gru_layer(model_config['lower_path'], device))
+        self.higherModel.load_state_dict(self.get_gru_layer(model_config['higher_path'], device))
+
+        for param in self.lowerModel.parameters():
+            param.requires_grad = False
+        for param in self.higherModel.parameters():
+            param.requires_grad = False
+
+        self.linear = nn.ModuleList([nn.Linear(self.num_classes*2, self.num_classes) for _ in range(self.num_outputs)])
+
+
+    @staticmethod
+    def get_gru_layer(path, device):
+        tmp = torch.load(path, map_location=torch.device(device))
+        a = {}
+        for i in tmp:
+            if 'gru' in i:
+                k = i[9:]
+                a[k] = tmp[i]
+        return a
+
+
+
+    def forward(self, input_):
+        logits_1 = self.higherModel(input_)
+        logits_2 = self.lowerModel(input_)
+        logits = torch.cat((torch.stack(logits_1), torch.stack(logits_2)), dim=-1)
+        logit_list = []
+        for index, logit in enumerate(logits):
+            logit_tmp = self.linear[index](logit)
+            logit_list.append(logit_tmp)
+
+        logit = torch.cat(logit_list, dim=0)
+        pred = F.softmax(torch.stack(logit_list), dim=-1)
+        return logit, pred
+
+
+    
\ No newline at end of file
diff --git a/train.py b/train.py
index 3e91053..ecefa33 100644
--- a/train.py
+++ b/train.py
@@ -1,3 +1,4 @@
+from model.multi_model import MultiModel
 import os
 import json
 import torch
@@ -16,14 +17,27 @@ def run(dataloader, model_config, args, region):
     type_model = args.model_type
     lr = args.learning_rate
     epochs = args.epochs
-    gamma = args.gamma if type_model == 'Higher' else -args.gamma
+
     output_model_dir = args.output_model_dir
     train_loader = dataloader['train']
     val_loader = dataloader['val']
     test_loader = dataloader['test']
 
     #Init Model
-    model = SingleModel(model_config, device, type_model=type_model).to(device)
+    if type_model in ['Lower', 'Higher']:
+        gamma = args.gamma if type_model == 'Higher' else -args.gamma
+        model = SingleModel(model_config, device, type_model=type_model).to(device)
+
+    elif type_model in ['Hybrid']:
+        gamma = 0
+        if args.best_model:
+            model_config['lower_path'] = os.path.join(args.model_dir, f'Best_Lower_region_{region}.pt')
+            model_config['higher_path'] = os.path.join(args.model_dir, f'Best_Higher_region_{region}.pt')
+        else:
+            model_config['lower_path'] = os.path.join(args.model_dir, f'Lower_region_{region}.pt')
+            model_config['higher_path'] = os.path.join(args.model_dir, f'Higher_region_{region}.pt')
+        model = MultiModel(model_config, device, type_model=type_model).to(device)
+
     def count_parameters(model):
         return sum(p.numel() for p in model.parameters() if p.requires_grad)
     print("Number of learnable parameters:",count_parameters(model))