Skip to content

Commit 3f87b34

Browse files
committed
Server side parameters via with_param
Fixes #142
1 parent de075b9 commit 3f87b34

File tree

5 files changed

+243
-20
lines changed

5 files changed

+243
-20
lines changed

src/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
#[macro_use]
66
extern crate static_assertions;
77

8-
use self::{error::Result, http_client::HttpClient};
8+
use self::{error::Result, http_client::HttpClient, sql::ser};
9+
use ::serde::Serialize;
910
use std::{collections::HashMap, fmt::Display, sync::Arc};
1011

1112
pub use self::{compression::Compression, row::Row};
@@ -160,6 +161,12 @@ impl Client {
160161
self
161162
}
162163

164+
pub fn with_param(self, name: &str, value: impl Serialize) -> Result<Self, String> {
165+
let mut param = String::from("");
166+
ser::write_param(&mut param, &value)?;
167+
Ok(self.with_option(format!("param_{name}"), param))
168+
}
169+
163170
/// Used to specify a header that will be passed to all queries.
164171
///
165172
/// # Example

src/query.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use hyper::{header::CONTENT_LENGTH, Method, Request};
2-
use serde::Deserialize;
2+
use serde::{Deserialize, Serialize};
33
use std::fmt::Display;
44
use url::Url;
55

@@ -10,7 +10,7 @@ use crate::{
1010
request_body::RequestBody,
1111
response::Response,
1212
row::Row,
13-
sql::{Bind, SqlBuilder},
13+
sql::{ser, Bind, SqlBuilder},
1414
Client,
1515
};
1616

@@ -195,6 +195,12 @@ impl Query {
195195
self.client.add_option(name, value);
196196
self
197197
}
198+
199+
pub fn with_param(self, name: &str, value: impl Serialize) -> Result<Self, String> {
200+
let mut param = String::from("");
201+
ser::write_param(&mut param, &value)?;
202+
Ok(self.with_option(format!("param_{name}"), param))
203+
}
198204
}
199205

200206
/// A cursor that emits rows.

src/sql/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ pub use bind::{Bind, Identifier};
99

1010
mod bind;
1111
pub(crate) mod escape;
12-
mod ser;
12+
pub(crate) mod ser;
1313

