Skip to content

Commit

Permalink
refactor(rust): Use distributor channel in new-streaming CSV reader and prepare scanning routine for true parallel reading (#21189)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored and Liam Brannigan committed Feb 12, 2025
1 parent 7b14ada commit 2be2036
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 96 deletions.
100 changes: 100 additions & 0 deletions crates/polars-io/src/csv/read/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use polars_core::prelude::*;
use polars_core::{config, POOL};
use polars_error::feature_gated;
use polars_utils::index::Bounded;
use polars_utils::select::select_unpredictable;
use rayon::prelude::*;

use super::buffer::Buffer;
Expand Down Expand Up @@ -607,6 +608,13 @@ pub struct CountLines {
quoting: bool,
}

/// Per-chunk line statistics produced by [`CountLines::analyze_chunk`].
#[derive(Copy, Clone, Debug)]
pub struct LineStats {
    // Number of newline (EOL) bytes attributed to this starting assumption.
    newline_count: usize,
    // Byte offset within the chunk of the last newline seen; stays 0 when no
    // newline was found (indistinguishable from a newline at offset 0 —
    // callers should consult newline_count first).
    last_newline_offset: usize,
    // Whether the chunk ends inside a quoted string, under the starting
    // assumption for which these stats were computed.
    end_inside_string: bool,
}

impl CountLines {
pub fn new(quote_char: Option<u8>, eol_char: u8) -> Self {
let quoting = quote_char.is_some();
Expand All @@ -626,6 +634,98 @@ impl CountLines {
}
}

/// Analyzes a chunk of CSV data.
///
/// Returns (newline_count, last_newline_offset, end_inside_string) twice,
/// the first is assuming the start of the chunk is *not* inside a string,
/// the second assuming the start is inside a string.
pub fn analyze_chunk(&self, bytes: &[u8]) -> [LineStats; 2] {
    let mut scan_offset = 0;

    // states[0] accumulates stats under the assumption that the chunk starts
    // outside a quoted string, states[1] under the assumption it starts
    // inside one. Both are tracked simultaneously so the caller can pick the
    // correct one once the true starting state is known.
    let mut states = [
        LineStats {
            newline_count: 0,
            last_newline_offset: 0,
            end_inside_string: false,
        },
        LineStats {
            newline_count: 0,
            last_newline_offset: 0,
            end_inside_string: false,
        },
    ];

    // false if even number of quotes seen so far, true otherwise.
    #[allow(unused_assignments)]
    let mut global_quote_parity = false;

    #[cfg(feature = "simd")]
    {
        // 0 if even number of quotes seen so far, u64::MAX otherwise.
        let mut global_quote_parity_mask = 0;
        // Process full 64-byte blocks with SIMD; the scalar loop below
        // handles the remainder.
        while scan_offset + 64 <= bytes.len() {
            // SAFETY: the loop condition guarantees that
            // scan_offset..scan_offset + 64 is in-bounds, and the slice is
            // exactly 64 bytes long so the array conversion cannot fail.
            let block: [u8; 64] = unsafe {
                bytes
                    .get_unchecked(scan_offset..scan_offset + 64)
                    .try_into()
                    .unwrap_unchecked()
            };
            let simd_bytes = SimdVec::from(block);
            // Bit i of eol_mask is set iff byte i of the block equals the EOL char.
            let eol_mask = simd_bytes.simd_eq(self.simd_eol_char).to_bitmask();
            if self.quoting {
                let quote_mask = simd_bytes.simd_eq(self.simd_quote_char).to_bitmask();
                // Bit i of quote_parity is the quote parity up to and
                // including byte i; XOR-ing with the carry mask extends the
                // parity across block boundaries.
                let quote_parity =
                    prefix_xorsum_inclusive(quote_mask) ^ global_quote_parity_mask;
                // Arithmetic right shift broadcasts the final (top) parity
                // bit to all 64 bits, forming the carry-in for the next block.
                global_quote_parity_mask = ((quote_parity as i64) >> 63) as u64;

                // Newlines at even parity belong to the "started outside a
                // string" interpretation.
                let start_outside_string_eol_mask = eol_mask & !quote_parity;
                states[0].newline_count += start_outside_string_eol_mask.count_ones() as usize;
                // Highest set bit marks the last newline in the block, at
                // offset scan_offset + 63 - leading_zeros. When the mask is
                // zero the wrapping_sub result is garbage, but the branchless
                // select then keeps the previous value instead.
                states[0].last_newline_offset = select_unpredictable(
                    start_outside_string_eol_mask != 0,
                    (scan_offset + 63)
                        .wrapping_sub(start_outside_string_eol_mask.leading_zeros() as usize),
                    states[0].last_newline_offset,
                );

                // Newlines at odd parity belong to the "started inside a
                // string" interpretation.
                let start_inside_string_eol_mask = eol_mask & quote_parity;
                states[1].newline_count += start_inside_string_eol_mask.count_ones() as usize;
                states[1].last_newline_offset = select_unpredictable(
                    start_inside_string_eol_mask != 0,
                    (scan_offset + 63)
                        .wrapping_sub(start_inside_string_eol_mask.leading_zeros() as usize),
                    states[1].last_newline_offset,
                );
            } else {
                // Quoting disabled: every newline counts and only states[0]
                // is ever updated.
                states[0].newline_count += eol_mask.count_ones() as usize;
                states[0].last_newline_offset = select_unpredictable(
                    eol_mask != 0,
                    (scan_offset + 63).wrapping_sub(eol_mask.leading_zeros() as usize),
                    states[0].last_newline_offset,
                );
            }

            scan_offset += 64;
        }

        global_quote_parity = global_quote_parity_mask > 0;
    }

    // Scalar loop over the remaining tail bytes (or the entire chunk when
    // the "simd" feature is disabled).
    while scan_offset < bytes.len() {
        // SAFETY: scan_offset < bytes.len().
        let c = unsafe { *bytes.get_unchecked(scan_offset) };
        // Only flip the parity on quote characters when quoting is enabled.
        global_quote_parity ^= (c == self.quote_char) & self.quoting;

        // The current parity selects which starting-assumption state this
        // byte contributes to.
        let state = &mut states[global_quote_parity as usize];
        state.newline_count += (c == self.eol_char) as usize;
        state.last_newline_offset =
            select_unpredictable(c == self.eol_char, scan_offset, state.last_newline_offset);

        scan_offset += 1;
    }

    // Odd final parity means the "started outside" interpretation ends
    // inside a string, and vice versa.
    states[0].end_inside_string = global_quote_parity;
    states[1].end_inside_string = !global_quote_parity;
    states
}

pub fn find_next(&self, bytes: &[u8], chunk_size: &mut usize) -> (usize, usize) {
loop {
let b = unsafe { bytes.get_unchecked(..(*chunk_size).min(bytes.len())) };
Expand Down
28 changes: 0 additions & 28 deletions crates/polars-stream/src/async_primitives/wait_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,34 +38,6 @@ impl WaitGroup {
}
}

/// Wait group with an associated index.
pub struct IndexedWaitGroup {
    // Identifies which pipeline/channel this wait group belongs to.
    index: usize,
    wait_group: WaitGroup,
}

impl IndexedWaitGroup {
pub fn new(index: usize) -> Self {
Self {
index,
wait_group: Default::default(),
}
}

pub fn index(&self) -> usize {
self.index
}

pub fn token(&self) -> WaitToken {
self.wait_group.token()
}

pub async fn wait(self) -> Self {
self.wait_group.wait().await;
self
}
}

struct WaitGroupFuture<'a> {
inner: &'a Arc<WaitGroupInner>,
}
Expand Down
101 changes: 33 additions & 68 deletions crates/polars-stream/src/nodes/io_sources/csv.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use std::sync::atomic::Ordering;
use std::sync::Arc;

use futures::stream::FuturesUnordered;
use futures::StreamExt;
use polars_core::config;
use polars_core::prelude::{AnyValue, DataType, Field};
use polars_core::scalar::Scalar;
Expand Down Expand Up @@ -32,24 +30,25 @@ use super::multi_scan::{MultiScanable, RowRestrication};
use super::{SourceNode, SourceOutput};
use crate::async_executor::{self, spawn};
use crate::async_primitives::connector::{connector, Receiver};
use crate::async_primitives::wait_group::{IndexedWaitGroup, WaitToken};
use crate::async_primitives::distributor_channel::distributor_channel;
use crate::async_primitives::wait_group::WaitGroup;
use crate::morsel::SourceToken;
use crate::nodes::compute_node_prelude::*;
use crate::nodes::io_sources::MorselOutput;
use crate::nodes::{MorselSeq, TaskPriority};
use crate::DEFAULT_DISTRIBUTOR_BUFFER_SIZE;

/// A batch of split CSV lines handed from the line-splitting task to a
/// parsing task.
struct LineBatch {
    // Raw bytes of this chunk (sliced from the source file's memory).
    bytes: MemSlice,
    // Number of lines contained in `bytes`.
    n_lines: usize,
    // (offset, len) row slice to apply when parsing this batch.
    slice: (usize, usize),
    // Global row index at which this batch starts.
    row_offset: usize,
    // Sequence id used to keep resulting morsels ordered downstream.
    morsel_seq: MorselSeq,
    // NOTE(review): this field is removed by this commit (diff view shows
    // both states); backpressure now comes from a per-task WaitGroup.
    wait_token: WaitToken,
    // File path to attach as a column when `include_file_paths` is set.
    path_name: Option<PlSmallStr>,
}

// Result of spawning the async line-splitting machinery: per-pipeline
// receivers of line batches, the shared chunk reader, and the abort-on-drop
// handle of the splitting task.
// NOTE(review): both the old connector receiver and the new distributor
// receiver element appear below because this is a diff view; only one of the
// two belongs in the final source.
type AsyncTaskData = (
    Vec<crate::async_primitives::connector::Receiver<LineBatch>>,
    Vec<crate::async_primitives::distributor_channel::Receiver<LineBatch>>,
    Arc<ChunkReader>,
    async_executor::AbortOnDropHandle<PolarsResult<()>>,
);
Expand Down Expand Up @@ -114,6 +113,7 @@ impl SourceNode for CsvSourceNode {
|(mut line_batch_rx, mut recv_from)| {
let chunk_reader = chunk_reader.clone();
let source_token = source_token.clone();
let wait_group = WaitGroup::default();

spawn(TaskPriority::Low, async move {
while let Ok(mut morsel_output) = recv_from.recv().await {
Expand All @@ -123,7 +123,6 @@ impl SourceNode for CsvSourceNode {
slice: (offset, len),
row_offset,
morsel_seq,
wait_token,
mut path_name,
}) = line_batch_rx.recv().await
{
Expand All @@ -150,11 +149,12 @@ impl SourceNode for CsvSourceNode {
}

let mut morsel = Morsel::new(df, morsel_seq, source_token.clone());
morsel.set_consume_token(wait_token);
morsel.set_consume_token(wait_group.token());

if morsel_output.port.send(morsel).await.is_err() {
break;
}
wait_group.wait().await;

if source_token.stop_requested() {
morsel_output.outcome.stop();
Expand Down Expand Up @@ -211,8 +211,8 @@ impl CsvSourceNode {
) -> AsyncTaskData {
let verbose = self.verbose;

let (mut line_batch_senders, line_batch_receivers): (Vec<_>, Vec<_>) =
(0..num_pipelines).map(|_| connector()).unzip();
let (mut line_batch_sender, line_batch_receivers) =
distributor_channel(num_pipelines, DEFAULT_DISTRIBUTOR_BUFFER_SIZE);

let scan_sources = self.scan_sources.clone();
let run_async = scan_sources.is_cloud_url() || config::force_async();
Expand Down Expand Up @@ -274,29 +274,19 @@ impl CsvSourceNode {
return Err(err);
}

let mut wait_groups = (0..num_pipelines)
.map(|index| IndexedWaitGroup::new(index).wait())
.collect::<FuturesUnordered<_>>();
let morsel_seq_ref = &mut MorselSeq::default();
let current_row_offset_ref = &mut 0usize;

let n_parts_hint = num_pipelines * 16;

let line_counter = CountLines::new(quote_char, eol_char);

let comment_prefix = comment_prefix.as_ref();
let morsel_seq_ref = &mut MorselSeq::default();
let current_row_offset_ref = &mut 0usize;
let memslice_sources = scan_sources.iter().map(|x| {
let bytes = x.to_memslice_async_assume_latest(run_async)?;
PolarsResult::Ok((
bytes,
include_file_paths.then(|| x.to_include_path_name().into()),
))
});

'main: for (i, v) in scan_sources
.iter()
.map(|x| {
let bytes = x.to_memslice_async_assume_latest(run_async)?;
PolarsResult::Ok((
bytes,
include_file_paths.then(|| x.to_include_path_name().into()),
))
})
.enumerate()
{
'main: for (i, v) in memslice_sources.enumerate() {
if verbose {
eprintln!(
"[CsvSource]: Start line splitting for file {} / {}",
Expand Down Expand Up @@ -326,7 +316,7 @@ impl CsvSourceNode {
skip_lines,
skip_rows_before_header,
skip_rows_after_header,
comment_prefix,
comment_prefix.as_ref(),
has_header,
)?;

Expand All @@ -337,7 +327,7 @@ impl CsvSourceNode {
let chunk_size = if global_slice.is_some() {
max_chunk_size
} else {
std::cmp::min(bytes.len() / n_parts_hint, max_chunk_size)
std::cmp::min(bytes.len() / (16 * num_pipelines), max_chunk_size)
};

// Use a small min chunk size to catch failures in tests.
Expand Down Expand Up @@ -385,49 +375,24 @@ impl CsvSourceNode {
(0, 0)
};

let mut mem_slice_this_chunk =
let mem_slice_this_chunk =
mem_slice.slice(slice_start..slice_start + position);

let morsel_seq = *morsel_seq_ref;
*morsel_seq_ref = morsel_seq.successor();

let Some(mut indexed_wait_group) = wait_groups.next().await else {
break;
};

let mut path_name = path_name.clone();

loop {
use crate::async_primitives::connector::SendError;

let channel_index = indexed_wait_group.index();
let wait_token = indexed_wait_group.token();

match line_batch_senders[channel_index].try_send(LineBatch {
bytes: mem_slice_this_chunk,
n_lines: count,
slice,
row_offset: current_row_offset,
morsel_seq,
wait_token,
path_name,
}) {
Ok(_) => {
wait_groups.push(indexed_wait_group.wait());
break;
},
Err(SendError::Closed(v)) => {
mem_slice_this_chunk = v.bytes;
path_name = v.path_name;
},
Err(SendError::Full(_)) => unreachable!(),
}
let path_name = path_name.clone();

let Some(v) = wait_groups.next().await else {
break 'main; // All channels closed
};

indexed_wait_group = v;
let batch = LineBatch {
bytes: mem_slice_this_chunk,
n_lines: count,
slice,
row_offset: current_row_offset,
morsel_seq,
path_name,
};
if line_batch_sender.send(batch).await.is_err() {
break;
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
feature(stdarch_aarch64_prefetch)
)]
#![cfg_attr(feature = "nightly", feature(core_intrinsics))] // For algebraic ops.
#![cfg_attr(feature = "nightly", feature(select_unpredictable))] // For branchless programming.
#![cfg_attr(feature = "nightly", allow(internal_features))]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
pub mod abs_diff;
Expand All @@ -26,6 +27,7 @@ pub mod mem;
pub mod min_max;
pub mod pl_str;
pub mod priority;
pub mod select;
pub mod slice;
pub mod sort;
pub mod sync;
Expand Down
13 changes: 13 additions & 0 deletions crates/polars-utils/src/select.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
/// Branchless-friendly select between two values.
///
/// On nightly this forwards to `bool::select_unpredictable`, which hints the
/// compiler to prefer a conditional move over a branch when `cond` is hard
/// to predict.
#[cfg(feature = "nightly")]
pub fn select_unpredictable<T>(cond: bool, true_val: T, false_val: T) -> T {
    cond.select_unpredictable(true_val, false_val)
}

/// Branchless-friendly select between two values.
///
/// Stable fallback with identical semantics, minus the codegen hint.
#[cfg(not(feature = "nightly"))]
pub fn select_unpredictable<T>(cond: bool, true_val: T, false_val: T) -> T {
    match cond {
        true => true_val,
        false => false_val,
    }
}

0 comments on commit 2be2036

Please sign in to comment.