Skip to content

Commit 4c11001

Browse files
authored
Improve internal DX around byte classification [1] (#16864)
This PR improves the internal DX when working with `u8` classification into a smaller enum. This is done by implementing a `ClassifyBytes` proc derive macro. The benefit of this is that the DX is much better and everything you will see here is done at compile time. Before: ```rs #[derive(Debug, Clone, Copy, PartialEq)] enum Class { ValidStart, ValidInside, OpenBracket, OpenParen, Slash, Other, } const CLASS_TABLE: [Class; 256] = { let mut table = [Class::Other; 256]; macro_rules! set { ($class:expr, $($byte:expr),+ $(,)?) => { $(table[$byte as usize] = $class;)+ }; } macro_rules! set_range { ($class:expr, $start:literal ..= $end:literal) => { let mut i = $start; while i <= $end { table[i as usize] = $class; i += 1; } }; } set_range!(Class::ValidStart, b'a'..=b'z'); set_range!(Class::ValidStart, b'A'..=b'Z'); set_range!(Class::ValidStart, b'0'..=b'9'); set!(Class::OpenBracket, b'['); set!(Class::OpenParen, b'('); set!(Class::Slash, b'/'); set!(Class::ValidInside, b'-', b'_', b'.'); table }; ``` After: ```rs #[derive(Debug, Clone, Copy, PartialEq, ClassifyBytes)] enum Class { #[bytes_range(b'a'..=b'z', b'A'..=b'Z', b'0'..=b'9')] ValidStart, #[bytes(b'-', b'_', b'.')] ValidInside, #[bytes(b'[')] OpenBracket, #[bytes(b'(')] OpenParen, #[bytes(b'/')] Slash, #[fallback] Other, } ``` Before we were generating a `CLASS_TABLE` that we could access directly, but now it will be part of the `Class`. This means that the usage has to change: ```diff - CLASS_TABLE[cursor.curr as usize] + Class::TABLE[cursor.curr as usize] ``` This is slightly worse UX, and this is where another change comes in. We implemented the `From<u8> for #enum_name` trait inside of the `ClassifyBytes` derive macro. This allows us to use `.into()` on any `u8` as long as we are comparing it to a `Class` instance. In our scenario: ```diff - Class::TABLE[cursor.curr as usize] + cursor.curr.into() ``` Usage wise, this looks something like this: ```diff while cursor.pos < len { - match Class::TABLE[cursor.curr as usize] { + match cursor.curr.into() { - Class::Escape => match Class::Table[cursor.next as usize] { + Class::Escape => match cursor.next.into() { // An escaped whitespace character is not allowed Class::Whitespace => return MachineState::Idle, // An escaped character, skip ahead to the next character _ => cursor.advance(), }, // End of the string Class::Quote if cursor.curr == end_char => return self.done(start_pos, cursor), // Any kind of whitespace is not allowed Class::Whitespace => return MachineState::Idle, // Everything else is valid _ => {} }; cursor.advance() } MachineState::Idle } } ``` If you manually look at the `Class::TABLE` in your editor for example, you can see that it is properly generated at compile time. Given this input: ```rs #[derive(Clone, Copy, ClassifyBytes)] enum Class { #[bytes_range(b'a'..=b'z')] AlphaLower, #[bytes_range(b'A'..=b'Z')] AlphaUpper, #[bytes(b'@')] At, #[bytes(b':')] Colon, #[bytes(b'-')] Dash, #[bytes(b'.')] Dot, #[bytes(b'\0')] End, #[bytes(b'!')] Exclamation, #[bytes_range(b'0'..=b'9')] Number, #[bytes(b'[')] OpenBracket, #[bytes(b']')] CloseBracket, #[bytes(b'(')] OpenParen, #[bytes(b'%')] Percent, #[bytes(b'"', b'\'', b'`')] Quote, #[bytes(b'/')] Slash, #[bytes(b'_')] Underscore, #[bytes(b' ', b'\t', b'\n', b'\r', b'\x0C')] Whitespace, #[fallback] Other, } ``` This is the result: <img width="1244" alt="image" src="https://github.com/user-attachments/assets/6ffd6ad3-0b2f-4381-a24c-593e4c72080e" />
1 parent 0b36dd5 commit 4c11001

14 files changed

+474
-514
lines changed

Cargo.lock

+10
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
[package]
2+
name = "classification-macros"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
[lib]
7+
proc-macro = true
8+
9+
[dependencies]
10+
syn = "2"
11+
quote = "1"
12+
proc-macro2 = "1"
+247
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
use proc_macro::TokenStream;
2+
use quote::quote;
3+
use syn::{
4+
parse_macro_input, punctuated::Punctuated, token::Comma, Attribute, Data, DataEnum,
5+
DeriveInput, Expr, ExprLit, ExprRange, Ident, Lit, RangeLimits, Result, Variant,
6+
};
7+
8+
/// A custom derive that supports:
9+
///
10+
/// - `#[bytes(…)]` for single byte literals
11+
/// - `#[bytes_range(…)]` for inclusive byte ranges (b'a'..=b'z')
12+
/// - `#[fallback]` for a variant that covers everything else
13+
///
14+
/// Example usage:
15+
///
16+
/// ```rust
17+
/// use classification_macros::ClassifyBytes;
18+
///
19+
/// #[derive(Clone, Copy, ClassifyBytes)]
20+
/// enum Class {
21+
/// #[bytes(b'a', b'b', b'c')]
22+
/// Letters,
23+
///
24+
/// #[bytes_range(b'0'..=b'9')]
25+
/// Digits,
26+
///
27+
/// #[fallback]
28+
/// Other,
29+
/// }
30+
/// ```
31+
/// Then call `b'a'.into()` to get `Example::SomeLetters`.
32+
#[proc_macro_derive(ClassifyBytes, attributes(bytes, bytes_range, fallback))]
33+
pub fn classify_bytes_derive(input: TokenStream) -> TokenStream {
34+
let ast = parse_macro_input!(input as DeriveInput);
35+
36+
// This derive only works on an enum
37+
let Data::Enum(DataEnum { variants, .. }) = &ast.data else {
38+
return syn::Error::new_spanned(
39+
&ast.ident,
40+
"ClassifyBytes can only be derived on an enum.",
41+
)
42+
.to_compile_error()
43+
.into();
44+
};
45+
46+
let enum_name = &ast.ident;
47+
48+
let mut byte_map: [Option<Ident>; 256] = [const { None }; 256];
49+
let mut fallback_variant: Option<Ident> = None;
50+
51+
// Start parsing the variants
52+
for variant in variants {
53+
let variant_ident = &variant.ident;
54+
55+
// If this variant has #[fallback], record it
56+
if has_fallback_attr(variant) {
57+
if fallback_variant.is_some() {
58+
let err = syn::Error::new_spanned(
59+
variant_ident,
60+
"Multiple variants have #[fallback]. Only one allowed.",
61+
);
62+
return err.to_compile_error().into();
63+
}
64+
fallback_variant = Some(variant_ident.clone());
65+
}
66+
67+
// Get #[bytes(…)]
68+
let single_bytes = get_bytes_attrs(&variant.attrs);
69+
70+
// Get #[bytes_range(…)]
71+
let range_bytes = get_bytes_range_attrs(&variant.attrs);
72+
73+
// Combine them
74+
let all_bytes = single_bytes
75+
.into_iter()
76+
.chain(range_bytes)
77+
.collect::<Vec<_>>();
78+
79+
// Mark them in the table
80+
for b in all_bytes {
81+
byte_map[b as usize] = Some(variant_ident.clone());
82+
}
83+
}
84+
85+
// If no fallback variant is found, default to "Other"
86+
let fallback_ident = fallback_variant.expect("A variant marked with #[fallback] is missing");
87+
88+
// For each of the 256 byte values, fill the table
89+
let fill = byte_map
90+
.clone()
91+
.into_iter()
92+
.map(|variant_opt| match variant_opt {
93+
Some(ident) => quote!(#enum_name::#ident),
94+
None => quote!(#enum_name::#fallback_ident),
95+
});
96+
97+
// Generate the final expanded code
98+
let expanded = quote! {
99+
impl #enum_name {
100+
pub const TABLE: [#enum_name; 256] = [
101+
#(#fill),*
102+
];
103+
}
104+
105+
impl From<u8> for #enum_name {
106+
fn from(byte: u8) -> Self {
107+
#enum_name::TABLE[byte as usize]
108+
}
109+
}
110+
};
111+
112+
TokenStream::from(expanded)
113+
}
114+
115+
/// Checks if a variant has `#[fallback]`
116+
fn has_fallback_attr(variant: &Variant) -> bool {
117+
variant
118+
.attrs
119+
.iter()
120+
.any(|attr| attr.path().is_ident("fallback"))
121+
}
122+
123+
/// Get all single byte literals from `#[bytes(…)]`
124+
fn get_bytes_attrs(attrs: &[Attribute]) -> Vec<u8> {
125+
let mut assigned = Vec::new();
126+
for attr in attrs {
127+
if attr.path().is_ident("bytes") {
128+
match parse_bytes_attr(attr) {
129+
Ok(list) => assigned.extend(list),
130+
Err(e) => panic!("Error parsing #[bytes(...)]: {}", e),
131+
}
132+
}
133+
}
134+
assigned
135+
}
136+
137+
/// Parse `#[bytes(...)]` as a comma-separated list of **byte literals**, e.g. `b'a'`, `b'\n'`.
138+
fn parse_bytes_attr(attr: &Attribute) -> Result<Vec<u8>> {
139+
// We'll parse it as a list of syn::Lit separated by commas: e.g. (b'a', b'b')
140+
let items: Punctuated<Lit, Comma> = attr.parse_args_with(Punctuated::parse_terminated)?;
141+
let mut out = Vec::new();
142+
for lit in items {
143+
match lit {
144+
Lit::Byte(lb) => out.push(lb.value()),
145+
_ => {
146+
return Err(syn::Error::new_spanned(
147+
lit,
148+
"Expected a byte literal like b'a'",
149+
))
150+
}
151+
}
152+
}
153+
Ok(out)
154+
}
155+
156+
/// Get all byte ranges from `#[bytes_range(...)]`
157+
fn get_bytes_range_attrs(attrs: &[Attribute]) -> Vec<u8> {
158+
let mut assigned = Vec::new();
159+
for attr in attrs {
160+
if attr.path().is_ident("bytes_range") {
161+
match parse_bytes_range_attr(attr) {
162+
Ok(list) => assigned.extend(list),
163+
Err(e) => panic!("Error parsing #[bytes_range(...)]: {}", e),
164+
}
165+
}
166+
}
167+
assigned
168+
}
169+
170+
/// Parse `#[bytes_range(...)]` as a comma-separated list of range expressions, e.g.:
171+
/// `b'a'..=b'z', b'0'..=b'9'`
172+
fn parse_bytes_range_attr(attr: &Attribute) -> Result<Vec<u8>> {
173+
// We'll parse each element as a syn::Expr, then see if it's an Expr::Range
174+
let exprs: Punctuated<Expr, Comma> = attr.parse_args_with(Punctuated::parse_terminated)?;
175+
let mut out = Vec::new();
176+
177+
for expr in exprs {
178+
if let Expr::Range(ExprRange {
179+
start: Some(start),
180+
end: Some(end),
181+
limits,
182+
..
183+
}) = expr
184+
{
185+
let from = extract_byte_literal(&start)?;
186+
let to = extract_byte_literal(&end)?;
187+
188+
match limits {
189+
RangeLimits::Closed(_) => {
190+
// b'a'..=b'z'
191+
if from <= to {
192+
out.extend(from..=to);
193+
}
194+
}
195+
RangeLimits::HalfOpen(_) => {
196+
// b'a'..b'z' => from..(to-1)
197+
if from < to {
198+
out.extend(from..to);
199+
}
200+
}
201+
}
202+
} else {
203+
return Err(syn::Error::new_spanned(
204+
expr,
205+
"Expected a byte range like b'a'..=b'z'",
206+
));
207+
}
208+
}
209+
210+
Ok(out)
211+
}
212+
213+
/// Extract a u8 from an expression that can be:
214+
///
215+
/// - `Expr::Lit(Lit::Byte(...))`, e.g. b'a'
216+
/// - `Expr::Lit(Lit::Int(...))`, e.g. 0x80 or 255
217+
fn extract_byte_literal(expr: &Expr) -> Result<u8> {
218+
if let Expr::Lit(ExprLit { lit, .. }) = expr {
219+
match lit {
220+
// Existing case: b'a'
221+
Lit::Byte(lb) => Ok(lb.value()),
222+
223+
// New case: 0x80, 255, etc.
224+
Lit::Int(li) => {
225+
let value = li.base10_parse::<u64>()?;
226+
if value <= 255 {
227+
Ok(value as u8)
228+
} else {
229+
Err(syn::Error::new_spanned(
230+
li,
231+
format!("Integer literal {} out of range for a byte (0..255)", value),
232+
))
233+
}
234+
}
235+
236+
_ => Err(syn::Error::new_spanned(
237+
lit,
238+
"Expected b'...' or an integer literal in range 0..=255",
239+
)),
240+
}
241+
} else {
242+
Err(syn::Error::new_spanned(
243+
expr,
244+
"Expected a literal expression like b'a' or 0x80",
245+
))
246+
}
247+
}

crates/oxide/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ ignore = "0.4.23"
1717
dunce = "1.0.5"
1818
bexpand = "1.2.0"
1919
fast-glob = "0.4.3"
20+
classification-macros = { path = "../classification-macros" }
2021

2122
[dev-dependencies]
2223
tempfile = "3.13.0"

0 commit comments

Comments
 (0)