Skip to content

Commit

Permalink
Server side parameters via with_param
Browse files Browse the repository at this point in the history
Fixes #142
  • Loading branch information
serprex committed Oct 2, 2024
1 parent de075b9 commit 3f87b34
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 20 deletions.
9 changes: 8 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
#[macro_use]
extern crate static_assertions;

use self::{error::Result, http_client::HttpClient};
use self::{error::Result, http_client::HttpClient, sql::ser};
use ::serde::Serialize;
use std::{collections::HashMap, fmt::Display, sync::Arc};

pub use self::{compression::Compression, row::Row};
Expand Down Expand Up @@ -160,6 +161,12 @@ impl Client {
self
}

pub fn with_param(self, name: &str, value: impl Serialize) -> Result<Self, String> {
let mut param = String::from("");
ser::write_param(&mut param, &value)?;
Ok(self.with_option(format!("param_{name}"), param))
}

/// Used to specify a header that will be passed to all queries.
///
/// # Example
Expand Down
10 changes: 8 additions & 2 deletions src/query.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use hyper::{header::CONTENT_LENGTH, Method, Request};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use url::Url;

Expand All @@ -10,7 +10,7 @@ use crate::{
request_body::RequestBody,
response::Response,
row::Row,
sql::{Bind, SqlBuilder},
sql::{ser, Bind, SqlBuilder},
Client,
};

Expand Down Expand Up @@ -195,6 +195,12 @@ impl Query {
self.client.add_option(name, value);
self
}

pub fn with_param(self, name: &str, value: impl Serialize) -> Result<Self, String> {
let mut param = String::from("");
ser::write_param(&mut param, &value)?;
Ok(self.with_option(format!("param_{name}"), param))
}
}

/// A cursor that emits rows.
Expand Down
2 changes: 1 addition & 1 deletion src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub use bind::{Bind, Identifier};

mod bind;
pub(crate) mod escape;
mod ser;
pub(crate) mod ser;

