Skip to content

Commit 92cd3f8

Browse files
Merge pull request #302 from neherlab/refactor/ascii-char
refactor: introduce `AsciiChar` wrapper over `u8`
2 parents 67ba54a + 6a7488a commit 92cd3f8

File tree

12 files changed

+427
-232
lines changed

12 files changed

+427
-232
lines changed

packages/treetime/benches/find_letter_ranges_benchmark.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use criterion::{black_box, criterion_group, criterion_main, Criterion};
22
use treetime::representation::seq::Seq;
3+
use treetime::representation::seq_char::AsciiChar;
34
use treetime::seq::find_char_ranges::find_letter_ranges_by;
45

56
const SEQ: &str = "\
@@ -27,8 +28,8 @@ const SEQ: &str = "\
2728
ACACTGTCTTCATGTTGTCGGCCCAAATGTTAACAAAGGTGAAGACATTCAACTTCTTAA\
2829
";
2930

30-
fn pred(c: u8) -> bool {
31-
c == b'N' || c == b'-'
31+
fn pred(c: AsciiChar) -> bool {
32+
c == AsciiChar(b'N') || c == AsciiChar(b'-')
3233
}
3334

3435
pub fn bench_1(c: &mut Criterion) {

packages/treetime/src/alphabet/alphabet.rs

Lines changed: 61 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::io::json::{json_write_str, JsonPretty};
22
use crate::representation::bitset128::BitSet128;
3+
use crate::representation::seq_char::AsciiChar;
34
use crate::representation::state_set::StateSet;
45
use crate::utils::string::quote;
56
use crate::{make_error, stateset, vec_u8};
@@ -16,9 +17,9 @@ use std::fmt::Display;
1617
use std::iter::once;
1718
use strum_macros::Display;
1819

19-
pub const NON_CHAR: u8 = b'.';
20-
pub const VARIABLE_CHAR: u8 = b'~';
21-
pub const FILL_CHAR: u8 = b' ';
20+
pub const NON_CHAR: AsciiChar = AsciiChar(b'.');
21+
pub const VARIABLE_CHAR: AsciiChar = AsciiChar(b'~');
22+
pub const FILL_CHAR: AsciiChar = AsciiChar(b' ');
2223

2324
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, ArgEnum, SmartDefault, Display)]
2425
#[clap(rename = "kebab-case")]
@@ -28,33 +29,33 @@ pub enum AlphabetName {
2829
Aa,
2930
}
3031

31-
pub type ProfileMap = IndexMap<u8, Array1<f64>>;
32-
pub type StateSetMap = IndexMap<u8, StateSet>;
33-
pub type CharToSet = IndexMap<u8, StateSet>;
34-
pub type SetToChar = IndexMap<StateSet, u8>;
32+
pub type ProfileMap = IndexMap<AsciiChar, Array1<f64>>;
33+
pub type StateSetMap = IndexMap<AsciiChar, StateSet>;
34+
pub type CharToSet = IndexMap<AsciiChar, StateSet>;
35+
pub type SetToChar = IndexMap<StateSet, AsciiChar>;
3536

3637
#[derive(Clone, Debug, Serialize, Deserialize)]
3738
pub struct Alphabet {
3839
all: StateSet,
3940
canonical: StateSet,
40-
ambiguous: IndexMap<u8, Vec<u8>>,
41+
ambiguous: IndexMap<AsciiChar, Vec<AsciiChar>>,
4142
ambiguous_keys: StateSet,
4243
determined: StateSet,
4344
undetermined: StateSet,
44-
unknown: u8,
45-
gap: u8,
45+
unknown: AsciiChar,
46+
gap: AsciiChar,
4647
treat_gap_as_unknown: bool,
4748
profile_map: ProfileMap,
4849

4950
#[serde(skip)]
50-
char_to_set: IndexMap<u8, StateSet>,
51+
char_to_set: IndexMap<AsciiChar, StateSet>,
5152
#[serde(skip)]
52-
set_to_char: IndexMap<StateSet, u8>,
53+
set_to_char: IndexMap<StateSet, AsciiChar>,
5354

5455
#[serde(skip)]
5556
char_to_index: Vec<Option<usize>>,
5657
#[serde(skip)]
57-
index_to_char: Vec<u8>,
58+
index_to_char: Vec<AsciiChar>,
5859
}
5960

