forked from RNAcentral/rnacentral-import-pipeline
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelpers.py
65 lines (54 loc) · 2.12 KB
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import asyncio
import io
import typing as ty
import aiohttp
from Bio import Entrez, SeqIO
from more_itertools import chunked
from throttler import throttle
# Contact address identifying this client to NCBI. Bio.Entrez attaches it to
# the E-utilities requests it issues itself; requests built by hand (e.g. with
# aiohttp) do not pick it up unless it is passed explicitly.
# NOTE(review): the literal below looks like an e-mail-obfuscation artifact
# from a web scrape rather than a real address — confirm before use.
Entrez.email = "[email protected]"
@throttle(rate_limit=2, period=1.0)
async def fetch_records(session, accessions: ty.List[str]):
    """Fetch GenBank records for one batch of accessions and extract taxids.

    Issues a single NCBI E-utilities ``efetch`` request (``db=nuccore``,
    ``rettype=gb``, ``retmode=text``) covering every accession in the batch,
    then parses the response text with Biopython's GenBank parser. Calls are
    rate limited by the ``throttle`` decorator (``rate_limit=2`` per
    ``period=1.0`` seconds).

    :param session: an open ``aiohttp.ClientSession`` used for the request.
    :param accessions: accession strings to fetch together in one request.
    :returns: a dict mapping each parsed record's ``record.id`` to its taxon
        id (``None`` when the record carries no taxon cross-reference). On any
        failure (network error, non-2xx HTTP status, or a parse error) the
        error is printed and every requested accession maps to ``None``.
    """
    if not accessions:
        return {}

    params = {
        "db": "nuccore",
        "id": ",".join(accessions),
        "rettype": "gb",
        "retmode": "text",
    }
    # NCBI asks E-utilities clients to identify themselves. Bio.Entrez only
    # adds this to requests it makes itself, so pass it explicitly here.
    if Entrez.email:
        params["email"] = Entrez.email

    try:
        async with session.get(
            "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi",
            params=params,
        ) as response:
            # Without this check an HTTP error page (e.g. a rate-limit reply)
            # parses as zero GenBank records and the batch vanishes silently.
            response.raise_for_status()
            records_text = await response.text()

        taxids = {}
        for record in SeqIO.parse(io.StringIO(records_text), format="genbank"):
            taxids[record.id] = extract_taxid(record)
        return taxids
    except Exception as e:
        # Best-effort: report the problem and mark the whole batch unresolved
        # so that the other concurrently running batches still complete.
        print(f"Error fetching records: {e}")
        return {accession: None for accession in accessions}
def extract_taxid(record):
    """Return the NCBI taxonomy id recorded on a GenBank record, if any.

    Scans ``record.features`` for features of type ``"source"`` and returns
    the integer from the first ``db_xref`` qualifier of the exact form
    ``"taxon:<id>"``.

    :param record: a parsed sequence record exposing ``.features``. Each
        feature has a ``.type`` string and a ``.qualifiers`` mapping from
        qualifier name to a list of strings (as ``Bio.SeqIO`` produces for
        GenBank input).
    :returns: the taxon id as an ``int``, or ``None`` when no well-formed
        ``taxon:`` cross-reference is present.
    """
    for feature in record.features:
        if feature.type != "source":
            continue
        for db_xref in feature.qualifiers.get("db_xref", ()):
            # Match the database name exactly. A substring test would also
            # accept any other xref whose text merely contains "taxon".
            database, _, identifier = db_xref.partition(":")
            if database != "taxon":
                continue
            try:
                return int(identifier)
            except ValueError:
                # Malformed id: skip it instead of raising out of the
                # record-parsing loop.
                continue
    return None
async def get_taxids_for_accessions(accessions, batch_size=100, requests_per_second=3):
    """Resolve taxon ids for many nucleotide accessions with batched requests.

    The accessions are de-duplicated and split into batches of ``batch_size``.
    Each batch is passed to ``fetch_records``, and all batches run concurrently
    on one shared ``aiohttp.ClientSession``.

    :param accessions: iterable of accession strings; duplicates are allowed.
    :param batch_size: maximum number of accessions per efetch request.
    :param requests_per_second: currently unused. The request rate is set by
        the ``@throttle`` decorator on ``fetch_records``; the parameter is
        kept for interface compatibility.
    :returns: dict mapping the record id returned by NCBI to its integer taxon
        id. Records without a taxid, and batches that failed, are omitted.
    """
    # Drop repeated accessions (keeping first-seen order) so that no id is
    # requested twice; several instances often share one accession.
    unique_accessions = list(dict.fromkeys(accessions))

    taxids = {}
    async with aiohttp.ClientSession() as session:
        tasks = [
            fetch_records(session, batch)
            for batch in chunked(unique_accessions, batch_size)
        ]
        for next_result in asyncio.as_completed(tasks):
            mapping = await next_result
            # fetch_records maps unresolved accessions to None; skip those.
            taxids.update(
                (accession, taxid)
                for accession, taxid in mapping.items()
                if taxid is not None
            )
    return taxids
def taxid_mapping(
    rows: ty.List[ty.Dict[str, str]], batch_size=100
) -> ty.Dict[str, int]:
    """Build an accession -> taxon id mapping for the instances in ``rows``.

    Each row's ``"Instances"`` value is split on commas. For every entry, the
    text before the first ``"/"`` is taken as the accession.
    NOTE(review): entries presumably look like ``"ACC/start-end"``; confirm
    against the producer of these rows.

    :param rows: dict rows, each with an ``"Instances"`` string.
    :param batch_size: accessions per NCBI efetch request.
    :returns: dict mapping NCBI record ids to integer taxon ids; accessions
        that could not be resolved are omitted.
    """
    accessions = []
    for row in rows:
        for instance in row["Instances"].split(","):
            # Strip whitespace around comma-separated entries and ignore empty
            # ones, so that no blank or padded ids are sent to NCBI.
            accession = instance.split("/", 1)[0].strip()
            if accession:
                accessions.append(accession)
    return asyncio.run(get_taxids_for_accessions(accessions, batch_size=batch_size))