DEV: Update to BIOSCAN 5M
millanp95 committed Jun 12, 2024
1 parent fe4e000 commit 5361381
Showing 14 changed files with 31 additions and 423 deletions.
8 changes: 3 additions & 5 deletions barcodebert/bzsl/feature_extraction/__init__.py
@@ -1,7 +1,5 @@
-from .utils import (
-    extract_clean_barcode_list,
-    extract_clean_barcode_list_for_aligned,
-    extract_dna_features,
-)
+from .utils import (extract_clean_barcode_list,
+                    extract_clean_barcode_list_for_aligned,
+                    extract_dna_features)

__all__ = ["extract_clean_barcode_list", "extract_clean_barcode_list_for_aligned", "extract_dna_features"]
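Most of the churn in this commit is a mechanical import-style change repeated across the files below: the black-compatible vertical-hanging-indent layout (one name per line, trailing comma) is collapsed into isort's default grid wrapping. A minimal sketch of the two styles, assuming isort >= 5 is the formatter in use; the repository's actual tooling config is not shown in this diff, and exact output depends on the isort version and settings.

# Sketch only: reproduces the two wrapping styles seen in this commit.
import isort

src = (
    "from .utils import extract_clean_barcode_list, "
    "extract_clean_barcode_list_for_aligned, extract_dna_features\n"
)

# Old layout: black-compatible "vertical hanging indent", one name per line.
print(isort.code(src, profile="black"))

# New layout: isort's default grid wrapping, as committed here.
print(isort.code(src, line_length=79))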
6 changes: 2 additions & 4 deletions barcodebert/bzsl/feature_extraction/main.py
@@ -4,10 +4,8 @@
import scipy.io as sio

from barcodebert.bzsl.feature_extraction import (
-    extract_clean_barcode_list,
-    extract_clean_barcode_list_for_aligned,
-    extract_dna_features,
-)
+    extract_clean_barcode_list, extract_clean_barcode_list_for_aligned,
+    extract_dna_features)
from barcodebert.bzsl.models import load_model

random.seed(10)
6 changes: 2 additions & 4 deletions barcodebert/bzsl/finetuning/supervised_learning.py
@@ -14,10 +14,8 @@
from tqdm import tqdm

from barcodebert.bzsl.feature_extraction import (
-    extract_clean_barcode_list,
-    extract_clean_barcode_list_for_aligned,
-    extract_dna_features,
-)
+    extract_clean_barcode_list, extract_clean_barcode_list_for_aligned,
+    extract_dna_features)
from barcodebert.bzsl.models import load_model

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
5 changes: 1 addition & 4 deletions barcodebert/bzsl/genus_species/main.py
@@ -7,10 +7,7 @@
import torch

from barcodebert.bzsl.genus_species.bayesian_classifier import (
-    BayesianClassifier,
-    apply_pca,
-    calculate_priors,
-)
+    BayesianClassifier, apply_pca, calculate_priors)
from barcodebert.bzsl.genus_species.dataset import get_data_splits, load_data


9 changes: 2 additions & 7 deletions barcodebert/bzsl/models/dnabert/tokenization_utils.py
@@ -26,13 +26,8 @@

from tokenizers.implementations import BaseTokenizer

-from .file_utils import (
-    cached_path,
-    hf_bucket_url,
-    is_remote_url,
-    is_tf_available,
-    is_torch_available,
-)
+from .file_utils import (cached_path, hf_bucket_url, is_remote_url,
+                         is_tf_available, is_torch_available)

if is_tf_available():
    import tensorflow as tf
7 changes: 2 additions & 5 deletions barcodebert/bzsl/surrogate_species/bayesian_model.py
@@ -6,11 +6,8 @@
from scipy.spatial.distance import cdist
from scipy.special import gammaln

-from barcodebert.bzsl.surrogate_species.utils import (
-    DataLoader,
-    apply_pca,
-    perf_calc_acc,
-)
+from barcodebert.bzsl.surrogate_species.utils import (DataLoader, apply_pca,
+                                                      perf_calc_acc)


class Model:
5 changes: 2 additions & 3 deletions barcodebert/finetuning.py
@@ -19,7 +19,8 @@
from barcodebert import utils
from barcodebert.datasets import DNADataset
from barcodebert.evaluation import evaluate
-from barcodebert.io import get_project_root, load_pretrained_model, safe_save_model
+from barcodebert.io import (get_project_root, load_pretrained_model,
+                            safe_save_model)

BASE_BATCH_SIZE = 64

@@ -163,8 +164,6 @@ def print_pass(*args, **kwargs):
"max_len",
"tokenizer",
"use_unk_token",
"pretrain_levenshtein",
"levenshtein_vectorized",
"n_layers",
"n_heads",
]
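The two dropped keys correspond to the deleted barcodebert/levenshtein.py below; the same pair is removed from identical lists in inference.py and knn_probing.py. A hypothetical sketch of how a key list like this is typically used; the helper name and surrounding code are assumptions, not code from this repository:

# Hypothetical sketch (helper name assumed): copy architecture settings
# from a pretraining checkpoint's config so the model is rebuilt with
# matching hyperparameters at fine-tuning time.
keys_to_copy = ["max_len", "tokenizer", "use_unk_token", "n_layers", "n_heads"]

def adopt_pretraining_config(config, ckpt_config):
    # Overwrite runtime attributes with the values stored in the checkpoint.
    for key in keys_to_copy:
        if key in ckpt_config:
            setattr(config, key, ckpt_config[key])
    return config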
2 changes: 0 additions & 2 deletions barcodebert/inference.py
@@ -80,8 +80,6 @@ def run(config):
"max_len",
"tokenizer",
"use_unk_token",
"pretrain_levenshtein",
"levenshtein_vectorized",
"n_layers",
"n_heads",
]
4 changes: 2 additions & 2 deletions barcodebert/io.py
@@ -88,7 +88,7 @@ def load_pretrained_model(checkpoint_path, device=None):
    return model, ckpt


-def load_old_pretrained_model(checkpoint_path, config, device=None):
+def load_old_pretrained_model(checkpoint_path, k_mer, device=None):
    """
    Load a pretrained model using the published format from a checkpoint file.
@@ -104,7 +104,7 @@ def load_old_pretrained_model(checkpoint_path, config, device=None):
    ckpt : dict
        The contents of the checkpoint file.
    """
-    vocab_size = 4**config.k_mer + 3
+    vocab_size = 4**k_mer + 3
    configuration = BertConfig(vocab_size=vocab_size, output_hidden_states=True)
    # Initializing a model (with random weights) from the bert-base-uncased style configuration
    model = BertForMaskedLM(configuration)
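A minimal usage sketch for the updated signature; the checkpoint path and k-mer size are assumptions, while the function name and its (model, ckpt) return pair come from this file. Passing k_mer directly replaces the old config object, and for k_mer=6 the computed vocabulary size is 4**6 + 3 = 4099 (every 6-mer over ACGT plus three special tokens):

# Usage sketch; "checkpoints/old_model.pt" is a hypothetical path.
from barcodebert.io import load_old_pretrained_model

model, ckpt = load_old_pretrained_model("checkpoints/old_model.pt", k_mer=6)
model.eval()  # BertForMaskedLM built with vocab_size = 4**6 + 3 = 4099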
2 changes: 0 additions & 2 deletions barcodebert/knn_probing.py
@@ -64,8 +64,6 @@ def run(config):
"max_len",
"tokenizer",
"use_unk_token",
"pretrain_levenshtein",
"levenshtein_vectorized",
"n_layers",
"n_heads",
]
141 changes: 0 additions & 141 deletions barcodebert/levenshtein.py

This file was deleted.

17 changes: 5 additions & 12 deletions barcodebert/linear_probing.py
@@ -27,7 +27,7 @@ def run(config):
    start_time = time.time()
    k = config.k_mer

-    extra = "soft_" if config.pretrain_levenshtein else ""
+    extra = ""
    representation_folder = os.path.join(get_project_root(), "data", f"{extra}{config.n_layers}_{config.n_layers}")
    if not os.path.exists(representation_folder):
        os.makedirs(representation_folder, exist_ok=True)
@@ -36,16 +36,10 @@

    # Vocabulary
    max_len = config.max_len
-    if config.pretrain_levenshtein:
-        kmer_iter = (["".join(kmer)] for kmer in product("ACGTN", repeat=k))
-        vocab = build_vocab_from_iterator(kmer_iter, specials=["<MASK>"])
-        vocab.set_default_index(vocab["N" * k])  # <UNK> and <CLS> do not exist anymore
-        vocab_size = len(vocab)
-    else:
-        kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k))
-        vocab = build_vocab_from_iterator(kmer_iter, specials=["<MASK>", "<UNK>"])
-        vocab.set_default_index(vocab["<UNK>"])  # <UNK> is necessary in the hard case
-        vocab_size = len(vocab)
+    kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k))
+    vocab = build_vocab_from_iterator(kmer_iter, specials=["<MASK>", "<UNK>"])
+    vocab.set_default_index(vocab["<UNK>"])  # <UNK> is necessary in the hard case
+    vocab_size = len(vocab)

    tokenizer = KmerTokenizer(k, vocab, stride=k, padding=True, max_len=max_len)

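The surviving branch is runnable on its own. A self-contained sketch, assuming torchtext is installed (KmerTokenizer is repo code and is omitted here); for k = 4 the vocabulary holds 4**4 + 2 = 258 entries:

# Sketch of the retained "hard" vocabulary: all k-mers over ACGT plus
# <MASK> and <UNK>; anything unseen (e.g. a k-mer containing N) maps to <UNK>.
from itertools import product

from torchtext.vocab import build_vocab_from_iterator

k = 4
kmer_iter = (["".join(kmer)] for kmer in product("ACGT", repeat=k))
vocab = build_vocab_from_iterator(kmer_iter, specials=["<MASK>", "<UNK>"])
vocab.set_default_index(vocab["<UNK>"])
print(len(vocab))  # 4**4 + 2 = 258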
@@ -183,7 +177,6 @@ def get_parser():
parser.add_argument("--max_len", action="store", type=int, default=660)
parser.add_argument("--n_layers", action="store", type=int, default=12)
parser.add_argument("--n_heads", action="store", type=int, default=12)
parser.add_argument("--pretrain_levenshtein", action="store_true")
return parser


