Skip to content

Commit 38caf97

Browse files
authored
fix case of f(scalar, array) invocation (#63)
1 parent 4956e2b commit 38caf97

File tree

4 files changed

+241
-116
lines changed

4 files changed

+241
-116
lines changed

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ paste = "1"
1717
log = "0.4"
1818

1919
[dev-dependencies]
20+
datafusion = { version = "44", default-features = false, features = ["nested_expressions"] }
2021
codspeed-criterion-compat = "2.6"
2122
criterion = "0.5.1"
2223
clap = "4"

src/common.rs

+163-115
Original file line numberDiff line numberDiff line change
@@ -95,17 +95,46 @@ impl From<i64> for JsonPath<'_> {
9595
}
9696
}
9797

98-
impl<'s> JsonPath<'s> {
99-
pub fn extract_path(args: &'s [ColumnarValue]) -> Vec<Self> {
100-
args[1..]
98+
enum JsonPathArgs<'a> {
99+
Array(&'a ArrayRef),
100+
Scalars(Vec<JsonPath<'a>>),
101+
}
102+
103+
impl<'s> JsonPathArgs<'s> {
104+
fn extract_path(path_args: &'s [ColumnarValue]) -> DataFusionResult<Self> {
105+
// If there is a single argument as an array, we know how to handle it
106+
if let Some((ColumnarValue::Array(array), &[])) = path_args.split_first() {
107+
return Ok(Self::Array(array));
108+
}
109+
110+
path_args
101111
.iter()
102-
.map(|arg| match arg {
103-
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s))) => Self::Key(s),
104-
ColumnarValue::Scalar(ScalarValue::UInt64(Some(i))) => (*i).into(),
105-
ColumnarValue::Scalar(ScalarValue::Int64(Some(i))) => (*i).into(),
106-
_ => Self::None,
112+
.enumerate()
113+
.map(|(pos, arg)| match arg {
114+
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)) | ScalarValue::LargeUtf8(Some(s))) => {
115+
Ok(JsonPath::Key(s))
116+
}
117+
ColumnarValue::Scalar(ScalarValue::UInt64(Some(i))) => Ok((*i).into()),
118+
ColumnarValue::Scalar(ScalarValue::Int64(Some(i))) => Ok((*i).into()),
119+
ColumnarValue::Scalar(
120+
ScalarValue::Null
121+
| ScalarValue::Utf8(None)
122+
| ScalarValue::LargeUtf8(None)
123+
| ScalarValue::UInt64(None)
124+
| ScalarValue::Int64(None),
125+
) => Ok(JsonPath::None),
126+
ColumnarValue::Array(_) => {
127+
// if there was a single arg, which is an array, handled above in the
128+
// split_first case. So this is multiple args of which one is an array
129+
exec_err!("More than 1 path element is not supported when querying JSON using an array.")
130+
}
131+
ColumnarValue::Scalar(arg) => exec_err!(
132+
"Unexpected argument type at position {}, expected string or int, got {arg:?}.",
133+
pos + 1
134+
),
107135
})
108-
.collect()
136+
.collect::<DataFusionResult<_>>()
137+
.map(JsonPathArgs::Scalars)
109138
}
110139
}
111140

