
Commit

Merge pull request #25 from vinhdc10998/Vinhdev
Vinh update code model
vinhdc10998 authored May 3, 2021
2 parents 7347d86 + f572caa commit a17e5b8
Showing 5 changed files with 24 additions and 50 deletions.
7 changes: 3 additions & 4 deletions model/custom_cross_entropy.py
@@ -5,9 +5,8 @@ class CustomCrossEntropyLoss(nn.Module):
     def __init__(self, gamma=0):
         super(CustomCrossEntropyLoss, self).__init__()
         self.gamma = gamma
+        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
 
     def forward(self, pred, target, a1_freq_list):
-        #TODO:
-        # assert pred, target, a1_freq_list
-        loss = nn.CrossEntropyLoss(reduction='none')
-        return (((2*a1_freq_list)**self.gamma) * loss(pred, target)).mean()
+        assert pred.shape[0] == target.shape[0] and target.shape[0] == a1_freq_list.shape[0]
+        return (((2*a1_freq_list)**self.gamma) * self.cross_entropy_loss(pred, target)).sum()
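
For context, the updated loss as a runnable standalone sketch (the class matches the diff; the toy tensors are made up for illustration). It weights each sample's cross-entropy by (2 * a1_freq)**gamma, down-weighting low-frequency samples when gamma > 0, and now sums instead of averaging:

```python
# Standalone sketch of the updated loss; the toy tensors are illustrative only.
import torch
from torch import nn

class CustomCrossEntropyLoss(nn.Module):
    def __init__(self, gamma=0):
        super(CustomCrossEntropyLoss, self).__init__()
        self.gamma = gamma
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')  # per-sample losses

    def forward(self, pred, target, a1_freq_list):
        assert pred.shape[0] == target.shape[0] == a1_freq_list.shape[0]
        # weight each sample's loss by (2 * a1_freq)**gamma, then sum (was .mean())
        return (((2 * a1_freq_list) ** self.gamma) * self.cross_entropy_loss(pred, target)).sum()

loss_fn = CustomCrossEntropyLoss(gamma=1.0)
pred = torch.randn(4, 3)                         # 4 samples, 3 classes (raw logits)
target = torch.tensor([0, 2, 1, 2])              # class indices
a1_freq = torch.tensor([0.10, 0.50, 0.30, 0.05]) # per-sample allele frequencies
print(loss_fn(pred, target, a1_freq))            # scalar; low-frequency samples get weights < 1
```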
56 changes: 14 additions & 42 deletions model/gru_model.py
@@ -1,5 +1,4 @@
 import numpy as np
-import copy
 import torch
 from torch import nn
 class GRUModel(nn.Module):
@@ -19,8 +18,9 @@ def __init__(self, model_config, device, type_model):
         self.device = device
 
         self._features = torch.tensor(np.load(f'model/features/region_{self.region}_model_features.npy')).to(self.device)
+        self.tanh = nn.Tanh()
 
-        self.gru = nn.ModuleDict(self._create_gru_cell(
+        self.gru = nn.ModuleList(self._create_gru_cell(
             self.feature_size,
             self.hidden_units,
             self.num_layers
@@ -34,14 +34,9 @@ def __init__(self, model_config, device, type_model):
 
     @staticmethod
     def _create_gru_cell(input_size, hidden_units, num_layers):
-        gru = [nn.GRU(input_size, hidden_units)] + [nn.GRU(hidden_units, hidden_units) for _ in range(num_layers-1)]
-        gru_fw = nn.ModuleList(copy.deepcopy(gru))
-        gru_bw = nn.ModuleList(copy.deepcopy(gru))
-        return {
-            'fw': gru_fw,
-            'bw': gru_bw
-        }
+        gru = [nn.GRU(input_size, hidden_units, bidirectional=True)] + [nn.GRU(hidden_units*2, hidden_units, bidirectional=True) for _ in range(num_layers-1)]
+        return gru
 
     @staticmethod
     def _create_linear_list(hidden_units, num_classes, output_points_fw, output_points_bw):
         list_linear = []
@@ -57,49 +52,26 @@ def forward(self, x):
         return logits_list(g) in paper
         '''
         batch_size = x.shape[0]
-        with torch.no_grad():
-            _input = torch.unbind(x, dim=1)
-            fw_end = self.output_points_fw[-1]
-            bw_end = self.output_points_bw[0]
-
-            gru_inputs = []
-            for index in range(self.num_inputs):
-                gru_input = torch.matmul(_input[index], self._features[index])
-                gru_inputs.append(gru_input)
-
-        outputs_fw = torch.zeros(self.num_inputs, batch_size, self.hidden_units)
-        outputs_bw = torch.zeros(self.num_inputs, batch_size, self.hidden_units)
-
-        if fw_end is not None:
-            inputs_fw = torch.stack(gru_inputs[: fw_end + 1])
-            outputs, _ = self._compute_gru(self.gru['fw'], inputs_fw, batch_size)
-            for t in range(fw_end + 1):
-                outputs_fw[t] = outputs[t]
-        if bw_end is not None:
-            inputs_bw = torch.stack([
-                gru_inputs[i]
-                for i in range(self.num_inputs - 1, bw_end - 1, -1)
-            ])
-            outputs, _ = self._compute_gru(self.gru['bw'], inputs_bw, batch_size)
-            for i, t in enumerate(
-                    range(self.num_inputs - 1, bw_end - 1, -1)):
-                outputs_bw[t] = outputs[i]
+        _input = torch.swapaxes(x, 0, 1)
+        gru_inputs = torch.matmul(_input, self._features)
+        outputs, _ = self._compute_gru(self.gru, gru_inputs, batch_size)
 
         logit_list = []
         for index, (t_fw, t_bw) in enumerate(zip(self.output_points_fw, self.output_points_bw)):
             gru_output = []
             if t_fw is not None:
-                gru_output.append(outputs_fw[t_fw])
+                gru_output.append(outputs[t_fw, :, :self.hidden_units])
             if t_bw is not None:
-                gru_output.append(outputs_bw[t_bw])
+                gru_output.append(outputs[t_bw, :, self.hidden_units:])
             gru_output = torch.cat(gru_output, dim=1).to(self.device)
             logit = self.list_linear[index](gru_output)
+            logit = self.tanh(logit)
             logit_list.append(logit)
-        return torch.stack(logit_list)
+        return logit_list
 
     def init_hidden(self, batch):
-        weight = next(self.gru.parameters()).data
-        hidden = weight.new(1, batch, self.hidden_units).zero_()
+        weight = next(self.parameters()).data
+        hidden = weight.new(2, batch, self.hidden_units).zero_()
         return hidden
 
     def _compute_gru(self, GRUs, _input, batch_size):
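
The main refactor here swaps the two manually mirrored forward/backward GRU stacks for PyTorch's built-in bidirectional GRU, whose output packs both directions into the last dimension. A minimal sketch of that layout (sizes made up for illustration):

```python
# How a bidirectional nn.GRU lays out its output; sizes here are made up.
import torch
from torch import nn

seq_len, batch_size, feature_size, hidden_units = 10, 4, 8, 16
gru = nn.GRU(feature_size, hidden_units, bidirectional=True)

x = torch.randn(seq_len, batch_size, feature_size)
h0 = torch.zeros(2, batch_size, hidden_units)  # 2 = one initial state per direction (cf. init_hidden)
outputs, hn = gru(x, h0)

print(outputs.shape)               # torch.Size([10, 4, 32]): both directions concatenated
fw = outputs[:, :, :hidden_units]  # forward half, read as outputs[t_fw, :, :self.hidden_units]
bw = outputs[:, :, hidden_units:]  # backward half, read as outputs[t_bw, :, self.hidden_units:]
```

This layout is also why init_hidden now allocates 2 initial states instead of 1, and why the deeper GRU layers take hidden_units*2 as their input size.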
7 changes: 5 additions & 2 deletions model/single_model.py
@@ -16,8 +16,11 @@ def __init__(self,model_config, device, type_model=None):
 
     def forward(self, input_):
         logit_list = self.gruModel(input_)
-        logit = torch.reshape(logit_list, shape=[-1, self.num_classes])
-        prediction = F.softmax(logit_list, dim=-1)
+        logit = torch.cat(logit_list, dim=0)
+        prediction = torch.reshape(
+            F.softmax(logit, dim=-1),
+            shape = [self.num_outputs, -1, self.num_classes]
+        )
         return logit, prediction
 
 
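Since GRUModel.forward now returns a Python list of per-output logits rather than a stacked tensor, the single model concatenates before the softmax and reshapes afterwards. A shape walk-through with made-up sizes:

```python
# Shape walk-through of the new forward(); sizes are made up for illustration.
import torch
import torch.nn.functional as F

num_outputs, batch_size, num_classes = 5, 4, 3
logit_list = [torch.randn(batch_size, num_classes) for _ in range(num_outputs)]

logit = torch.cat(logit_list, dim=0)        # [num_outputs * batch, num_classes] -> [20, 3]
prediction = torch.reshape(
    F.softmax(logit, dim=-1),
    shape=[num_outputs, -1, num_classes]    # back to [num_outputs, batch, num_classes]
)
print(logit.shape, prediction.shape)        # torch.Size([20, 3]) torch.Size([5, 4, 3])
```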
2 changes: 1 addition & 1 deletion train.py
@@ -87,7 +87,7 @@ def main():
         model_config['region'] = region
         train_val_set = RegionDataset(root_dir, region, chromosome)
         test_set = RegionDataset(test_dir, region, chromosome)
-        train_size = int(0.7 * len(train_val_set))
+        train_size = int(0.8 * len(train_val_set))
         val_size = len(train_val_set) - train_size
         train_set, val_set = torch.utils.data.random_split(train_val_set, [train_size, val_size])
 
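
The only change here is the train/validation ratio, from 70/30 to 80/20; a toy example of the resulting split (a dummy dataset stands in for RegionDataset):

```python
# Toy illustration of the 80/20 split; a dummy dataset stands in for RegionDataset.
import torch

train_val_set = torch.utils.data.TensorDataset(torch.arange(100).float())
train_size = int(0.8 * len(train_val_set))   # 80
val_size = len(train_val_set) - train_size   # 20; subtraction keeps the sizes summing exactly
train_set, val_set = torch.utils.data.random_split(train_val_set, [train_size, val_size])
print(len(train_set), len(val_set))          # 80 20
```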
2 changes: 1 addition & 1 deletion utils/argument_parser.py
@@ -21,7 +21,7 @@ def get_argument():
                         dest='chromosome', help='Chromosome')
     parser.add_argument('--lr', type=float, default=1e-4, required=False,
                         dest='learning_rate', help='Learning rate')
-    parser.add_argument('--gamma', type=str, default=1.5, required=False,
+    parser.add_argument('--gamma', type=float, default=1, required=False,
                         dest='gamma', help='gamma in loss function')
     parser.add_argument('--output-model-dir', type=str, default='model/weights', required=False,
                         dest='output_model_dir', help='Output weights model dir')
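
The type fix matters because argparse would otherwise hand a CLI-supplied --gamma to the loss as a string, and the exponentiation (2*a1_freq_list)**gamma raises a TypeError on a string exponent. A standalone sketch covering just this option:

```python
# Standalone sketch of just the --gamma option; the full parser has more arguments.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gamma', type=float, default=1, required=False,
                    dest='gamma', help='gamma in loss function')

args = parser.parse_args(['--gamma', '1.5'])
print(type(args.gamma), args.gamma)  # <class 'float'> 1.5, ready for exponentiation
```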
