From 269f0031c53c2f4433fae555e21ea36a7c796520 Mon Sep 17 00:00:00 2001 From: hatoo Date: Thu, 7 Nov 2024 21:27:08 +0900 Subject: [PATCH] Use Service trait --- README.md | 28 +++++---- examples/dev_proxy.rs | 87 +++++++++++++------------- examples/https.rs | 7 +-- examples/proxy.rs | 28 +++++---- examples/websocket.rs | 140 ++++++++++++++++++++++-------------------- src/lib.rs | 42 ++++++------- tests/test.rs | 59 ++++++++++-------- 7 files changed, 205 insertions(+), 186 deletions(-) diff --git a/README.md b/README.md index 083070f..bf35689 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ use std::path::PathBuf; use clap::{Args, Parser}; use http_mitm_proxy::{DefaultClient, MitmProxy}; +use hyper::service::service_fn; use moka::sync::Cache; use tracing_subscriber::EnvFilter; @@ -88,23 +89,26 @@ async fn main() { let client = DefaultClient::new().unwrap(); let server = proxy - .bind(("127.0.0.1", 3003), move |_client_addr, req| { - let client = client.clone(); - async move { - let uri = req.uri().clone(); + .bind( + ("127.0.0.1", 3003), + service_fn(move |req| { + let client = client.clone(); + async move { + let uri = req.uri().clone(); - // You can modify request here - // or You can just return response anywhere + // You can modify request here + // or You can just return response anywhere - let (res, _upgrade) = client.send_request(req).await?; + let (res, _upgrade) = client.send_request(req).await?; - println!("{} -> {}", uri, res.status()); + println!("{} -> {}", uri, res.status()); - // You can modify response here + // You can modify response here - Ok::<_, http_mitm_proxy::default_client::Error>(res) - } - }) + Ok::<_, http_mitm_proxy::default_client::Error>(res) + } + }), + ) .await .unwrap(); diff --git a/examples/dev_proxy.rs b/examples/dev_proxy.rs index c0e3494..32f9963 100644 --- a/examples/dev_proxy.rs +++ b/examples/dev_proxy.rs @@ -5,7 +5,7 @@ use bytes::Bytes; use clap::{Args, Parser}; use http_body_util::{BodyExt, Full}; use http_mitm_proxy::{DefaultClient, MitmProxy}; -use hyper::Response; +use hyper::{service::service_fn, Response}; use moka::sync::Cache; #[derive(Parser)] @@ -82,50 +82,53 @@ async fn main() { let client = DefaultClient::new().unwrap(); let proxy = proxy - .bind(("127.0.0.1", 3003), move |_client_addr, mut req| { - let client = client.clone(); - async move { - // Forward connection from http/https dev.example to http://127.0.0.1:3333 - if req.uri().host() == Some("dev.example") { - // Return a response created by the proxy - if req.uri().path() == "/test.json" { - let res = Response::builder() - .header(hyper::header::CONTENT_TYPE, "application/json") - .body( - Full::new(Bytes::from("{data: 123}")) - .map_err(|e| match e {}) - .boxed(), - ) - .unwrap(); - return Ok(res); + .bind( + ("127.0.0.1", 3003), + service_fn(move |mut req| { + let client = client.clone(); + async move { + // Forward connection from http/https dev.example to http://127.0.0.1:3333 + if req.uri().host() == Some("dev.example") { + // Return a response created by the proxy + if req.uri().path() == "/test.json" { + let res = Response::builder() + .header(hyper::header::CONTENT_TYPE, "application/json") + .body( + Full::new(Bytes::from("{data: 123}")) + .map_err(|e| match e {}) + .boxed(), + ) + .unwrap(); + return Ok(res); + } + + req.headers_mut().insert( + hyper::header::HOST, + hyper::header::HeaderValue::from_maybe_shared(format!( + "127.0.0.1:{}", + port + )) + .unwrap(), + ); + + let mut parts = req.uri().clone().into_parts(); + parts.scheme = Some(hyper::http::uri::Scheme::HTTP); + parts.authority = Some( + hyper::http::uri::Authority::from_maybe_shared(format!( + "127.0.0.1:{}", + port + )) + .unwrap(), + ); + *req.uri_mut() = hyper::Uri::from_parts(parts).unwrap(); } - req.headers_mut().insert( - hyper::header::HOST, - hyper::header::HeaderValue::from_maybe_shared(format!( - "127.0.0.1:{}", - port - )) - .unwrap(), - ); - - let mut parts = req.uri().clone().into_parts(); - parts.scheme = Some(hyper::http::uri::Scheme::HTTP); - parts.authority = Some( - hyper::http::uri::Authority::from_maybe_shared(format!( - "127.0.0.1:{}", - port - )) - .unwrap(), - ); - *req.uri_mut() = hyper::Uri::from_parts(parts).unwrap(); - } - - let (res, _upgrade) = client.send_request(req).await?; + let (res, _upgrade) = client.send_request(req).await?; - Ok::<_, http_mitm_proxy::default_client::Error>(res.map(|b| b.boxed())) - } - }) + Ok::<_, http_mitm_proxy::default_client::Error>(res.map(|b| b.boxed())) + } + }), + ) .await .unwrap(); diff --git a/examples/https.rs b/examples/https.rs index 7812175..7ef008d 100644 --- a/examples/https.rs +++ b/examples/https.rs @@ -100,7 +100,7 @@ async fn main() { let server = async move { loop { - let (stream, client_addr) = listener.accept().await.unwrap(); + let (stream, _client_addr) = listener.accept().await.unwrap(); let proxy = proxy.clone(); let client = client.clone(); let tls_acceptor = tls_acceptor.clone(); @@ -111,9 +111,8 @@ async fn main() { MitmProxy::hyper_service( proxy.clone(), - client_addr, req, - move |_client_addr, req| { + service_fn(move |req| { let client = client.clone(); async move { let uri = req.uri().clone(); @@ -129,7 +128,7 @@ async fn main() { Ok::<_, http_mitm_proxy::default_client::Error>(res) } - }, + }), ) }); diff --git a/examples/proxy.rs b/examples/proxy.rs index 015e305..000f2ce 100644 --- a/examples/proxy.rs +++ b/examples/proxy.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use clap::{Args, Parser}; use http_mitm_proxy::{DefaultClient, MitmProxy}; +use hyper::service::service_fn; use moka::sync::Cache; use tracing_subscriber::EnvFilter; @@ -75,23 +76,26 @@ async fn main() { let client = DefaultClient::new().unwrap(); let server = proxy - .bind(("127.0.0.1", 3003), move |_client_addr, req| { - let client = client.clone(); - async move { - let uri = req.uri().clone(); + .bind( + ("127.0.0.1", 3003), + service_fn(move |req| { + let client = client.clone(); + async move { + let uri = req.uri().clone(); - // You can modify request here - // or You can just return response anywhere + // You can modify request here + // or You can just return response anywhere - let (res, _upgrade) = client.send_request(req).await?; + let (res, _upgrade) = client.send_request(req).await?; - println!("{} -> {}", uri, res.status()); + println!("{} -> {}", uri, res.status()); - // You can modify response here + // You can modify response here - Ok::<_, http_mitm_proxy::default_client::Error>(res) - } - }) + Ok::<_, http_mitm_proxy::default_client::Error>(res) + } + }), + ) .await .unwrap(); diff --git a/examples/websocket.rs b/examples/websocket.rs index 874d67a..635046a 100644 --- a/examples/websocket.rs +++ b/examples/websocket.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use clap::{Args, Parser}; use http_mitm_proxy::{default_client::Upgraded, DefaultClient, MitmProxy}; +use hyper::service::service_fn; use moka::sync::Cache; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tracing_subscriber::EnvFilter; @@ -78,95 +79,98 @@ async fn main() { let client = DefaultClient::new().unwrap().with_upgrades(); let server = proxy - .bind(("127.0.0.1", 3003), move |_client_addr, req| { - let client = client.clone(); - async move { - let uri = req.uri().clone(); + .bind( + ("127.0.0.1", 3003), + service_fn(move |req| { + let client = client.clone(); + async move { + let uri = req.uri().clone(); - // You can modify request here - // or You can just return response anywhere + // You can modify request here + // or You can just return response anywhere - let (res, upgrade) = client.send_request(req).await?; + let (res, upgrade) = client.send_request(req).await?; - // println!("{} -> {}", uri, res.status()); - if let Some(upgrade) = upgrade { - // If the response is an upgrade, e.g. Websocket, you can see traffic. - // Modifying upgraded traffic is not supported yet. + // println!("{} -> {}", uri, res.status()); + if let Some(upgrade) = upgrade { + // If the response is an upgrade, e.g. Websocket, you can see traffic. + // Modifying upgraded traffic is not supported yet. - // You can try https://echo.websocket.org/.ws to test websocket. - println!("Upgrade connection"); + // You can try https://echo.websocket.org/.ws to test websocket. + println!("Upgrade connection"); - tokio::spawn(async move { - let Upgraded { client, server } = upgrade.await.unwrap().unwrap(); - let url = uri.to_string(); + tokio::spawn(async move { + let Upgraded { client, server } = upgrade.await.unwrap().unwrap(); + let url = uri.to_string(); - let (mut client_rx, mut client_tx) = tokio::io::split(client); - let (mut server_rx, mut server_tx) = tokio::io::split(server); + let (mut client_rx, mut client_tx) = tokio::io::split(client); + let (mut server_rx, mut server_tx) = tokio::io::split(server); - let url0 = url.clone(); - let client_to_server = async move { - let mut buf = Vec::new(); + let url0 = url.clone(); + let client_to_server = async move { + let mut buf = Vec::new(); - loop { - if client_rx.read_buf(&mut buf).await.unwrap() == 0 { - break; - } loop { - let input = &mut buf.as_slice(); - if let Ok((frame, read)) = - websocket::frame.with_taken().parse_next(input) - { - println!( - "{} Client: {}", - &url0, - String::from_utf8_lossy(&frame.payload_data) - ); - server_tx.write_all(read).await.unwrap(); - buf = input.to_vec(); - } else { + if client_rx.read_buf(&mut buf).await.unwrap() == 0 { break; } + loop { + let input = &mut buf.as_slice(); + if let Ok((frame, read)) = + websocket::frame.with_taken().parse_next(input) + { + println!( + "{} Client: {}", + &url0, + String::from_utf8_lossy(&frame.payload_data) + ); + server_tx.write_all(read).await.unwrap(); + buf = input.to_vec(); + } else { + break; + } + } } - } - }; + }; - let url0 = url.clone(); - let server_to_client = async move { - let mut buf = Vec::new(); + let url0 = url.clone(); + let server_to_client = async move { + let mut buf = Vec::new(); - loop { - if server_rx.read_buf(&mut buf).await.unwrap() == 0 { - break; - } loop { - let input = &mut buf.as_slice(); - if let Ok((frame, read)) = - websocket::frame.with_taken().parse_next(input) - { - println!( - "{} Server: {}", - &url0, - String::from_utf8_lossy(&frame.payload_data) - ); - client_tx.write_all(read).await.unwrap(); - buf = input.to_vec(); - } else { + if server_rx.read_buf(&mut buf).await.unwrap() == 0 { break; } + loop { + let input = &mut buf.as_slice(); + if let Ok((frame, read)) = + websocket::frame.with_taken().parse_next(input) + { + println!( + "{} Server: {}", + &url0, + String::from_utf8_lossy(&frame.payload_data) + ); + client_tx.write_all(read).await.unwrap(); + buf = input.to_vec(); + } else { + break; + } + } } - } - }; + }; - tokio::spawn(client_to_server); - tokio::spawn(server_to_client); - }); - } + tokio::spawn(client_to_server); + tokio::spawn(server_to_client); + }); + } - // You can modify response here + // You can modify response here - Ok::<_, http_mitm_proxy::default_client::Error>(res) - } - }) + Ok::<_, http_mitm_proxy::default_client::Error>(res) + } + }), + ) .await .unwrap(); diff --git a/src/lib.rs b/src/lib.rs index b34cb19..23ea465 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,12 +5,12 @@ use http_body_util::{combinators::BoxBody, BodyExt, Empty}; use hyper::{ body::{Body, Incoming}, server, - service::service_fn, + service::{service_fn, HttpService}, Method, Request, Response, StatusCode, }; use hyper_util::rt::{TokioExecutor, TokioIo}; use moka::sync::Cache; -use std::{borrow::Borrow, future::Future, net::SocketAddr, sync::Arc}; +use std::{borrow::Borrow, future::Future, sync::Arc}; use tls::{generate_cert, CertifiedKeyDer}; use tokio::net::{TcpListener, TcpStream, ToSocketAddrs}; @@ -53,7 +53,7 @@ impl MitmProxy { impl + Send + Sync + 'static> MitmProxy { /// Bind to a socket address and return a future that runs the proxy server. /// URL for requests that passed to service are full URL including scheme. - pub async fn bind( + pub async fn bind( self, addr: A, service: S, @@ -62,8 +62,11 @@ impl + Send + Sync + 'static> MitmProxy { B: Body + Send + Sync + 'static, E: std::error::Error + Send + Sync + 'static, E2: std::error::Error + Send + Sync + 'static, - S: Fn(SocketAddr, Request) -> F + Send + Sync + Clone + 'static, - F: Future, E2>> + Send, + S: HttpService + + Send + + Sync + + Clone + + 'static, { let listener = TcpListener::bind(addr).await?; @@ -71,7 +74,7 @@ impl + Send + Sync + 'static> MitmProxy { Ok(async move { loop { - let Ok((stream, client_addr)) = listener.accept().await else { + let Ok((stream, _)) = listener.accept().await else { continue; }; @@ -85,12 +88,7 @@ impl + Send + Sync + 'static> MitmProxy { .serve_connection( TokioIo::new(stream), service_fn(|req| { - Self::hyper_service( - proxy.clone(), - client_addr, - req, - service.clone(), - ) + Self::hyper_service(proxy.clone(), req, service.clone()) }), ) .with_upgrades() @@ -107,15 +105,17 @@ impl + Send + Sync + 'static> MitmProxy { /// See `examples/https.rs` for usage. /// If you want to serve simple HTTP proxy server, you can use `bind` method instead. /// `bind` will call this method internally. - pub async fn hyper_service( + pub async fn hyper_service( proxy: Arc, - client_addr: SocketAddr, req: Request, - service: S, + mut service: S, ) -> Result>, E2> where - S: Fn(SocketAddr, Request) -> F + Send + Clone + 'static, - F: Future, E2>> + Send, + S: HttpService + + Send + + Sync + + Clone + + 'static, B: Body + Send + Sync + 'static, E: std::error::Error + Send + Sync + 'static, E2: std::error::Error + Send + Sync + 'static, @@ -167,11 +167,11 @@ impl + Send + Sync + 'static> MitmProxy { }; let f = move |mut req: Request<_>| { let connect_authority = connect_authority.clone(); - let service = service.clone(); + let mut service = service.clone(); async move { inject_authority(&mut req, connect_authority.clone()); - service(client_addr, req).await + service.call(req).await } }; let res = if client.get_ref().1.alpn_protocol() == Some(b"h2") { @@ -209,9 +209,7 @@ impl + Send + Sync + 'static> MitmProxy { )) } else { // http - service(client_addr, req) - .await - .map(|res| res.map(|b| b.boxed())) + service.call(req).await.map(|res| res.map(|b| b.boxed())) } } diff --git a/tests/test.rs b/tests/test.rs index ab232ce..033463e 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -1,6 +1,5 @@ use std::{ convert::Infallible, - net::SocketAddr, sync::{atomic::AtomicU16, Arc}, }; @@ -15,7 +14,8 @@ use futures::stream; use http_mitm_proxy::{DefaultClient, MitmProxy}; use hyper::{ body::{Body, Incoming}, - Response, Uri, + service::{service_fn, HttpService}, + Uri, }; use moka::sync::Cache; @@ -76,13 +76,12 @@ fn proxy_client() -> DefaultClient { DefaultClient::new().unwrap() } -async fn setup(app: Router, service: S) -> (u16, u16) +async fn setup(app: Router, service: S) -> (u16, u16) where B: Body + Send + Sync + 'static, E: std::error::Error + Send + Sync + 'static, E2: std::error::Error + Send + Sync + 'static, - S: Fn(SocketAddr, Request) -> F + Send + Sync + Clone + 'static, - F: std::future::Future, E2>> + Send + 'static, + S: HttpService + Send + Sync + Clone + 'static, { let proxy = MitmProxy::new(Some(root_cert()), Some(Cache::new(128))); let proxy_port = get_port(); @@ -98,7 +97,7 @@ where (proxy_port, port) } -async fn setup_tls( +async fn setup_tls( app: Router, service: S, root_cert: Arc, @@ -107,8 +106,7 @@ where B: Body + Send + Sync + 'static, E: std::error::Error + Send + Sync + 'static, E2: std::error::Error + Send + Sync + 'static, - S: Fn(SocketAddr, Request) -> F + Send + Sync + Clone + 'static, - F: std::future::Future, E2>> + Send + 'static, + S: HttpService + Send + Sync + Clone + 'static, { let proxy = MitmProxy::new(Some(root_cert), Some(Cache::new(128))); let proxy_port = get_port(); @@ -130,10 +128,13 @@ async fn test_simple_http() { let app = Router::new().route("/", get(|| async move { BODY })); let proxy_client = proxy_client(); - let (proxy_port, port) = setup(app, move |_, req| { - let proxy_client = proxy_client.clone(); - async move { proxy_client.send_request(req).await.map(|t| t.0) } - }) + let (proxy_port, port) = setup( + app, + service_fn(move |req| { + let proxy_client = proxy_client.clone(); + async move { proxy_client.send_request(req).await.map(|t| t.0) } + }), + ) .await; let client = client(proxy_port); @@ -162,14 +163,17 @@ async fn test_modify_http() { ); let proxy_client = proxy_client(); - let (proxy_port, port) = setup(app, move |_, mut req| { - let proxy_client = proxy_client.clone(); - async move { - req.headers_mut() - .insert("X-test", "modified".parse().unwrap()); - proxy_client.send_request(req).await.map(|t| t.0) - } - }) + let (proxy_port, port) = setup( + app, + service_fn(move |mut req| { + let proxy_client = proxy_client.clone(); + async move { + req.headers_mut() + .insert("X-test", "modified".parse().unwrap()); + proxy_client.send_request(req).await.map(|t| t.0) + } + }), + ) .await; let client = client(proxy_port); @@ -196,10 +200,13 @@ async fn test_sse_http() { ); let proxy_client = proxy_client(); - let (proxy_port, port) = setup(app, move |_, req| { - let proxy_client = proxy_client.clone(); - async move { proxy_client.send_request(req).await.map(|t| t.0) } - }) + let (proxy_port, port) = setup( + app, + service_fn(move |req| { + let proxy_client = proxy_client.clone(); + async move { proxy_client.send_request(req).await.map(|t| t.0) } + }), + ) .await; let client = client(proxy_port); @@ -225,7 +232,7 @@ async fn test_simple_https() { let proxy_client = proxy_client(); let (proxy_port, port) = setup_tls( app, - move |_, mut req| { + service_fn(move |mut req| { let proxy_client = proxy_client.clone(); async move { let mut parts = req.uri().clone().into_parts(); @@ -235,7 +242,7 @@ async fn test_simple_https() { proxy_client.send_request(req).await.map(|t| t.0) } - }, + }), cert.clone(), ) .await;