6061
impl Default for Alphabet {
@@ -111,15 +112,21 @@ impl Alphabet {
111112
treat_gap_as_unknown,
112113
} = cfg;
113114

115+
let gap = AsciiChar::from(*gap);
116+
let unknown = AsciiChar::from(*unknown);
117+
114118
let canonical = StateSet::from_iter(canonical);
115119
if canonical.is_empty() {
116120
return make_error!("When creating alphabet: canonical set of characters is empty. This is not allowed.");
117121
}
118122

119-
let ambiguous: IndexMap<u8, Vec<u8>> = ambiguous.to_owned();
123+
let ambiguous: IndexMap<AsciiChar, Vec<AsciiChar>> = ambiguous
124+
.iter()
125+
.map(|(k, v)| (AsciiChar(*k), v.iter().copied().map(AsciiChar).collect()))
126+
.collect();
120127
let ambiguous_keys = ambiguous.keys().collect();
121128

122-
let undetermined = stateset! {*unknown, *gap};
129+
let undetermined = stateset! {unknown, gap};
123130
let determined = StateSet::from_union([canonical, ambiguous_keys]);
124131
let all = StateSet::from_union([canonical, ambiguous_keys, undetermined]);
125132

@@ -128,7 +135,7 @@ impl Alphabet {
128135
let mut char_to_index = vec![None; 128];
129136
let mut index_to_char = Vec::with_capacity(canonical.len());
130137
for (i, c) in canonical.iter().enumerate() {
131-
char_to_index[c as usize] = Some(i);
138+
char_to_index[usize::from(c)] = Some(i);
132139
index_to_char.push(c);
133140
}
134141

@@ -137,8 +144,8 @@ impl Alphabet {
137144
ambiguous.iter().for_each(|(key, chars)| {
138145
char_to_set.insert(*key, StateSet::from_iter(chars));
139146
});
140-
char_to_set.insert(*gap, StateSet::from_char(*gap));
141-
char_to_set.insert(*unknown, StateSet::from_char(*unknown));
147+
char_to_set.insert(gap, StateSet::from_char(gap));
148+
char_to_set.insert(unknown, StateSet::from_char(unknown));
142149
char_to_set
143150
};
144151

@@ -153,8 +160,8 @@ impl Alphabet {
153160
ambiguous_keys,
154161
determined,
155162
undetermined,
156-
unknown: *unknown,
157-
gap: *gap,
163+
unknown,
164+
gap,
158165
treat_gap_as_unknown: *treat_gap_as_unknown,
159166
profile_map,
160167
char_to_set,
@@ -163,7 +170,7 @@ impl Alphabet {
163170
}
164171

165172
#[inline]
166-
pub fn get_profile(&self, c: u8) -> &Array1<f64> {
173+
pub fn get_profile(&self, c: AsciiChar) -> &Array1<f64> {
167174
self
168175
.profile_map
169176
.get(&c)
@@ -180,7 +187,7 @@ impl Alphabet {
180187
pub fn construct_profile<I, T>(&self, chars: I) -> Result<Array1<f64>, Report>
181188
where
182189
I: IntoIterator<Item = T>,
183-
T: Borrow<u8> + Display,
190+
T: Borrow<AsciiChar> + Display,
184191
{
185192
let mut profile = Array1::<f64>::zeros(self.n_canonical());
186193
for c in chars {
@@ -193,7 +200,7 @@ impl Alphabet {
193200
Ok(profile)
194201
}
195202

196-
pub fn get_code(&self, profile: &Array1<f64>) -> u8 {
203+
pub fn get_code(&self, profile: &Array1<f64>) -> AsciiChar {
197204
// TODO(perf): this mapping needs to be precomputed
198205
self
199206
.profile_map
@@ -204,39 +211,39 @@ impl Alphabet {
204211
}
205212

206213
#[allow(single_use_lifetimes)] // TODO: remove when anonymous lifetimes in `impl Trait` are stabilized
207-
pub fn seq2prof<'a>(&self, chars: impl IntoIterator<Item = &'a u8>) -> Result<Array2<f64>, Report> {
214+
pub fn seq2prof<'a>(&self, chars: impl IntoIterator<Item = &'a AsciiChar>) -> Result<Array2<f64>, Report> {
208215
let prof = stack(
209216
Axis(0),
210217
&chars.into_iter().map(|&c| self.get_profile(c).view()).collect_vec(),
211218
)?;
212219
Ok(prof)
213220
}
214221

215-
pub fn set_to_char(&self, c: StateSet) -> u8 {
222+
pub fn set_to_char(&self, c: StateSet) -> AsciiChar {
216223
self.set_to_char[&c]
217224
}
218225

219-
pub fn char_to_set(&self, c: u8) -> StateSet {
220-
self.char_to_set[&c]
226+
pub fn char_to_set(&self, c: impl Into<AsciiChar>) -> StateSet {
227+
self.char_to_set[&c.into()]
221228
}
222229

223230
/// All existing characters (including 'unknown' and 'gap')
224-
pub fn chars(&self) -> impl Iterator<Item = u8> + '_ {
231+
pub fn chars(&self) -> impl Iterator<Item = AsciiChar> + '_ {
225232
self.all.iter()
226233
}
227234

228235
/// Get character by index (indexed in the same order as given by `.chars()`)
229-
pub fn char(&self, index: usize) -> u8 {
236+
pub fn char(&self, index: usize) -> AsciiChar {
230237
self.index_to_char[index]
231238
}
232239

233240
/// Get index of a character (indexed in the same order as given by `.chars()`)
234-
pub fn index(&self, c: u8) -> usize {
235-
self.char_to_index[c as usize].unwrap()
241+
pub fn index(&self, c: impl Into<usize>) -> usize {
242+
self.char_to_index[c.into()].unwrap()
236243
}
237244

238245
/// Check if character is in alphabet (including 'unknown' and 'gap')
239-
pub fn contains(&self, c: u8) -> bool {
246+
pub fn contains(&self, c: AsciiChar) -> bool {
240247
self.all.contains(c)
241248
}
242249

@@ -245,12 +252,12 @@ impl Alphabet {
245252
}
246253

247254
/// Canonical (unambiguous) characters (e.g. 'A', 'C', 'G', 'T' in nuc alphabet)
248-
pub fn canonical(&self) -> impl Iterator<Item = u8> + '_ {
255+
pub fn canonical(&self) -> impl Iterator<Item = AsciiChar> + '_ {
249256
self.canonical.iter()
250257
}
251258

252259
/// Check is character is canonical
253-
pub fn is_canonical(&self, c: u8) -> bool {
260+
pub fn is_canonical(&self, c: AsciiChar) -> bool {
254261
self.canonical.contains(c)
255262
}
256263

@@ -259,12 +266,12 @@ impl Alphabet {
259266
}
260267

261268
/// Ambiguous characters (e.g. 'R', 'S' etc. in nuc alphabet)
262-
pub fn ambiguous(&self) -> impl Iterator<Item = u8> + '_ {
269+
pub fn ambiguous(&self) -> impl Iterator<Item = AsciiChar> + '_ {
263270
self.ambiguous_keys.iter()
264271
}
265272

266273
/// Check if character is ambiguous (e.g. 'R', 'S' etc. in nuc alphabet)
267-
pub fn is_ambiguous(&self, c: u8) -> bool {
274+
pub fn is_ambiguous(&self, c: AsciiChar) -> bool {
268275
self.ambiguous_keys.contains(c)
269276
}
270277

@@ -273,11 +280,11 @@ impl Alphabet {
273280
}
274281

275282
/// Determined characters: canonical or ambiguous
276-
pub fn determined(&self) -> impl Iterator<Item = u8> + '_ {
283+
pub fn determined(&self) -> impl Iterator<Item = AsciiChar> + '_ {
277284
self.determined.iter()
278285
}
279286

280-
pub fn is_determined(&self, c: u8) -> bool {
287+
pub fn is_determined(&self, c: AsciiChar) -> bool {
281288
self.determined.contains(c)
282289
}
283290

@@ -286,11 +293,11 @@ impl Alphabet {
286293
}
287294

288295
/// Undetermined characters: gap or unknown
289-
pub fn undetermined(&self) -> impl Iterator<Item = u8> + '_ {
296+
pub fn undetermined(&self) -> impl Iterator<Item = AsciiChar> + '_ {
290297
self.undetermined.iter()
291298
}
292299

293-
pub fn is_undetermined(&self, c: u8) -> bool {
300+
pub fn is_undetermined(&self, c: AsciiChar) -> bool {
294301
self.undetermined.contains(c)
295302
}
296303

@@ -299,23 +306,23 @@ impl Alphabet {
299306
}
300307

301308
/// Get 'unknown' character
302-
pub fn unknown(&self) -> u8 {
309+
pub fn unknown(&self) -> AsciiChar {
303310
self.unknown
304311
}
305312

306313
/// Check if character is an 'unknown' character
307-
pub fn is_unknown(&self, c: u8) -> bool {
308-
c == self.unknown()
314+
pub fn is_unknown(&self, c: impl Into<AsciiChar>) -> bool {
315+
c.into() == self.unknown()
309316
}
310317

311318
/// Get 'gap' character
312-
pub fn gap(&self) -> u8 {
319+
pub fn gap(&self) -> AsciiChar {
313320
self.gap
314321
}
315322

316323
/// Check if character is a gap
317-
pub fn is_gap(&self, c: u8) -> bool {
318-
c == self.gap()
324+
pub fn is_gap(&self, c: impl Into<AsciiChar>) -> bool {
325+
c.into() == self.gap()
319326
}
320327
}
321328

@@ -338,6 +345,9 @@ impl AlphabetConfig {
338345
treat_gap_as_unknown,
339346
} = self;
340347

348+
let gap = AsciiChar::from(*gap);
349+
let unknown = AsciiChar::from(*unknown);
350+
341351
self
342352
.validate()
343353
.wrap_err("When validating alphabet config")
@@ -353,11 +363,11 @@ impl AlphabetConfig {
353363
let mut profile_map: ProfileMap = canonical
354364
.iter()
355365
.zip(eye.rows())
356-
.map(|(s, x)| (*s, x.to_owned()))
366+
.map(|(s, x)| (AsciiChar(*s), x.to_owned()))
357367
.collect();
358368

359369
// Add unknown to profile map
360-
profile_map.insert(*unknown, Array1::<f64>::ones(canonical.len()));
370+
profile_map.insert(unknown, Array1::<f64>::ones(canonical.len()));
361371

362372
// Add ambiguous to profile map
363373
ambiguous.iter().for_each(|(&key, values)| {
@@ -366,12 +376,12 @@ impl AlphabetConfig {
366376
.enumerate()
367377
.map(|(i, c)| if values.contains(c) { 1.0 } else { 0.0 })
368378
.collect::<Array1<f64>>();
369-
profile_map.insert(key, profile);
379+
profile_map.insert(AsciiChar(key), profile);
370380
});
371381

372382
if *treat_gap_as_unknown {
373383
// Add gap to profile map
374-
profile_map.insert(*gap, profile_map[unknown].clone());
384+
profile_map.insert(gap, profile_map[&unknown].clone());
375385
}
376386

377387
Ok(profile_map)
@@ -397,7 +407,7 @@ impl AlphabetConfig {
397407
.collect_vec();
398408

399409
for reserved in [NON_CHAR, VARIABLE_CHAR, FILL_CHAR] {
400-
if all.iter().any(|&c| c == reserved) {
410+
if all.iter().any(|&c| c == u8::from(reserved)) {
401411
return make_error!("Alphabet contains reserved character: {reserved}");
402412
}
403413
}

0 commit comments

Comments
 (0)