Commit 77f8255: Vinh update code

vinhdc10998 committed May 6, 2021
2 parents: a767e10 + d6a6b81
Showing 10 changed files with 67 additions and 1,612 deletions.
1,485 changes: 0 additions & 1,485 deletions .ipynb_checkpoints/GenotypeImputation-checkpoint.ipynb

This file was deleted.

Binary file removed data/__pycache__/dataset.cpython-36.pyc
7 changes: 3 additions & 4 deletions model/custom_cross_entropy.py
@@ -5,9 +5,8 @@ class CustomCrossEntropyLoss(nn.Module):
     def __init__(self, gamma=0):
         super(CustomCrossEntropyLoss, self).__init__()
         self.gamma = gamma
+        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
 
     def forward(self, pred, target, a1_freq_list):
-        #TODO:
-        # assert pred, target, a1_freq_list
-        loss = nn.CrossEntropyLoss(reduction='none')
-        return (((2*a1_freq_list)**self.gamma) * loss(pred, target)).mean()
+        assert pred.shape[0] == target.shape[0] and target.shape[0] == a1_freq_list.shape[0]
+        return (((2*a1_freq_list)**self.gamma) * self.cross_entropy_loss(pred, target)).mean()
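Note: a quick sketch of what this loss computes, with toy tensors (names and sizes made up for illustration). The per-example cross entropy is scaled by (2*a1_freq)**gamma and averaged; since train.py negates args.gamma for the non-'Higher' model, a negative gamma upweights rare variants (small a1_freq):

    import torch
    from torch import nn

    ce = nn.CrossEntropyLoss(reduction='none')   # per-example losses, as in the class above
    pred = torch.randn(4, 3)                     # 4 examples, 3 genotype classes (toy values)
    target = torch.tensor([0, 2, 1, 0])
    a1_freq = torch.tensor([0.05, 0.40, 0.25, 0.10])
    gamma = -1.0                                 # negative gamma upweights rare alleles
    loss = (((2 * a1_freq) ** gamma) * ce(pred, target)).mean()
    print(loss)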
64 changes: 13 additions & 51 deletions model/gru_model.py
@@ -1,6 +1,5 @@
 import numpy as np
 import torch
-import copy
 from torch import nn
 class GRUModel(nn.Module):
     def __init__(self, model_config, device, type_model):
@@ -19,15 +18,8 @@ def __init__(self, model_config, device, type_model):
         self.device = device
 
         self._features = torch.tensor(np.load(f'model/features/region_{self.region}_model_features.npy')).to(self.device)
-        output_linear_dim = self.feature_size * 2
-        self.linear = nn.Sequential(
-            nn.Linear(self.feature_size, output_linear_dim),
-            nn.BatchNorm1d(output_linear_dim),
-            nn.ReLU()
-        )
-        self.sigmoid = nn.Sigmoid()
 
-        self.gru = nn.ModuleDict(self._create_gru_cell(
+        self.gru = nn.ModuleList(self._create_gru_cell(
             self.feature_size,
             self.hidden_units,
             self.num_layers
@@ -41,14 +33,10 @@ def __init__(self, model_config, device, type_model):
 
     @staticmethod
     def _create_gru_cell(input_size, hidden_units, num_layers):
-        gru = [nn.GRU(input_size, hidden_units)] + [nn.GRU(hidden_units, hidden_units) for _ in range(num_layers-1)]
-        gru_fw = nn.ModuleList(copy.deepcopy(gru))
-        gru_bw = nn.ModuleList(copy.deepcopy(gru))
-        return {
-            'fw': gru_fw,
-            'bw': gru_bw
-        }
-
+        gru = [nn.GRU(input_size, hidden_units, bidirectional=True)] # First layer
+        gru += [nn.GRU(hidden_units*2, hidden_units, bidirectional=True) for _ in range(num_layers-1)] # 2 -> num_layers
+        return gru
+
     @staticmethod
     def _create_linear_list(hidden_units, num_classes, output_points_fw, output_points_bw):
         list_linear = []
@@ -64,51 +52,25 @@ def forward(self, x):
             return logits_list(g) in paper
         '''
         batch_size = x.shape[0]
-        with torch.no_grad():
-            _input = torch.unbind(x, dim=1)
-        fw_end = self.output_points_fw[-1]
-        bw_start = self.output_points_bw[0] #bw end
-
-        gru_inputs = []
-        for index in range(self.num_inputs):
-            gru_input = torch.matmul(_input[index], self._features[index])
-            # gru_input = self.linear(gru_input)
-            gru_inputs.append(gru_input)
-
-        outputs_fw = torch.zeros(self.num_inputs, batch_size, self.hidden_units)
-        outputs_bw = torch.zeros(self.num_inputs, batch_size, self.hidden_units)
-
-        if fw_end is not None:
-            inputs_fw = torch.stack(gru_inputs[: fw_end + 1])
-            outputs, _ = self._compute_gru(self.gru['fw'], inputs_fw, batch_size)
-            for t in range(fw_end + 1):
-                outputs_fw[t] = outputs[t]
-        if bw_start is not None:
-            inputs_bw = torch.stack([
-                gru_inputs[i]
-                for i in range(self.num_inputs - 1, bw_start - 1, -1)
-            ])
-            outputs, _ = self._compute_gru(self.gru['bw'], inputs_bw, batch_size)
-            for i, t in enumerate(
-                    range(self.num_inputs - 1, bw_start - 1, -1)):
-                outputs_bw[t] = outputs[i]
+        _input = torch.swapaxes(x, 0, 1)
+        gru_inputs = torch.matmul(_input, self._features)
+        outputs, _ = self._compute_gru(self.gru, gru_inputs, batch_size)
 
         logit_list = []
         for index, (t_fw, t_bw) in enumerate(zip(self.output_points_fw, self.output_points_bw)):
             gru_output = []
             if t_fw is not None:
-                gru_output.append(outputs_fw[t_fw])
+                gru_output.append(outputs[t_fw, :, :self.hidden_units])
             if t_bw is not None:
-                gru_output.append(outputs_bw[t_bw])
+                gru_output.append(outputs[t_bw, :, self.hidden_units:])
             gru_output = torch.cat(gru_output, dim=1).to(self.device)
             logit = self.list_linear[index](gru_output)
-            # logit = self.sigmoid(logit)
             logit_list.append(logit)
         return logit_list
 
     def init_hidden(self, batch):
-        weight = next(self.gru.parameters()).data
-        hidden = weight.new(1, batch, self.hidden_units).zero_()
+        weight = next(self.parameters()).data
+        hidden = weight.new(2, batch, self.hidden_units).zero_()
        return hidden
 
@@ -121,4 +83,4 @@ def _compute_gru(self, GRUs, _input, batch_size):
             _input = output
             hidden = state
         logits, state = _input, hidden
-        return logits, state
+        return logits, state
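Note: the net effect of this refactor is that the two hand-rolled one-directional GRU stacks (held in a dict under 'fw'/'bw' and run over the sequence twice) collapse into a single stack of bidirectional GRUs, with forward and backward features recovered by slicing the output's last dimension. A minimal sketch under assumed toy sizes (not the repo's config):

    import torch
    from torch import nn

    feature_size, hidden_units, num_layers = 16, 8, 2
    # Same construction as the new _create_gru_cell:
    gru = nn.ModuleList(
        [nn.GRU(feature_size, hidden_units, bidirectional=True)]
        + [nn.GRU(hidden_units * 2, hidden_units, bidirectional=True)
           for _ in range(num_layers - 1)]
    )

    seq_len, batch = 10, 4
    out = torch.randn(seq_len, batch, feature_size)  # time-major, as in forward()
    for layer in gru:
        h0 = torch.zeros(2, batch, hidden_units)     # 2 = num_directions, cf. init_hidden
        out, _ = layer(out, h0)

    fw = out[:, :, :hidden_units]  # forward-direction features at each position
    bw = out[:, :, hidden_units:]  # backward-direction features
    print(out.shape)               # torch.Size([10, 4, 16])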
5 changes: 1 addition & 4 deletions model/single_model.py
@@ -17,10 +17,7 @@ def __init__(self,model_config, device, type_model=None):
     def forward(self, input_):
         logit_list = self.gruModel(input_)
         logit = torch.cat(logit_list, dim=0)
-        prediction = torch.reshape(
-            F.softmax(logit, dim=-1),
-            shape=[self.num_outputs, -1, self.num_classes]
-        )
+        prediction = F.softmax(torch.stack(logit_list), dim=-1)
         return logit, prediction


39 changes: 7 additions & 32 deletions requirements.txt
@@ -1,42 +1,17 @@
-absl-py==0.12.0
-asciitree==0.3.3
-astor==0.8.1
-cached-property==1.5.2
-certifi==2020.12.5
 cycler==0.10.0
-Cython==0.29.22
-dask==2021.3.0
-fasteners==0.16
-gast==0.2.2
-google-pasta==0.2.0
-grpcio==1.36.1
-h5py==3.1.0
-importlib-metadata==3.7.3
-jedi==0.17.0
 joblib==1.0.1
-Keras-Applications==1.0.8
-Keras-Preprocessing==1.1.2
 kiwisolver==1.3.1
-Markdown==3.3.4
-matplotlib==3.3.4
-numcodecs==0.7.3
-numpy==1.19.5
+matplotlib==3.4.1
+numpy==1.20.2
+pandas==1.2.4
+Pillow==8.2.0
-protobuf==3.15.6
 pyparsing==2.4.7
-PyYAML==5.4.1
-pyzmq==20.0.0
-scikit-allel==1.3.3
+python-dateutil==2.8.1
+pytz==2021.1
 scikit-learn==0.24.1
-scipy==1.5.4
+scipy==1.6.3
 six==1.15.0
 sklearn==0.0
-termcolor==1.1.0
 threadpoolctl==2.1.0
-toolz==0.11.1
 torch==1.8.1
-traitlets==4.3.3
 typing-extensions==3.7.4.3
-Werkzeug==1.0.1
-wrapt==1.12.1
-zarr==2.7.0
-zipp==3.4.1
51 changes: 31 additions & 20 deletions train.py
@@ -1,7 +1,6 @@
 import os
 import json
 import torch
-from sklearn.metrics import r2_score
 from model.custom_cross_entropy import CustomCrossEntropyLoss
 from model.single_model import SingleModel
 from model.early_stopping import EarlyStopping
@@ -12,51 +11,57 @@
 from utils.imputation import train, evaluation, get_device, save_model
 torch.manual_seed(42)
 
-def run(dataloader, model_config, args, region, epochs=200):
+def run(dataloader, model_config, args, region):
     device = get_device(args.gpu)
     type_model = args.model_type
     lr = args.learning_rate
+    epochs = args.epochs
     gamma = args.gamma if type_model == 'Higher' else -args.gamma
     output_model_dir = args.output_model_dir
     train_loader = dataloader['train']
+    val_loader = dataloader['val']
     test_loader = dataloader['test']
 
     #Init Model
     model = SingleModel(model_config, device, type_model=type_model).to(device)
+    def count_parameters(model):
+        return sum(p.numel() for p in model.parameters() if p.requires_grad)
+    print("Number of learnable parameters:",count_parameters(model))
 
     loss_fn = CustomCrossEntropyLoss(gamma)
     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
-    early_stopping = EarlyStopping(patience=10)
+    # early_stopping = EarlyStopping(patience=10)
 
     #Start train
     _r2_score_list, loss_values = [], [] #train
     r2_val_list, val_loss_list = [], [] #validation
-    best_val_r2 = -99999999
+    best_val_loss = 99999999
     for t in range(epochs):
         train_loss, r2_train = train(train_loader, model, device, loss_fn, optimizer, scheduler)
-        val_loss, r2_val, _ = evaluation(test_loader, model, device, loss_fn)
+        val_loss, r2_val, _ = evaluation(val_loader, model, device, loss_fn)
+        test_loss, r2_test, _ = evaluation(test_loader, model, device, loss_fn)
         loss_values.append(train_loss)
         _r2_score_list.append(r2_train)
         r2_val_list.append(r2_val)
         val_loss_list.append(val_loss)
-        print(f"[REGION {region} - EPOCHS {t+1}]: train_loss: {train_loss:>7f}, train_r2: {r2_train:>7f}, val_loss: {val_loss:>7f}, val_r2: {r2_val:>7f}")
-
+        print(f"[REGION {region} - EPOCHS {t+1}]\
+            lr: {optimizer.param_groups[0]['lr']}\
+            train_loss: {train_loss:>7f}, train_r2: {r2_train:>7f},\
+            val_loss: {val_loss:>7f}, val_r2: {r2_val:>7f},\
+            test_loss: {test_loss:>7f}, test_r2: {r2_test:>7f}")
         # Save best model
-        if r2_val > best_val_r2:
-            best_val_r2 = r2_val
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
             best_epochs = t+1
             save_model(model, region, type_model, output_model_dir, best=True)
 
         #Early stopping
-        if args.early_stopping:
-            early_stopping(val_loss)
-            if early_stopping.early_stop:
-                break
-        print(f"Best model at epochs {best_epochs} with R2 score: {best_val_r2}")
+        # if args.early_stopping:
+        #     early_stopping(val_loss)
+        #     if early_stopping.early_stop:
+        #         break
 
+    print(f"Best model at epochs {best_epochs} with loss: {best_val_loss}")
     draw_chart(loss_values, _r2_score_list, val_loss_list, r2_val_list, region, type_model)
     save_model(model, region, type_model, output_model_dir)
 
@@ -80,20 +85,26 @@ def main():
         with open(os.path.join(model_config_dir, f'region_{region}_config.json'), "r") as json_config:
             model_config = json.load(json_config)
             model_config['region'] = region
-        train_set = RegionDataset(root_dir, region, chromosome)
+        train_val_set = RegionDataset(root_dir, region, chromosome)
         test_set = RegionDataset(test_dir, region, chromosome)
-        print("[Train - test]:", len(train_set), len(test_set), 'samples')
+        train_size = int(0.8 * len(train_val_set))
+        val_size = len(train_val_set) - train_size
+        train_set, val_set = torch.utils.data.random_split(train_val_set, [train_size, val_size])
+
+        print("[Train - Val- Test]:", len(train_set), len(val_set), len(test_set), 'samples')
         train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
+        val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
         test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)
         dataloader = {
             'train': train_loader,
-            'test': test_loader}
+            'test': test_loader,
+            'val': val_loader
+        }
         run(
             dataloader,
             model_config,
             args,
-            region,
-            epochs
+            region
         )
 
 if __name__ == "__main__":
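Note: a self-contained sketch of the 80/20 split this commit introduces, with TensorDataset standing in for the repo's RegionDataset. random_split draws a random permutation, so the torch.manual_seed(42) call at the top of train.py is what keeps the split reproducible between runs:

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split

    torch.manual_seed(42)  # as in train.py; also fixes the split below
    full = TensorDataset(torch.randn(100, 5), torch.randint(0, 3, (100,)))
    train_size = int(0.8 * len(full))
    val_size = len(full) - train_size
    train_set, val_set = random_split(full, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=16, shuffle=False)
    print(len(train_set), len(val_set))  # 80 20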
2 changes: 1 addition & 1 deletion utils/argument_parser.py
@@ -21,7 +21,7 @@ def get_argument():
                         dest='chromosome', help='Chromosome')
     parser.add_argument('--lr', type=float, default=1e-4, required=False,
                         dest='learning_rate', help='Learning rate')
-    parser.add_argument('--gamma', type=str, default=0.1, required=False,
+    parser.add_argument('--gamma', type=float, default=1, required=False,
                         dest='gamma', help='gamma in loss function')
     parser.add_argument('--output-model-dir', type=str, default='model/weights', required=False,
                         dest='output_model_dir', help='Output weights model dir')
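Note: the change from type=str to type=float matters because run() negates the parsed value (gamma = args.gamma if type_model == 'Higher' else -args.gamma), and unary minus on a string raises TypeError. An illustrative snippet:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--gamma', type=float, default=1, dest='gamma')
    args = parser.parse_args(['--gamma', '0.5'])
    print(-args.gamma)  # -0.5; under type=str this negation would raise TypeError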
21 changes: 8 additions & 13 deletions utils/imputation.py
@@ -1,14 +1,14 @@
 import os
 import torch
 from sklearn.metrics import r2_score
-from data.load_data import *
+from data.load_data import mkdir
 
 def evaluation(dataloader, model, device, loss_fn):
     '''
     Evaluate model with R square score
     '''
-    size = len(dataloader)
     model.eval()
+    size = len(dataloader)
     test_loss = 0
     with torch.no_grad():
         predictions = []
@@ -19,7 +19,6 @@ def evaluation(dataloader, model, device, loss_fn):
             # Compute prediction error
             logits, prediction = model(X)
             y_pred = torch.argmax(prediction, dim=-1).T
-
             test_loss += loss_fn(logits, torch.flatten(y.T), torch.flatten(a1_freq.T)).item()
 
             predictions.append(y_pred)
@@ -28,10 +27,8 @@ def evaluation(dataloader, model, device, loss_fn):
     predictions = torch.cat(predictions, dim=0)
     labels = torch.cat(labels, dim=0)
     test_loss /= size
-    _r2_score = r2_score(
-        labels.cpu().detach().numpy(),
-        predictions.cpu().detach().numpy()
-    )
+    n_samples = len(labels)
+    _r2_score = sum([r2_score(labels[i].cpu().detach().numpy(), predictions[i].cpu().detach().numpy()) for i in range(n_samples)])/n_samples
     return test_loss, _r2_score, (predictions, labels)
 
 def train(dataloader, model, device, loss_fn, optimizer, scheduler):
@@ -58,16 +55,14 @@ def train(dataloader, model, device, loss_fn, optimizer, scheduler):
         optimizer.zero_grad()
         loss.backward()
         optimizer.step()
-        scheduler.step()
 
         train_loss = loss.item()
 
+    scheduler.step()
     predictions = torch.cat(predictions, dim=0)
     labels = torch.cat(labels, dim=0)
-    _r2_score = r2_score(
-        labels.cpu().detach().numpy(),
-        predictions.cpu().detach().numpy()
-    )
+    n_samples = len(labels)
+    _r2_score = sum([r2_score(labels[i].cpu().detach().numpy(), predictions[i].cpu().detach().numpy()) for i in range(n_samples)])/n_samples
 
     return train_loss, _r2_score
 
 def save_model(model, region, type_model, path, best=False):
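Note: the metric itself changes here, not just its layout. sklearn's r2_score on 2D arrays averages R^2 column-wise (per variant, with samples as rows), while the new code averages R^2 row-wise (per sample). A toy comparison, with made-up genotype matrices of shape (samples, variants):

    import numpy as np
    from sklearn.metrics import r2_score

    labels = np.array([[0, 1, 2, 1], [1, 1, 0, 2], [2, 0, 1, 1]])
    preds = np.array([[0, 1, 2, 0], [1, 2, 0, 2], [2, 0, 1, 1]])

    pooled = r2_score(labels, preds)  # column-wise R^2, uniformly averaged (old behaviour)
    per_sample = sum(r2_score(labels[i], preds[i]) for i in range(len(labels))) / len(labels)
    print(pooled, per_sample)  # generally not equal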
5 changes: 3 additions & 2 deletions utils/plot_chart.py
@@ -35,10 +35,11 @@ def draw_MAF_R2(pred, label, a1_freq_list, type_model, region, bins=30, output_p
     for index in range(bins):
         y = torch.stack(label_bins[index]).detach().numpy().T
         y_pred = torch.stack(pred_bins[index]).detach().numpy().T
-        _r2_score = r2_score(y, y_pred)
+        n_samples = len(y)
+        _r2_score = sum([r2_score(y[i], y_pred[i]) for i in range(n_samples)])/n_samples
         r2_score_list.append(_r2_score)
-
     x_axis = np.unique(pd.cut(a1_freq_list, bins, labels=np.linspace(start=0, stop=0.5, num=bins)))
+    print(np.unique(bins_list))
     plt.plot(x_axis, r2_score_list)
     plt.grid(linestyle='--')
     plt.xlabel("Minor Allele Frequency")
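Note: a sketch of how the x-axis is built above. pd.cut buckets the allele frequencies into equal-width bins labelled with points on [0, 0.5], and np.unique keeps the labels of the non-empty bins in order (toy inputs):

    import numpy as np
    import pandas as pd

    bins = 5
    a1_freq_list = np.array([0.02, 0.04, 0.11, 0.18, 0.27, 0.33, 0.49])
    labels = np.linspace(start=0, stop=0.5, num=bins)  # one label per bin
    x_axis = np.unique(pd.cut(a1_freq_list, bins, labels=labels))
    print(x_axis)  # labels of the non-empty bins, ascending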
