diff --git a/analysis/arabidopsis/Snakefile b/analysis/arabidopsis/Snakefile
index d79acbc..f2e8c7f 100644
--- a/analysis/arabidopsis/Snakefile
+++ b/analysis/arabidopsis/Snakefile
@@ -25,7 +25,7 @@ from tqdm import tqdm
 tqdm.pandas()
 
 from gpn.make_dataset_mlm import make_windows, get_seq
-from gpn.utils import load_fasta, save_fasta, load_table, load_repeatmasker, Genome
+from gpn.data import load_fasta, save_fasta, load_table, load_repeatmasker, Genome
 
 configfile: "config.yaml"
 
diff --git a/analysis/arabidopsis/embedding_umap.ipynb b/analysis/arabidopsis/embedding_umap.ipynb
index 93f89ec..bddae3b 100644
--- a/analysis/arabidopsis/embedding_umap.ipynb
+++ b/analysis/arabidopsis/embedding_umap.ipynb
@@ -28,7 +28,7 @@
    "source": [
     "import bioframe as bf\n",
     "import gpn.model\n",
-    "from gpn.utils import load_table, Genome\n",
+    "from gpn.data import load_table, Genome\n",
     "import matplotlib.pyplot as plt\n",
     "import numpy as np\n",
     "import pandas as pd\n",
diff --git a/analysis/arabidopsis/modisco_run.py b/analysis/arabidopsis/modisco_run.py
index 928d8df..4a856ec 100644
--- a/analysis/arabidopsis/modisco_run.py
+++ b/analysis/arabidopsis/modisco_run.py
@@ -1,5 +1,5 @@
 import bioframe as bf
-#from gpn.utils import Genome
+#from gpn.data import Genome
 import modiscolite
 import numpy as np
 import pandas as pd
diff --git a/analysis/arabidopsis/motif_perplexity.ipynb b/analysis/arabidopsis/motif_perplexity.ipynb
index 4d6da0b..fe974b3 100644
--- a/analysis/arabidopsis/motif_perplexity.ipynb
+++ b/analysis/arabidopsis/motif_perplexity.ipynb
@@ -28,7 +28,7 @@
     "tqdm.pandas()\n",
     "\n",
     "from gpn.define_intervals import intersect_intervals\n",
-    "from gpn.utils import Genome, load_table"
+    "from gpn.data import Genome, load_table"
    ]
   },
   {
diff --git a/gpn/data.py b/gpn/data.py
new file mode 100644
index 0000000..db2275c
--- /dev/null
+++ b/gpn/data.py
@@ -0,0 +1,304 @@
+import gzip
+from Bio import SeqIO, bgzf
+from Bio.Seq import Seq
+import bioframe as bf
+from datasets import load_dataset
+import numpy as np
+import pandas as pd
+from tqdm import tqdm
+tqdm.pandas()
+
+
+DEFINED_SYMBOLS = np.frombuffer("ACGTacgt".encode('ascii'), dtype="S1")
+UNMASKED_SYMBOLS = np.frombuffer("ACGT".encode('ascii'), dtype="S1")
+
+
+def load_fasta(path, subset_chroms=None):
+    with gzip.open(path, "rt") if path.endswith(".gz") else open(path) as handle:
+        genome = pd.Series({
+            rec.id: str(rec.seq) for rec in SeqIO.parse(handle, "fasta")
+            if subset_chroms is None or rec.id in subset_chroms
+        })
+    return genome
+
+
+def save_fasta(path, genome):
+    with bgzf.BgzfWriter(path, "wb") if path.endswith(".gz") else open(path, "w") as handle:
+        SeqIO.write(genome.values(), handle, "fasta")
+
+
+# Some standard formats
+def load_table(path):
+    if path.endswith('.parquet'):
+        df = pd.read_parquet(path)
+    elif 'csv' in path:
+        df = pd.read_csv(path)
+    elif 'tsv' in path:
+        df = pd.read_csv(path, sep='\t')
+    elif 'vcf' in path:
+        df = pd.read_csv(
+            path, sep="\t", header=None, comment="#", usecols=[0,1,3,4], dtype={0: str},
+        ).rename(columns={0: 'chrom', 1: 'pos', 3: 'ref', 4: 'alt'})
+    elif 'gtf' in path or 'gff' in path:
+        df = pd.read_csv(
+            path,
+            sep="\t",
+            header=None,
+            comment="#",
+            dtype={"chrom": str},
+            names=[
+                "chrom",
+                "source",
+                "feature",
+                "start",
+                "end",
+                "score",
+                "strand",
+                "frame",
+                "attribute",
+            ],
+        )
+        df.start -= 1
+    df.chrom = df.chrom.astype(str)
+    return df
+
+
+def load_repeatmasker(path):
+    df = pd.read_csv(path, sep="\t").rename(
+        columns=dict(genoName="chrom", genoStart="start", genoEnd="end")
+    )
+    df.chrom = df.chrom.astype(str)
+    return df
+
+
+class Genome:
+    def __init__(self, path, subset_chroms=None):
+        self._genome = load_fasta(path, subset_chroms=subset_chroms)
+
+    def get_seq(self, chrom, start, end, strand="+"):
+        seq = self._genome[chrom][start:end]
+        if strand == "-":
+            seq = str(Seq(seq).reverse_complement())
+        return seq
+
+    def get_nuc(self, chrom, pos, strand="+"):
+        # pos is assumed to be 1-based as in VCF
+        seq = self._genome[chrom][pos-1]
+        if strand == "-":
+            seq = str(Seq(seq).reverse_complement())
+        return seq
+
+    def filter_chroms(self, chroms):
+        self._genome = self._genome[chroms]
+
+    def get_seq_fwd_rev(self, chrom, start, end):
+        seq_fwd = self.get_seq(chrom, start, end)
+        seq_rev = str(Seq(seq_fwd).reverse_complement())
+        return seq_fwd, seq_rev
+
+    def get_all_intervals(self):
+        return pd.DataFrame([
+            {"chrom": chrom, "start": 0, "end": len(seq)}
+            for chrom, seq in self._genome.items()
+        ])
+
+    def get_intervals_matching_symbols(self, symbols):
+        def get_intervals_matching_symbols_chrom(chrom):
+            complete_interval = pd.DataFrame({"chrom": [chrom.name], "start": [0], "end": [len(chrom.seq)]})
+            intervals = pd.DataFrame(dict(
+                start=np.where(~np.isin(np.frombuffer(chrom.seq.encode("ascii"), dtype="S1"), symbols))[0]
+            ))
+            if len(intervals) > 0:
+                intervals["chrom"] = chrom.name
+                intervals["end"] = intervals.start + 1
+                intervals = bf.merge(intervals).drop(columns="n_intervals")
+                return bf.subtract(complete_interval, intervals)
+            return complete_interval
+
+        return pd.concat(
+            self._genome.rename("seq").to_frame().progress_apply(
+                get_intervals_matching_symbols_chrom, axis=1,
+            ).values,
+            ignore_index=True,
+        )
+
+    def get_defined_intervals(self):
+        return self.get_intervals_matching_symbols(DEFINED_SYMBOLS)
+
+    def get_unmasked_intervals(self):
+        return self.get_intervals_matching_symbols(UNMASKED_SYMBOLS)
+
+
+def add_space_every_k(seq, k):
+    return " ".join([seq[x:x+k] for x in range(0, len(seq), k)])
+
+
+def load_dataset_from_file_or_dir(
+    path, split="test", format="parquet", is_file=False, **kwargs,
+):
+    # TODO: should add handling of vcf, could use load_table and create dataset
+    # from pandas df
+    if is_file:
+        return load_dataset(format, data_files=path, split="train", **kwargs)
+    else:
+        return load_dataset(path, split=split, **kwargs)
+
+
+def token_input_id(token, tokenizer, n_prefix=0):
+    return tokenizer(token)["input_ids"][n_prefix]
+
+
+# TODO: maybe call it I or Is, or ivals instead of intervals
+
+
+def get_annotation_features(annotation, feature):
+    annotation_features = annotation[annotation.feature == feature]
+    return bf.merge(bf.sanitize_bedframe(annotation_features[["chrom", "start", "end"]]))
+
+
+def intersect_intervals(a, b):
+    return bf.overlap(a, b, how="inner", return_overlap=True)[
+        ["chrom", "overlap_start", "overlap_end"]
+    ].rename(columns=dict(overlap_start="start", overlap_end="end"))
+
+
+def union_intervals(a, b):
+    return bf.merge(pd.concat([a, b], ignore_index=True)).drop(columns="n_intervals")
+
+
+def intervals_size(intervals):
+    return (intervals.end-intervals.start).sum()
+
+
+def add_flank(intervals, flank):
+    return bf.merge(bf.expand(intervals, pad=flank)).drop(columns="n_intervals")
+
+
+def add_jitter(intervals, magnitude, seed=42):
+    # After using this function, we recommend intersecting with
+    # Genome.get_all_intervals(), to avoid getting out of chromosome bounds
+    # or smaller subsets such as Genome.get_defined_intervals()
+    rng = np.random.default_rng(seed)
+    jitter = rng.integers(-magnitude, magnitude, size=len(intervals), endpoint=True)
+    new_intervals = intervals.copy()
+    new_intervals.start += jitter
+    new_intervals.end += jitter
+    return bf.merge(new_intervals)
+
+
+def filter_length(intervals, min_interval_len):
+    return intervals[intervals.end-intervals.start>=min_interval_len]
+
+
+def filter_defined(intervals, genome, include_flank=None):
+    defined = genome.get_defined_intervals()
+    if include_flank is not None:
+        defined = add_flank(defined, include_flank)
+    return intersect_intervals(intervals, defined)
+
+
+def filter_unmasked(intervals, genome, include_flank=None):
+    unmasked = genome.get_unmasked_intervals()
+    if include_flank is not None:
+        unmasked = add_flank(unmasked, include_flank)
+    return intersect_intervals(intervals, unmasked)
+
+
+def filter_annotation_features(
+    intervals, annotation, feature, include_flank=None, jitter=None,
+):
+    annotation_features = get_annotation_features(annotation, feature)
+    if include_flank is not None:
+        annotation_features = add_flank(annotation_features, include_flank)
+    if jitter is not None:
+        annotation_features = add_jitter(annotation_features, jitter)
+    return intersect_intervals(intervals, annotation_features)
+
+
+def get_promoters(annotation, upstream_size, downstream_size=0):
+    # not exactly getting promoters, just getting regions upstream of TSS
+
+    def get_promoter(transcript):
+        if transcript.strand == "+":
+            start, end = transcript.start-upstream_size, transcript.start+downstream_size
+        else:
+            start, end = transcript.end-downstream_size, transcript.end+upstream_size
+        return pd.Series(dict(chrom=transcript.chrom, start=start, end=end))
+
+    transcripts = annotation[annotation.feature.isin(["mRNA", "transcript"])]
+    promoters = transcripts.apply(get_promoter, axis=1)
+    return bf.merge(promoters).drop(columns="n_intervals")
+
+
+def get_random_intervals(intervals, size, n, seed=42):
+    rng = np.random.default_rng(seed)
+    interval_size = (intervals.end-intervals.start).values
+    # the number of random intervals that can be generated per interval
+    # e.g. if target size is 512, an interval of size 512 can produce 1 interval,
+    # and interval of size 513 can produce 2 intervals
+    interval_w = 1 + interval_size - size
+    interval_p = interval_w / interval_w.sum()
+    rand_interval_index = rng.choice(len(intervals), p=interval_p, size=n)
+
+    rand_intervals = []
+    for i in range(n):
+        interval = intervals.iloc[rand_interval_index[i]]
+        start = rng.integers(interval.start, interval.end - size, endpoint=True)
+        end = start + size
+        rand_intervals.append([interval.chrom, start, end])
+    rand_intervals = pd.DataFrame(rand_intervals, columns=["chrom", "start", "end"])
+    return bf.merge(rand_intervals).drop(columns="n_intervals")
+
+
+def get_balanced_intervals(defined_intervals, annotation, window_size, promoter_upstream=1000):
+    # there's the issue of pseudogenes though... should be aware
+    exons = add_flank(get_annotation_features(annotation, "exon"), window_size//2)
+    print("exons: ", intervals_size(exons)/intervals_size(defined_intervals))
+    promoters = add_flank(get_promoters(annotation, promoter_upstream), window_size//2)
+    print("promoters: ", intervals_size(promoters)/intervals_size(defined_intervals))
+    intervals = union_intervals(exons, promoters)
+    intervals = intersect_intervals(add_jitter(intervals, 100), defined_intervals)
+    # in case they collide with undefined intervals
+    intervals = filter_length(intervals, window_size)
+    print("intervals: ", intervals_size(intervals)/intervals_size(defined_intervals))
+    # maybe add a 0.5 factor
+    n_random_intervals = intervals_size(intervals) // window_size
+    random_intervals = get_random_intervals(defined_intervals, window_size, n_random_intervals)
+    print("random_intervals: ", intervals_size(random_intervals)/intervals_size(defined_intervals))
+    intervals = union_intervals(intervals, random_intervals)
+    print("intervals: ", intervals_size(intervals)/intervals_size(defined_intervals))
+    print((intervals.end-intervals.start).min())
+    assert (intervals.end-intervals.start).min() >= window_size
+    return intervals
+
+
+def make_windows(intervals, window_size, step_size, add_rc=False):
+    return pd.concat(
+        intervals.progress_apply(
+            lambda interval: get_interval_windows(interval, window_size, step_size, add_rc), axis=1,
+        ).values,
+        ignore_index=True,
+    )
+
+
+def get_interval_windows(interval, window_size, step_size, add_rc):
+    windows = pd.DataFrame(
+        dict(start=np.arange(interval.start, interval.end-window_size+1, step_size))
+    )
+    windows["end"] = windows.start + window_size
+    windows["chrom"] = interval.chrom
+    windows = windows[["chrom", "start", "end"]]  # just re-ordering
+    windows["strand"] = "+"
+    if add_rc:
+        windows_neg = windows.copy()  # TODO: this should be optional
+        windows_neg.strand = "-"
+        return pd.concat([windows, windows_neg], ignore_index=True)
+    return windows
+
+
+def get_seq(intervals, genome):
+    intervals["seq"] = intervals.progress_apply(
+        lambda i: genome.get_seq(i.chrom, i.start, i.end, i.strand),
+        axis=1,
+    )
+    return intervals
\ No newline at end of file
diff --git a/gpn/define_intervals.py b/gpn/define_intervals.py
index 59f1213..51a23ee 100644
--- a/gpn/define_intervals.py
+++ b/gpn/define_intervals.py
@@ -1,137 +1,10 @@
 import argparse
-from Bio import SeqIO
 import bioframe as bf
-from gpn.utils import load_table, Genome
-import gzip
-import numpy as np
-import pandas as pd
-from tqdm import tqdm
-tqdm.pandas()
-
-
-# TODO: maybe call it I or Is, or ivals instead of intervals
-
-
-def get_annotation_features(annotation, feature):
-    annotation_features = annotation[annotation.feature == feature]
-    return bf.merge(bf.sanitize_bedframe(annotation_features[["chrom", "start", "end"]]))
-
-
-def intersect_intervals(a, b):
-    return bf.overlap(a, b, how="inner", return_overlap=True)[
-        ["chrom", "overlap_start", "overlap_end"]
-    ].rename(columns=dict(overlap_start="start", overlap_end="end"))
-
-
-def union_intervals(a, b):
-    return bf.merge(pd.concat([a, b], ignore_index=True)).drop(columns="n_intervals")
-
-
-def intervals_size(intervals):
-    return (intervals.end-intervals.start).sum()
-
-
-def add_flank(intervals, flank):
-    return bf.merge(bf.expand(intervals, pad=flank)).drop(columns="n_intervals")
-
-
-def add_jitter(intervals, magnitude, seed=42):
-    # After using this function, we recommend intersecting with
-    # Genome.get_all_intervals(), to avoid getting out of chromosome bounds
-    # or smaller subsets such as Genome.get_defined_intervals()
-    rng = np.random.default_rng(seed)
-    jitter = rng.integers(-magnitude, magnitude, size=len(intervals), endpoint=True)
-    new_intervals = intervals.copy()
-    new_intervals.start += jitter
-    new_intervals.end += jitter
-    return bf.merge(new_intervals)
-
-
-def filter_length(intervals, min_interval_len):
-    return intervals[intervals.end-intervals.start>=min_interval_len]
-
-
-def filter_defined(intervals, genome, include_flank=None):
-    defined = genome.get_defined_intervals()
-    if include_flank is not None:
-        defined = add_flank(defined, include_flank)
-    return intersect_intervals(intervals, defined)
-
-
-def filter_unmasked(intervals, genome, include_flank=None):
-    unmasked = genome.get_unmasked_intervals()
-    if include_flank is not None:
-        unmasked = add_flank(unmasked, include_flank)
-    return intersect_intervals(intervals, unmasked)
-
-
-def filter_annotation_features(
-    intervals, annotation, feature, include_flank=None, jitter=None,
-):
-    annotation_features = get_annotation_features(annotation, feature)
-    if include_flank is not None:
-        annotation_features = add_flank(annotation_features, include_flank)
-    if jitter is not None:
-        annotation_features = add_jitter(annotation_features, jitter)
-    return intersect_intervals(intervals, annotation_features)
-
-
-def get_promoters(annotation, upstream_size, downstream_size=0):
-    # not exactly getting promoters, just gettting regions upstream of TSS
-
-    def get_promoter(transcript):
-        if transcript.strand == "+":
-            start, end = transcript.start-upstream_size, transcript.start+downstream_size
-        else:
-            start, end = transcript.end-downstream_size, transcript.end+upstream_size
-        return pd.Series(dict(chrom=transcript.chrom, start=start, end=end))
-
-    transcripts = annotation[annotation.feature.isin(["mRNA", "transcript"])]
-    promoters = transcripts.apply(get_promoter, axis=1)
-    return bf.merge(promoters).drop(columns="n_intervals")
-
-
-def get_random_intervals(intervals, size, n, seed=42):
-    rng = np.random.default_rng(seed)
-    interval_size = (intervals.end-intervals.start).values
-    # the number of random intervals that can be generated per interval
-    # e.g. if target size is 512, an interval of size 512 can produce 1 interval,
-    # and interval of size 513 can produce 2 intervals
-    interval_w = 1 + interval_size - size
-    interval_p = interval_w / interval_w.sum()
-    rand_interval_index = rng.choice(len(intervals), p=interval_p, size=n)
-
-    rand_intervals = []
-    for i in range(n):
-        interval = intervals.iloc[rand_interval_index[i]]
-        start = rng.integers(interval.start, interval.end - size, endpoint=True)
-        end = start + size
-        rand_intervals.append([interval.chrom, start, end])
-    rand_intervals = pd.DataFrame(rand_intervals, columns=["chrom", "start", "end"])
-    return bf.merge(rand_intervals).drop(columns="n_intervals")
-
-
-def get_balanced_intervals(defined_intervals, annotation, window_size, promoter_upstream=1000):
-    # there's the issue of pseudogenes though... should be aware
-    exons = add_flank(get_annotation_features(annotation, "exon"), window_size//2)
-    print("exons: ", intervals_size(exons)/intervals_size(defined_intervals))
-    promoters = add_flank(get_promoters(annotation, promoter_upstream), window_size//2)
-    print("promoters: ", intervals_size(promoters)/intervals_size(defined_intervals))
-    intervals = union_intervals(exons, promoters)
-    intervals = intersect_intervals(add_jitter(intervals, 100), defined_intervals)
-    # in case they collide with undefined intervals
-    intervals = filter_length(intervals, window_size)
-    print("intervals: ", intervals_size(intervals)/intervals_size(defined_intervals))
-    # maybe add a 0.5 factor
-    n_random_intervals = intervals_size(intervals) // window_size
-    random_intervals = get_random_intervals(defined_intervals, window_size, n_random_intervals)
-    print("random_intervals: ", intervals_size(random_intervals)/intervals_size(defined_intervals))
-    intervals = union_intervals(intervals, random_intervals)
-    print("intervals: ", intervals_size(intervals)/intervals_size(defined_intervals))
-    print((intervals.end-intervals.start).min())
-    assert (intervals.end-intervals.start).min() >= window_size
-    return intervals
+from .data import (
+    load_table, Genome, filter_defined, filter_unmasked, filter_length,
+    filter_annotation_features, add_jitter, add_flank,
+)
 
 
 def main(args):
diff --git a/gpn/get_embeddings.py b/gpn/get_embeddings.py
index c8e11af..e9642c5 100644
--- a/gpn/get_embeddings.py
+++ b/gpn/get_embeddings.py
@@ -11,7 +11,7 @@
 from transformers import AutoTokenizer, AutoModel, Trainer, TrainingArguments
 
 import gpn.model
-from gpn.utils import Genome, load_dataset_from_file_or_dir
+from gpn.data import Genome, load_dataset_from_file_or_dir
 
 
 class ModelCenterEmbedding(torch.nn.Module):
diff --git a/gpn/get_logits.py b/gpn/get_logits.py
index 19d1c3c..c2e4d27 100644
--- a/gpn/get_logits.py
+++ b/gpn/get_logits.py
@@ -11,7 +11,7 @@
 from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments
 
 import gpn.model
-from gpn.utils import Genome, load_dataset_from_file_or_dir, token_input_id
+from gpn.data import Genome, load_dataset_from_file_or_dir, token_input_id
 
 
 class MLMforLogitsModel(torch.nn.Module):
diff --git a/gpn/make_dataset_mlm.py b/gpn/make_dataset_mlm.py
index 07179b5..41e329c 100644
--- a/gpn/make_dataset_mlm.py
+++ b/gpn/make_dataset_mlm.py
@@ -1,44 +1,8 @@
 import argparse
 import numpy as np
 import pandas as pd
-#from pandarallel import pandarallel
-#pandarallel.initialize(progress_bar=True)
-from tqdm import tqdm
-tqdm.pandas()
-from .utils import Genome, load_table
-
-
-def make_windows(intervals, window_size, step_size, add_rc=False):
-    return pd.concat(
-        intervals.progress_apply(
-            lambda interval: get_interval_windows(interval, window_size, step_size, add_rc), axis=1,
-        ).values,
-        ignore_index=True,
-    )
-
-
-def get_interval_windows(interval, window_size, step_size, add_rc):
-    windows = pd.DataFrame(
-        dict(start=np.arange(interval.start, interval.end-window_size+1, step_size))
-    )
-    windows["end"] = windows.start + window_size
-    windows["chrom"] = interval.chrom
-    windows = windows[["chrom", "start", "end"]]  # just re-ordering
-    windows["strand"] = "+"
-    if add_rc:
-        windows_neg = windows.copy()  # TODO: this should be optional
-        windows_neg.strand = "-"
-        return pd.concat([windows, windows_neg], ignore_index=True)
-    return windows
-
-
-def get_seq(intervals, genome):
-    intervals["seq"] = intervals.progress_apply(
-        lambda i: genome.get_seq(i.chrom, i.start, i.end, i.strand),
-        axis=1,
-    )
-    return intervals
+from .data import Genome, load_table, make_windows, get_interval_windows, get_seq
 
 
 if __name__ == "__main__":
diff --git a/gpn/run_vep.py b/gpn/run_vep.py
index cedae30..ebae353 100644
--- a/gpn/run_vep.py
+++ b/gpn/run_vep.py
@@ -11,7 +11,7 @@
 from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments
 
 import gpn.model
-from gpn.utils import Genome, load_dataset_from_file_or_dir, token_input_id
+from gpn.data import Genome, load_dataset_from_file_or_dir, token_input_id
 
 
 class MLMforVEPModel(torch.nn.Module):
diff --git a/gpn/utils.py b/gpn/utils.py
deleted file mode 100644
index f7000e5..0000000
--- a/gpn/utils.py
+++ /dev/null
@@ -1,148 +0,0 @@
-import gzip
-from Bio import SeqIO, bgzf
-from Bio.Seq import Seq
-import bioframe as bf
-from datasets import load_dataset
-import numpy as np
-import pandas as pd
-from tqdm import tqdm
-tqdm.pandas()
-
-
-DEFINED_SYMBOLS = np.frombuffer("ACGTacgt".encode('ascii'), dtype="S1")
-UNMASKED_SYMBOLS = np.frombuffer("ACGT".encode('ascii'), dtype="S1")
-
-
-def load_fasta(path, subset_chroms=None):
-    with gzip.open(path, "rt") if path.endswith(".gz") else open(path) as handle:
-        genome = pd.Series({
-            rec.id: str(rec.seq) for rec in SeqIO.parse(handle, "fasta")
-            if subset_chroms is None or rec.id in subset_chroms
-        })
-    return genome
-
-
-def save_fasta(path, genome):
-    with bgzf.BgzfWriter(path, "wb") if path.endswith(".gz") else open(path, "w") as handle:
-        SeqIO.write(genome.values(), handle, "fasta")
-
-
-# Some standard formats
-def load_table(path):
-    if path.endswith('.parquet'):
-        df = pd.read_parquet(path)
-    elif 'csv' in path:
-        df = pd.read_csv(path)
-    elif 'tsv' in path:
-        df = pd.read_csv(path, sep='\t')
-    elif 'vcf' in path:
-        df = pd.read_csv(
-            path, sep="\t", header=None, comment="#", usecols=[0,1,3,4], dtype={0: str},
-        ).rename(columns={0: 'chrom', 1: 'pos', 3: 'ref', 4: 'alt'})
-    elif 'gtf' in path or 'gff' in path:
-        df = pd.read_csv(
-            path,
-            sep="\t",
-            header=None,
-            comment="#",
-            dtype={"chrom": str},
-            names=[
-                "chrom",
-                "source",
-                "feature",
-                "start",
-                "end",
-                "score",
-                "strand",
-                "frame",
-                "attribute",
-            ],
-        )
-        df.start -= 1
-    df.chrom = df.chrom.astype(str)
-    return df
-
-
-def load_repeatmasker(path):
-    df = pd.read_csv(path, sep="\t").rename(
-        columns=dict(genoName="chrom", genoStart="start", genoEnd="end")
-    )
-    df.chrom = df.chrom.astype(str)
-    return df
-
-
-class Genome:
-    def __init__(self, path, subset_chroms=None):
-        self._genome = load_fasta(path, subset_chroms=subset_chroms)
-
-    def get_seq(self, chrom, start, end, strand="+"):
-        seq = self._genome[chrom][start:end]
-        if strand == "-":
-            seq = str(Seq(seq).reverse_complement())
-        return seq
-
-    def get_nuc(self, chrom, pos, strand="+"):
-        # pos is assumed to be 1-based as in VCF
-        seq = self._genome[chrom][pos-1]
-        if strand == "-":
-            seq = str(Seq(seq).reverse_complement())
-        return seq
-
-    def filter_chroms(self, chroms):
-        self._genome = self._genome[chroms]
-
-    def get_seq_fwd_rev(self, chrom, start, end):
-        seq_fwd = self.get_seq(chrom, start, end)
-        seq_rev = str(Seq(seq_fwd).reverse_complement())
-        return seq_fwd, seq_rev
-
-    def get_all_intervals(self):
-        return pd.DataFrame([
-            {"chrom": chrom, "start": 0, "end": len(seq)}
-            for chrom, seq in self._genome.items()
-        ])
-
-    def get_intervals_matching_symbols(self, symbols):
-        def get_intervals_matching_symbols_chrom(chrom):
-            complete_interval = pd.DataFrame({"chrom": [chrom.name], "start": [0], "end": [len(chrom.seq)]})
-            intervals = pd.DataFrame(dict(
-                start=np.where(~np.isin(np.frombuffer(chrom.seq.encode("ascii"), dtype="S1"), symbols))[0]
-            ))
-            if len(intervals) > 0:
-                intervals["chrom"] = chrom.name
-                intervals["end"] = intervals.start + 1
-                intervals = bf.merge(intervals).drop(columns="n_intervals")
-                return bf.subtract(complete_interval, intervals)
-            return complete_interval
-
-        return pd.concat(
-            self._genome.rename("seq").to_frame().progress_apply(
-                get_intervals_matching_symbols_chrom, axis=1,
-            ).values,
-            ignore_index=True,
-        )
-
-    def get_defined_intervals(self):
-        return self.get_intervals_matching_symbols(DEFINED_SYMBOLS)
-
-    def get_unmasked_intervals(self):
-        return self.get_intervals_matching_symbols(UNMASKED_SYMBOLS)
-
-
-def add_space_every_k(seq, k):
-    return " ".join([seq[x:x+k] for x in range(0, len(seq), k)])
-
-
-def load_dataset_from_file_or_dir(
-    path, split="test", format="parquet", is_file=False, **kwargs,
-):
-    # TODO: should add handling of vcf, could use load_table and create dataset
-    # from pandas df
-    if is_file:
-        return load_dataset(format, data_files=path, split="train", **kwargs)
-    else:
-        return load_dataset(path, split=split, **kwargs)
-
-
-def token_input_id(token, tokenizer, n_prefix=0):
-    return tokenizer(token)["input_ids"][n_prefix]
diff --git a/setup.py b/setup.py
index c872116..88cefaf 100644
--- a/setup.py
+++ b/setup.py
@@ -13,12 +13,13 @@
     "einops",
     "pandarallel",
     "bioframe",
+    "zstandard",
 ]
 
 
 setup(
     name='gpn',
-    version='0.2',
+    version='0.3',
     description='gpn',
     url='http://github.com/songlab-cal/gpn',
     author='Gonzalo Benegas',
diff --git a/workflow/make_dataset/README.md b/workflow/make_dataset/README.md
new file mode 100644
index 0000000..8bd0e63
--- /dev/null
+++ b/workflow/make_dataset/README.md
@@ -0,0 +1,33 @@
+# Workflow to create a training dataset
+[Example dataset](https://huggingface.co/datasets/gonzalobenegas/example_dataset) (with the default config, should take around 5 minutes)
+1. Download data from NCBI given a list of accessions, or alternatively, use your own fasta files.
+2. Define a set of training intervals, e.g. full chromosomes, only exons, etc.
+3. Shard the dataset for efficient loading with Hugging Face libraries.
+4. Optional: upload to Hugging Face Hub.
+
+## Requirements:
+- [GPN](https://github.com/songlab-cal/gpn)
+- [Snakemake](https://snakemake.github.io/)
+- If you want to automatically download data from NCBI, install [NCBI Datasets](https://www.ncbi.nlm.nih.gov/datasets/docs/v2/download-and-install/) (e.g. `conda install -c conda-forge ncbi-datasets-cli`).
+
+## Configuration:
+- See `config/config.yaml` and `config/assemblies.tsv`.
+- Check the notes in `workflow/Snakefile` for running with your own set of fasta files.
+
+## Running:
+- `snakemake --cores all`
+
+## Uploading to Hugging Face Hub:
+For easy distribution and deployment, the dataset can be uploaded to HF Hub (optionally, as a private dataset).
+It can be automatically streamed during training (no need to fully download the data locally).
+Make sure to first install the [HF Hub client library](https://huggingface.co/docs/huggingface_hub/index).
+```python
+from huggingface_hub import HfApi
+api = HfApi()
+
+private = False
+repo_id = "gonzalobenegas/example_dataset"  # replace with your username, dataset name
+folder_path = "results/dataset"
+api.create_repo(repo_id=repo_id, repo_type="dataset", private=private)
+api.upload_folder(repo_id=repo_id, folder_path=folder_path, repo_type="dataset")
+```
diff --git a/workflow/make_dataset/config/assemblies.tsv b/workflow/make_dataset/config/assemblies.tsv
new file mode 100644
index 0000000..5f0725c
--- /dev/null
+++ b/workflow/make_dataset/config/assemblies.tsv
@@ -0,0 +1,4 @@
+Assembly Accession	Assembly Name	Organism Name	Organism Infraspecific Names Breed	Organism Infraspecific Names Strain	Organism Infraspecific Names Cultivar	Organism Infraspecific Names Ecotype	Organism Infraspecific Names Isolate	Organism Infraspecific Names Sex	Annotation Name	Assembly Stats Total Sequence Length	Assembly Level	Assembly Submission Date	WGS project accession	genus	Priority
+GCF_000001735.4	TAIR10.1	Arabidopsis thaliana	Columbia	Annotation submitted by TAIR and Araport	119146348	Chromosome	2018-03-15	Arabidopsis	0_High
+GCF_000150535.2	Papaya1.0	Carica papaya	SunUp	NCBI Annotation Release 100	369781828	Scaffold	2008-05-06	ABIM01	Carica	1_Low
+GCF_000801105.1	Rs1.0	Raphanus sativus	WK10039	NCBI Annotation Release 100	426202243	Scaffold	2015-09-29	JRUI02	Raphanus	1_Low
diff --git a/workflow/make_dataset/config/config.yaml b/workflow/make_dataset/config/config.yaml
new file mode 100644
index 0000000..f47320a
--- /dev/null
+++ b/workflow/make_dataset/config/config.yaml
@@ -0,0 +1,31 @@
+# assumes the first column contains the assembly name
+assemblies_path: "config/assemblies.tsv"
+
+# Intervals from fasta file used for training:
+# - "all": all positions
+# - "defined": positions with defined nucleotides (not N)
+# - "annotation_{feature}": only positions from annotation, e.g. CDS, exon
+# - "balanced_v1": recipe used in original paper
+target_intervals: "all"
+
+window_size: 512
+step_size: 256
+add_rc: True
+
+# chroms will be randomly assigned to splits
+split_proportion:
+  train: 0.99
+  validation: 0.005
+  test: 0.005
+
+# these chroms are forced to be in the validation set
+whitelist_validation_chroms:
+- "NC_003075.7"  # Arabidopsis thaliana chr4
+# these chroms are forced to be in the test set
+whitelist_test_chroms:
+- "NC_003076.8"  # Arabidopsis thaliana chr5
+
+# We want to split data into shards of e.g. ~100MB each
+# It's good to have at least num_cpus shards to increase parallel loading speed
+# of iterable datasets from HF Hub
+samples_per_file: 500_000
diff --git a/workflow/make_dataset/workflow/Snakefile b/workflow/make_dataset/workflow/Snakefile
new file mode 100644
index 0000000..e695e4f
--- /dev/null
+++ b/workflow/make_dataset/workflow/Snakefile
@@ -0,0 +1,21 @@
+import pandas as pd
+
+
+configfile: "config/config.yaml"
+print(config)
+
+assemblies = pd.read_csv(config["assemblies_path"], sep="\t", index_col=0)
+splits = ["train", "validation", "test"]
+
+# comment out if you have your own fasta files
+# and make sure you have genomes (and annotations, if applicable) in the right place:
+# results/genome/{assembly}.fa.gz (and results/annotation/{assembly}.gff.gz)
+include: "rules/download.smk"
+
+include: "rules/intervals.smk"
+include: "rules/dataset.smk"
+
+
+rule all:
+    input:
+        expand("results/dataset/data/{split}", split=splits),
diff --git a/workflow/make_dataset/workflow/rules/dataset.smk b/workflow/make_dataset/workflow/rules/dataset.smk
new file mode 100644
index 0000000..f65728c
--- /dev/null
+++ b/workflow/make_dataset/workflow/rules/dataset.smk
@@ -0,0 +1,81 @@
+from gpn.data import Genome, make_windows, get_seq
+import math
+import numpy as np
+import os
+import pandas as pd
+from tqdm import tqdm
+
+
+split_proportions = [config["split_proportion"][split] for split in splits]
+assert np.isclose(sum(split_proportions), 1)
+
+
+rule make_dataset_assembly:
+    input:
+        lambda wildcards: f"results/intervals/{wildcards['assembly']}/{config['target_intervals']}.parquet",
+        "results/genome/{assembly}.fa.gz",
+    output:
+        temp(expand("results/dataset_assembly/{{assembly}}/{split}.parquet", split=splits)),
+    threads: 2
+    run:
+        intervals = pd.read_parquet(input[0])
+        genome = Genome(input[1])
+        intervals = make_windows(
+            intervals, config["window_size"], config["step_size"], config["add_rc"],
+        )
+        print(intervals)
+        intervals = intervals.sample(frac=1.0, random_state=42)
+        intervals["assembly"] = wildcards["assembly"]
+        intervals = intervals[["assembly", "chrom", "start", "end", "strand"]]
+        intervals = get_seq(intervals, genome)
+        print(intervals)
+
+        chroms = intervals.chrom.unique()
+        chrom_split = np.random.choice(
+            splits, p=split_proportions, size=len(chroms),
+        )
+        chrom_split[np.isin(chroms, config["whitelist_validation_chroms"])] = "validation"
+        chrom_split[np.isin(chroms, config["whitelist_test_chroms"])] = "test"
+        chrom_split = pd.Series(chrom_split, index=chroms)
+
+        intervals_split = chrom_split[intervals.chrom]
+
+        for path, split in zip(output, splits):
+            print(path, split)
+            # to parquet to be able to load faster later
+            intervals[(intervals_split==split).values].to_parquet(
+                path, index=False,
+            )
+
+
+# before uploading to HF Hub, remove data/split/.snakemake_timestamp files
+rule merge_datasets:
+    input:
+        expand("results/dataset_assembly/{assembly}/{{split}}.parquet", assembly=assemblies.index),
+    output:
+        directory("results/dataset/data/{split}"),
+    threads: workflow.cores
+    run:
+        intervals = pd.concat(
+            tqdm((pd.read_parquet(path) for path in input), total=len(input)),
+            ignore_index=True,
+        ).sample(frac=1, random_state=42)
+        print(intervals)
+
+        if config.get("subsample_to_target", False) and wildcards.split == "train":
+            n_target = (intervals.assembly==config["target_assembly"]).sum()
+            intervals = intervals.groupby("assembly").sample(
+                n=n_target, random_state=42
+            ).sample(frac=1, random_state=42)
+            print(wildcards.split, intervals.assembly.value_counts())
+            print(intervals)
+
+        n_shards = math.ceil(len(intervals) / config["samples_per_file"])
+        assert n_shards < 10000
+        os.makedirs(output[0])
+        for i in tqdm(range(n_shards)):
+            path = os.path.join(output[0], f"shard_{i:05}.jsonl.zst")
+            intervals.iloc[i::n_shards].to_json(
+                path, orient="records", lines=True,
+                compression={'method': 'zstd', 'threads': -1}
+            )
diff --git a/workflow/make_dataset/workflow/rules/download.smk b/workflow/make_dataset/workflow/rules/download.smk
new file mode 100644
index 0000000..06035b2
--- /dev/null
+++ b/workflow/make_dataset/workflow/rules/download.smk
@@ -0,0 +1,25 @@
+assemblies["Assembly Name"] = assemblies["Assembly Name"].str.replace(" ", "_")
+assemblies["genome_path"] = (
+    "tmp/" + assemblies.index + "/ncbi_dataset/data/" + assemblies.index + "/"
+    + assemblies.index + "_" + assemblies["Assembly Name"] + "_genomic.fna"
+)
+assemblies["annotation_path"] = (
+    "tmp/" + assemblies.index + "/ncbi_dataset/data/" + assemblies.index + "/genomic.gff"
+)
+
+
+rule download_genome:
+    output:
+        "results/genome/{assembly}.fa.gz",
+        "results/annotation/{assembly}.gff.gz",
+    params:
+        tmp_dir=directory("tmp/{assembly}"),
+        genome_path=lambda wildcards: assemblies.loc[wildcards.assembly, "genome_path"],
+        annotation_path=lambda wildcards: assemblies.loc[wildcards.assembly, "annotation_path"],
+    shell:
+        """
+        mkdir -p {params.tmp_dir} && cd {params.tmp_dir} &&
+        datasets download genome accession {wildcards.assembly} --include genome,gff3 \
+        && unzip ncbi_dataset.zip && cd - && gzip -c {params.genome_path} > {output[0]}\
+        && gzip -c {params.annotation_path} > {output[1]} && rm -r {params.tmp_dir}
+        """
diff --git a/workflow/make_dataset/workflow/rules/intervals.smk b/workflow/make_dataset/workflow/rules/intervals.smk
new file mode 100644
index 0000000..b155c8e
--- /dev/null
+++ b/workflow/make_dataset/workflow/rules/intervals.smk
@@ -0,0 +1,65 @@
+from gpn.data import (
+    Genome, load_table, get_balanced_intervals, filter_length,
+    filter_annotation_features,
+)
+
+
+rule make_all_intervals:
+    input:
+        "results/genome/{assembly}.fa.gz",
+    output:
+        "results/intervals/{assembly}/all.parquet",
+    threads: 2
+    run:
+        I = Genome(input[0]).get_all_intervals()
+        I = filter_length(I, config["window_size"])
+        I.to_parquet(output[0], index=False)
+
+
+rule make_defined_intervals:
+    input:
+        "results/genome/{assembly}.fa.gz",
+    output:
+        "results/intervals/{assembly}/defined.parquet",
+    threads: 2
+    run:
+        I = Genome(input[0]).get_defined_intervals()
+        I = filter_length(I, config["window_size"])
+        I.to_parquet(output[0], index=False)
+
+
+rule make_annotation_intervals:
+    input:
+        "results/intervals/{assembly}/defined.parquet",
+        "results/annotation/{assembly}.gff.gz",
+    output:
+        "results/intervals/{assembly}/annotation_{feature}.parquet",
+    run:
+        I = pd.read_parquet(input[0])
+        annotation = load_table(input[1])
+        include_flank = config.get(
+            "annotation_features_include_flank", config["window_size"] // 2
+        )
+        add_jitter = config.get("annotation_features_add_jitter", 100)
+        I = filter_annotation_features(
+            I, annotation, wildcards.feature,
+            include_flank=include_flank, jitter=add_jitter,
+        )
+        I = filter_length(I, config["window_size"])
+        I.to_parquet(output[0], index=False)
+
+
+rule make_balanced_v1_intervals:
+    input:
+        "results/intervals/{assembly}/defined.parquet",
+        "results/annotation/{assembly}.gff.gz",
+    output:
+        "results/intervals/{assembly}/balanced_v1.parquet",
+    run:
+        defined_intervals = load_table(input[0])
+        annotation = load_table(input[1])
+        intervals = get_balanced_intervals(
+            defined_intervals, annotation, config["window_size"],
+            config.get("promoter_upstream", 1000),
+        )
+        intervals.to_parquet(output[0], index=False)
diff --git a/workflow/make_dataset_from_ncbi/Snakefile b/workflow/make_dataset_from_ncbi/Snakefile
index 3ad1a3e..daa8549 100644
--- a/workflow/make_dataset_from_ncbi/Snakefile
+++ b/workflow/make_dataset_from_ncbi/Snakefile
@@ -1,6 +1,6 @@
 from gpn.define_intervals import load_table, get_balanced_intervals
 from gpn.make_dataset_mlm import make_windows, get_seq
-from gpn.utils import Genome
+from gpn.data import Genome
 import gzip
 import math
 import numpy as np
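
As a usage note on the consolidated `gpn.data` module introduced above: the snippet below is a minimal sketch (not part of the diff) of calling the same helpers the dataset rules use (`Genome`, `filter_length`, `make_windows`, `get_seq`) directly, mirroring the defaults in `config/config.yaml`; the genome path is a placeholder.

```python
from gpn.data import Genome, filter_length, make_windows, get_seq

# placeholder path; any fasta(.gz) accepted by Genome/load_fasta should work
genome = Genome("results/genome/GCF_000001735.4.fa.gz")

# keep only intervals of defined nucleotides long enough to hold one window
intervals = filter_length(genome.get_defined_intervals(), 512)

# sliding windows on both strands, then attach the sequence of each window
windows = make_windows(intervals, window_size=512, step_size=256, add_rc=True)
windows = get_seq(windows, genome)
print(windows.head())
```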
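The new README states that the uploaded dataset can be streamed during training. A minimal sketch of what that could look like, assuming the example repo id from the README and that Hugging Face Datasets picks up the `data/{train,validation,test}` folder layout automatically (otherwise `data_files` would need to be passed explicitly); `zstandard`, added to `setup.py` above, is required to read the `.jsonl.zst` shards.

```python
from datasets import load_dataset

# repo id taken from the README example; replace with your own dataset
dataset = load_dataset(
    "gonzalobenegas/example_dataset",
    split="train",
    streaming=True,
)

# each record carries the columns written by merge_datasets
for example in dataset.take(2):
    print(example["assembly"], example["chrom"], example["start"],
          example["end"], example["strand"], example["seq"][:20])
```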