@@ -116,154 +145,173 @@ pub fn invoke<C: FromIterator<Option<I>> + 'static, I>(
116145
to_scalar: impl Fn(Option<I>) -> ScalarValue,
117146
return_dict: bool,
118147
) -> DataFusionResult<ColumnarValue> {
119-
let Some(first_arg) = args.first() else {
120-
// I think this can't happen, but I assumed the same about args[1] and I was wrong, so better to be safe
148+
let Some((json_arg, path_args)) = args.split_first() else {
121149
return exec_err!("expected at least one argument");
122150
};
123-
match first_arg {
124-
ColumnarValue::Array(json_array) => {
125-
let array = match args.get(1) {
126-
Some(ColumnarValue::Array(a)) => {
127-
if args.len() > 2 {
128-
// TODO perhaps we could support this by zipping the arrays, but it's not trivial, #23
129-
exec_err!("More than 1 path element is not supported when querying JSON using an array.")
130-
} else {
131-
invoke_array(json_array, a, to_array, jiter_find, return_dict)
132-
}
133-
}
134-
Some(ColumnarValue::Scalar(_)) => scalar_apply(
135-
json_array,
136-
&JsonPath::extract_path(args),
137-
to_array,
138-
jiter_find,
139-
return_dict,
140-
),
141-
None => scalar_apply(json_array, &[], to_array, jiter_find, return_dict),
142-
};
143-
array.map(ColumnarValue::from)
151+
152+
let path = JsonPathArgs::extract_path(path_args)?;
153+
match (json_arg, path) {
154+
(ColumnarValue::Array(json_array), JsonPathArgs::Array(path_array)) => {
155+
invoke_array_array(json_array, path_array, to_array, jiter_find, return_dict).map(ColumnarValue::Array)
156+
}
157+
(ColumnarValue::Array(json_array), JsonPathArgs::Scalars(path)) => {
158+
invoke_array_scalars(json_array, &path, to_array, jiter_find, return_dict).map(ColumnarValue::Array)
159+
}
160+
(ColumnarValue::Scalar(s), JsonPathArgs::Array(path_array)) => {
161+
invoke_scalar_array(s, path_array, jiter_find, to_array)
162+
}
163+
(ColumnarValue::Scalar(s), JsonPathArgs::Scalars(path)) => {
164+
invoke_scalar_scalars(s, &path, jiter_find, to_scalar)
144165
}
145-
ColumnarValue::Scalar(s) => invoke_scalar(s, args, jiter_find, to_scalar),
146166
}
147167
}
148168

149-
fn invoke_array<C: FromIterator<Option<I>> + 'static, I>(
169+
fn invoke_array_array<C: FromIterator<Option<I>> + 'static, I>(
150170
json_array: &ArrayRef,
151-
needle_array: &ArrayRef,
171+
path_array: &ArrayRef,
152172
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
153173
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
154174
return_dict: bool,
155175
) -> DataFusionResult<ArrayRef> {
156176
downcast_dictionary_array!(
157-
needle_array => match needle_array.values().data_type() {
158-
DataType::Utf8 => zip_apply(json_array, needle_array.downcast_dict::<StringArray>().unwrap(), to_array, jiter_find, true, return_dict),
159-
DataType::LargeUtf8 => zip_apply(json_array, needle_array.downcast_dict::<LargeStringArray>().unwrap(), to_array, jiter_find, true, return_dict),
160-
DataType::Utf8View => zip_apply(json_array, needle_array.downcast_dict::<StringViewArray>().unwrap(), to_array, jiter_find, true, return_dict),
161-
DataType::Int64 => zip_apply(json_array, needle_array.downcast_dict::<Int64Array>().unwrap(), to_array, jiter_find, false, return_dict),
162-
DataType::UInt64 => zip_apply(json_array, needle_array.downcast_dict::<UInt64Array>().unwrap(), to_array, jiter_find, false, return_dict),
163-
other => exec_err!("unexpected second argument type, expected string or int array, got {:?}", other),
164-
},
165-
DataType::Utf8 => zip_apply(json_array, needle_array.as_string::<i32>(), to_array, jiter_find, true, return_dict),
166-
DataType::LargeUtf8 => zip_apply(json_array, needle_array.as_string::<i64>(), to_array, jiter_find, true, return_dict),
167-
DataType::Utf8View => zip_apply(json_array, needle_array.as_string_view(), to_array, jiter_find, true, return_dict),
168-
DataType::Int64 => zip_apply(json_array, needle_array.as_primitive::<Int64Type>(), to_array, jiter_find, false, return_dict),
169-
DataType::UInt64 => zip_apply(json_array, needle_array.as_primitive::<UInt64Type>(), to_array, jiter_find, false, return_dict),
170-
other => exec_err!("unexpected second argument type, expected string or int array, got {:?}", other)
177+
json_array => {
178+
let values = invoke_array_array(json_array.values(), path_array, to_array, jiter_find, return_dict)?;
179+
post_process_dict(json_array, values, return_dict)
180+
}
181+
DataType::Utf8 => zip_apply(json_array.as_string::<i32>().iter(), path_array, to_array, jiter_find),
182+
DataType::LargeUtf8 => zip_apply(json_array.as_string::<i64>().iter(), path_array, to_array, jiter_find),
183+
DataType::Utf8View => zip_apply(json_array.as_string_view().iter(), path_array, to_array, jiter_find),
184+
other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup_array(path_array.data_type())) {
185+
zip_apply(string_array.iter(), path_array, to_array, jiter_find)
186+
} else {
187+
exec_err!("unexpected json array type {:?}", other)
188+
}
171189
)
172190
}
173191

174-
fn zip_apply<'a, P: Into<JsonPath<'a>>, C: FromIterator<Option<I>> + 'static, I>(
192+
fn invoke_array_scalars<C: FromIterator<Option<I>>, I>(
175193
json_array: &ArrayRef,
176-
path_array: impl ArrayAccessor<Item = P>,
194+
path: &[JsonPath],
177195
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
178196
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
179-
object_lookup: bool,
180197
return_dict: bool,
181198
) -> DataFusionResult<ArrayRef> {
199+
fn inner<'j, C: FromIterator<Option<I>>, I>(
200+
json_iter: impl IntoIterator<Item = Option<&'j str>>,
201+
path: &[JsonPath],
202+
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
203+
) -> C {
204+
json_iter
205+
.into_iter()
206+
.map(|opt_json| jiter_find(opt_json, path).ok())
207+
.collect::<C>()
208+
}
209+
182210
let c = downcast_dictionary_array!(
183211
json_array => {
184-
let values = zip_apply(json_array.values(), path_array, to_array, jiter_find, object_lookup, false)?;
212+
let values = invoke_array_scalars(json_array.values(), path, to_array, jiter_find, false)?;
185213
return post_process_dict(json_array, values, return_dict);
186214
}
187-
DataType::Utf8 => zip_apply_iter(json_array.as_string::<i32>().iter(), path_array, jiter_find),
188-
DataType::LargeUtf8 => zip_apply_iter(json_array.as_string::<i64>().iter(), path_array, jiter_find),
189-
DataType::Utf8View => zip_apply_iter(json_array.as_string_view().iter(), path_array, jiter_find),
190-
other => if let Some(string_array) = nested_json_array(json_array, object_lookup) {
191-
zip_apply_iter(string_array.iter(), path_array, jiter_find)
215+
DataType::Utf8 => inner(json_array.as_string::<i32>(), path, jiter_find),
216+
DataType::LargeUtf8 => inner(json_array.as_string::<i64>(), path, jiter_find),
217+
DataType::Utf8View => inner(json_array.as_string_view(), path, jiter_find),
218+
other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
219+
inner(string_array, path, jiter_find)
192220
} else {
193221
return exec_err!("unexpected json array type {:?}", other);
194222
}
195223
);
196-
197224
to_array(c)
198225
}
199226

