draft: try to use channel in worker loop
Keksoj committed Jan 2, 2024
1 parent a9b1ce0 commit b063ac4
Showing 3 changed files with 120 additions and 48 deletions.
127 changes: 93 additions & 34 deletions bin/src/command/mod.rs
@@ -17,12 +17,13 @@ use futures::{
mpsc::{channel, Receiver, Sender},
oneshot,
},
{SinkExt, StreamExt},
AsyncReadExt, {SinkExt, StreamExt},
};
use futures_lite::{
future,
io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
};
use mio::net::UnixStream as MioUnixStream;
use nix::{
sys::signal::{kill, Signal},
unistd::Pid,
@@ -31,7 +32,7 @@ use prost::Message;
use serde::{Deserialize, Serialize};

use sozu_command_lib::{
channel::delimiter_size,
channel::{delimiter_size, Channel},
config::Config,
logging::setup_logging_with_config,
proto::command::{
@@ -257,20 +258,21 @@ impl CommandServer {
.worker_channel
.take()
.with_context(|| format!("No channel present in worker {}", worker.id))?
.sock;
.sock
.as_raw_fd();
let (worker_tx, worker_rx) = channel(10000);
worker.sender = Some(worker_tx);

let main_to_worker_stream = Async::new(unsafe {
let fd = main_to_worker_channel.into_raw_fd();
UnixStream::from_raw_fd(fd)
})
.with_context(|| "Could not get a unix stream from the file descriptor")?;
// let main_to_worker_stream = Async::new(unsafe {
// let fd = main_to_worker_channel.into_raw_fd();
// UnixStream::from_raw_fd(fd)
// })
// .with_context(|| "Could not get a unix stream from the file descriptor")?;

let id = worker.id;
let command_tx = command_tx.clone();
smol::spawn(async move {
worker_loop(id, main_to_worker_stream, command_tx, worker_rx).await;
worker_loop(id, main_to_worker_channel, command_tx, worker_rx).await;
})
.detach();
}
@@ -425,14 +427,15 @@ impl CommandServer {
let sender = Some(worker_tx);

debug!("deserializing worker: {:?}", serialized);
let worker_stream = Async::new(unsafe { UnixStream::from_raw_fd(serialized.fd) })
.with_context(|| "Could not create an async unix stream to spawn the worker")?;
// let worker_stream = unsafe { MioUnixStream::from_raw_fd(serialized.fd) };
// .with_context(|| "Could not create an async unix stream to spawn the worker")?;

let id = serialized.id;
let command_tx = tx.clone();
//async fn worker(id: u32, sock: Async<UnixStream>, tx: Sender<CommandMessage>, rx: Receiver<()>) -> std::io::Result<()> {
let socket_fd = serialized.fd.clone();
smol::spawn(async move {
worker_loop(id, worker_stream, command_tx, worker_rx).await;
worker_loop(id, socket_fd, command_tx, worker_rx).await;
})
.detach();

@@ -638,7 +641,7 @@ impl CommandServer {
info!("created new worker: {}", new_worker_id);
self.next_worker_id += 1;

let sock = new_worker
let worker_socket = new_worker
.worker_channel
.take()
.with_context(|| {
@@ -647,19 +650,23 @@
new_worker.id
)
})? // this used to crash with unwrap(), do we still want to crash?
.sock;
.sock
.as_raw_fd();
let (worker_tx, worker_rx) = channel(10_000);
new_worker.sender = Some(worker_tx);

/*
let stream = Async::new(unsafe {
let fd = sock.into_raw_fd();
UnixStream::from_raw_fd(fd)
})?;
*/

let new_worker_id = new_worker.id;
let command_tx = self.command_tx.clone();

smol::spawn(async move {
worker_loop(new_worker_id, stream, command_tx, worker_rx).await;
worker_loop(new_worker_id, worker_socket, command_tx, worker_rx).await;
})
.detach();

@@ -756,7 +763,7 @@ impl CommandServer {
// FIXME: this message happens a lot at startup because AddCluster
// messages receive responses from each of the HTTP, HTTPS and TCP
// proxies. The clusters list should be merged
debug!("unknown response id: {}", response.id);
// debug!("unknown response id: {}", response.id);
}
Some((mut requester_tx, mut expected_responses)) => {
let response_id = response.id.clone();
@@ -788,7 +795,7 @@ impl CommandServer {
Ok(Success::WorkerResponse)
}

/// Count frontends and backends in the state, update their count in the CommandServer
/// Count frontends and backends in the state, update the server state accordingly
pub fn update_counts(&mut self) {
self.backends_count = self.state.count_backends();
self.frontends_count = self.state.count_frontends();
@@ -1059,40 +1066,78 @@ async fn client_loop(
/// - parse ProxyResponses from the unix stream and send them to the CommandServer
async fn worker_loop(
worker_id: u32,
stream: Async<UnixStream>,
socket_fd: i32,
mut command_tx: Sender<CommandMessage>,
mut worker_rx: Receiver<WorkerRequest>,
) {
let read_stream = Arc::new(stream);
let mut write_stream = read_stream.clone();
let read_stream = unsafe { MioUnixStream::from_raw_fd(socket_fd) };
let mut read_channel: Arc<Channel<WorkerRequest, WorkerResponse>> =
Arc::new(Channel::new(read_stream, 0, 169_480));
// read_channel.blocking().unwrap();

let write_stream = unsafe { MioUnixStream::from_raw_fd(socket_fd) };
let mut write_channel: Channel<WorkerRequest, ()> = Channel::new(write_stream, 0, 169_480);
write_channel.blocking().unwrap();
// let read_stream = Arc::new(stream);
// let mut write_stream = read_stream.clone();

smol::spawn(async move {
debug!("will start sending messages to worker {}", worker_id);
while let Some(worker_request) = worker_rx.next().await {
debug!("sending to worker {}: {:?}", worker_id, worker_request);

write_channel.write_message(&worker_request).expect("yolo");

/*
// TODO: the best would be to use a channel
let payload = worker_request.encode_to_vec();
let payload_len = payload.len() + delimiter_size();
let delimiter = payload_len.to_le_bytes();
let _ = write_stream.write_all(&delimiter).await;
let _ = write_stream.write_all(&payload).await;
let _ = write_channel.write_all(&delimiter).await;
let _ = write_channel.write_all(&payload).await;
*/
}
})
.detach();

debug!("will start receiving messages from worker {}", worker_id);

while let Ok(response) = read_channel.read_message() {
debug!("got worker response: {:?}", response);
if let Err(e) = command_tx
.send(CommandMessage::WorkerResponse {
worker_id,
response,
})
.await
{
error!("error sending worker response to command server: {:?}", e);
}
}
/*
// this does essentially what Channel::try_read_delimited_message() does
let mut buf_reader = BufReader::new(read_stream);
let mut buf_reader = BufReader::new(read_channel);
// this buffer is growable if the message is incomplete
let mut message_buffer: Vec<u8> = Vec::new();
let mut missing_message_length: usize = 0;
let mut loop_counter = 0usize;
let mut non_empty_times = 0usize;
loop {
println!("worker loop {}", loop_counter);
loop_counter += 1;
println!("buf reader len: {}", buf_reader.buffer().len());
if !buf_reader.buffer().is_empty() {
non_empty_times += 1;
} else {
non_empty_times = 0;
}
let buffer = match buf_reader.fill_buf().await {
Ok(buf) => buf,
Err(e) => {
@@ -1107,29 +1152,42 @@ async fn worker_loop(
}
let buffer_len = buffer.len();
println!("buffer length: {}", buffer_len);
println!("missing message length: {}", missing_message_length);
if missing_message_length == 0 {
let mut delimiter = [0u8; delimiter_size()];
if buffer_len >= delimiter_size() {
let delimiter: [u8; delimiter_size()] = match buffer[..delimiter_size()].try_into()
{
delimiter = match buffer[..delimiter_size()].try_into() {
Ok(delimiter) => delimiter,
Err(_) => {
error!("mismatched buffer size");
break;
}
};
let message_len = usize::from_le_bytes(delimiter);
} else {
// let mut missing_delimiter: Vec<u8> = vec![0; delimiter_size() - buffer_len];
delimiter[..buffer_len].copy_from_slice(buffer);
let _ = buf_reader
.read(&mut delimiter[buffer_len..])
.await
.expect("Read delimiter, the black magic failed lol");
}
let message_len = usize::from_le_bytes(delimiter);
if message_len > 1_000_000_000 {
error!("Skipping invalid message");
buf_reader.consume(buffer_len);
}
if message_len > 1_000_000_000 {
// println!("{:02x?}", buffer);
panic!("skipping invalid message");
error!("Skipping invalid message");
buf_reader.consume(buffer_len);
continue;
}
buf_reader.consume(delimiter_size());
buf_reader.consume(delimiter_size());
missing_message_length = message_len - delimiter_size();
missing_message_length = message_len - delimiter_size();
}
continue;
continue;
}
// grow the incomplete message buffer
@@ -1167,6 +1225,7 @@ async fn worker_loop(
error!("error sending worker response to command server: {:?}", e);
}
}
*/
error!("worker loop stopped, will close the worker {}", worker_id);

// if the loop breaks, request the command server to close the worker
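For context, the framing that the commented-out reader and writer in worker_loop perform by hand is a length-delimited protocol: every message is prefixed with a little-endian usize whose value is the payload length plus the size of the prefix itself (delimiter_size()). The sketch below is a hypothetical, stripped-down illustration of that framing over blocking std I/O; it is not sozu's Channel implementation, which wraps the same bookkeeping behind write_message and read_message.

use std::io::{self, Read, Write};

// Length of the prefix; the code above reads it into a [u8; delimiter_size()]
// and decodes it with usize::from_le_bytes, so size_of::<usize>() is assumed here.
const DELIMITER_SIZE: usize = std::mem::size_of::<usize>();

// Write one frame: [total length as little-endian usize][payload bytes].
// As in the commented-out sender, the length includes the prefix itself.
fn write_frame<W: Write>(writer: &mut W, payload: &[u8]) -> io::Result<()> {
    let total_len = payload.len() + DELIMITER_SIZE;
    writer.write_all(&total_len.to_le_bytes())?;
    writer.write_all(payload)?;
    writer.flush()
}

// Read one frame back: decode the prefix, then read exactly the remaining bytes.
fn read_frame<R: Read>(reader: &mut R) -> io::Result<Vec<u8>> {
    let mut delimiter = [0u8; DELIMITER_SIZE];
    reader.read_exact(&mut delimiter)?;
    let total_len = usize::from_le_bytes(delimiter);
    let body_len = total_len
        .checked_sub(DELIMITER_SIZE)
        .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "frame shorter than its delimiter"))?;
    let mut payload = vec![0u8; body_len];
    reader.read_exact(&mut payload)?;
    Ok(payload)
}

write_frame mirrors what the commented-out sender did by hand (encode, prefix, write), and read_frame mirrors what the comment above attributes to Channel::try_read_delimited_message; in the new code path this bookkeeping is delegated to Channel, which is what the commit title refers to.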
35 changes: 24 additions & 11 deletions bin/src/command/requests.rs
@@ -2,7 +2,7 @@ use std::{
collections::{BTreeMap, HashSet},
fs::File,
io::{ErrorKind, Read},
os::unix::io::{FromRawFd, IntoRawFd},
os::unix::io::{AsRawFd, FromRawFd, IntoRawFd},
os::unix::net::UnixStream,
time::{Duration, Instant},
};
@@ -155,7 +155,7 @@ impl CommandServer {

let mut buffer = Buffer::with_capacity(200000);

info!("starting to load state from {}", path);
println!("starting to load state from {}", path);

let mut message_counter = 0usize;
let mut diff_counter = 0usize;
@@ -180,7 +180,11 @@
match parse_several_requests::<WorkerRequest>(buffer.data()) {
Ok((i, requests)) => {
if !i.is_empty() {
debug!("could not parse {} bytes", i.len());
debug!(
"load-state: {} unparsed bytes: {}",
i.len(),
String::from_utf8_lossy(i)
);
if previous == buffer.available_data() {
bail!("error consuming load state message");
}
@@ -254,8 +258,10 @@

let command_tx = self.command_tx.to_owned();
let path = path.to_owned();

println!("MAIN: creating thread");
smol::spawn(async move {
println!("MAIN load-state: awaiting worker responses");
debug!("MAIN load-state: awaiting worker responses");
let mut ok = 0usize;
let mut error = 0usize;
while let Some((proxy_response, _)) = load_state_rx.next().await {
@@ -276,8 +282,8 @@
Some(client_id) => client_id,
None => {
match error {
0 => info!("loading state: {} ok messages, 0 errors", ok),
_ => error!("loading state: {} ok messages, {} errors", ok, error),
0 => println!("loading state: {} ok messages, 0 errors", ok),
_ => println!("loading state: {} ok messages, {} errors", ok, error),
}
return;
}
@@ -304,6 +310,7 @@
}
})
.detach();
println!("MAIN: detached thread");

self.update_counts();
Ok(None)
@@ -393,24 +400,27 @@

self.next_worker_id += 1;

let sock = worker
let worker_socket = worker
.worker_channel
.take()
.expect("No channel on the worker being launched")
.sock;
.sock
.as_raw_fd();
let (worker_tx, worker_rx) = channel(10000);
worker.sender = Some(worker_tx);

/*
let stream = Async::new(unsafe {
let fd = sock.into_raw_fd();
UnixStream::from_raw_fd(fd)
})?;
*/

let id = worker.id;
let command_tx = self.command_tx.clone();

smol::spawn(async move {
super::worker_loop(id, stream, command_tx, worker_rx).await;
super::worker_loop(id, worker_socket, command_tx, worker_rx).await;
})
.detach();

@@ -530,7 +540,8 @@ impl CommandServer {
.worker_channel
.take()
.with_context(|| "No channel on new worker".to_string())?
.sock;
.sock
.as_raw_fd();
let (worker_tx, worker_rx) = channel(10000);
new_worker.sender = Some(worker_tx);

@@ -686,15 +697,17 @@ impl CommandServer {
None => error!("could not get the list of listeners from the previous worker"),
};

/*
let stream = Async::new(unsafe {
let fd = sock.into_raw_fd();
UnixStream::from_raw_fd(fd)
})?;
*/

let id = new_worker.id;
let command_tx = self.command_tx.clone();
smol::spawn(async move {
super::worker_loop(id, stream, command_tx, worker_rx).await;
super::worker_loop(id, sock, command_tx, worker_rx).await;
})
.detach();

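The load_state path in requests.rs repeatedly fills a buffer from the saved state file, hands it to parse_several_requests, consumes whatever parsed cleanly, and keeps the incomplete tail for the next read, bailing out if a pass makes no progress. Below is a self-contained sketch of that fill/parse/consume loop; parse_complete_lines is a deliberately trivial, hypothetical stand-in for parse_several_requests so the example runs on its own.

use std::io::{self, Read};

// Stand-in parser: here a "request" is just one newline-terminated line.
// Returns the parsed requests and how many bytes of the input were consumed.
fn parse_complete_lines(data: &[u8]) -> (Vec<String>, usize) {
    let mut requests = Vec::new();
    let mut consumed = 0;
    for chunk in data.split_inclusive(|byte| *byte == b'\n') {
        if chunk.last() == Some(&b'\n') {
            requests.push(String::from_utf8_lossy(&chunk[..chunk.len() - 1]).into_owned());
            consumed += chunk.len();
        }
    }
    (requests, consumed)
}

// Fill a buffer from the reader, parse as many complete requests as possible,
// drop the consumed bytes, and keep the incomplete tail for the next read.
fn load_requests<R: Read>(mut file: R) -> io::Result<Vec<String>> {
    let mut buffer: Vec<u8> = Vec::new();
    let mut chunk = [0u8; 4096];
    let mut all_requests = Vec::new();
    loop {
        let bytes_read = file.read(&mut chunk)?;
        if bytes_read == 0 {
            break; // end of file
        }
        buffer.extend_from_slice(&chunk[..bytes_read]);
        let (requests, consumed) = parse_complete_lines(&buffer);
        all_requests.extend(requests);
        buffer.drain(..consumed);
    }
    if !buffer.is_empty() {
        // leftover bytes that never formed a complete request
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("{} unparsed bytes at end of state file", buffer.len()),
        ));
    }
    Ok(all_requests)
}

The real code also guards against looping forever on corrupt input by bailing when a parse pass leaves buffer.available_data() unchanged; in this sketch the same concern shows up as the trailing-bytes check after the read loop.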
