diff --git a/Cargo.toml b/Cargo.toml index b0af201..79b6d48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,7 +59,6 @@ io-enum = "1.0.0" flate2 = { version = "1.0", default-features = false } lru = "0.7" mysql_common = { version = "0.28.0", default-features = false } -socket2 = "0.4" once_cell = "1.7.2" pem = "1.0.1" percent-encoding = "2.1.0" @@ -94,3 +93,9 @@ named_pipe = "~0.4" [target.'cfg(unix)'.dependencies] libc = "0.2" + +[target.'cfg(not(target_os = "wasi"))'.dependencies] +socket2 = "0.4" + +[target.'cfg(target_os = "wasi")'.dependencies] +wasmedge_wasi_socket = { git = "https://github.com/second-state/wasmedge_wasi_socket" } diff --git a/src/conn/mod.rs b/src/conn/mod.rs index 7ce7147..b3404ce 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -28,6 +28,8 @@ use mysql_common::{ packets::SslRequest, }; +#[cfg(not(target_os = "wasi"))] +use std::process; use std::{ borrow::{Borrow, Cow}, cmp, @@ -36,7 +38,6 @@ use std::{ io::{self, Write as _}, mem, ops::{Deref, DerefMut}, - process, sync::Arc, }; @@ -406,33 +407,56 @@ impl Conn { let tcp_nodelay = opts.get_tcp_nodelay(); let tcp_connect_timeout = opts.get_tcp_connect_timeout(); let bind_address = opts.bind_address().cloned(); - let stream = if let Some(socket) = opts.get_socket() { - Stream::connect_socket(&*socket, read_timeout, write_timeout)? - } else { + #[cfg(not(target_os = "wasi"))] + { + let stream = if let Some(socket) = opts.get_socket() { + Stream::connect_socket(&*socket, read_timeout, write_timeout)? + } else { + let port = opts.get_tcp_port(); + let ip_or_hostname = match opts.get_host() { + url::Host::Domain(domain) => domain, + url::Host::Ipv4(ip) => ip.to_string(), + url::Host::Ipv6(ip) => ip.to_string(), + }; + Stream::connect_tcp( + &*ip_or_hostname, + port, + read_timeout, + write_timeout, + tcp_keepalive_time, + #[cfg(any(target_os = "linux", target_os = "macos",))] + tcp_keepalive_probe_interval_secs, + #[cfg(any(target_os = "linux", target_os = "macos",))] + tcp_keepalive_probe_count, + #[cfg(target_os = "linux")] + tcp_user_timeout, + tcp_nodelay, + tcp_connect_timeout, + bind_address, + )? + }; + self.0.stream = Some(MySyncFramed::new(stream)); + } + #[cfg(target_os = "wasi")] + { let port = opts.get_tcp_port(); let ip_or_hostname = match opts.get_host() { url::Host::Domain(domain) => domain, url::Host::Ipv4(ip) => ip.to_string(), url::Host::Ipv6(ip) => ip.to_string(), }; - Stream::connect_tcp( + let stream = Stream::connect_tcp( &*ip_or_hostname, port, read_timeout, write_timeout, tcp_keepalive_time, - #[cfg(any(target_os = "linux", target_os = "macos",))] - tcp_keepalive_probe_interval_secs, - #[cfg(any(target_os = "linux", target_os = "macos",))] - tcp_keepalive_probe_count, - #[cfg(target_os = "linux")] - tcp_user_timeout, tcp_nodelay, tcp_connect_timeout, bind_address, - )? - }; - self.0.stream = Some(MySyncFramed::new(stream)); + )?; + self.0.stream = Some(MySyncFramed::new(stream)); + } Ok(()) } @@ -650,7 +674,10 @@ impl Conn { attrs.insert("_client_name".into(), "rust-mysql-simple".into()); attrs.insert("_client_version".into(), env!("CARGO_PKG_VERSION").into()); attrs.insert("_os".into(), env!("CARGO_CFG_TARGET_OS").into()); + #[cfg(not(target_os = "wasi"))] attrs.insert("_pid".into(), process::id().to_string()); + #[cfg(target_os = "wasi")] + attrs.insert("_pid".into(), "66666".into()); attrs.insert("_platform".into(), env!("CARGO_CFG_TARGET_ARCH").into()); attrs.insert("program_name".into(), program_name); @@ -1389,6 +1416,7 @@ mod test { assert_eq!(db_name, DB_NAME); } + #[cfg(not(target_os = "wasi"))] #[test] fn should_connect_by_hostname() { let opts = OptsBuilder::from_opts(get_opts()).ip_or_hostname(Some("localhost")); @@ -1475,6 +1503,7 @@ mod test { ); } + #[cfg(not(target_os = "wasi"))] #[test] fn should_parse_large_text_result() { let mut conn = Conn::new(get_opts()).unwrap(); @@ -1538,6 +1567,7 @@ mod test { assert_eq!(rows, vec![row1, row2]); } + #[cfg(not(target_os = "wasi"))] #[test] fn should_parse_large_binary_result() { let mut conn = Conn::new(get_opts()).unwrap(); @@ -1714,6 +1744,7 @@ mod test { ); } + #[cfg(not(target_os = "wasi"))] #[test] fn should_connect_via_socket_for_127_0_0_1() { let opts = OptsBuilder::from_opts(get_opts()); @@ -1723,6 +1754,7 @@ mod test { } } + #[cfg(not(target_os = "wasi"))] #[test] fn should_connect_via_socket_localhost() { let opts = OptsBuilder::from_opts(get_opts()).ip_or_hostname(Some("localhost")); @@ -1735,6 +1767,7 @@ mod test { /// QueryResult::drop hangs on connectivity errors (see [blackbeam/rust-mysql-simple#306][1]). /// /// [1]: https://github.com/blackbeam/rust-mysql-simple/issues/306 + #[cfg(not(target_os = "wasi"))] #[test] fn issue_306() { let (tx, rx) = channel::<()>(); @@ -1884,6 +1917,7 @@ mod test { .unwrap(); } + #[cfg(not(target_os = "wasi"))] #[test] fn issue_285() { let (tx, rx) = sync_channel::<()>(0); @@ -1962,6 +1996,7 @@ mod test { } } + #[cfg(not(target_os = "wasi"))] #[test] fn should_handle_tcp_connect_timeout() { use crate::error::{DriverError::ConnectTimeout, Error::DriverError}; @@ -1999,6 +2034,7 @@ mod test { assert_eq!(result.affected_rows(), 1); } + #[cfg(not(target_os = "wasi"))] #[test] fn should_bind_before_connect() { let port = 28000 + (rand::random::() % 2000); @@ -2017,6 +2053,7 @@ mod test { ); } + #[cfg(not(target_os = "wasi"))] #[test] fn should_bind_before_connect_with_timeout() { let port = 30000 + (rand::random::() % 2000); @@ -2177,8 +2214,10 @@ mod test { ); } } - + #[cfg(not(target_os = "wasi"))] let pid = process::id().to_string(); + #[cfg(target_os = "wasi")] + let pid = "66666".to_string(); let progname = std::env::args_os() .next() .unwrap() @@ -2211,6 +2250,7 @@ mod test { } } + #[cfg(not(target_os = "wasi"))] #[test] fn should_read_binlog() -> crate::Result<()> { use std::{ diff --git a/src/conn/pool.rs b/src/conn/pool.rs index 2115945..b6bf933 100644 --- a/src/conn/pool.rs +++ b/src/conn/pool.rs @@ -396,6 +396,7 @@ impl Queryable for PooledConn { } } +#[cfg(not(target_os = "wasi"))] #[cfg(test)] #[allow(non_snake_case)] mod test { diff --git a/src/io/mod.rs b/src/io/mod.rs index d87908b..3686d51 100644 --- a/src/io/mod.rs +++ b/src/io/mod.rs @@ -11,16 +11,17 @@ use io_enum::*; #[cfg(windows)] use named_pipe as np; +#[cfg(not(target_os = "wasi"))] +use std::net::{self, SocketAddr}; #[cfg(unix)] use std::os::{ unix, unix::io::{AsRawFd, RawFd}, }; -use std::{ - fmt, io, - net::{self, SocketAddr}, - time::Duration, -}; +use std::{fmt, io, time::Duration}; + +#[cfg(target_os = "wasi")] +use wasmedge_wasi_socket::{self, SocketAddr}; use crate::error::{ DriverError::{ConnectTimeout, CouldNotConnect}, @@ -141,6 +142,7 @@ impl Stream { } } + #[cfg(not(target_os = "wasi"))] pub fn is_socket(&self) -> bool { match self { Stream::SocketStream(_) => true, @@ -148,6 +150,11 @@ impl Stream { } } + #[cfg(target_os = "wasi")] + pub fn is_socket(&self) -> bool { + false + } + #[cfg(all(not(feature = "native-tls"), not(feature = "rustls")))] pub fn make_secure(self, _host: url::Host, _ssl_opts: crate::SslOpts) -> MyResult { panic!( @@ -173,7 +180,10 @@ pub enum TcpStream { Secure(BufStream>), #[cfg(feature = "rustls")] Secure(BufStream>), + #[cfg(not(target_os = "wasi"))] Insecure(BufStream), + #[cfg(target_os = "wasi")] + Insecure(BufStream), } #[cfg(unix)] diff --git a/src/io/tcp.rs b/src/io/tcp.rs index 1523a5a..10c67bd 100644 --- a/src/io/tcp.rs +++ b/src/io/tcp.rs @@ -5,13 +5,15 @@ // license , at your // option. All files in the project carrying such notice may not be copied, // modified, or distributed except according to those terms. - +#[cfg(not(target_os = "wasi"))] use socket2::{Domain, SockAddr, Socket, Type}; - -use std::{ - io, - net::{SocketAddr, TcpStream, ToSocketAddrs}, - time::Duration, +#[cfg(not(target_os = "wasi"))] +use std::net::{SocketAddr, TcpStream, ToSocketAddrs}; +use std::{io, time::Duration}; +#[cfg(target_os = "wasi")] +use wasmedge_wasi_socket::{ + socket::{AddressFamily, Socket, SocketType}, + SocketAddr, TcpStream, ToSocketAddrs, }; pub struct MyTcpBuilder { @@ -132,15 +134,30 @@ impl MyTcpBuilder { let fold_fun = |prev, sock_addr: &SocketAddr| match prev { Ok(socket) => Ok(socket), Err(_) => { - let domain = Domain::for_address(*sock_addr); - let socket = Socket::new(domain, Type::STREAM, None)?; - socket.bind(&bind_address.into())?; - if let Some(connect_timeout) = connect_timeout { - socket.connect_timeout(&SockAddr::from(*sock_addr), connect_timeout)?; - } else { - socket.connect(&SockAddr::from(*sock_addr))?; + #[cfg(not(target_os = "wasi"))] + { + let domain = Domain::for_address(*sock_addr); + let socket = Socket::new(domain, Type::STREAM, None)?; + socket.bind(&bind_address.into())?; + if let Some(connect_timeout) = connect_timeout { + socket.connect_timeout(&SockAddr::from(*sock_addr), connect_timeout)?; + } else { + socket.connect(&SockAddr::from(*sock_addr))?; + } + Ok(socket) + } + #[cfg(target_os = "wasi")] + { + let domain = if sock_addr.is_ipv4() { + AddressFamily::Inet4 + } else { + AddressFamily::Inet6 + }; + let socket = Socket::new(domain, SocketType::Stream)?; + socket.bind(&bind_address.into())?; + socket.connect(&SocketAddr::from(*sock_addr))?; + Ok(socket) } - Ok(socket) } }; @@ -166,24 +183,40 @@ impl MyTcpBuilder { .fold(Err(err), |prev, sock_addr| match prev { Ok(socket) => Ok(socket), Err(_) => { - let domain = Domain::for_address(sock_addr); - let socket = Socket::new(domain, Type::STREAM, None)?; - if let Some(connect_timeout) = connect_timeout { - socket.connect_timeout(&sock_addr.into(), connect_timeout)?; - } else { + #[cfg(not(target_os = "wasi"))] + { + let domain = Domain::for_address(sock_addr); + let socket = Socket::new(domain, Type::STREAM, None)?; + if let Some(connect_timeout) = connect_timeout { + socket.connect_timeout(&sock_addr.into(), connect_timeout)?; + } else { + socket.connect(&sock_addr.into())?; + } + Ok(socket) + } + #[cfg(target_os = "wasi")] + { + let domain = if sock_addr.is_ipv4() { + AddressFamily::Inet4 + } else { + AddressFamily::Inet6 + }; + let socket = Socket::new(domain, SocketType::Stream)?; socket.connect(&sock_addr.into())?; + Ok(socket) } - Ok(socket) } }) }?; - - socket.set_read_timeout(read_timeout)?; - socket.set_write_timeout(write_timeout)?; - if let Some(duration) = keepalive_time_ms { - let conf = - socket2::TcpKeepalive::new().with_time(Duration::from_millis(duration as u64)); - socket.set_tcp_keepalive(&conf)?; + #[cfg(not(target_os = "wasi"))] + { + socket.set_read_timeout(read_timeout)?; + socket.set_write_timeout(write_timeout)?; + if let Some(duration) = keepalive_time_ms { + let conf = + socket2::TcpKeepalive::new().with_time(Duration::from_millis(duration as u64)); + socket.set_tcp_keepalive(&conf)?; + } } #[cfg(any(target_os = "linux", target_os = "macos",))] if let Some(keepalive_probe_interval_secs) = keepalive_probe_interval_secs { @@ -236,6 +269,7 @@ impl MyTcpBuilder { } } } + #[cfg(not(target_os = "wasi"))] socket.set_nodelay(nodelay)?; Ok(TcpStream::from(socket)) }