200-
#[allow(clippy::needless_pass_by_value)] // ArrayAccessor is implemented on references
201-
fn zip_apply_iter<'a, 'j, P: Into<JsonPath<'a>>, C: FromIterator<Option<I>> + 'static, I>(
202-
json_iter: impl Iterator<Item = Option<&'j str>>,
203-
path_array: impl ArrayAccessor<Item = P>,
227+
fn invoke_scalar_array<C: FromIterator<Option<I>> + 'static, I>(
228+
scalar: &ScalarValue,
229+
path_array: &ArrayRef,
204230
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
205-
) -> C {
206-
json_iter
207-
.enumerate()
208-
.map(|(i, opt_json)| {
209-
if path_array.is_null(i) {
210-
None
211-
} else {
212-
let path = path_array.value(i).into();
213-
jiter_find(opt_json, &[path]).ok()
214-
}
215-
})
216-
.collect::<C>()
231+
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
232+
) -> DataFusionResult<ColumnarValue> {
233+
let s = extract_json_scalar(scalar)?;
234+
// TODO: possible optimization here if path_array is a dictionary; can apply against the
235+
// dictionary values directly for less work
236+
zip_apply(
237+
std::iter::repeat(s).take(path_array.len()),
238+
path_array,
239+
to_array,
240+
jiter_find,
241+
)
242+
.map(ColumnarValue::Array)
217243
}
218244