#[derive(Debug, Clone)]
pub(crate) enum SqlBuilder {
Expand Down
199 changes: 183 additions & 16 deletions src/sql/ser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,32 +8,32 @@ use thiserror::Error;

use super::escape;

// === SqlSerializerError ===
// === SerializerError ===

#[derive(Debug, Error)]
enum SqlSerializerError {
enum SerializerError {
#[error("{0} is unsupported")]
Unsupported(&'static str),
#[error("{0}")]
Custom(String),
}

impl ser::Error for SqlSerializerError {
impl ser::Error for SerializerError {
fn custom<T: fmt::Display>(msg: T) -> Self {
Self::Custom(msg.to_string())
}
}

impl From<fmt::Error> for SqlSerializerError {
impl From<fmt::Error> for SerializerError {
fn from(err: fmt::Error) -> Self {
Self::Custom(err.to_string())
}
}

// === SqlSerializer ===

type Result<T = (), E = SqlSerializerError> = std::result::Result<T, E>;
type Impossible = ser::Impossible<(), SqlSerializerError>;
type Result<T = (), E = SerializerError> = std::result::Result<T, E>;
type Impossible = ser::Impossible<(), SerializerError>;

struct SqlSerializer<'a, W> {
writer: &'a mut W,
Expand All @@ -43,7 +43,7 @@ macro_rules! unsupported {
($ser_method:ident($ty:ty) -> $ret:ty, $($other:tt)*) => {
#[inline]
fn $ser_method(self, _v: $ty) -> $ret {
Err(SqlSerializerError::Unsupported(stringify!($ser_method)))
Err(SerializerError::Unsupported(stringify!($ser_method)))
}
unsupported!($($other)*);
};
Expand All @@ -53,7 +53,7 @@ macro_rules! unsupported {
($ser_method:ident, $($other:tt)*) => {
#[inline]
fn $ser_method(self) -> Result {
Err(SqlSerializerError::Unsupported(stringify!($ser_method)))
Err(SerializerError::Unsupported(stringify!($ser_method)))
}
unsupported!($($other)*);
};
Expand All @@ -73,7 +73,7 @@ macro_rules! forward_to_display {
}

impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
type Error = SqlSerializerError;
type Error = SerializerError;
type Ok = ();
type SerializeMap = Impossible;
type SerializeSeq = SqlListSerializer<'a, W>;
Expand Down Expand Up @@ -177,12 +177,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
_variant: &'static str,
_value: &T,
) -> Result {
Err(SqlSerializerError::Unsupported("serialize_newtype_variant"))
Err(SerializerError::Unsupported("serialize_newtype_variant"))
}

#[inline]
fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result<Impossible> {
Err(SqlSerializerError::Unsupported("serialize_tuple_struct"))
Err(SerializerError::Unsupported("serialize_tuple_struct"))
}

#[inline]
Expand All @@ -193,12 +193,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
_variant: &'static str,
_len: usize,
) -> Result<Impossible> {
Err(SqlSerializerError::Unsupported("serialize_tuple_variant"))
Err(SerializerError::Unsupported("serialize_tuple_variant"))
}

#[inline]
fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
Err(SqlSerializerError::Unsupported("serialize_struct"))
Err(SerializerError::Unsupported("serialize_struct"))
}

#[inline]
Expand All @@ -209,7 +209,7 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> {
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant> {
Err(SqlSerializerError::Unsupported("serialize_struct_variant"))
Err(SerializerError::Unsupported("serialize_struct_variant"))
}

#[inline]
Expand All @@ -227,7 +227,7 @@ struct SqlListSerializer<'a, W> {
}

impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> {
type Error = SqlSerializerError;
type Error = SerializerError;
type Ok = ();

#[inline]
Expand All @@ -254,7 +254,7 @@ impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> {
}

impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> {
type Error = SqlSerializerError;
type Error = SerializerError;
type Ok = ();

#[inline]
Expand All @@ -271,6 +271,167 @@ impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> {
}
}

// === ParamSerializer ===

struct ParamSerializer<'a, W> {
writer: &'a mut W,
}

impl<'a, W: Write> Serializer for ParamSerializer<'a, W> {
type Error = SerializerError;
type Ok = ();
type SerializeMap = Impossible;
type SerializeSeq = SqlListSerializer<'a, W>;
type SerializeStruct = Impossible;
type SerializeStructVariant = Impossible;
type SerializeTuple = SqlListSerializer<'a, W>;
type SerializeTupleStruct = Impossible;
type SerializeTupleVariant = Impossible;

unsupported!(
serialize_map(Option<usize>) -> Result<Impossible>,
serialize_bytes(&[u8]),
serialize_unit,
serialize_unit_struct(&'static str),
);

forward_to_display!(
serialize_i8(i8),
serialize_i16(i16),
serialize_i32(i32),
serialize_i64(i64),
serialize_i128(i128),
serialize_u8(u8),
serialize_u16(u16),
serialize_u32(u32),
serialize_u64(u64),
serialize_u128(u128),
serialize_f32(f32),
serialize_f64(f64),
serialize_bool(bool),
);

#[inline]
fn serialize_char(self, value: char) -> Result {
let mut tmp = [0u8; 4];
self.serialize_str(value.encode_utf8(&mut tmp))
}

#[inline]
fn serialize_str(self, value: &str) -> Result {
// ClickHouse expects strings in params to be unquoted until inside a nested type
// nested types go through serialize_seq which'll quote strings
let mut rest = value;
while let Some(nextidx) = rest.find('\\') {
let (before, after) = rest.split_at(nextidx + 1);
rest = after;
self.writer.write_str(before)?;
self.writer.write_char('\\')?;
}
self.writer.write_str(rest)?;
Ok(())
}

#[inline]
fn serialize_seq(self, _len: Option<usize>) -> Result<SqlListSerializer<'a, W>> {
self.writer.write_char('[')?;
Ok(SqlListSerializer {
writer: self.writer,
has_items: false,
closing_char: ']',
})
}

#[inline]
fn serialize_tuple(self, _len: usize) -> Result<SqlListSerializer<'a, W>> {
self.writer.write_char('(')?;
Ok(SqlListSerializer {
writer: self.writer,
has_items: false,
closing_char: ')',
})
}

#[inline]
fn serialize_some<T: Serialize + ?Sized>(self, _value: &T) -> Result {
_value.serialize(self)
}

#[inline]
fn serialize_none(self) -> std::result::Result<Self::Ok, Self::Error> {
self.writer.write_str("NULL")?;
Ok(())
}

#[inline]
fn serialize_unit_variant(
self,
_name: &'static str,
_variant_index: u32,
variant: &'static str,
) -> Result {
escape::string(variant, self.writer)?;
Ok(())
}

#[inline]
fn serialize_newtype_struct<T: Serialize + ?Sized>(
self,
_name: &'static str,
value: &T,
) -> Result {
value.serialize(self)
}

#[inline]
fn serialize_newtype_variant<T: Serialize + ?Sized>(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_value: &T,
) -> Result {
Err(SerializerError::Unsupported("serialize_newtype_variant"))
}

#[inline]
fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result<Impossible> {
Err(SerializerError::Unsupported("serialize_tuple_struct"))
}

#[inline]
fn serialize_tuple_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Impossible> {
Err(SerializerError::Unsupported("serialize_tuple_variant"))
}

#[inline]
fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
Err(SerializerError::Unsupported("serialize_struct"))
}

#[inline]
fn serialize_struct_variant(
self,
_name: &'static str,
_variant_index: u32,
_variant: &'static str,
_len: usize,
) -> Result<Self::SerializeStructVariant> {
Err(SerializerError::Unsupported("serialize_struct_variant"))
}

#[inline]
fn is_human_readable(&self) -> bool {
true
}
}

// === Public API ===

pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> {
Expand All @@ -279,6 +440,12 @@ pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Resu
.map_err(|err| err.to_string())
}

pub(crate) fn write_param(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> {
value
.serialize(ParamSerializer { writer })
.map_err(|err| err.to_string())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading

0 comments on commit 3f87b34

Please sign in to comment.