Skip to content

Commit d1ebdbf

Browse files
authored
Add tests and CI for optional pyarrow module (#1711)
* Implement other side of conversion * Add test workflow * Add (failing) tests * Get unit tests passing * Use python -m pip * Debug LD_LIBRARY_PATH * Set LIBRARY_PATH * Update help with better info
1 parent e4a056f commit d1ebdbf

File tree

2 files changed

+136
-4
lines changed

2 files changed

+136
-4
lines changed

.github/workflows/rust.yml

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,55 @@ jobs:
230230
# do not produce debug symbols to keep memory usage down
231231
RUSTFLAGS: "-C debuginfo=0"
232232

233+
test-datafusion-pyarrow:
234+
needs: [linux-build-lib]
235+
runs-on: ubuntu-latest
236+
strategy:
237+
matrix:
238+
arch: [amd64]
239+
rust: [stable]
240+
container:
241+
image: ${{ matrix.arch }}/rust
242+
env:
243+
# Disable full debug symbol generation to speed up CI build and keep memory down
244+
# "1" means line tables only, which is useful for panic tracebacks.
245+
RUSTFLAGS: "-C debuginfo=1"
246+
steps:
247+
- uses: actions/checkout@v2
248+
with:
249+
submodules: true
250+
- name: Cache Cargo
251+
uses: actions/cache@v2
252+
with:
253+
path: /github/home/.cargo
254+
# this key equals the ones on `linux-build-lib` for re-use
255+
key: cargo-cache-
256+
- name: Cache Rust dependencies
257+
uses: actions/cache@v2
258+
with:
259+
path: /github/home/target
260+
# this key equals the ones on `linux-build-lib` for re-use
261+
key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }}
262+
- uses: actions/setup-python@v2
263+
with:
264+
python-version: "3.8"
265+
- name: Install PyArrow
266+
run: |
267+
echo "LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV
268+
python -m pip install pyarrow
269+
- name: Setup Rust toolchain
270+
run: |
271+
rustup toolchain install ${{ matrix.rust }}
272+
rustup default ${{ matrix.rust }}
273+
rustup component add rustfmt
274+
- name: Run tests
275+
run: |
276+
cd datafusion
277+
cargo test --features=pyarrow
278+
env:
279+
CARGO_HOME: "/github/home/.cargo"
280+
CARGO_TARGET_DIR: "/github/home/target"
281+
233282
lint:
234283
name: Lint
235284
runs-on: ubuntu-latest

datafusion/src/pyarrow.rs

Lines changed: 87 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use pyo3::exceptions::{PyException, PyNotImplementedError};
18+
use pyo3::exceptions::PyException;
1919
use pyo3::prelude::*;
2020
use pyo3::types::PyList;
21-
use pyo3::PyNativeType;
2221

2322
use crate::arrow::array::ArrayData;
2423
use crate::arrow::pyarrow::PyArrowConvert;
@@ -49,8 +48,13 @@ impl PyArrowConvert for ScalarValue {
4948
Ok(scalar)
5049
}
5150

52-
fn to_pyarrow(&self, _py: Python) -> PyResult<PyObject> {
53-
Err(PyNotImplementedError::new_err("Not implemented"))
51+
fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
52+
let array = self.to_array();
53+
// convert to pyarrow array using C data interface
54+
let pyarray = array.data_ref().clone().into_py(py);
55+
let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?;
56+
57+
Ok(pyscalar)
5458
}
5559
}
5660

@@ -65,3 +69,82 @@ impl<'a> IntoPy<PyObject> for ScalarValue {
6569
self.to_pyarrow(py).unwrap()
6670
}
6771
}
72+
73+
#[cfg(test)]
74+
mod tests {
75+
use super::*;
76+
use pyo3::prepare_freethreaded_python;
77+
use pyo3::py_run;
78+
use pyo3::types::PyDict;
79+
use pyo3::Python;
80+
81+
fn init_python() {
82+
prepare_freethreaded_python();
83+
Python::with_gil(|py| {
84+
if let Err(err) = py.run("import pyarrow", None, None) {
85+
let locals = PyDict::new(py);
86+
py.run(
87+
"import sys; executable = sys.executable; python_path = sys.path",
88+
None,
89+
Some(locals),
90+
)
91+
.expect("Couldn't get python info");
92+
let executable: String =
93+
locals.get_item("executable").unwrap().extract().unwrap();
94+
let python_path: Vec<&str> =
95+
locals.get_item("python_path").unwrap().extract().unwrap();
96+
97+
Err(err).expect(
98+
format!(
99+
"pyarrow not found\nExecutable: {}\nPython path: {:?}\n\
100+
HINT: try `pip install pyarrow`\n\
101+
NOTE: On Mac OS, you must compile against a Framework Python \
102+
(default in python.org installers and brew, but not pyenv)\n\
103+
NOTE: On Mac OS, PYO3 might point to incorrect Python library \
104+
path when using virtual environments. Try \
105+
`export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n",
106+
executable, python_path
107+
)
108+
.as_ref(),
109+
)
110+
}
111+
})
112+
}
113+
114+
#[test]
115+
fn test_roundtrip() {
116+
init_python();
117+
118+
let example_scalars = vec![
119+
ScalarValue::Boolean(Some(true)),
120+
ScalarValue::Int32(Some(23)),
121+
ScalarValue::Float64(Some(12.34)),
122+
ScalarValue::Utf8(Some("Hello!".to_string())),
123+
ScalarValue::Date32(Some(1234)),
124+
];
125+
126+
Python::with_gil(|py| {
127+
for scalar in example_scalars.iter() {
128+
let result =
129+
ScalarValue::from_pyarrow(scalar.to_pyarrow(py).unwrap().as_ref(py))
130+
.unwrap();
131+
assert_eq!(scalar, &result);
132+
}
133+
});
134+
}
135+
136+
#[test]
137+
fn test_py_scalar() {
138+
init_python();
139+
140+
Python::with_gil(|py| {
141+
let scalar_float = ScalarValue::Float64(Some(12.34));
142+
let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap();
143+
py_run!(py, py_float, "assert py_float == 12.34");
144+
145+
let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string()));
146+
let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap();
147+
py_run!(py, py_string, "assert py_string == 'Hello!'");
148+
});
149+
}
150+
}

0 commit comments

Comments
 (0)