Skip to content

Commit

Permalink
Add new make_dataset workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
gonzalobenegas committed Sep 1, 2023
1 parent e7a8d0d commit 1a979b2
Show file tree
Hide file tree
Showing 20 changed files with 579 additions and 325 deletions.
2 changes: 1 addition & 1 deletion analysis/arabidopsis/Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ from tqdm import tqdm
tqdm.pandas()

from gpn.make_dataset_mlm import make_windows, get_seq
from gpn.utils import load_fasta, save_fasta, load_table, load_repeatmasker, Genome
from gpn.data import load_fasta, save_fasta, load_table, load_repeatmasker, Genome


configfile: "config.yaml"
Expand Down
2 changes: 1 addition & 1 deletion analysis/arabidopsis/embedding_umap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"source": [
"import bioframe as bf\n",
"import gpn.model\n",
"from gpn.utils import load_table, Genome\n",
"from gpn.data import load_table, Genome\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
Expand Down
2 changes: 1 addition & 1 deletion analysis/arabidopsis/modisco_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import bioframe as bf
#from gpn.utils import Genome
#from gpn.data import Genome
import modiscolite
import numpy as np
import pandas as pd
Expand Down
2 changes: 1 addition & 1 deletion analysis/arabidopsis/motif_perplexity.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
"tqdm.pandas()\n",
"\n",
"from gpn.define_intervals import intersect_intervals\n",
"from gpn.utils import Genome, load_table"
"from gpn.data import Genome, load_table"
]
},
{
Expand Down
304 changes: 304 additions & 0 deletions gpn/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
import gzip
from Bio import SeqIO, bgzf
from Bio.Seq import Seq
import bioframe as bf
from datasets import load_dataset
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()


DEFINED_SYMBOLS = np.frombuffer("ACGTacgt".encode('ascii'), dtype="S1")
UNMASKED_SYMBOLS = np.frombuffer("ACGT".encode('ascii'), dtype="S1")


def load_fasta(path, subset_chroms=None):
    """Load a FASTA file (optionally gzip-compressed) as a pandas Series.

    Keys are record ids, values are the sequences as plain strings.
    When subset_chroms is given, only records whose id is in it are kept.
    """
    handle_cm = gzip.open(path, "rt") if path.endswith(".gz") else open(path)
    records = {}
    with handle_cm as handle:
        for record in SeqIO.parse(handle, "fasta"):
            if subset_chroms is not None and record.id not in subset_chroms:
                continue
            records[record.id] = str(record.seq)
    return pd.Series(records)


def save_fasta(path, genome):
    # Write `genome` out as FASTA, using blocked gzip (BGZF) compression when
    # the target path ends in ".gz" (BGZF is the random-access-friendly gzip
    # variant used by htslib-style tools).
    # NOTE(review): SeqIO.write expects SeqRecord objects, and `genome.values()`
    # is a method call — if callers pass the pd.Series of plain strings that
    # load_fasta returns, `.values` is an attribute (not callable) and the
    # elements are not SeqRecords. Presumably callers pass a dict of
    # SeqRecords here — confirm against call sites.
    with bgzf.BgzfWriter(path, "wb") if path.endswith(".gz") else open(path, "w") as handle:
        SeqIO.write(genome.values(), handle, "fasta")


# Some standard formats
def load_table(path):
    """Load a tabular file into a DataFrame, dispatching on the path name.

    Supported formats: parquet, csv, tsv, vcf (only chrom/pos/ref/alt are
    kept) and gtf/gff (start converted from 1-based to 0-based, matching
    bioframe conventions). A `chrom` column, when present, is coerced to str
    so numeric chromosome names don't become ints.

    Raises:
        ValueError: if the path matches none of the known formats.
    """
    if path.endswith(".parquet"):
        df = pd.read_parquet(path)
    elif "csv" in path:
        df = pd.read_csv(path)
    elif "tsv" in path:
        df = pd.read_csv(path, sep="\t")
    elif "vcf" in path:
        # Keep only coordinate and allele columns (0=CHROM, 1=POS, 3=REF, 4=ALT).
        df = pd.read_csv(
            path, sep="\t", header=None, comment="#", usecols=[0, 1, 3, 4],
            dtype={0: str},
        ).rename(columns={0: 'chrom', 1: 'pos', 3: 'ref', 4: 'alt'})
    elif "gtf" in path or "gff" in path:
        df = pd.read_csv(
            path,
            sep="\t",
            header=None,
            comment="#",
            dtype={"chrom": str},
            names=[
                "chrom",
                "source",
                "feature",
                "start",
                "end",
                "score",
                "strand",
                "frame",
                "attribute",
            ],
        )
        # GTF/GFF coordinates are 1-based inclusive; shift to 0-based half-open.
        df.start -= 1
    else:
        # Previously unknown formats fell through and crashed with
        # UnboundLocalError; fail with an explicit message instead.
        raise ValueError(f"Unrecognized table format: {path}")
    # Previously this line ran unconditionally and raised AttributeError for
    # any table without a `chrom` column (e.g. an arbitrary csv).
    if "chrom" in df.columns:
        df.chrom = df.chrom.astype(str)
    return df


def load_repeatmasker(path):
    """Load a RepeatMasker (UCSC rmsk-style) table as a bioframe-style df."""
    rename_map = {"genoName": "chrom", "genoStart": "start", "genoEnd": "end"}
    df = pd.read_csv(path, sep="\t").rename(columns=rename_map)
    # chrom as str so numeric chromosome names don't become ints
    df["chrom"] = df["chrom"].astype(str)
    return df


class Genome:
    """In-memory genome: wraps a pd.Series mapping chromosome name -> sequence string."""

    def __init__(self, path, subset_chroms=None):
        # path: FASTA file, optionally .gz (see load_fasta).
        self._genome = load_fasta(path, subset_chroms=subset_chroms)

    def get_seq(self, chrom, start, end, strand="+"):
        # Sequence of [start, end), 0-based half-open; reverse-complemented
        # when strand == "-".
        seq = self._genome[chrom][start:end]
        if strand == "-":
            seq = str(Seq(seq).reverse_complement())
        return seq

    def get_nuc(self, chrom, pos, strand="+"):
        # Single nucleotide; complemented when strand == "-".
        # pos is assumed to be 1-based as in VCF
        seq = self._genome[chrom][pos-1]
        if strand == "-":
            seq = str(Seq(seq).reverse_complement())
        return seq

    def filter_chroms(self, chroms):
        # Restrict the in-memory genome to the given chromosome names.
        self._genome = self._genome[chroms]

    def get_seq_fwd_rev(self, chrom, start, end):
        # Return both the forward sequence and its reverse complement.
        seq_fwd = self.get_seq(chrom, start, end)
        seq_rev = str(Seq(seq_fwd).reverse_complement())
        return seq_fwd, seq_rev

    def get_all_intervals(self):
        # One interval per chromosome covering its full length.
        return pd.DataFrame([
            {"chrom": chrom, "start": 0, "end": len(seq)}
            for chrom, seq in self._genome.items()
        ])

    def get_intervals_matching_symbols(self, symbols):
        # Merged intervals containing only characters from `symbols`.
        # Per chromosome: locate positions NOT in `symbols` as 1-bp intervals,
        # merge them, and subtract from the full-chromosome interval.
        def get_intervals_matching_symbols_chrom(chrom):
            complete_interval = pd.DataFrame({"chrom": [chrom.name], "start": [0], "end": [len(chrom.seq)]})
            intervals = pd.DataFrame(dict(
                start=np.where(~np.isin(np.frombuffer(chrom.seq.encode("ascii"), dtype="S1"), symbols))[0]
            ))
            if len(intervals) > 0:
                intervals["chrom"] = chrom.name
                intervals["end"] = intervals.start + 1
                intervals = bf.merge(intervals).drop(columns="n_intervals")
                return bf.subtract(complete_interval, intervals)
            # no non-matching positions: the whole chromosome qualifies
            return complete_interval

        return pd.concat(
            self._genome.rename("seq").to_frame().progress_apply(
                get_intervals_matching_symbols_chrom, axis=1,
            ).values,
            ignore_index=True,
        )

    def get_defined_intervals(self):
        # Intervals of ACGT/acgt only (excludes N and other ambiguity codes).
        return self.get_intervals_matching_symbols(DEFINED_SYMBOLS)

    def get_unmasked_intervals(self):
        # Intervals of uppercase ACGT only (excludes soft-masked lowercase).
        return self.get_intervals_matching_symbols(UNMASKED_SYMBOLS)


def add_space_every_k(seq, k):
    """Insert a space between consecutive k-character chunks of `seq`."""
    chunks = (seq[i:i + k] for i in range(0, len(seq), k))
    return " ".join(chunks)


def load_dataset_from_file_or_dir(
    path, split="test", format="parquet", is_file=False, **kwargs,
):
    """Load a HuggingFace dataset from a single file or a dataset dir/hub id.

    A single file is loaded through the `format` builder (its lone split is
    named "train"); otherwise `path` is passed straight to load_dataset and
    the requested `split` is returned.
    """
    # TODO: should add handling of vcf, could use load_table and create dataset
    # from pandas df
    if not is_file:
        return load_dataset(path, split=split, **kwargs)
    return load_dataset(format, data_files=path, split="train", **kwargs)


def token_input_id(token, tokenizer, n_prefix=0):
    """Return the input id of `token`, skipping `n_prefix` leading special tokens."""
    encoded = tokenizer(token)
    return encoded["input_ids"][n_prefix]


# TODO: maybe call it I or Is, or ivals instead of intervals


def get_annotation_features(annotation, feature):
    """Extract merged chrom/start/end intervals for one annotation feature type."""
    selected = annotation[annotation.feature == feature]
    sanitized = bf.sanitize_bedframe(selected[["chrom", "start", "end"]])
    return bf.merge(sanitized)


def intersect_intervals(a, b):
    """Pairwise intersection of two interval sets (inner overlaps only)."""
    overlaps = bf.overlap(a, b, how="inner", return_overlap=True)
    overlaps = overlaps[["chrom", "overlap_start", "overlap_end"]]
    return overlaps.rename(columns={"overlap_start": "start", "overlap_end": "end"})


def union_intervals(a, b):
    """Union of two interval sets, merged into non-overlapping intervals."""
    combined = pd.concat([a, b], ignore_index=True)
    return bf.merge(combined).drop(columns="n_intervals")


def intervals_size(intervals):
    """Total number of bases covered (overlapping intervals counted multiply)."""
    lengths = intervals.end - intervals.start
    return lengths.sum()


def add_flank(intervals, flank):
    """Expand each interval by `flank` bases on both sides, merging overlaps."""
    expanded = bf.expand(intervals, pad=flank)
    return bf.merge(expanded).drop(columns="n_intervals")


def add_jitter(intervals, magnitude, seed=42):
    """Shift each interval by a random offset drawn from [-magnitude, magnitude].

    After using this function, we recommend intersecting with
    Genome.get_all_intervals(), to avoid getting out of chromosome bounds,
    or smaller subsets such as Genome.get_defined_intervals().
    """
    rng = np.random.default_rng(seed)
    offsets = rng.integers(-magnitude, magnitude, size=len(intervals), endpoint=True)
    shifted = intervals.copy()
    shifted.start += offsets
    shifted.end += offsets
    return bf.merge(shifted)


def filter_length(intervals, min_interval_len):
    """Keep only intervals at least `min_interval_len` bases long."""
    long_enough = (intervals.end - intervals.start) >= min_interval_len
    return intervals[long_enough]


def filter_defined(intervals, genome, include_flank=None):
    """Restrict intervals to the genome's defined (non-N) regions.

    include_flank optionally expands the defined regions first, so windows
    slightly overhanging undefined sequence are still retained.
    """
    defined = genome.get_defined_intervals()
    defined = defined if include_flank is None else add_flank(defined, include_flank)
    return intersect_intervals(intervals, defined)


def filter_unmasked(intervals, genome, include_flank=None):
    """Restrict intervals to the genome's unmasked (uppercase) regions.

    include_flank optionally expands the unmasked regions first, so windows
    slightly overhanging masked sequence are still retained.
    """
    unmasked = genome.get_unmasked_intervals()
    unmasked = unmasked if include_flank is None else add_flank(unmasked, include_flank)
    return intersect_intervals(intervals, unmasked)


def filter_annotation_features(
    intervals, annotation, feature, include_flank=None, jitter=None,
):
    """Restrict intervals to regions annotated as `feature`.

    The feature regions can optionally be expanded by `include_flank` bases
    and randomly shifted by up to `jitter` bases before intersecting.
    """
    features = get_annotation_features(annotation, feature)
    if include_flank is not None:
        features = add_flank(features, include_flank)
    if jitter is not None:
        features = add_jitter(features, jitter)
    return intersect_intervals(intervals, features)


def get_promoters(annotation, upstream_size, downstream_size=0):
    """Return merged strand-aware regions around each transcript's TSS.

    Not exactly promoters — just the window from `upstream_size` bases
    upstream of the TSS to `downstream_size` bases downstream of it.
    """
    def promoter_of(tx):
        if tx.strand == "+":
            start = tx.start - upstream_size
            end = tx.start + downstream_size
        else:
            start = tx.end - downstream_size
            end = tx.end + upstream_size
        return pd.Series(dict(chrom=tx.chrom, start=start, end=end))

    transcripts = annotation[annotation.feature.isin(["mRNA", "transcript"])]
    promoters = transcripts.apply(promoter_of, axis=1)
    return bf.merge(promoters).drop(columns="n_intervals")


def get_random_intervals(intervals, size, n, seed=42):
    """Sample n random intervals of fixed `size` from within `intervals`.

    Source intervals are chosen with replacement, weighted by how many
    distinct windows of length `size` they can produce (an interval of size
    `size` can produce 1 window, size+1 can produce 2, etc.). Intervals
    shorter than `size` are excluded — previously they produced negative
    sampling weights and made rng.choice raise.

    Returns the samples merged into non-overlapping intervals, so fewer than
    n rows may come back when samples overlap.
    """
    rng = np.random.default_rng(seed)
    lengths = (intervals.end - intervals.start).values
    # the number of random intervals that can be generated per interval
    weights = 1 + lengths - size
    usable = weights > 0
    candidates = intervals[usable]
    weights = weights[usable]
    interval_p = weights / weights.sum()
    rand_interval_index = rng.choice(len(candidates), p=interval_p, size=n)

    rows = []
    for idx in rand_interval_index:
        interval = candidates.iloc[idx]
        start = rng.integers(interval.start, interval.end - size, endpoint=True)
        rows.append([interval.chrom, start, start + size])
    rand_intervals = pd.DataFrame(rows, columns=["chrom", "start", "end"])
    return bf.merge(rand_intervals).drop(columns="n_intervals")


def get_balanced_intervals(defined_intervals, annotation, window_size, promoter_upstream=1000):
    """Build a training interval set balancing annotated and background sequence.

    Takes exon and promoter-like regions (each flanked by window_size//2),
    jitters them, and adds a similarly-sized sample of random intervals drawn
    from the defined genome. Coverage ratios (relative to the defined genome)
    are printed as the set is assembled.
    """
    # there's the issue of pseudogenes though... should be aware
    exons = add_flank(get_annotation_features(annotation, "exon"), window_size//2)
    print("exons: ", intervals_size(exons)/intervals_size(defined_intervals))
    promoters = add_flank(get_promoters(annotation, promoter_upstream), window_size//2)
    print("promoters: ", intervals_size(promoters)/intervals_size(defined_intervals))
    intervals = union_intervals(exons, promoters)
    # jitter so interval boundaries don't always align with annotation edges,
    # then keep only the parts inside the defined genome
    intervals = intersect_intervals(add_jitter(intervals, 100), defined_intervals)
    # in case they collide with undefined intervals
    intervals = filter_length(intervals, window_size)
    print("intervals: ", intervals_size(intervals)/intervals_size(defined_intervals))
    # maybe add a 0.5 factor
    # add roughly one random background interval per annotated window
    n_random_intervals = intervals_size(intervals) // window_size
    random_intervals = get_random_intervals(defined_intervals, window_size, n_random_intervals)
    print("random_intervals: ", intervals_size(random_intervals)/intervals_size(defined_intervals))
    intervals = union_intervals(intervals, random_intervals)
    print("intervals: ", intervals_size(intervals)/intervals_size(defined_intervals))
    print((intervals.end-intervals.start).min())
    # every interval must be able to hold at least one full window
    assert (intervals.end-intervals.start).min() >= window_size
    return intervals


def make_windows(intervals, window_size, step_size, add_rc=False):
    """Tile every interval into fixed-size windows (see get_interval_windows)."""
    per_interval = intervals.progress_apply(
        lambda interval: get_interval_windows(interval, window_size, step_size, add_rc),
        axis=1,
    )
    return pd.concat(per_interval.values, ignore_index=True)


def get_interval_windows(interval, window_size, step_size, add_rc):
    """Tile one interval into windows of `window_size` every `step_size` bases.

    Windows are emitted on the "+" strand; with add_rc, a "-" strand copy of
    each window is appended as well.
    """
    starts = np.arange(interval.start, interval.end - window_size + 1, step_size)
    windows = pd.DataFrame(dict(start=starts))
    windows["end"] = windows.start + window_size
    windows["chrom"] = interval.chrom
    windows = windows[["chrom", "start", "end"]]  # column order only
    windows["strand"] = "+"
    if not add_rc:
        return windows
    rc = windows.copy()  # TODO: emitting the "+" copy should be optional
    rc.strand = "-"
    return pd.concat([windows, rc], ignore_index=True)


def get_seq(intervals, genome):
    """Add a "seq" column to `intervals` (mutated in place and returned)."""
    def fetch(row):
        return genome.get_seq(row.chrom, row.start, row.end, row.strand)

    intervals["seq"] = intervals.progress_apply(fetch, axis=1)
    return intervals
Loading

0 comments on commit 1a979b2

Please sign in to comment.