|
| 1 | +//! Functionality for using `serde_qs` with `axum`. |
| 2 | +//! |
| 3 | +//! Source: https://github.com/samscott89/serde_qs/blob/b7278b73c637f7c427be762082fee5938ba0c023/src/axum.rs |
| 4 | +//! NOTE: The source was copy-pasted here because `serde_qs` is unmaintained and this implementation |
| 5 | +//! would prevent us from upgrading Axum. |
| 6 | +//! NOTE: We could also use the simpler approach from https://github.com/tokio-rs/axum/issues/434#issuecomment-954898159 |
| 7 | +
|
| 8 | +use std::sync::Arc; |
| 9 | + |
| 10 | +use serde_qs::{Config as QsConfig, Error as QsError}; |
| 11 | + |
| 12 | +use axum::{ |
| 13 | + extract::{Extension, FromRequestParts}, |
| 14 | + http::StatusCode, |
| 15 | + response::{IntoResponse, Response}, |
| 16 | + BoxError, Error, |
| 17 | +}; |
| 18 | + |
| 19 | +#[derive(Clone, Copy, Default)] |
| 20 | +/// Extract typed information from from the request's query. |
| 21 | +/// |
| 22 | +/// ## Example |
| 23 | +/// |
| 24 | +/// ```rust |
| 25 | +/// # extern crate axum_framework as axum; |
| 26 | +/// use serde_qs::axum::QsQuery; |
| 27 | +/// use serde_qs::Config; |
| 28 | +/// use axum::{response::IntoResponse, routing::get, Router, body::Body}; |
| 29 | +/// |
| 30 | +/// #[derive(serde::Deserialize)] |
| 31 | +/// pub struct UsersFilter { |
| 32 | +/// id: Vec<u64>, |
| 33 | +/// } |
| 34 | +/// |
| 35 | +/// async fn filter_users( |
| 36 | +/// QsQuery(info): QsQuery<UsersFilter> |
| 37 | +/// ) -> impl IntoResponse { |
| 38 | +/// info.id |
| 39 | +/// .iter() |
| 40 | +/// .map(|i| i.to_string()) |
| 41 | +/// .collect::<Vec<String>>() |
| 42 | +/// .join(", ") |
| 43 | +/// } |
| 44 | +/// |
| 45 | +/// fn main() { |
| 46 | +/// let app = Router::<()>::new() |
| 47 | +/// .route("/users", get(filter_users)); |
| 48 | +/// } |
| 49 | +pub struct QsQuery<T>(pub T); |
| 50 | + |
| 51 | +impl<T> std::ops::Deref for QsQuery<T> { |
| 52 | + type Target = T; |
| 53 | + |
| 54 | + fn deref(&self) -> &Self::Target { |
| 55 | + &self.0 |
| 56 | + } |
| 57 | +} |
| 58 | + |
| 59 | +impl<T: std::fmt::Display> std::fmt::Display for QsQuery<T> { |
| 60 | + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
| 61 | + self.0.fmt(f) |
| 62 | + } |
| 63 | +} |
| 64 | + |
| 65 | +impl<T: std::fmt::Debug> std::fmt::Debug for QsQuery<T> { |
| 66 | + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { |
| 67 | + self.0.fmt(f) |
| 68 | + } |
| 69 | +} |
| 70 | + |
| 71 | +#[axum::async_trait] |
| 72 | +impl<T, S> FromRequestParts<S> for QsQuery<T> |
| 73 | +where |
| 74 | + T: serde::de::DeserializeOwned, |
| 75 | + S: Send + Sync, |
| 76 | +{ |
| 77 | + type Rejection = QsQueryRejection; |
| 78 | + |
| 79 | + async fn from_request_parts( |
| 80 | + parts: &mut axum::http::request::Parts, |
| 81 | + state: &S, |
| 82 | + ) -> Result<Self, Self::Rejection> { |
| 83 | + let Extension(qs_config) = Extension::<QsQueryConfig>::from_request_parts(parts, state) |
| 84 | + .await |
| 85 | + .unwrap_or_else(|_| Extension(QsQueryConfig::default())); |
| 86 | + let error_handler = qs_config.error_handler.clone(); |
| 87 | + let config: QsConfig = qs_config.into(); |
| 88 | + let query = parts.uri.query().unwrap_or_default(); |
| 89 | + match config.deserialize_str::<T>(query) { |
| 90 | + Ok(value) => Ok(QsQuery(value)), |
| 91 | + Err(err) => match error_handler { |
| 92 | + Some(handler) => Err((handler)(err)), |
| 93 | + None => Err(QsQueryRejection::new(err, StatusCode::BAD_REQUEST)), |
| 94 | + }, |
| 95 | + } |
| 96 | + } |
| 97 | +} |
| 98 | + |
| 99 | +#[derive(Debug)] |
| 100 | +/// Rejection type for extractors that deserialize query strings |
| 101 | +pub struct QsQueryRejection { |
| 102 | + error: axum::Error, |
| 103 | + status: StatusCode, |
| 104 | +} |
| 105 | + |
| 106 | +impl std::fmt::Display for QsQueryRejection { |
| 107 | + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| 108 | + write!( |
| 109 | + f, |
| 110 | + "Failed to deserialize query string. Error: {}", |
| 111 | + self.error, |
| 112 | + ) |
| 113 | + } |
| 114 | +} |
| 115 | + |
| 116 | +impl QsQueryRejection { |
| 117 | + /// Create new rejection |
| 118 | + pub fn new<E>(error: E, status: StatusCode) -> Self |
| 119 | + where |
| 120 | + E: Into<BoxError>, |
| 121 | + { |
| 122 | + QsQueryRejection { |
| 123 | + error: Error::new(error), |
| 124 | + status, |
| 125 | + } |
| 126 | + } |
| 127 | +} |
| 128 | + |
| 129 | +impl IntoResponse for QsQueryRejection { |
| 130 | + fn into_response(self) -> Response { |
| 131 | + let mut res = self.to_string().into_response(); |
| 132 | + *res.status_mut() = self.status; |
| 133 | + res |
| 134 | + } |
| 135 | +} |
| 136 | + |
| 137 | +impl std::error::Error for QsQueryRejection { |
| 138 | + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { |
| 139 | + Some(&self.error) |
| 140 | + } |
| 141 | +} |
| 142 | + |
| 143 | +#[derive(Clone)] |
| 144 | +/// Query extractor configuration |
| 145 | +/// |
| 146 | +/// QsQueryConfig wraps [`Config`](crate::de::Config) and implement [`Clone`] |
| 147 | +/// for [`FromRequest`](https://docs.rs/axum/0.5/axum/extract/trait.FromRequest.html) |
| 148 | +/// |
| 149 | +/// ## Example |
| 150 | +/// |
| 151 | +/// ```rust |
| 152 | +/// # extern crate axum_framework as axum; |
| 153 | +/// use serde_qs::axum::{QsQuery, QsQueryConfig, QsQueryRejection}; |
| 154 | +/// use serde_qs::Config; |
| 155 | +/// use axum::{ |
| 156 | +/// response::IntoResponse, |
| 157 | +/// routing::get, |
| 158 | +/// Router, |
| 159 | +/// body::Body, |
| 160 | +/// extract::Extension, |
| 161 | +/// http::StatusCode, |
| 162 | +/// }; |
| 163 | +/// use std::sync::Arc; |
| 164 | +/// |
| 165 | +/// #[derive(serde::Deserialize)] |
| 166 | +/// pub struct UsersFilter { |
| 167 | +/// id: Vec<u64>, |
| 168 | +/// } |
| 169 | +/// |
| 170 | +/// async fn filter_users( |
| 171 | +/// QsQuery(info): QsQuery<UsersFilter> |
| 172 | +/// ) -> impl IntoResponse { |
| 173 | +/// info.id |
| 174 | +/// .iter() |
| 175 | +/// .map(|i| i.to_string()) |
| 176 | +/// .collect::<Vec<String>>() |
| 177 | +/// .join(", ") |
| 178 | +/// } |
| 179 | +/// |
| 180 | +/// fn main() { |
| 181 | +/// let app = Router::<()>::new() |
| 182 | +/// .route("/users", get(filter_users)) |
| 183 | +/// .layer(Extension(QsQueryConfig::new(5, false) |
| 184 | +/// .error_handler(|err| { |
| 185 | +/// QsQueryRejection::new(err, StatusCode::UNPROCESSABLE_ENTITY) |
| 186 | +/// }))); |
| 187 | +/// } |
| 188 | +pub struct QsQueryConfig { |
| 189 | + max_depth: usize, |
| 190 | + strict: bool, |
| 191 | + error_handler: Option<Arc<dyn Fn(QsError) -> QsQueryRejection + Send + Sync>>, |
| 192 | +} |
| 193 | + |
| 194 | +impl QsQueryConfig { |
| 195 | + /// Create new config wrapper |
| 196 | + pub fn new(max_depth: usize, strict: bool) -> Self { |
| 197 | + Self { |
| 198 | + max_depth, |
| 199 | + strict, |
| 200 | + error_handler: None, |
| 201 | + } |
| 202 | + } |
| 203 | + |
| 204 | + /// Set custom error handler |
| 205 | + pub fn error_handler<F>(mut self, f: F) -> Self |
| 206 | + where |
| 207 | + F: Fn(QsError) -> QsQueryRejection + Send + Sync + 'static, |
| 208 | + { |
| 209 | + self.error_handler = Some(Arc::new(f)); |
| 210 | + self |
| 211 | + } |
| 212 | +} |
| 213 | + |
| 214 | +impl From<QsQueryConfig> for QsConfig { |
| 215 | + fn from(config: QsQueryConfig) -> Self { |
| 216 | + Self::new(config.max_depth, config.strict) |
| 217 | + } |
| 218 | +} |
| 219 | + |
| 220 | +impl Default for QsQueryConfig { |
| 221 | + fn default() -> Self { |
| 222 | + Self { |
| 223 | + max_depth: 5, |
| 224 | + strict: true, |
| 225 | + error_handler: None, |
| 226 | + } |
| 227 | + } |
| 228 | +} |
0 commit comments