Skip to content

Commit

Permalink
Read Fasta
Browse files Browse the repository at this point in the history
  • Loading branch information
millanp95 committed Jun 26, 2024
1 parent 61a3bee commit c30a0ce
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 8 deletions.
64 changes: 64 additions & 0 deletions barcodebert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,67 @@ def concat_all_gather(tensor, **kwargs):
torch.distributed.all_gather(tensors_gather, tensor, **kwargs)
output = torch.cat(tensors_gather, dim=0)
return output


# Translation table applied before tokenization: lowercase bases are
# uppercased, U/u (RNA) become T, and IUPAC ambiguity codes other than
# N, Y, R are collapsed to N. Built once at import time instead of on
# every call.
_BASE_MASK = bytearray.maketrans(
    b"acgtuUswkmyrbdhvnSWKMBDHV",
    b"ACGTTTNNNNNNNNNNNNNNNNNNN",
)
# Whitespace bytes silently removed from sequences before validation.
_SEQ_WHITESPACE = b" \t\n\r"
# Bytes considered valid after masking. The gap character '-' introduced
# by alignment is deliberately kept as valid.
_VALID_BYTES = b"ACGTNYR-"


def check_sequence(header, seq):
    """Validate a FASTA header/sequence pair and normalize the sequence.

    Adapted from VAMB: https://github.com/RasmussenLab/vamb

    Checks that there are no invalid characters or bad formatting in the
    entry. Gaps ('-') introduced by alignment are considered valid.

    Parameters
    ----------
    header : str
        Sequence identifier (without the leading '>').
    seq : bytes or bytearray
        Raw sequence bytes.

    Returns
    -------
    bytes or bytearray
        The masked sequence (same type as ``seq``): uppercased, U -> T,
        unusual ambiguity codes -> N, whitespace removed.

    Raises
    ------
    ValueError
        If the header starts with '>', '#' or whitespace, contains a tab,
        or the sequence contains a byte outside the valid alphabet.
    """
    if len(header) > 0 and (header[0] in (">", "#") or header[0].isspace()):
        raise ValueError("Bad character in sequence header")
    if "\t" in header:
        raise ValueError("tab included in header")

    # Preprocessing of sequence before tokenization: mask bases and drop
    # whitespace in a single C-level pass.
    masked = seq.translate(_BASE_MASK, _SEQ_WHITESPACE)
    # Anything left after deleting the valid alphabet is an invalid byte.
    stripped = masked.translate(None, _VALID_BYTES)
    if len(stripped) > 0:
        bad_character = chr(stripped[0])
        msg = "Invalid DNA byte in sequence {}: '{}'"
        raise ValueError(msg.format(header, bad_character))
    return masked

def read_fasta(fname):
    """Read a FASTA file and return its validated sequences and headers.

    Parameters
    ----------
    fname : str or path-like
        Path to the FASTA file. Lines starting with '#' are ignored as
        comments; lines starting with '>' begin a new entry.

    Returns
    -------
    tuple[list[str], list[str]]
        ``(barcodes, ids)`` where ``barcodes[i]`` is the masked sequence
        (see :func:`check_sequence`) and ``ids[i]`` is the corresponding
        header text without the leading '>'.

    Raises
    ------
    ValueError
        Propagated from :func:`check_sequence` for malformed entries.
    """
    barcodes = []
    ids = []
    lines = []  # raw byte chunks of the entry currently being read
    seq_id = ""

    # Binary mode: validation/masking operates on bytes throughout.
    with open(fname, "rb") as handle:
        for line in handle:
            if line.startswith(b"#"):
                # Comment line: skip.
                continue
            if line.startswith(b">"):
                if seq_id != "":
                    # Flush the previous entry before starting a new one.
                    seq = check_sequence(seq_id, bytearray().join(lines))
                    barcodes.append(seq.decode())
                    ids.append(seq_id)
                    lines = []
                # Drop the '>' and any trailing newline. rstrip handles
                # both \n and \r\n, and a final line with no newline at
                # all (the previous [1:-1] slice chopped a real character
                # in that case).
                seq_id = line[1:].rstrip(b"\r\n").decode()  # Modify this according to your labels.
            else:
                lines.append(line.strip())

    # Flush the final entry; the guard avoids emitting a spurious
    # ("", "") pair for an empty input file.
    if seq_id != "" or lines:
        seq = check_sequence(seq_id, bytearray().join(lines))
        barcodes.append(seq.decode())
        ids.append(seq_id)

    return barcodes, ids
