Skip to content

Commit c3b00da

Browse files
committed
Fix genome mapping serialization
Actually test and fix issues with using the new serialization based approach for dealing with blat hits.
1 parent 841df5f commit c3b00da

File tree

6 files changed

+156
-48
lines changed

6 files changed

+156
-48
lines changed

rnacentral_pipeline/cli/genome_mapping.py

+14-13
Original file line numberDiff line numberDiff line change
@@ -37,41 +37,42 @@ def hits():
3737
pass
3838

3939

40-
@hits.command('as-json')
40+
@hits.command('serialize')
4141
@click.argument('assembly_id')
4242
@click.argument('hits', default='-', type=click.File('r'))
43-
@click.argument('output', default='-', type=click.File('w'))
43+
@click.argument('output', default='-', type=click.File('wb'))
4444
def hits_json(assembly_id, hits, output):
4545
"""
46-
Convert the PSL file into a JSON-line file (one object per line). This is a
47-
lossy operation but keeps everything needed for selecting later.
46+
Serialize the PSL file into something that python can later process. This is
47+
a lossy operation but keeps everything needed for selecting later. This
48+
exists so we can do mulitple select steps and still merge the results.
4849
"""
49-
blat.as_json(assembly_id, hits, output)
50+
blat.as_pickle(assembly_id, hits, output)
5051

5152

52-
@cli.command('as-importable')
53-
@click.argument('hits', default='-', type=click.File('r'))
53+
@hits.command('as-importable')
54+
@click.argument('hits', default='-', type=click.File('rb'))
5455
@click.argument('output', default='-', type=click.File('w'))
55-
def as_importable(raw, output):
56+
def as_importable(hits, output):
5657
"""
5758
Convert a json-line file into a CSV that can be used for import by pgloader.
5859
This is lossy as it only keeps the things needed for the database.
5960
"""
60-
blat.as_importable(raw, output)
61+
blat.write_importable(hits, output)
6162

6263

6364
@hits.command('select')
64-
@click.option('--sort', default=False)
65-
@click.argument('hits', default='-', type=click.File('r'))
66-
@click.argument('output', default='-', type=click.File('w'))
65+
@click.option('--sort', is_flag=True, default=False)
66+
@click.argument('hits', default='-', type=click.File('rb'))
67+
@click.argument('output', default='-', type=click.File('wb'))
6768
def select_hits(hits, output, sort=False):
6869
"""
6970
Parse a JSON-line file and select the best hits in the file. The best hits
7071
are written to the output file. This assumes the file is sorted by
7172
urs_taxid unless --sort is given in which case the data is sorted in memory.
7273
That may be very expensive.
7374
"""
74-
blat.select_json(hits, output, sort=sort)
75+
blat.select_pickle(hits, output, sort=sort)
7576

7677

7778
@cli.command('url-for')

rnacentral_pipeline/databases/data/regions.py

+31-11
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def display_int(self):
7474

7575

7676
# @enum.unique
77-
class CoordianteStart(enum.Enum):
77+
class CoordinateStart(enum.Enum):
7878
zero = 0
7979
one = 1
8080

@@ -124,9 +124,19 @@ class CoordinateSystem(object):
124124
http://genome.ucsc.edu/blog/the-ucsc-genome-browser-coordinate-counting-systems/
125125
"""
126126

127-
basis = attr.ib(validator=is_a(CoordianteStart))
127+
basis = attr.ib(validator=is_a(CoordinateStart))
128128
close_status = attr.ib(validator=is_a(CloseStatus))
129129

130+
@classmethod
131+
def build(cls, value):
132+
if isinstance(value, six.text_type):
133+
return cls.from_name(value)
134+
if isinstance(value, dict):
135+
return cls(**value)
136+
if isinstance(value, cls):
137+
return value
138+
raise ValueError("Cannot build CoordinateSystem from %s" % str(value))
139+
130140
@classmethod
131141
def from_name(cls, name):
132142
"""
@@ -143,7 +153,7 @@ def from_name(cls, name):
143153
raise UnknownCoordinateSystem(name)
144154

145155
return cls(
146-
basis=CoordianteStart.from_name(basis_name),
156+
basis=CoordinateStart.from_name(basis_name),
147157
close_status=CloseStatus.from_name(close_name),
148158
)
149159

@@ -178,9 +188,9 @@ def size(self, location):
178188

