Skip to content

Commit

Permalink
Show TCP read
Browse files Browse the repository at this point in the history
  • Loading branch information
hatoo committed Oct 4, 2022
1 parent 0982257 commit ea9b661
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 17 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Unreleased

- Show TCP read bytes instead of body size

# 0.5.5 (2022-09-19)

- Add colors to the tui view #64
Expand Down
33 changes: 16 additions & 17 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use futures::future::FutureExt;
use futures::StreamExt;
use rand::prelude::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use thiserror::Error;

use crate::tcp_stream::CustomTcpStream;
use crate::ConnectToEntry;

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -128,6 +130,7 @@ impl ClientBuilder {
rng: rand::rngs::StdRng::from_entropy(),
},
client: None,
read_bytes_counter: Arc::new(AtomicUsize::new(0)),
timeout: self.timeout,
http_version: self.http_version,
redirect_limit: self.redirect_limit,
Expand Down Expand Up @@ -195,6 +198,7 @@ pub struct Client {
body: Option<&'static [u8]>,
dns: DNS,
client: Option<hyper::client::conn::SendRequest<hyper::Body>>,
read_bytes_counter: Arc<AtomicUsize>,
timeout: Option<std::time::Duration>,
redirect_limit: usize,
disable_keepalive: bool,
Expand All @@ -211,6 +215,7 @@ impl Client {
} else {
let stream = tokio::net::TcpStream::connect(addr).await?;
stream.set_nodelay(true)?;
let stream = CustomTcpStream::new(stream, self.read_bytes_counter.clone());
// stream.set_keepalive(std::time::Duration::from_secs(1).into())?;
let (send, conn) = hyper::client::conn::handshake(stream).await?;
tokio::spawn(conn);
Expand All @@ -225,6 +230,7 @@ impl Client {
) -> Result<hyper::client::conn::SendRequest<hyper::Body>, ClientError> {
let stream = tokio::net::TcpStream::connect(addr).await?;
stream.set_nodelay(true)?;
let stream = CustomTcpStream::new(stream, self.read_bytes_counter.clone());

let connector = if self.insecure {
native_tls::TlsConnector::builder()
Expand All @@ -251,6 +257,7 @@ impl Client {
) -> Result<hyper::client::conn::SendRequest<hyper::Body>, ClientError> {
let stream = tokio::net::TcpStream::connect(addr).await?;
stream.set_nodelay(true)?;
let stream = CustomTcpStream::new(stream, self.read_bytes_counter.clone());

let mut root_cert_store = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs()? {
Expand Down Expand Up @@ -331,19 +338,17 @@ impl Client {
connection_time = Some(ConnectionTime { dns_lookup, dialup });
}
let request = self.request(&self.url)?;
self.read_bytes_counter.store(0, Ordering::Relaxed);
match send_request.send_request(request).await {
Ok(res) => {
let (parts, mut stream) = res.into_parts();
let mut status = parts.status;

let mut len_sum = 0;
while let Some(chunk) = stream.next().await {
len_sum += chunk?.len();
}
while stream.next().await.is_some() {}

if self.redirect_limit != 0 {
if let Some(location) = parts.headers.get("Location") {
let (send_request_redirect, new_status, len) = self
let (send_request_redirect, new_status) = self
.redirect(
send_request,
&self.url.clone(),
Expand All @@ -354,7 +359,6 @@ impl Client {

send_request = send_request_redirect;
status = new_status;
len_sum = len;
}
}

Expand All @@ -364,7 +368,7 @@ impl Client {
start,
end,
status,
len_bytes: len_sum,
len_bytes: self.read_bytes_counter.load(Ordering::Relaxed),
connection_time,
};

Expand Down Expand Up @@ -404,7 +408,6 @@ impl Client {
(
hyper::client::conn::SendRequest<hyper::Body>,
http::StatusCode,
usize,
),
ClientError,
>,
Expand Down Expand Up @@ -451,28 +454,25 @@ impl Client {
)?,
);
}
self.read_bytes_counter.store(0, Ordering::Relaxed);
let res = send_request.send_request(request).await?;
let (parts, mut stream) = res.into_parts();
let mut status = parts.status;

let mut len_sum = 0;
while let Some(chunk) = stream.next().await {
len_sum += chunk?.len();
}
while stream.next().await.is_some() {}

if let Some(location) = parts.headers.get("Location") {
let (send_request_redirect, new_status, len) = self
let (send_request_redirect, new_status) = self
.redirect(send_request, &url, location, limit - 1)
.await?;
send_request = send_request_redirect;
status = new_status;
len_sum = len;
}

if let Some(send_request_base) = send_request_base {
Ok((send_request_base, status, len_sum))
Ok((send_request_base, status))
} else {
Ok((send_request, status, len_sum))
Ok((send_request, status))
}
}
.boxed()
Expand Down Expand Up @@ -546,7 +546,6 @@ pub async fn work(
n_tasks: usize,
n_workers: usize,
) {
use std::sync::atomic::{AtomicUsize, Ordering};
let counter = Arc::new(AtomicUsize::new(0));

let futures = (0..n_workers)
Expand Down
1 change: 1 addition & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod client;
mod histogram;
mod monitor;
mod printer;
mod tcp_stream;
mod timescale;

use client::{ClientError, RequestResult};
Expand Down
69 changes: 69 additions & 0 deletions src/tcp_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Ported from https://github.com/lnx-search/rewrk/pull/6
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::ReadBuf;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;

use std::io::{IoSlice, Result};

pub struct CustomTcpStream {
inner: TcpStream,
counter: Arc<AtomicUsize>,
}

impl CustomTcpStream {
pub fn new(stream: TcpStream, counter: Arc<AtomicUsize>) -> Self {
Self {
inner: stream,
counter,
}
}
}

impl AsyncRead for CustomTcpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<()>> {
let result = Pin::new(&mut self.inner).poll_read(cx, buf);

self.counter
.fetch_add(buf.filled().len(), Ordering::Relaxed);

result
}
}

impl AsyncWrite for CustomTcpStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

0 comments on commit ea9b661

Please sign in to comment.