219-
fn invoke_scalar<I>(
245+
fn invoke_scalar_scalars<I>(
220246
scalar: &ScalarValue,
221-
args: &[ColumnarValue],
247+
path: &[JsonPath],
222248
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
223249
to_scalar: impl Fn(Option<I>) -> ScalarValue,
224250
) -> DataFusionResult<ColumnarValue> {
225-
match scalar {
226-
ScalarValue::Dictionary(_, b) => invoke_scalar(b.as_ref(), args, jiter_find, to_scalar),
227-
ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => {
228-
let path = JsonPath::extract_path(args);
229-
let v = jiter_find(s.as_ref().map(String::as_str), &path).ok();
230-
Ok(ColumnarValue::Scalar(to_scalar(v)))
231-
}
232-
ScalarValue::Union(type_id_value, union_fields, _) => {
233-
let opt_json = json_from_union_scalar(type_id_value.as_ref(), union_fields);
234-
let v = jiter_find(opt_json, &JsonPath::extract_path(args)).ok();
235-
Ok(ColumnarValue::Scalar(to_scalar(v)))
236-
}
237-
_ => {
238-
exec_err!("unexpected first argument type, expected string or JSON union")
239-
}
240-
}
251+
let s = extract_json_scalar(scalar)?;
252+
let v = jiter_find(s, path).ok();
253+
Ok(ColumnarValue::Scalar(to_scalar(v)))
241254
}
242255

243-
fn scalar_apply<C: FromIterator<Option<I>>, I>(
244-
json_array: &ArrayRef,
245-
path: &[JsonPath],
256+
fn zip_apply<'a, C: FromIterator<Option<I>> + 'static, I>(
257+
json_array: impl IntoIterator<Item = Option<&'a str>>,
258+
path_array: &ArrayRef,
246259
to_array: impl Fn(C) -> DataFusionResult<ArrayRef>,
247260
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
248-
return_dict: bool,
249261
) -> DataFusionResult<ArrayRef> {
262+
#[allow(clippy::needless_pass_by_value)] // ArrayAccessor is implemented on references
263+
fn inner<'a, 'j, P: Into<JsonPath<'a>>, C: FromIterator<Option<I>> + 'static, I>(
264+
json_iter: impl IntoIterator<Item = Option<&'j str>>,
265+
path_array: impl ArrayAccessor<Item = P>,
266+
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
267+
) -> C {
268+
json_iter
269+
.into_iter()
270+
.enumerate()
271+
.map(|(i, opt_json)| {
272+
if path_array.is_null(i) {
273+
None
274+
} else {
275+
let path = path_array.value(i).into();
276+
jiter_find(opt_json, &[path]).ok()
277+
}
278+
})
279+
.collect::<C>()
280+
}
281+
250282
let c = downcast_dictionary_array!(
251-
json_array => {
252-
let values = scalar_apply(json_array.values(), path, to_array, jiter_find, false)?;
253-
return post_process_dict(json_array, values, return_dict);
254-
}
255-
DataType::Utf8 => scalar_apply_iter(json_array.as_string::<i32>().iter(), path, jiter_find),
256-
DataType::LargeUtf8 => scalar_apply_iter(json_array.as_string::<i64>().iter(), path, jiter_find),
257-
DataType::Utf8View => scalar_apply_iter(json_array.as_string_view().iter(), path, jiter_find),
258-
other => if let Some(string_array) = nested_json_array(json_array, is_object_lookup(path)) {
259-
scalar_apply_iter(string_array.iter(), path, jiter_find)
260-
} else {
261-
return exec_err!("unexpected json array type {:?}", other);
262-
}
283+
path_array => match path_array.values().data_type() {
284+
DataType::Utf8 => inner(json_array, path_array.downcast_dict::<StringArray>().unwrap(), jiter_find),
285+
DataType::LargeUtf8 => inner(json_array, path_array.downcast_dict::<LargeStringArray>().unwrap(), jiter_find),
286+
DataType::Utf8View => inner(json_array, path_array.downcast_dict::<StringViewArray>().unwrap(), jiter_find),
287+
DataType::Int64 => inner(json_array, path_array.downcast_dict::<Int64Array>().unwrap(), jiter_find),
288+
DataType::UInt64 => inner(json_array, path_array.downcast_dict::<UInt64Array>().unwrap(), jiter_find),
289+
other => return exec_err!("unexpected second argument type, expected string or int array, got {:?}", other),
290+
},
291+
DataType::Utf8 => inner(json_array, path_array.as_string::<i32>(), jiter_find),
292+
DataType::LargeUtf8 => inner(json_array, path_array.as_string::<i64>(), jiter_find),
293+
DataType::Utf8View => inner(json_array, path_array.as_string_view(), jiter_find),
294+
DataType::Int64 => inner(json_array, path_array.as_primitive::<Int64Type>(), jiter_find),
295+
DataType::UInt64 => inner(json_array, path_array.as_primitive::<UInt64Type>(), jiter_find),
296+
other => return exec_err!("unexpected second argument type, expected string or int array, got {:?}", other)
263297
);
298+
264299
to_array(c)
265300
}
266301

302+
fn extract_json_scalar(scalar: &ScalarValue) -> DataFusionResult<Option<&str>> {
303+
match scalar {
304+
ScalarValue::Dictionary(_, b) => extract_json_scalar(b.as_ref()),
305+
ScalarValue::Utf8(s) | ScalarValue::Utf8View(s) | ScalarValue::LargeUtf8(s) => Ok(s.as_deref()),
306+
ScalarValue::Union(type_id_value, union_fields, _) => {
307+
Ok(json_from_union_scalar(type_id_value.as_ref(), union_fields))
308+
}
309+
_ => {
310+
exec_err!("unexpected first argument type, expected string or JSON union")
311+
}
312+
}
313+
}
314+
267315
/// Take a dictionary array of JSON data and an array of result values and combine them.
268316
fn post_process_dict<T: ArrowDictionaryKeyType>(
269317
dict_array: &DictionaryArray<T>,
@@ -295,12 +343,12 @@ fn is_object_lookup(path: &[JsonPath]) -> bool {
295343
}
296344
}
297345

298-
fn scalar_apply_iter<'j, C: FromIterator<Option<I>>, I>(
299-
json_iter: impl Iterator<Item = Option<&'j str>>,
300-
path: &[JsonPath],
301-
jiter_find: impl Fn(Option<&str>, &[JsonPath]) -> Result<I, GetError>,
302-
) -> C {
303-
json_iter.map(|opt_json| jiter_find(opt_json, path).ok()).collect::<C>()
346+
fn is_object_lookup_array(data_type: &DataType) -> bool {
347+
match data_type {
348+
DataType::Dictionary(_, value_type) => is_object_lookup_array(value_type),
349+
DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => true,
350+
_ => false,
351+
}
304352
}
305353

306354
pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> {

src/rewrite.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use datafusion::logical_expr::sqlparser::ast::BinaryOperator;
1212
pub(crate) struct JsonFunctionRewriter;
1313

1414
impl FunctionRewrite for JsonFunctionRewriter {
15-
fn name(&self) -> &str {
15+
fn name(&self) -> &'static str {
1616
"JsonFunctionRewriter"
1717
}
1818

0 commit comments

Comments
 (0)