STY: Comply with style guides
millanp95 committed Jul 16, 2024
1 parent a9fa309 commit 73a3cd7
Showing 5 changed files with 103 additions and 166 deletions.
49 changes: 22 additions & 27 deletions baselines/embedders.py
@@ -16,7 +16,8 @@

# Adapted from https://github.com/frederikkemarin/BEND/blob/main/bend/models/dnabert2_padding.py
# Which was adapted from https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/bert_padding.py
# Which was adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
# Which was adapted from:
# https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py


import os
@@ -27,8 +28,6 @@

import numpy as np
import torch
from sklearn.preprocessing import LabelEncoder
from torch import nn
from torchtext.vocab import build_vocab_from_iterator
from tqdm.auto import tqdm
from transformers import (
@@ -38,7 +37,6 @@
BertConfig,
BertModel,
BertTokenizer,
BigBirdModel,
logging,
)

@@ -102,12 +100,10 @@ def __call__(self, sequence: str, *args, **kwargs):
The embedding of the sequence.
"""
return self.embed([sequence], *args, disable_tqdm=True, **kwargs)[0]
return embeddings


##
## DNABert https://doi.org/10.1093/bioinformatics/btab083
## Download from https://github.com/jerryji1993/DNABERT
# DNABERT https://doi.org/10.1093/bioinformatics/btab083
# Download from https://github.com/jerryji1993/DNABERT


class DNABertEmbedder(BaseEmbedder):
@@ -120,7 +116,8 @@ def load_model(self, model_path: str = "../../external-models/DNABERT/", kmer: i
----------
model_path : str
The path to the model directory. Defaults to "../../external-models/DNABERT/".
The DNABERT models need to be downloaded manually as indicated in the DNABERT repository at https://github.com/jerryji1993/DNABERT.
The DNABERT models need to be downloaded manually as indicated in the DNABERT repository at:
https://github.com/jerryji1993/DNABERT.
kmer : int
The kmer size of the model. Defaults to 6.
@@ -132,7 +129,8 @@ def load_model(self, model_path: str = "../../external-models/DNABERT/", kmer: i

if not os.path.exists(dnabert_path):
print(
f"Path {dnabert_path} does not exists, check if the wrong path was given. If not download from https://github.com/jerryji1993/DNABERT"
f"Path {dnabert_path} does not exists, check if the wrong path was given. \
If not download from https://github.com/jerryji1993/DNABERT"
)

config = BertConfig.from_pretrained(dnabert_path)
@@ -330,17 +328,16 @@ def embed(
"""

self.model.eval()
cls_tokens = []
embeddings = []

with torch.no_grad():
for n, s in enumerate(tqdm(sequences, disable=disable_tqdm)):
for _n, s in enumerate(tqdm(sequences, disable=disable_tqdm)):
# print('sequence', n)
s_chunks = [
s[chunk : chunk + self.max_seq_len] for chunk in range(0, len(s), self.max_seq_len)
] # split into chunks
embedded_seq = []
for n_chunk, chunk in enumerate(s_chunks): # embed each chunk
for _n_chunk, chunk in enumerate(s_chunks): # embed each chunk
tokens_ids = self.tokenizer(chunk, return_tensors="pt")["input_ids"].int().to(device)
if len(tokens_ids[0]) > self.max_tokens: # too long to fit into the model
split = torch.split(tokens_ids, self.max_tokens, dim=-1)
@@ -368,7 +365,8 @@ def embed(
if upsample_embeddings:
outs = self._repeat_embedding_vectors(self.tokenizer.convert_ids_to_tokens(tokens_ids[0]), outs)
embedded_seq.append(outs[:, 1:] if remove_special_tokens else outs)
# print('chunk', n_chunk, 'chunk length', len(chunk), 'tokens length', len(tokens_ids[0]), 'chunk embedded shape', outs.shape)
# print('chunk', n_chunk, 'chunk length', len(chunk), 'tokens length', len(tokens_ids[0]), ...
# 'chunk embedded shape', outs.shape)
embeddings.append(np.concatenate(embedded_seq, axis=1))

return embeddings
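
The loop in this hunk splits each input sequence into fixed-length chunks, embeds each chunk separately, and concatenates the chunk embeddings along the token axis before returning them. A minimal sketch of that pattern (the tokenize_and_embed helper and its (1, n_tokens, dim) output shape are illustrative assumptions, not code from this repository):

import numpy as np

def embed_long_sequence(sequence, tokenize_and_embed, max_seq_len=512):
    """Embed a sequence longer than the model context by chunking.

    `tokenize_and_embed` is assumed to map a string chunk to an array of
    shape (1, n_tokens, dim); chunk embeddings are concatenated along the
    token axis, mirroring the loop shown in the diff above.
    """
    chunks = [sequence[i : i + max_seq_len] for i in range(0, len(sequence), max_seq_len)]
    embedded_chunks = [tokenize_and_embed(chunk) for chunk in chunks]
    return np.concatenate(embedded_chunks, axis=1)  # shape (1, total_tokens, dim)
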
@@ -412,8 +410,10 @@ def load_model(self, model_path="pretrained_models/hyenadna/hyenadna-tiny-1k-seq
----------
model_path : str, optional
Path to the model checkpoint. Defaults to 'pretrained_models/hyenadna/hyenadna-tiny-1k-seqlen'.
If the path does not exist, the model will be downloaded from HuggingFace. Rather than just downloading the model,
HyenaDNA's `from_pretrained` method relies on cloning the HuggingFace-hosted repository, and using git lfs to download the model.
If the path does not exist, the model will be downloaded from HuggingFace. Rather than just
downloading the model,
HyenaDNA's `from_pretrained` method relies on cloning the HuggingFace-hosted repository,
and using git lfs to download the model.
This requires git lfs to be installed on your system, and will fail if it is not.
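
Because `from_pretrained` here clones the HuggingFace-hosted repository and pulls the weights with git lfs, a quick pre-flight check gives a clearer error than a failed clone. A small optional sketch (not part of this commit); it assumes the Git LFS binary is installed as git-lfs on the PATH:

import shutil

def check_git_lfs_available() -> None:
    """Raise early if git or git-lfs is missing from the PATH."""
    for binary in ("git", "git-lfs"):
        if shutil.which(binary) is None:
            raise RuntimeError(
                f"{binary} was not found on the PATH; it is needed to download "
                "the HyenaDNA checkpoint via from_pretrained."
            )
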
@@ -432,9 +432,9 @@ def load_model(self, model_path="pretrained_models/hyenadna/hyenadna-tiny-1k-seq
# all these settings are copied directly from huggingface.py

# data settings:
use_padding = True
rc_aug = False # reverse complement augmentation
add_eos = False # add end of sentence token
# use_padding = True
# rc_aug = False # reverse complement augmentation
# add_eos = False # add end of sentence token

# we need these for the decoder head, if using
use_head = False
@@ -501,20 +501,15 @@ def embed(
List of embeddings.
"""

# # prep model and forward
# model.to(device)
# with torch.inference_mode():

embeddings = []
with torch.inference_mode():
for s in tqdm(sequences, disable=disable_tqdm):
chunks = [
s[chunk : chunk + self.max_length] for chunk in range(0, len(s), self.max_length)
] # split into chunks
embedded_chunks = []
for n_chunk, chunk in enumerate(chunks):
#### Single embedding example ####

for _n_chunk, chunk in enumerate(chunks):
# Single embedding example
# create a sample 450k long, prepare
# sequence = 'ACTG' * int(self.max_length/4)
tok_seq = self.tokenizer(chunk) # adds CLS and SEP tokens
@@ -820,7 +815,7 @@ def load_model(self, checkpoint_path=None, from_paper=False, k_mer=8, n_heads=4,
if not from_paper:
model, ckpt = load_pretrained_model(checkpoint_path, device=device)
else:
model = load_old_pretrained_model(checkpoint_path, config, device=device)
model = load_old_pretrained_model(checkpoint_path, k_mer, device=device)

else:
arch = f"{k_mer}_{n_heads}_{n_layers}"
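
Taken together, the embedders in this file share the BaseEmbedder interface visible in the hunks above: `load_model(...)` prepares the backbone and tokenizer, `embed(sequences)` returns one embedding array per sequence, and calling the object on a single string delegates to `embed`. A hedged usage sketch; the no-argument constructor is an assumption, while the `load_model` parameters mirror the defaults shown in the diff:

from baselines.embedders import DNABertEmbedder

embedder = DNABertEmbedder()  # assumes the constructor needs no required arguments
embedder.load_model(model_path="../../external-models/DNABERT/", kmer=6)

single = embedder("ACGTACGTACGTACGTACGT")               # one sequence via __call__
batch = embedder.embed(["ACGT" * 100, "TTGACA" * 50])   # list of sequences
print(single.shape, len(batch))

If the constructor already forwards its arguments to `load_model` (as in the BEND embedders this module adapts), the explicit `load_model` call above would be redundant.
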
11 changes: 0 additions & 11 deletions baselines/io.py
@@ -2,11 +2,7 @@
Input/output utilities.
"""

import os
from inspect import getsourcefile

import torch
from transformers import BertConfig, BertForMaskedLM, BertForTokenClassification

from baselines.embedders import (
BarcodeBERTEmbedder,
@@ -16,13 +12,6 @@
NucleotideTransformerEmbedder,
)

# PACKAGE_DIR = os.path.dirname(os.path.abspath(getsourcefile(lambda: 0)))


# def get_project_root() -> str:
# return os.path.dirname(PACKAGE_DIR)


device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

20 changes: 10 additions & 10 deletions baselines/models/dnabert2.py
@@ -15,7 +15,6 @@
import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from transformers.activations import ACT2FN
from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
from transformers.models.bert.modeling_bert import BertPreTrainedModel
@@ -30,7 +29,7 @@

try:
from .flash_attn_triton import flash_attn_qkvpacked_func
except ImportError as e:
except ImportError:
flash_attn_qkvpacked_func = None

logger = logging.getLogger(__name__)
@@ -83,11 +82,9 @@ def forward(
assert isinstance(self.token_type_ids, torch.LongTensor)
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
token_type_ids = buffered_token_type_ids_expanded # type: ignore
token_type_ids = buffered_token_type_ids_expanded
else:
token_type_ids = torch.zeros(
input_shape, dtype=torch.long, device=self.word_embeddings.device # type: ignore
) # type: ignore # yapf: disable
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.word_embeddings.device)

if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
@@ -120,7 +117,9 @@ def __init__(self, config):
# Warn if defaulting to pytorch because of import issues
if flash_attn_qkvpacked_func is None:
warnings.warn(
"Unable to import Triton; defaulting MosaicBERT attention implementation to pytorch (this will reduce throughput when using this model)."
"Unable to import Triton; defaulting MosaicBERT attention \
implementation to pytorch (this will reduce throughput when using this model).",
stacklevel=2,
)

def forward(
@@ -410,7 +409,7 @@ def forward(
# Add alibi matrix to extended_attention_mask
if self._current_alibi_size < seqlen:
# Rebuild the alibi tensor when needed
warnings.warn(f"Increasing alibi size from {self._current_alibi_size} to {seqlen}")
warnings.warn(f"Increasing alibi size from {self._current_alibi_size} to {seqlen}", stacklevel=2)
self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device)
elif self.alibi.device != hidden_states.device:
# Device catch-up
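
For context on the hunk above: the ALiBi bias tensor is cached and only rebuilt when the incoming sequence length exceeds the cached size. The repository's `rebuild_alibi_tensor` is not shown here; as a generic illustration, symmetric ALiBi biases of shape (1, n_heads, size, size) are typically built like this (assuming n_heads is a power of two):

import torch

def build_symmetric_alibi_bias(n_heads: int, size: int) -> torch.Tensor:
    """Return a (1, n_heads, size, size) tensor of -slope * |i - j| biases."""
    # Geometric head slopes from the ALiBi paper: 2^(-8h/n_heads), h = 1..n_heads.
    slopes = torch.tensor([2.0 ** (-8.0 * (h + 1) / n_heads) for h in range(n_heads)])
    positions = torch.arange(size)
    distance = (positions[None, :] - positions[:, None]).abs()   # (size, size)
    return (-slopes[:, None, None] * distance).unsqueeze(0)      # (1, n_heads, size, size)
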
@@ -592,7 +591,7 @@ def forward(
else:
# TD [2022-03-01]: the indexing here is very tricky.
attention_mask_bool = attention_mask.bool()
subset_idx = subset_mask[attention_mask_bool] # type: ignore
subset_idx = subset_mask[attention_mask_bool]
sequence_output = encoder_outputs[-1][masked_tokens_mask[attention_mask_bool][subset_idx]]
if self.pooler is not None:
pool_input = encoder_outputs[-1][first_col_mask[attention_mask_bool][subset_idx]]
@@ -658,7 +657,8 @@ def __init__(self, config):
if config.is_decoder:
warnings.warn(
"If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
"bi-directional self-attention."
"bi-directional self-attention.",
stacklevel=2,
)

self.bert = BertModel(config, add_pooling_layer=False)
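
Several hunks in this file add `stacklevel=2` to `warnings.warn` calls. A standalone example of the effect, independent of the repository code: with the default `stacklevel=1` the warning is attributed to the line inside the library that raises it, while `stacklevel=2` points at the caller, which is usually the line a user can act on.

import warnings

def load_backend():
    # stacklevel=2 makes Python report the caller's line, not this one.
    warnings.warn("Falling back to the slow backend.", stacklevel=2)

def user_code():
    load_backend()  # the UserWarning is reported against this line

user_code()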