
Commit dc0d35a

Add Interruptible Query Execution in Jupyter via KeyboardInterrupt Support (#1141)
* fix: enhance error handling in async wait_for_future function
* feat: implement async execution for execution plans in PySessionContext
* fix: improve error message for execution failures in PySessionContext
* fix: enhance error handling and improve execution plan retrieval in PyDataFrame
* fix: ensure 'static lifetime for futures in wait_for_future function
* fix: handle potential errors when caching DataFrame and retrieving execution plan
* fix: flatten batches in PyDataFrame to ensure proper schema conversion
* fix: correct error handling in batch processing for schema conversion
* fix: flatten nested structure in PyDataFrame to ensure proper RecordBatch iteration
* fix: improve error handling in PyDataFrame stream execution
* fix: add utility to get Tokio Runtime with time enabled and update wait_for_future to use it
* fix: store result of converting RecordBatches to PyArrow for debugging
* fix: handle error from wait_for_future in PyDataFrame collect method
* fix: propagate error from wait_for_future in collect_partitioned method
* fix: enable IO in Tokio runtime with time support
* main register_listing_table
* Revert "main register_listing_table" (reverts commit 52a5efe)
* fix: propagate error correctly from wait_for_future in PySessionContext methods
* fix: simplify error handling in PySessionContext by unwrapping wait_for_future result
* test: add interruption handling test for long-running queries in DataFusion
* test: move test_collect_interrupted to test_dataframe.py
* fix: add const for interval in wait_for_future utility
* fix: use get_tokio_runtime instead of the custom get_runtime
* Revert "fix: use get_tokio_runtime instead of the custom get_runtime" (reverts commit ca2d892)
* fix: use get_tokio_runtime instead of the custom get_runtime
* .
* Revert "." (reverts commit b8ce3e4)
* fix: improve query interruption handling in test_collect_interrupted
* fix: ensure proper handling of query interruption in test_collect_interrupted
* fix: improve error handling in database table retrieval
* refactor: add helper for async move
* Revert "refactor: add helper for async move" (reverts commit faabf6d)
* move py_err_to_datafusion_err to errors.rs
* add create_csv_read_options
* fix
* create_csv_read_options -> PyDataFusionResult
* revert to before create_csv_read_options
* refactor: simplify file compression type parsing in PySessionContext
* fix: parse_compression_type once only
* add create_ndjson_read_options
* refactor comment for clarity in wait_for_future function
* refactor wait_for_future to avoid spawn
* remove unused py_err_to_datafusion_err function
* add comment to clarify error handling in next method of PyRecordBatchStream
* handle error from wait_for_future in PySubstraitSerializer
* clarify comment on future pinning in wait_for_future function
* refactor wait_for_future to use Duration for signal check interval
* handle error from wait_for_future in count method of PyDataFrame
* fix ruff errors
* fix clippy errors
* remove unused get_and_enter_tokio_runtime function and simplify wait_for_future
* Refactor async handling in PySessionContext and PyDataFrame:
  - Simplified async handling by removing unnecessary cloning of strings and context in various methods.
  - Streamlined the use of `wait_for_future` to directly handle futures without intermediate variables.
  - Improved error handling by directly propagating results from futures.
  - Enhanced readability by reducing boilerplate code in methods related to reading and writing data.
  - Updated the `wait_for_future` function to improve signal checking and future handling.
* Organize imports in utils.rs for improved readability
* map_err instead of panic
* Fix error handling in async stream execution for PySessionContext and PyDataFrame
1 parent: d6ef9bc · commit: dc0d35a
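
The thread running through this history is the rework of `wait_for_future` in utils.rs so that blocking on a Tokio future periodically checks Python's signal state. Below is a minimal sketch of that pattern, not the commit's exact code: the `SIGNAL_CHECK_INTERVAL` name is illustrative, and the sketch builds a local runtime where the real code reuses the shared one from `get_tokio_runtime`.

```rust
use std::time::Duration;

use pyo3::prelude::*;

// Illustrative constant; the commit adds a similar interval in utils.rs.
const SIGNAL_CHECK_INTERVAL: Duration = Duration::from_millis(100);

fn wait_for_future<F: std::future::Future>(py: Python<'_>, fut: F) -> PyResult<F::Output> {
    // A time-enabled runtime is required for tokio::time::timeout, which is
    // why the history above mentions enabling IO/time support on the runtime.
    let runtime = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .map_err(|e| pyo3::exceptions::PyRuntimeError::new_err(e.to_string()))?;
    // Pin the future on the stack so it can be polled repeatedly by reference.
    tokio::pin!(fut);
    loop {
        // Drive the future for at most one interval, then check Python's
        // pending signals so Ctrl-C raises KeyboardInterrupt promptly
        // instead of blocking until the query finishes.
        match runtime.block_on(tokio::time::timeout(SIGNAL_CHECK_INTERVAL, &mut fut)) {
            Ok(output) => return Ok(output),
            Err(_elapsed) => py.check_signals()?, // surfaces KeyboardInterrupt
        }
    }
}
```

Because the wrapped future's own output is usually a `Result`, call sites now unwrap two layers; that is the `??` visible throughout the diffs below.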

File tree

7 files changed: +211 -66 lines


python/tests/test_dataframe.py

Lines changed: 121 additions & 0 deletions
```diff
@@ -14,9 +14,12 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import ctypes
 import datetime
 import os
 import re
+import threading
+import time
 from typing import Any
 
 import pyarrow as pa
@@ -2060,3 +2063,121 @@ def test_fill_null_all_null_column(ctx):
     # Check that all nulls were filled
     result = filled_df.collect()[0]
     assert result.column(1).to_pylist() == ["filled", "filled", "filled"]
+
+
+def test_collect_interrupted():
+    """Test that a long-running query can be interrupted with Ctrl-C.
+
+    This test simulates a Ctrl-C keyboard interrupt by raising a KeyboardInterrupt
+    exception in the main thread during a long-running query execution.
+    """
+    # Create a context and a DataFrame with a query that will run for a while
+    ctx = SessionContext()
+
+    # Create a recursive computation that will run for some time
+    batches = []
+    for i in range(10):
+        batch = pa.RecordBatch.from_arrays(
+            [
+                pa.array(list(range(i * 1000, (i + 1) * 1000))),
+                pa.array([f"value_{j}" for j in range(i * 1000, (i + 1) * 1000)]),
+            ],
+            names=["a", "b"],
+        )
+        batches.append(batch)
+
+    # Register tables
+    ctx.register_record_batches("t1", [batches])
+    ctx.register_record_batches("t2", [batches])
+
+    # Create a large join operation that will take time to process
+    df = ctx.sql("""
+    WITH t1_expanded AS (
+        SELECT
+            a,
+            b,
+            CAST(a AS DOUBLE) / 1.5 AS c,
+            CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS d
+        FROM t1
+        CROSS JOIN (SELECT 1 AS dummy FROM t1 LIMIT 5)
+    ),
+    t2_expanded AS (
+        SELECT
+            a,
+            b,
+            CAST(a AS DOUBLE) * 2.5 AS e,
+            CAST(a AS DOUBLE) * CAST(a AS DOUBLE) * CAST(a AS DOUBLE) AS f
+        FROM t2
+        CROSS JOIN (SELECT 1 AS dummy FROM t2 LIMIT 5)
+    )
+    SELECT
+        t1.a, t1.b, t1.c, t1.d,
+        t2.a AS a2, t2.b AS b2, t2.e, t2.f
+    FROM t1_expanded t1
+    JOIN t2_expanded t2 ON t1.a % 100 = t2.a % 100
+    WHERE t1.a > 100 AND t2.a > 100
+    """)
+
+    # Flag to track if the query was interrupted
+    interrupted = False
+    interrupt_error = None
+    main_thread = threading.main_thread()
+
+    # Shared flag to indicate query execution has started
+    query_started = threading.Event()
+    max_wait_time = 5.0  # Maximum wait time in seconds
+
+    # This function will be run in a separate thread and will raise
+    # KeyboardInterrupt in the main thread
+    def trigger_interrupt():
+        """Poll for query start, then raise KeyboardInterrupt in the main thread"""
+        # Poll for query to start with small sleep intervals
+        start_time = time.time()
+        while not query_started.is_set():
+            time.sleep(0.1)  # Small sleep between checks
+            if time.time() - start_time > max_wait_time:
+                msg = f"Query did not start within {max_wait_time} seconds"
+                raise RuntimeError(msg)
+
+        # Check if thread ID is available
+        thread_id = main_thread.ident
+        if thread_id is None:
+            msg = "Cannot get main thread ID"
+            raise RuntimeError(msg)
+
+        # Use ctypes to raise exception in main thread
+        exception = ctypes.py_object(KeyboardInterrupt)
+        res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
+            ctypes.c_long(thread_id), exception
+        )
+        if res != 1:
+            # If res is 0, the thread ID was invalid
+            # If res > 1, we modified multiple threads
+            ctypes.pythonapi.PyThreadState_SetAsyncExc(
+                ctypes.c_long(thread_id), ctypes.py_object(0)
+            )
+            msg = "Failed to raise KeyboardInterrupt in main thread"
+            raise RuntimeError(msg)
+
+    # Start a thread to trigger the interrupt
+    interrupt_thread = threading.Thread(target=trigger_interrupt)
+    # we mark as daemon so the test process can exit even if this thread doesn't finish
+    interrupt_thread.daemon = True
+    interrupt_thread.start()
+
+    # Execute the query and expect it to be interrupted
+    try:
+        # Signal that we're about to start the query
+        query_started.set()
+        df.collect()
+    except KeyboardInterrupt:
+        interrupted = True
+    except Exception as e:
+        interrupt_error = e
+
+    # Assert that the query was interrupted properly
+    if not interrupted:
+        pytest.fail(f"Query was not interrupted; got error: {interrupt_error}")
+
+    # Make sure the interrupt thread has finished
+    interrupt_thread.join(timeout=1.0)
```
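
For context, the user-visible behavior this test exercises looks roughly like the sketch below; the table and query are illustrative, not part of the commit. While `collect()` runs, Ctrl-C in a terminal (or interrupting the kernel in Jupyter) now raises `KeyboardInterrupt` promptly instead of blocking until the query finishes.

```python
import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
# Register a small in-memory table (illustrative data).
batch = pa.RecordBatch.from_arrays([pa.array(list(range(1000)))], names=["a"])
ctx.register_record_batches("t", [[batch]])

# A deliberately expensive query so there is something to interrupt.
df = ctx.sql("SELECT t1.a AS a1, t2.a AS a2 FROM t t1 CROSS JOIN t t2")
try:
    batches = df.collect()  # press Ctrl-C while this runs
except KeyboardInterrupt:
    print("query cancelled")
```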

src/catalog.rs

Lines changed: 1 addition & 1 deletion
```diff
@@ -97,7 +97,7 @@ impl PyDatabase {
     }
 
     fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
-        if let Some(table) = wait_for_future(py, self.database.table(name))? {
+        if let Some(table) = wait_for_future(py, self.database.table(name))?? {
             Ok(PyTable::new(table))
         } else {
             Err(PyDataFusionError::Common(format!(
```
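
The doubled `?` reflects the new return shape of `wait_for_future`: an outer `Result` meaning "interrupted while waiting" wrapping the future's own `Result`. A self-contained toy illustrating why call sites unwrap twice (the types and names here are hypothetical, not the repository's):

```rust
#[derive(Debug)]
struct Interrupted; // stand-in for a KeyboardInterrupt observed while polling

// Mimics the nested shape: outer error = interrupted, inner = the future's result.
fn wait_for_future_like(
    interrupted: bool,
    inner: Result<u64, String>,
) -> Result<Result<u64, String>, Interrupted> {
    if interrupted {
        Err(Interrupted)
    } else {
        Ok(inner)
    }
}

fn call_site() -> Result<u64, String> {
    // Mirrors `wait_for_future(py, fut)??`: the first `?` propagates the
    // interrupt, the second propagates the query's own error.
    let value = wait_for_future_like(false, Ok(42))
        .map_err(|_| "interrupted".to_string())?;
    Ok(value?)
}

fn main() {
    assert_eq!(call_site(), Ok(42));
}
```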

src/context.rs

Lines changed: 31 additions & 21 deletions
```diff
@@ -34,7 +34,7 @@ use pyo3::prelude::*;
 use crate::catalog::{PyCatalog, PyTable};
 use crate::dataframe::PyDataFrame;
 use crate::dataset::Dataset;
-use crate::errors::{py_datafusion_err, PyDataFusionResult};
+use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
 use crate::expr::sort_expr::PySortExpr;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::PyRecordBatchStream;
@@ -375,7 +375,7 @@ impl PySessionContext {
             None => {
                 let state = self.ctx.state();
                 let schema = options.infer_schema(&state, &table_path);
-                wait_for_future(py, schema)?
+                wait_for_future(py, schema)??
             }
         };
         let config = ListingTableConfig::new(table_path)
@@ -400,7 +400,7 @@
     /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
     pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
         let result = self.ctx.sql(query);
-        let df = wait_for_future(py, result)?;
+        let df = wait_for_future(py, result)??;
         Ok(PyDataFrame::new(df))
     }
 
@@ -417,7 +417,7 @@
             SQLOptions::new()
         };
         let result = self.ctx.sql_with_options(query, options);
-        let df = wait_for_future(py, result)?;
+        let df = wait_for_future(py, result)??;
         Ok(PyDataFrame::new(df))
     }
 
@@ -451,7 +451,7 @@
 
         self.ctx.register_table(&*table_name, Arc::new(table))?;
 
-        let table = wait_for_future(py, self._table(&table_name))?;
+        let table = wait_for_future(py, self._table(&table_name))??;
 
         let df = PyDataFrame::new(table);
         Ok(df)
@@ -650,7 +650,7 @@
             .collect();
 
         let result = self.ctx.register_parquet(name, path, options);
-        wait_for_future(py, result)?;
+        wait_for_future(py, result)??;
         Ok(())
     }
 
@@ -693,11 +693,11 @@
         if path.is_instance_of::<PyList>() {
             let paths = path.extract::<Vec<String>>()?;
             let result = self.register_csv_from_multiple_paths(name, paths, options);
-            wait_for_future(py, result)?;
+            wait_for_future(py, result)??;
         } else {
             let path = path.extract::<String>()?;
             let result = self.ctx.register_csv(name, &path, options);
-            wait_for_future(py, result)?;
+            wait_for_future(py, result)??;
         }
 
         Ok(())
@@ -734,7 +734,7 @@
         options.schema = schema.as_ref().map(|x| &x.0);
 
         let result = self.ctx.register_json(name, path, options);
-        wait_for_future(py, result)?;
+        wait_for_future(py, result)??;
 
         Ok(())
     }
@@ -764,7 +764,7 @@
         options.schema = schema.as_ref().map(|x| &x.0);
 
         let result = self.ctx.register_avro(name, path, options);
-        wait_for_future(py, result)?;
+        wait_for_future(py, result)??;
 
         Ok(())
     }
@@ -825,9 +825,19 @@
     }
 
     pub fn table(&self, name: &str, py: Python) -> PyResult<PyDataFrame> {
-        let x = wait_for_future(py, self.ctx.table(name))
+        let res = wait_for_future(py, self.ctx.table(name))
             .map_err(|e| PyKeyError::new_err(e.to_string()))?;
-        Ok(PyDataFrame::new(x))
+        match res {
+            Ok(df) => Ok(PyDataFrame::new(df)),
+            Err(e) => {
+                if let datafusion::error::DataFusionError::Plan(msg) = &e {
+                    if msg.contains("No table named") {
+                        return Err(PyKeyError::new_err(msg.to_string()));
+                    }
+                }
+                Err(py_datafusion_err(e))
+            }
+        }
     }
 
     pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
@@ -865,10 +875,10 @@
         let df = if let Some(schema) = schema {
             options.schema = Some(&schema.0);
             let result = self.ctx.read_json(path, options);
-            wait_for_future(py, result)?
+            wait_for_future(py, result)??
         } else {
             let result = self.ctx.read_json(path, options);
-            wait_for_future(py, result)?
+            wait_for_future(py, result)??
         };
         Ok(PyDataFrame::new(df))
     }
@@ -915,12 +925,12 @@
             let paths = path.extract::<Vec<String>>()?;
             let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
             let result = self.ctx.read_csv(paths, options);
-            let df = PyDataFrame::new(wait_for_future(py, result)?);
+            let df = PyDataFrame::new(wait_for_future(py, result)??);
             Ok(df)
         } else {
             let path = path.extract::<String>()?;
             let result = self.ctx.read_csv(path, options);
-            let df = PyDataFrame::new(wait_for_future(py, result)?);
+            let df = PyDataFrame::new(wait_for_future(py, result)??);
             Ok(df)
         }
     }
@@ -958,7 +968,7 @@
             .collect();
 
         let result = self.ctx.read_parquet(path, options);
-        let df = PyDataFrame::new(wait_for_future(py, result)?);
+        let df = PyDataFrame::new(wait_for_future(py, result)??);
         Ok(df)
     }
 
@@ -978,10 +988,10 @@
         let df = if let Some(schema) = schema {
             options.schema = Some(&schema.0);
             let read_future = self.ctx.read_avro(path, options);
-            wait_for_future(py, read_future)?
+            wait_for_future(py, read_future)??
         } else {
             let read_future = self.ctx.read_avro(path, options);
-            wait_for_future(py, read_future)?
+            wait_for_future(py, read_future)??
        };
         Ok(PyDataFrame::new(df))
     }
@@ -1021,8 +1031,8 @@
         let plan = plan.plan.clone();
         let fut: JoinHandle<datafusion::common::Result<SendableRecordBatchStream>> =
             rt.spawn(async move { plan.execute(part, Arc::new(ctx)) });
-        let stream = wait_for_future(py, fut).map_err(py_datafusion_err)?;
-        Ok(PyRecordBatchStream::new(stream?))
+        let stream = wait_for_future(py, async { fut.await.map_err(to_datafusion_err) })???;
+        Ok(PyRecordBatchStream::new(stream))
     }
 }
```