179189
def as_zero_based(self, location):
180190
start = location.start
181-
if self.basis is CoordianteStart.zero:
191+
if self.basis is CoordinateStart.zero:
182192
pass
183-
elif self.basis is CoordianteStart.one:
193+
elif self.basis is CoordinateStart.one:
184194
start = start - 1
185195
else:
186196
raise ValueError("Unknown type of start: %s" % self.basis)
@@ -189,9 +199,9 @@ def as_zero_based(self, location):
189199

190200
def as_one_based(self, location):
191201
start = location.start
192-
if self.basis is CoordianteStart.zero:
202+
if self.basis is CoordinateStart.zero:
193203
start = start + 1
194-
elif self.basis is CoordianteStart.one:
204+
elif self.basis is CoordinateStart.one:
195205
pass
196206
else:
197207
raise ValueError("Unknown type of start: %s" % self.basis)
@@ -218,16 +228,26 @@ def greater_than_start(self, attribute, value):
218228
(value, self.start))
219229

220230

231+
def as_sorted_exons(raw):
232+
exons = []
233+
for entry in raw:
234+
if isinstance(entry, dict):
235+
exons.append(Exon(**entry))
236+
else:
237+
exons.append(entry)
238+
return tuple(sorted(exons, key=op.attrgetter('start')))
239+
240+
221241
@attr.s(frozen=True, hash=True, slots=True)
222242
class SequenceRegion(object):
223243
assembly_id = attr.ib(validator=is_a(six.text_type), converter=six.text_type)
224244
chromosome = attr.ib(validator=is_a(six.text_type), converter=six.text_type)
225245
strand = attr.ib(validator=is_a(Strand), converter=Strand.build)
226-
exons = attr.ib(
227-
validator=is_a(tuple),
228-
converter=lambda es: tuple(sorted(es, key=op.attrgetter('start'))),
246+
exons = attr.ib(validator=is_a(tuple), converter=as_sorted_exons)
247+
coordinate_system = attr.ib(
248+
validator=is_a(CoordinateSystem),
249+
converter=CoordinateSystem.build,
229250
)
230-
coordinate_system = attr.ib(validator=is_a(CoordinateSystem))
231251

232252
@property
233253
def start(self):

rnacentral_pipeline/rnacentral/genome_mapping/blat.py

+12-22
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""
1515

1616
import csv
17+
import json
1718
import operator as op
1819
import itertools as it
1920
import logging
@@ -23,6 +24,7 @@
2324
import attr
2425
from attr.validators import instance_of as is_a
2526

27+
from rnacentral_pipeline import utils
2628
from rnacentral_pipeline.databases.data.regions import Exon
2729
from rnacentral_pipeline.databases.data.regions import Strand
2830
from rnacentral_pipeline.databases.data.regions import SequenceRegion
@@ -129,12 +131,6 @@ def parse_psl(assembly_id, handle):
129131
yield BlatHit.build(assembly_id, result)
130132

131133

132-
def parse_json(handle):
133-
for line in handle:
134-
data = json.loads(line)
135-
yield BlatHit(**data)
136-
137-
138134
def select_hits(hits):
139135
hits = six.moves.filter(select_possible, hits)
140136
hits = it.groupby(hits, op.attrgetter('upi'))
@@ -144,27 +140,21 @@ def select_hits(hits):
144140
return hits
145141

146142

147-
def write_json(hits, output):
148-
for hit in hits:
149-
output.write(json.dumps(hit))
150-
output.write('\n')
151-
152-
153-
def write_importable(hits, output):
143+
def write_importable(handle, output):
144+
hits = utils.unpickle_stream(handle)
154145
writeable = six.moves.map(op.methodcaller('writeable'), hits)
155-
writeable = it.chain.from_iterable(hits)
146+
writeable = it.chain.from_iterable(writeable)
156147
csv.writer(output).writerows(writeable)
157148

158149

159-
def as_json(assembly_id, hits, output):
150+
def as_pickle(assembly_id, hits, output):
160151
parsed = parse_psl(assembly_id, hits)
161-
parsed = six.moves.map(attr.asdict, selected)
162-
write_json(parsed, output)
152+
utils.pickle_stream(parsed, output)
163153

164154

165-
def select_json(hits, output, sort=False):
166-
parsed = parse_json(hits)
155+
def select_pickle(handle, output, sort=False):
156+
hits = utils.unpickle_stream(handle)
167157
if sort:
168-
parsed = sorted(parsed, key=op.itemgetter('upi'))
169-
selected = select_hits(parsed)
170-
write_json(selected, output)
158+
hits = sorted(hits, key=op.attrgetter('upi'))
159+
selected = select_hits(hits)
160+
utils.pickle_stream(selected, output)

rnacentral_pipeline/utils.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
Copyright [2009-2018] EMBL-European Bioinformatics Institute
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
16+
import six.moves.cPickle as pickle
17+
18+
def pickle_stream(stream, handle, *args, **kwargs):
19+
for entry in stream:
20+
pickle.dump(entry, handle, *args, **kwargs)
21+
22+
23+
def unpickle_stream(handle, *args, **kwargs):
24+
try:
25+
while True:
26+
yield pickle.load(handle)
27+
except EOFError:
28+
raise StopIteration()

tests/databases/data/regions_test.py

+41-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
limitations under the License.
1414
"""
1515

16+
import six.moves.cPickle as pickle
17+
1618
import pytest
1719

1820
from rnacentral_pipeline.databases.data.regions import *
@@ -46,12 +48,12 @@ def test_fails_with_bad_strands(raw):
4648
@pytest.mark.parametrize('name,expected', [
4749
('0-start, half-open',
4850
CoordinateSystem(
49-
basis=CoordianteStart.zero,
51+
basis=CoordinateStart.zero,
5052
close_status=CloseStatus.open)
5153
),
5254
('1-start, fully-closed',
5355
CoordinateSystem(
54-
basis=CoordianteStart.one,
56+
basis=CoordinateStart.one,
5557
close_status=CloseStatus.closed)
5658
),
5759
])
@@ -234,3 +236,40 @@ def test_can_generate_correct_writeable(accession, strand, coordinate_system, ex
234236
)
235237
upi = accession.startswith('U')
236238
assert list(region.writeable(accession, is_upi=upi)) == expected
239+
240+
241+
@pytest.mark.parametrize('data', [
242+
(Strand.reverse),
243+
(Strand.unknown),
244+
(Strand.forward),
245+
(CoordinateStart.zero),
246+
(CoordinateStart.one),
247+
(CloseStatus.closed),
248+
(CloseStatus.open),
249+
])
250+
def test_can_serialize_enums_to_and_from_pickle(data):
251+
assert pickle.loads(pickle.dumps(data)) is data
252+
253+
254+
@pytest.mark.parametrize('region', [
255+
(SequenceRegion(
256+
assembly_id='GRCh38',
257+
chromosome='1',
258+
strand=Strand.unknown,
259+
exons=[Exon(start=100000, stop=2000000), Exon(start=-1, stop=10000)],
260+
coordinate_system=CoordinateSystem.from_name('0-start, half-open')
261+
)),
262+
(SequenceRegion(
263+
assembly_id='GRCh38',
264+
chromosome='MT',
265+
strand=Strand.forward,
266+
exons=[Exon(start=10, stop=200), Exon(start=-1, stop=10000)],
267+
coordinate_system=CoordinateSystem.from_name('1-start, fully-closed')
268+
)),
269+
])
270+
def test_can_serialize_regions_to_and_from_pickle(region):
271+
data = attr.asdict(region)
272+
loaded = pickle.loads(pickle.dumps(data))
273+
assert loaded == data
274+
assert SequenceRegion(**loaded) == region
275+
assert pickle.loads(pickle.dumps(region)) == region

tests/utils_test.py

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
Copyright [2009-2018] EMBL-European Bioinformatics Institute
5+
Licensed under the Apache License, Version 2.0 (the "License");
6+
you may not use this file except in compliance with the License.
7+
You may obtain a copy of the License at
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
"""
15+
16+
import tempfile
17+
18+
import six
19+
20+
from rnacentral_pipeline import utils
21+
22+
def test_can_serialize_stream():
23+
data = ['a', 1, 2, 3, 4]
24+
with tempfile.NamedTemporaryFile() as tmp:
25+
utils.pickle_stream(data, tmp)
26+
tmp.seek(0)
27+
result = utils.unpickle_stream(tmp)
28+
assert not isinstance(result, list)
29+
for index, obj in enumerate(result):
30+
assert data[index] == obj

0 commit comments

Comments
 (0)