From f902b588cada88c899e5d95f36f9e55bf01882f0 Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 8 Feb 2025 21:09:47 +0900 Subject: [PATCH 1/4] add --proxy-header --- src/client.rs | 53 +++++++++++++++++++++++++++++++++++---------------- src/db.rs | 1 + src/main.rs | 38 ++++++++++++++++++++++++++---------- 3 files changed, 66 insertions(+), 26 deletions(-) diff --git a/src/client.rs b/src/client.rs index 327cd818..bba17310 100644 --- a/src/client.rs +++ b/src/client.rs @@ -174,6 +174,7 @@ pub struct Client { pub url_generator: UrlGenerator, pub method: http::Method, pub headers: http::header::HeaderMap, + pub proxy_headers: http::header::HeaderMap, pub body: Option<&'static [u8]>, pub dns: Dns, pub timeout: Option, @@ -494,14 +495,21 @@ impl Client { let (dns_lookup, stream) = self.client(proxy_url, rng, self.is_proxy_http2()).await?; if url.scheme() == "https" { // Do CONNECT request to proxy - let req = http::Request::builder() - .method(Method::CONNECT) - .uri(format!( - "{}:{}", - url.host_str().unwrap(), - url.port_or_known_default().unwrap() - )) - .body(http_body_util::Full::default())?; + let req = { + let mut builder = + http::Request::builder() + .method(Method::CONNECT) + .uri(format!( + "{}:{}", + url.host_str().unwrap(), + url.port_or_known_default().unwrap() + )); + *builder + .headers_mut() + .ok_or(ClientError::GetHeaderFromBuilderError)? = + self.proxy_headers.clone(); + builder.body(http_body_util::Full::default())? + }; let res = if self.proxy_http_version == http::Version::HTTP_2 { let mut send_request = stream.handshake_http2().await?; send_request.send_request(req).await? @@ -557,6 +565,12 @@ impl Client { aws_config.sign_request(self.method.as_str(), &mut headers, url, bytes)? } + if use_proxy { + for (key, value) in self.proxy_headers.iter() { + headers.insert(key, value.clone()); + } + } + *builder .headers_mut() .ok_or(ClientError::GetHeaderFromBuilderError)? = headers; @@ -670,14 +684,21 @@ impl Client { if let Some(proxy_url) = &self.proxy_url { let (dns_lookup, stream) = self.client(proxy_url, rng, self.is_proxy_http2()).await?; if url.scheme() == "https" { - let req = http::Request::builder() - .method(Method::CONNECT) - .uri(format!( - "{}:{}", - url.host_str().unwrap(), - url.port_or_known_default().unwrap() - )) - .body(http_body_util::Full::default())?; + let req = { + let mut builder = + http::Request::builder() + .method(Method::CONNECT) + .uri(format!( + "{}:{}", + url.host_str().unwrap(), + url.port_or_known_default().unwrap() + )); + *builder + .headers_mut() + .ok_or(ClientError::GetHeaderFromBuilderError)? = + self.proxy_headers.clone(); + builder.body(http_body_util::Full::default())? + }; let res = if self.proxy_http_version == http::Version::HTTP_2 { let mut send_request = stream.handshake_http2().await?; send_request.send_request(req).await? diff --git a/src/db.rs b/src/db.rs index 26131774..0145f17c 100644 --- a/src/db.rs +++ b/src/db.rs @@ -78,6 +78,7 @@ mod test_db { url_generator: UrlGenerator::new_static("http://example.com".parse().unwrap()), method: Method::GET, headers: HeaderMap::new(), + proxy_headers: HeaderMap::new(), body: None, dns: Dns { resolver: hickory_resolver::AsyncResolver::tokio_from_system_conf().unwrap(), diff --git a/src/main.rs b/src/main.rs index c4eb6d47..8f0003d4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,9 +4,12 @@ use clap::Parser; use crossterm::tty::IsTty; use hickory_resolver::config::{ResolverConfig, ResolverOpts}; use humantime::Duration; -use hyper::http::{ - self, - header::{HeaderName, HeaderValue}, +use hyper::{ + http::{ + self, + header::{HeaderName, HeaderValue}, + }, + HeaderMap, }; use printer::{PrintConfig, PrintMode}; use rand_regex::Regex; @@ -143,6 +146,11 @@ Note: If qps is specified, burst will be ignored", method: http::Method, #[arg(help = "Custom HTTP header. Examples: -H \"foo: bar\"", short = 'H')] headers: Vec, + #[arg( + help = "Custom Proxy HTTP header. Examples: --proxy-header \"foo: bar\"", + long = "proxy-header" + )] + proxy_headers: Vec, #[arg(help = "Timeout for each request. Default to infinite.", short = 't')] timeout: Option, #[arg(help = "HTTP Accept Header.", short = 'A')] @@ -501,13 +509,7 @@ async fn run() -> anyhow::Result<()> { for (k, v) in opts .headers .into_iter() - .map(|s| { - let header = s.splitn(2, ':').collect::>(); - anyhow::ensure!(header.len() == 2, anyhow::anyhow!("Parse header")); - let name = HeaderName::from_str(header[0])?; - let value = HeaderValue::from_str(header[1].trim_start_matches(' '))?; - Ok::<(HeaderName, HeaderValue), anyhow::Error>((name, value)) - }) + .map(|s| parse_header(s.as_str())) .collect::>>()? { headers.insert(k, v); @@ -516,6 +518,13 @@ async fn run() -> anyhow::Result<()> { headers }; + let proxy_headers = { + opts.proxy_headers + .into_iter() + .map(|s| parse_header(s.as_str())) + .collect::>>()? + }; + let body: Option<&'static [u8]> = match (opts.body_string, opts.body_path) { (Some(body), _) => Some(Box::leak(body.into_boxed_str().into_boxed_bytes())), (_, Some(path)) => { @@ -550,6 +559,7 @@ async fn run() -> anyhow::Result<()> { url_generator, method: opts.method, headers, + proxy_headers, body, dns: client::Dns { resolver, @@ -946,3 +956,11 @@ impl Opts { } } } + +fn parse_header(s: &str) -> Result<(HeaderName, HeaderValue), anyhow::Error> { + let header = s.splitn(2, ':').collect::>(); + anyhow::ensure!(header.len() == 2, anyhow::anyhow!("Parse header")); + let name = HeaderName::from_str(header[0])?; + let value = HeaderValue::from_str(header[1].trim_start_matches(' '))?; + Ok::<(HeaderName, HeaderValue), anyhow::Error>((name, value)) +} From 14b8b20175e9cd160e68d80b406b898dde652fcb Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 8 Feb 2025 21:40:12 +0900 Subject: [PATCH 2/4] test --proxy-header --- tests/tests.rs | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/tests/tests.rs b/tests/tests.rs index 15a05264..b849a97d 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -755,22 +755,31 @@ where let proxy = proxy.clone(); let service = service.clone(); + + let outer = service_fn(move |req| { + // Test --proxy-header option + assert_eq!( + req.headers() + .get("proxy-authorization") + .unwrap() + .to_str() + .unwrap(), + "test" + ); + + MitmProxy::wrap_service(proxy.clone(), service.clone()).call(req) + }); + tokio::spawn(async move { if http2 { let _ = hyper::server::conn::http2::Builder::new(TokioExecutor::new()) - .serve_connection( - TokioIo::new(stream), - MitmProxy::wrap_service(proxy, service), - ) + .serve_connection(TokioIo::new(stream), outer) .await; } else { let _ = hyper::server::conn::http1::Builder::new() .preserve_header_case(true) .title_case_headers(true) - .serve_connection( - TokioIo::new(stream), - MitmProxy::wrap_service(proxy, service), - ) + .serve_connection(TokioIo::new(stream), outer) .with_upgrades() .await; } @@ -800,6 +809,7 @@ async fn test_proxy_with_setting(https: bool, http2: bool, proxy_http2: bool) { let scheme = if https { "https" } else { "http" }; proc.args(["--no-tui", "--debug", "--insecure", "-x"]) .arg(format!("http://127.0.0.1:{proxy_port}/")) + .args(["--proxy-header", "proxy-authorization: test"]) .arg(format!("{scheme}://example.com/")); if http2 { proc.arg("--http2"); @@ -810,16 +820,10 @@ async fn test_proxy_with_setting(https: bool, http2: bool, proxy_http2: bool) { proc.stdin(std::process::Stdio::null()) .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::null()); - let stdout = proc - .spawn() - .unwrap() - .wait_with_output() - .await - .unwrap() - .stdout; - - assert!(String::from_utf8(stdout).unwrap().contains("Hello World"),); + .stderr(std::process::Stdio::piped()); + let outputs = proc.spawn().unwrap().wait_with_output().await.unwrap(); + let stdout = String::from_utf8(outputs.stdout).unwrap(); + assert!(stdout.contains("Hello World"),); } #[tokio::test] From 12b98e04f94f6e2870177d3c4a9f0264482f4232 Mon Sep 17 00:00:00 2001 From: hatoo Date: Sat, 8 Feb 2025 21:41:37 +0900 Subject: [PATCH 3/4] README --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 2d3ee85d..b4580369 100644 --- a/README.md +++ b/README.md @@ -140,6 +140,8 @@ Options: HTTP method [default: GET] -H Custom HTTP header. Examples: -H "foo: bar" + --proxy-header + Custom Proxy HTTP header. Examples: --proxy-header "foo: bar" -t Timeout for each request. Default to infinite. -A From 07427ec621e96cef01a0cf3f49fd65b6e1e1d978 Mon Sep 17 00:00:00 2001 From: hatoo Date: Sun, 9 Feb 2025 12:02:25 +0900 Subject: [PATCH 4/4] Fix test for windows --- tests/tests.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/tests.rs b/tests/tests.rs index b849a97d..2bce9b9c 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -818,12 +818,17 @@ async fn test_proxy_with_setting(https: bool, http2: bool, proxy_http2: bool) { proc.arg("--proxy-http2"); } + // When std::process::Stdio::piped() is used, the wait_with_output() method will hang in Windows. proc.stdin(std::process::Stdio::null()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()); + .stdout(std::process::Stdio::inherit()) + .stderr(std::process::Stdio::inherit()); + // So, we test status code only for now. + assert!(proc.status().await.unwrap().success()); + /* let outputs = proc.spawn().unwrap().wait_with_output().await.unwrap(); let stdout = String::from_utf8(outputs.stdout).unwrap(); assert!(stdout.contains("Hello World"),); + */ } #[tokio::test]