1414
#[derive(Debug, Clone)]
1515
pub(crate) enum SqlBuilder {

src/sql/ser.rs

Lines changed: 183 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,32 @@ use thiserror::Error;
88

99
use super::escape;
1010

11-
// === SqlSerializerError ===
11+
// === SerializerError ===
1212

1313
#[derive(Debug, Error)]
14-
enum SqlSerializerError {
14+
enum SerializerError {
1515
#[error("{0} is unsupported")]
1616
Unsupported(&'static str),
1717
#[error("{0}")]
1818
Custom(String),
1919
}
2020

21-
impl ser::Error for SqlSerializerError {
21+
impl ser::Error for SerializerError {
2222
fn custom<T: fmt::Display>(msg: T) -> Self {
2323
Self::Custom(msg.to_string())
2424
}
2525
}
2626

27-
impl From<fmt::Error> for SqlSerializerError {
27+
impl From<fmt::Error> for SerializerError {
2828
fn from(err: fmt::Error) -> Self {
2929
Self::Custom(err.to_string())
3030
}
3131
}
3232

3333
// === SqlSerializer ===
3434

35-
type Result<T = (), E = SqlSerializerError> = std::result::Result<T, E>;
36-
type Impossible = ser::Impossible<(), SqlSerializerError>;
35+
type Result<T = (), E = SerializerError> = std::result::Result<T, E>;
36+
type Impossible = ser::Impossible<(), SerializerError>;
3737

3838
struct SqlSerializer<'a, W> {
3939
writer: &'a mut W,
@@ -43,7 +43,7 @@ macro_rules! unsupported {
4343
($ser_method:ident($ty:ty) -> $ret:ty, $($other:tt)*) => {
4444
#[inline]
4545
fn $ser_method(self, _v: $ty) -> $ret {
46-
Err(SqlSerializerError::Unsupported(stringify!($ser_method)))
46+
Err(SerializerError::Unsupported(stringify!($ser_method)))
4747
}
4848
unsupported!($($other)*);
4949
};
@@ -53,7 +53,7 @@ macro_rules! unsupported {
5353
($ser_method:ident, $($other:tt)*) => {
5454
#[inline]
5555
fn $ser_method(self) -> Result {
56-
Err(SqlSerializerError::Unsupported(stringify!($ser_method)))
56+
Err(SerializerError::Unsupported(stringify!($ser_method)))
5757
}
5858
unsupported!($($other)*);
5959
};
@@ -73,7 +73,7 @@ macro_rules! forward_to_display {
7373
}
7474

7575
impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
76-
type Error = SqlSerializerError;
76+
type Error = SerializerError;
7777
type Ok = ();
7878
type SerializeMap = Impossible;
7979
type SerializeSeq = SqlListSerializer<'a, W>;
@@ -177,12 +177,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
177177
_variant: &'static str,
178178
_value: &T,
179179
) -> Result {
180-
Err(SqlSerializerError::Unsupported("serialize_newtype_variant"))
180+
Err(SerializerError::Unsupported("serialize_newtype_variant"))
181181
}
182182

183183
#[inline]
184184
fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result<Impossible> {
185-
Err(SqlSerializerError::Unsupported("serialize_tuple_struct"))
185+
Err(SerializerError::Unsupported("serialize_tuple_struct"))
186186
}
187187

188188
#[inline]
@@ -193,12 +193,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
193193
_variant: &'static str,
194194
_len: usize,
195195
) -> Result<Impossible> {
196-
Err(SqlSerializerError::Unsupported("serialize_tuple_variant"))
196+
Err(SerializerError::Unsupported("serialize_tuple_variant"))
197197
}
198198

199199
#[inline]
200200
fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
201-
Err(SqlSerializerError::Unsupported("serialize_struct"))
201+
Err(SerializerError::Unsupported("serialize_struct"))
202202
}
203203

204204
#[inline]
@@ -209,7 +209,7 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
209209
_variant: &'static str,
210210
_len: usize,
211211
) -> Result<Self::SerializeStructVariant> {
212-
Err(SqlSerializerError::Unsupported("serialize_struct_variant"))
212+
Err(SerializerError::Unsupported("serialize_struct_variant"))
213213
}
214214

215215
#[inline]
@@ -227,7 +227,7 @@ struct SqlListSerializer<'a, W> {
227227
}
228228

229229
impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> {
230-
type Error = SqlSerializerError;
230+
type Error = SerializerError;
231231
type Ok = ();
232232

233233
#[inline]
@@ -254,7 +254,7 @@ impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> {
254254
}
255255

256256
impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> {
257-
type Error = SqlSerializerError;
257+
type Error = SerializerError;
258258
type Ok = ();
259259

260260
#[inline]
@@ -271,6 +271,167 @@ impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> {
271271
}
272272
}
273273

274+
// === ParamSerializer ===
275+
276+
struct ParamSerializer<'a, W> {
277+
writer: &'a mut W,
278+
}
279+
280+
impl<'a, W: Write> Serializer for ParamSerializer<'a, W> {
281+
type Error = SerializerError;
282+
type Ok = ();
283+
type SerializeMap = Impossible;
284+
type SerializeSeq = SqlListSerializer<'a, W>;
285+
type SerializeStruct = Impossible;
286+
type SerializeStructVariant = Impossible;
287+
type SerializeTuple = SqlListSerializer<'a, W>;
288+
type SerializeTupleStruct = Impossible;
289+
type SerializeTupleVariant = Impossible;
290+
291+
unsupported!(
292+
serialize_map(Option<usize>) -> Result<Impossible>,
293+
serialize_bytes(&[u8]),
294+
serialize_unit,
295+
serialize_unit_struct(&'static str),
296+
);
297+
298+
forward_to_display!(
299+
serialize_i8(i8),
300+
serialize_i16(i16),
301+
serialize_i32(i32),
302+
serialize_i64(i64),
303+
serialize_i128(i128),
304+
serialize_u8(u8),
305+
serialize_u16(u16),
306+
serialize_u32(u32),
307+
serialize_u64(u64),
308+
serialize_u128(u128),
309+
serialize_f32(f32),
310+
serialize_f64(f64),
311+
serialize_bool(bool),
312+
);
313+
314+
#[inline]
315+
fn serialize_char(self, value: char) -> Result {
316+
let mut tmp = [0u8; 4];
317+
self.serialize_str(value.encode_utf8(&mut tmp))
318+
}
319+
320+
#[inline]
321+
fn serialize_str(self, value: &str) -> Result {
322+
// ClickHouse expects strings in params to be unquoted until inside a nested type
323+
// nested types go through serialize_seq which'll quote strings
324+
let mut rest = value;
325+
while let Some(nextidx) = rest.find('\\') {
326+
let (before, after) = rest.split_at(nextidx + 1);
327+
rest = after;
328+
self.writer.write_str(before)?;
329+
self.writer.write_char('\\')?;
330+
}
331+
self.writer.write_str(rest)?;
332+
Ok(())
333+
}
334+
335+
#[inline]
336+
fn serialize_seq(self, _len: Option<usize>) -> Result<SqlListSerializer<'a, W>> {
337+
self.writer.write_char('[')?;
338+
Ok(SqlListSerializer {
339+
writer: self.writer,
340+
has_items: false,
341+
closing_char: ']',
342+
})
343+
}
344+
345+
#[inline]
346+
fn serialize_tuple(self, _len: usize) -> Result<SqlListSerializer<'a, W>> {
347+
self.writer.write_char('(')?;
348+
Ok(SqlListSerializer {
349+
writer: self.writer,
350+
has_items: false,
351+
closing_char: ')',
352+
})
353+
}
354+
355+
#[inline]
356+
fn serialize_some<T: Serialize + ?Sized>(self, _value: &T) -> Result {
357+
_value.serialize(self)
358+
}
359+
360+
#[inline]
361+
fn serialize_none(self) -> std::result::Result<Self::Ok, Self::Error> {
362+
self.writer.write_str("NULL")?;
363+
Ok(())
364+
}
365+
366+
#[inline]
367+
fn serialize_unit_variant(
368+
self,
369+
_name: &'static str,
370+
_variant_index: u32,
371+
variant: &'static str,
372+
) -> Result {
373+
escape::string(variant, self.writer)?;
374+
Ok(())
375+
}
376+
377+
#[inline]
378+
fn serialize_newtype_struct<T: Serialize + ?Sized>(
379+
self,
380+
_name: &'static str,
381+
value: &T,
382+
) -> Result {
383+
value.serialize(self)
384+
}
385+
386+
#[inline]
387+
fn serialize_newtype_variant<T: Serialize + ?Sized>(
388+
self,
389+
_name: &'static str,
390+
_variant_index: u32,
391+
_variant: &'static str,
392+
_value: &T,
393+
) -> Result {
394+
Err(SerializerError::Unsupported("serialize_newtype_variant"))
395+
}
396+
397+
#[inline]
398+
fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result<Impossible> {
399+
Err(SerializerError::Unsupported("serialize_tuple_struct"))
400+
}
401+
402+
#[inline]
403+
fn serialize_tuple_variant(
404+
self,
405+
_name: &'static str,
406+
_variant_index: u32,
407+
_variant: &'static str,
408+
_len: usize,
409+
) -> Result<Impossible> {
410+
Err(SerializerError::Unsupported("serialize_tuple_variant"))
411+
}
412+
413+
#[inline]
414+
fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
415+
Err(SerializerError::Unsupported("serialize_struct"))
416+
}
417+
418+
#[inline]
419+
fn serialize_struct_variant(
420+
self,
421+
_name: &'static str,
422+
_variant_index: u32,
423+
_variant: &'static str,
424+
_len: usize,
425+
) -> Result<Self::SerializeStructVariant> {
426+
Err(SerializerError::Unsupported("serialize_struct_variant"))
427+
}
428+
429+
#[inline]
430+
fn is_human_readable(&self) -> bool {
431+
true
432+
}
433+
}
434+
274435
// === Public API ===
275436

276437
pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> {
@@ -279,6 +440,12 @@ pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Resu
279440
.map_err(|err| err.to_string())
280441
}
281442

443+
pub(crate) fn write_param(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> {
444+
value
445+
.serialize(ParamSerializer { writer })
446+
.map_err(|err| err.to_string())
447+
}
448+
282449
#[cfg(test)]
283450
mod tests {
284451
use super::*;

0 commit comments

Comments
 (0)