Skip to content

Commit

Permalink
feat(cmd): Add execute timeout (#16)
Browse files Browse the repository at this point in the history
Added execute timeout for command. This can prevent command hanging cause csync hanging.
  • Loading branch information
fioncat authored Mar 28, 2024
1 parent c0a18cb commit 0064994
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 28 deletions.
5 changes: 2 additions & 3 deletions src/net/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,17 @@ pub use watch::WatchClient;
use std::borrow::Cow;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use anyhow::{bail, Context, Result};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{lookup_host, TcpSocket, TcpStream};
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tokio::time::{self, Instant};

use crate::net::auth::Auth;
use crate::net::conn::Connection;
use crate::net::frame::{self, DataFrame, Frame};
use crate::utils;

struct Client<S: AsyncWrite + AsyncRead + Unpin + Send> {
conn: Arc<Mutex<Connection<S>>>,
Expand Down Expand Up @@ -97,7 +96,7 @@ impl<S: AsyncWrite + AsyncRead + Unpin + Send + 'static> Client<S> {
let _ = done_tx.send(result);
});

match time::timeout_at(Instant::now() + Duration::from_secs(1), done_rx).await {
match utils::with_timeout(done_rx).await {
Ok(result) => result.unwrap(),
Err(_) => bail!("send data timeout after 1s"),
}
Expand Down
1 change: 1 addition & 0 deletions src/sync/read.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ impl Reader {
}
let result = Cmd::new(&self.cfg.cmd, None, true)
.execute()
.await
.context("execute read command");
if result.is_err() && self.cfg.allow_cmd_failure {
Ok(None)
Expand Down
12 changes: 6 additions & 6 deletions src/sync/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,13 @@ impl Writer {

let is_image = from_utf8(&frame.body).is_err();
if is_image {
if let Err(err) = self.handle_image(frame) {
if let Err(err) = self.handle_image(frame).await {
self.err_tx.send(err).await.unwrap();
}
continue;
}

if let Err(err) = self.handle_text(frame) {
if let Err(err) = self.handle_text(frame).await {
self.err_tx.send(err).await.unwrap();
}
}
Expand All @@ -104,7 +104,7 @@ impl Writer {
Ok(())
}

fn handle_image(&mut self, frame: DataFrame) -> Result<()> {
async fn handle_image(&mut self, frame: DataFrame) -> Result<()> {
if self.cfg.download_image {
let path = self.image_path.as_ref().unwrap();
println!("<Download image to {}>", path.display());
Expand All @@ -116,19 +116,19 @@ impl Writer {
if !self.cfg.image_cmd.is_empty() {
println!("<Execute image command>");
let mut cmd = Cmd::new(&self.cfg.image_cmd, Some(frame.body), false);
cmd.execute().context("execute image command")?;
cmd.execute().await.context("execute image command")?;
return Ok(());
}

println!("<Image data>");
Ok(())
}

fn handle_text(&mut self, frame: DataFrame) -> Result<()> {
async fn handle_text(&mut self, frame: DataFrame) -> Result<()> {
if !self.cfg.text_cmd.is_empty() {
println!("<Execute text command>");
let mut cmd = Cmd::new(&self.cfg.text_cmd, Some(frame.body), false);
cmd.execute().context("execute text command")?;
cmd.execute().await.context("execute text command")?;
return Ok(());
}

Expand Down
56 changes: 37 additions & 19 deletions src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
use std::fs;
use std::io::{self, Read, Write};
use std::future::Future;
use std::io;
use std::path::Path;
use std::process::Command;
use std::process::Stdio;
use std::time::Duration;

use anyhow::bail;
use anyhow::{Context, Result};
use log::info;
use sha2::{Digest, Sha256};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::process::Command;
use tokio::time::{self, Instant, Timeout};

pub fn ensure_dir<P: AsRef<Path>>(dir: P) -> Result<()> {
match fs::read_dir(dir.as_ref()) {
Expand Down Expand Up @@ -81,38 +85,48 @@ impl Cmd {
Cmd { cmd, input }
}

pub fn execute(&mut self) -> Result<Option<Vec<u8>>> {
pub async fn execute(&mut self) -> Result<Option<Vec<u8>>> {
let mut child = match self.cmd.spawn() {
Ok(child) => child,
Err(e) if e.kind() == io::ErrorKind::NotFound => {
bail!(
"cannot find command `{}`, please make sure it is installed",
self.get_name()
);
}
Err(e) => {
return Err(e)
.with_context(|| format!("cannot launch command `{}`", self.get_name()))
bail!("cannot find command, please make sure it is installed");
}
Err(e) => return Err(e).context("cannot launch command"),
};

if let Some(input) = &self.input {
let handle = child.stdin.as_mut().unwrap();
handle
.write_all(input)
.with_context(|| format!("write input to command `{}`", self.get_name()))?;
.await
.context("write input to command")?;
drop(child.stdin.take());
}

let mut stdout = child.stdout.take();

let status = child.wait().context("wait command done")?;
let status = match with_timeout(child.wait()).await {
Ok(result) => result.context("wait command exit")?,
Err(_) => {
// The command hang, try to kill it to avoid leakage. The kill also has a
// timeout.
if with_timeout(child.kill()).await.is_err() {
// Kill failed, the child process is completely blocked now and cannot
// handle kill signal. We donot known how to handle this, report the
// warning message. Let user to handle this.
let id = child.id().unwrap_or(0);
println!("WARN: Failed to kill child process {id} after timeout, process leakage may appear, please be attention");
}
bail!("execute command timeout after 1s");
}
};
let output = match stdout.as_mut() {
Some(stdout) => {
let mut out = Vec::new();
stdout
.read_to_end(&mut out)
.with_context(|| format!("read stdout from command `{}`", self.get_name()))?;
.await
.context("read stdout from command")?;
Some(out)
}
None => None,
Expand All @@ -128,11 +142,6 @@ impl Cmd {
None => bail!("command exited with unknown code"),
}
}

#[inline]
fn get_name(&self) -> &str {
self.cmd.get_program().to_str().unwrap_or("<unknown>")
}
}

pub fn get_digest(data: &[u8]) -> String {
Expand All @@ -147,3 +156,12 @@ pub fn shellexpand(s: impl AsRef<str>) -> Result<String> {
.with_context(|| format!("expand env for '{}'", s.as_ref()))
.map(|s| s.into_owned())
}

/// Every long operations should have an 1s timeout.
#[inline]
pub fn with_timeout<F>(future: F) -> Timeout<F>
where
F: Future,
{
time::timeout_at(Instant::now() + Duration::from_secs(1), future)
}

0 comments on commit 0064994

Please sign in to comment.