23 changes: 19 additions & 4 deletions baselines/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,21 @@
from tqdm.auto import tqdm
from transformers import AutoTokenizer

from barcodebert.utils import read_fasta


class DNADataset(Dataset):
def __init__(self, file_path, embedder, randomize_offset=False, max_length=660):
self.randomize_offset = randomize_offset

df = pd.read_csv(file_path, sep="\t" if file_path.endswith(".tsv") else ",", keep_default_na=False)
self.barcodes = df["nucleotides"].to_list()
if file_path.endswith((".tsv",".csv")):

self.ids = df["species_index"].to_list() # ideally, this should be process id
df = pd.read_csv(file_path, sep="\t" if file_path.endswith(".tsv") else ",", keep_default_na=False)
self.barcodes = df["nucleotides"].to_list()
self.ids = df["species_index"].to_list() # ideally, this should be process id
elif file_path.endswith((".fas",".fa", ".fasta")):
self.barcodes, self.ids = read_fasta(file_path)

self.tokenizer = embedder.tokenizer
self.backbone_name = embedder.name
self.max_len = max_length
Expand Down Expand Up @@ -60,7 +66,7 @@ def __getitem__(self, idx):
return processed_barcode, label


def representations_from_df(filename, embedder, batch_size=128):
def representations_from_file(filename, embedder, batch_size=128):

# create embeddings folder
if not os.path.isdir("embeddings"):
Expand Down Expand Up @@ -171,3 +177,12 @@ def labels_from_df(filename, target_level, label_pipeline):
labels = df[target_level].to_list()
return np.array(list(map(label_pipeline, labels)))
# return df[target_level].to_numpy()


def labels_from_fasta(filename, target_level, label_pipeline):
    """Extract per-sequence labels at ``target_level`` from a FASTA file.

    Parameters
    ----------
    filename : str or path-like
        Path to the FASTA file; headers are read via ``read_fasta``.
    target_level : str
        One of "class", "order", "family", "subfamily", "tribe",
        "genus", "species".
    label_pipeline : callable
        Maps a raw label string to its encoded value (e.g. an index).

    Returns
    -------
    numpy.ndarray
        Encoded labels, one per FASTA entry.
    """
    taxonomy = ["class", "order", "family", "subfamily", "tribe", "genus", "species"]
    level_idx = taxonomy.index(target_level)  # resolve once, not per header

    _, ids = read_fasta(filename)
    # Assumes each header looks like "<id> <class|order|...|species>",
    # i.e. a space-separated id followed by a '|'-joined taxonomy string
    # — TODO confirm against the files actually in use.
    labels = [tag.split(" ")[1].split("|")[level_idx] for tag in ids]

    return np.array(list(map(label_pipeline, labels)))
8 changes: 4 additions & 4 deletions baselines/linear_probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from barcodebert import utils
from barcodebert.datasets import KmerTokenizer
from barcodebert.io import load_pretrained_model
from baselines.datasets import labels_from_df, representations_from_df
from baselines.datasets import labels_from_df, representations_from_file
from baselines.io import load_baseline_model


Expand Down Expand Up @@ -122,17 +122,17 @@ def run(config):

# Generate embeddings for the training, test and validation sets
print("Generating embeddings for test set", flush=True)
X_test = representations_from_df(test_filename, embedder, batch_size=128)
X_test = representations_from_file(test_filename, embedder, batch_size=128)
y_test = labels_from_df(test_filename, target_level, label_pipeline)
print(X_test.shape, y_test.shape)

print("Generating embeddings for validation set", flush=True)
X_val = representations_from_df(validation_filename, embedder, batch_size=128)
X_val = representations_from_file(validation_filename, embedder, batch_size=128)
y_val = labels_from_df(validation_filename, target_level, label_pipeline)
print(X_test.shape, y_test.shape)

print("Generating embeddings for train set", flush=True)
X = representations_from_df(train_filename, embedder, batch_size=128)
X = representations_from_file(train_filename, embedder, batch_size=128)
y = labels_from_df(train_filename, target_level, label_pipeline)
print(X.shape, y.shape)

Expand Down

0 comments on commit c30a0ce

Please sign in to comment.