Skip to content

Commit

Permalink
Merge pull request #26 from vinhdc10998/Vinhdev
Browse files Browse the repository at this point in the history
Vinhdev
  • Loading branch information
vinhdc10998 authored May 3, 2021
2 parents a17e5b8 + 8c05e79 commit d6a6b81
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
13 changes: 5 additions & 8 deletions utils/imputation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ def evaluation(dataloader, model, device, loss_fn):
predictions = torch.cat(predictions, dim=0)
labels = torch.cat(labels, dim=0)
test_loss /= size
_r2_score = r2_score(
labels.cpu().detach().numpy(),
predictions.cpu().detach().numpy()
)
n_samples = len(labels)
_r2_score = sum([r2_score(labels[i].cpu().detach().numpy(), predictions[i].cpu().detach().numpy()) for i in range(n_samples)])/n_samples
return test_loss, _r2_score, (predictions, labels)

def train(dataloader, model, device, loss_fn, optimizer, scheduler):
Expand Down Expand Up @@ -63,10 +61,9 @@ def train(dataloader, model, device, loss_fn, optimizer, scheduler):

predictions = torch.cat(predictions, dim=0)
labels = torch.cat(labels, dim=0)
_r2_score = r2_score(
labels.cpu().detach().numpy(),
predictions.cpu().detach().numpy()
)
n_samples = len(labels)
_r2_score = sum([r2_score(labels[i].cpu().detach().numpy(), predictions[i].cpu().detach().numpy()) for i in range(n_samples)])/n_samples

return train_loss, _r2_score

def save_model(model, region, type_model, path, best=False):
Expand Down
5 changes: 3 additions & 2 deletions utils/plot_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,11 @@ def draw_MAF_R2(pred, label, a1_freq_list, type_model, region, bins=30, output_p
for index in range(bins):
y = torch.stack(label_bins[index]).detach().numpy().T
y_pred = torch.stack(pred_bins[index]).detach().numpy().T
_r2_score = r2_score(y, y_pred)
n_samples = len(y)
_r2_score = sum([r2_score(y[i], y_pred[i]) for i in range(n_samples)])/n_samples
r2_score_list.append(_r2_score)

x_axis = np.unique(pd.cut(a1_freq_list, bins, labels=np.linspace(start=0, stop=0.5, num=bins)))
print(np.unique(bins_list))
plt.plot(x_axis, r2_score_list)
plt.grid(linestyle='--')
plt.xlabel("Minor Allele Frequency")
Expand Down

0 comments on commit d6a6b81

Please sign in to comment.