Skip to content

Commit

Permalink
Merge pull request #28 from vinhdc10998/Vinhdev
Browse files Browse the repository at this point in the history
Fix gru and add multi model
  • Loading branch information
vinhdc10998 authored May 10, 2021
2 parents 9352658 + 2abd3c3 commit 5468742
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 4 deletions.
4 changes: 2 additions & 2 deletions model/gru_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, model_config, device, type_model):
self.type_model = type_model
self.device = device

self._features = torch.tensor(np.load(f'model/features/region_{self.region}_model_features.npy')).to(self.device)
self.linear = nn.Linear(self.input_dim, self.feature_size, bias=True)

self.gru = nn.ModuleList(self._create_gru_cell(
self.feature_size,
Expand Down Expand Up @@ -53,7 +53,7 @@ def forward(self, x):
'''
batch_size = x.shape[0]
_input = torch.swapaxes(x, 0, 1)
gru_inputs = torch.matmul(_input, self._features)
gru_inputs = self.linear(_input)
outputs, _ = self._compute_gru(self.gru, gru_inputs, batch_size)

logit_list = []
Expand Down
55 changes: 55 additions & 0 deletions model/multi_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
from torch import nn
from .gru_model import GRUModel
from torch.nn import functional as F

TYPE_MODEL = ['Hybrid']
class MultiModel(nn.Module):
def __init__(self,model_config, device, type_model=None):
super(MultiModel,self).__init__()
assert type_model in TYPE_MODEL
self.num_classes = model_config['num_classes']
self.num_outputs = model_config['num_outputs']
self.type_model = type_model

self.lowerModel = GRUModel(model_config, device, type_model='Lower')
self.higherModel = GRUModel(model_config, device, type_model='Higher')

self.lowerModel.load_state_dict(self.get_gru_layer(model_config['lower_path'], device))
self.higherModel.load_state_dict(self.get_gru_layer(model_config['higher_path'], device))

for param in self.lowerModel.parameters():
param.requires_grad = False
for param in self.higherModel.parameters():
param.requires_grad = False

self.linear = nn.ModuleList([nn.Linear(self.num_classes*2, self.num_classes) for _ in range(self.num_outputs)])


@staticmethod
def get_gru_layer(path, device):
tmp = torch.load(path, map_location=torch.device(device))
a = {}
for i in tmp:
if 'gru' in i:
k = i[9:]
a[k] = tmp[i]
return a



def forward(self, input_):
logits_1 = self.higherModel(input_)
logits_2 = self.lowerModel(input_)
logits = torch.cat((torch.stack(logits_1), torch.stack(logits_2)), dim=-1)
logit_list = []
for index, logit in enumerate(logits):
logit_tmp = self.linear[index](logit)
logit_list.append(logit_tmp)

logit = torch.cat(logit_list, dim=0)
pred = F.softmax(torch.stack(logit_list), dim=-1)
return logit, pred



18 changes: 16 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from model.multi_model import MultiModel
import os
import json
import torch
Expand All @@ -16,14 +17,27 @@ def run(dataloader, model_config, args, region):
type_model = args.model_type
lr = args.learning_rate
epochs = args.epochs
gamma = args.gamma if type_model == 'Higher' else -args.gamma

output_model_dir = args.output_model_dir
train_loader = dataloader['train']
val_loader = dataloader['val']
test_loader = dataloader['test']

#Init Model
model = SingleModel(model_config, device, type_model=type_model).to(device)
if type_model in ['Lower', 'Higher']:
gamma = args.gamma if type_model == 'Higher' else -args.gamma
model = SingleModel(model_config, device, type_model=type_model).to(device)

elif type_model in ['Hybrid']:
gamma = 0
if args.best_model:
model_config['lower_path'] = os.path.join(args.model_dir, f'Best_Lower_region_{region}.pt')
model_config['higher_path'] = os.path.join(args.model_dir, f'Best_Higher_region_{region}.pt')
else:
model_config['lower_path'] = os.path.join(args.model_dir, f'Lower_region_{region}.pt')
model_config['higher_path'] = os.path.join(args.model_dir, f'Higher_region_{region}.pt')
model = MultiModel(model_config, device, type_model=type_model).to(device)

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Number of learnable parameters:",count_parameters(model))
Expand Down

0 comments on commit 5468742

Please sign in to comment.