diff --git a/src/client/tls_connection.rs b/src/client/tls_connection.rs index 932b414f..1d188e1c 100644 --- a/src/client/tls_connection.rs +++ b/src/client/tls_connection.rs @@ -5,6 +5,7 @@ use crate::message::rdata::a_rdata::ARdata; use crate::message::resource_record::ResourceRecord; use super::client_connection::ConnectionProtocol; use super::client_error::ClientError; +use super::client_security::ClientSecurity; use async_trait::async_trait; use futures_util::TryFutureExt; use rustls::pki_types::ServerName; @@ -39,10 +40,10 @@ pub struct ClientTLSConnection { } #[async_trait] -impl ClientConnection for ClientTLSConnection { +impl ClientSecurity for ClientTLSConnection { /// Creates TLSConnection - fn new(server_addr:IpAddr, timeout: Duration, new_default: usize) -> Self { + fn new(server_addr:IpAddr, timeout: Duration) -> Self { ClientTLSConnection { server_addr: server_addr, timeout: timeout, @@ -78,15 +79,15 @@ impl ClientConnection for ClientTLSConnection { // get the domain name to a srting let dns_name_from_message = dns_query.get_question().get_qname().to_string(); let server_name = ServerName::try_from(dns_name_from_message).expect("invalid DNS name"); - let name_server= ServerName::try_from(dns_query.get_question().get_qname().to_string()).expect("invalid DNS name"); + //let name_server= ServerName::try_from(dns_query.get_question().get_qname().to_string()).expect("invalid DNS name"); //let mut conn = rustls::ClientConnection::new(Arc::new(config), server_name).unwrap(); - let mut connector = TlsConnector::from(Arc::new(config)); + let connector = TlsConnector::from(Arc::new(config)); let conn_timeout: Duration = self.get_timeout(); //let bytes: Vec = dns_query.to_bytes(); let server_addr:SocketAddr = SocketAddr::new(self.get_server_addr(), 453); - let name= dns_query.get_question().get_qname().to_string(); + // let name= dns_query.get_question().get_qname().to_string(); let stream = TcpStream::connect(server_addr).await; let stream = match stream { @@ -102,7 +103,7 @@ impl ClientConnection for ClientTLSConnection { format!("IP mismatch: expected {}, got {}", expected_ip, actual_ip), )).into()); } - let mut tls_stream = connector.connect(server_name, stream).await; + let tls_stream = connector.connect(server_name, stream).await; //let mut tls = rustls::Stream::new(&mut conn, &mut stream); @@ -124,10 +125,11 @@ impl ClientConnection for ClientTLSConnection { }; tls_stream_result.write_all(&full_msg).await?; // Read response - let mut msg_size_response: [u8; 2] = [0; 2]; + let msg_size_response: [u8; 2] = [0; 2]; //tls.read_exact(msg_size_response).await?; - let response_length = u16::from_be_bytes(msg_size_response) as usize; - let mut response = vec![0u8; response_length]; + + //let response_length = u16::from_be_bytes(msg_size_response) as usize; + //tls.read_exact(&mut response); @@ -197,16 +199,15 @@ mod tls_connection_test{ use crate::domain_name::DomainName; use crate::message::rrtype::Rrtype; use crate::message::rclass::Rclass; - const DEFAULT_SIZE: usize = 512; #[test] - fn create_tcp() { + fn create_tls() { // let domain_name = String::from("uchile.cl"); let ip_addr = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)); let _port: u16 = 8088; let timeout = Duration::from_secs(100); - let _conn_new = ClientTLSConnection::new(ip_addr,timeout, DEFAULT_SIZE); + let _conn_new = ClientTLSConnection::new(ip_addr,timeout); assert_eq!(_conn_new.get_server_addr(), IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))); assert_eq!(_conn_new.get_timeout(), Duration::from_secs(100)); @@ -215,7 +216,7 @@ mod tls_connection_test{ fn get_ip_v4(){ let ip_address = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)); let timeout = Duration::from_secs(100); - let connection = ClientTLSConnection::new(ip_address, timeout, DEFAULT_SIZE); + let connection = ClientTLSConnection::new(ip_address, timeout); //check if the ip is the same assert_eq!(connection.get_ip(), IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))); } @@ -225,7 +226,7 @@ mod tls_connection_test{ // ip in V6 version is the equivalent to (192, 168, 0, 1) in V4 let ip_address = IpAddr::V6(Ipv6Addr::new(0xc0, 0xa8, 0, 1, 0, 0, 0, 0)); let timeout = Duration::from_secs(100); - let connection = ClientTLSConnection::new(ip_address, timeout, DEFAULT_SIZE); + let connection = ClientTLSConnection::new(ip_address, timeout); //check if the ip is the same assert_eq!(connection.get_ip(), IpAddr::V6(Ipv6Addr::new(0xc0, 0xa8, 0, 1, 0, 0, 0, 0))); } @@ -233,7 +234,7 @@ mod tls_connection_test{ fn get_server_addr(){ let ip_addr = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)); let timeout = Duration::from_secs(100); - let mut _conn_new = ClientTLSConnection::new(ip_addr,timeout, DEFAULT_SIZE); + let mut _conn_new = ClientTLSConnection::new(ip_addr,timeout); assert_eq!(_conn_new.get_server_addr(), IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))); } @@ -242,7 +243,7 @@ mod tls_connection_test{ fn set_server_addr(){ let ip_addr = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)); let timeout = Duration::from_secs(100); - let mut _conn_new = ClientTLSConnection::new(ip_addr,timeout, DEFAULT_SIZE); + let mut _conn_new = ClientTLSConnection::new(ip_addr,timeout); assert_eq!(_conn_new.get_server_addr(), IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1))); @@ -254,7 +255,7 @@ mod tls_connection_test{ fn get_timeout(){ let ip_addr = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)); let timeout = Duration::from_secs(100); - let mut _conn_new = ClientTLSConnection::new(ip_addr,timeout, DEFAULT_SIZE); + let mut _conn_new = ClientTLSConnection::new(ip_addr,timeout); assert_eq!(_conn_new.get_timeout(), Duration::from_secs(100)); } @@ -263,7 +264,7 @@ mod tls_connection_test{ fn set_timeout(){ let ip_addr = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)); let timeout = Duration::from_secs(100); - let mut _conn_new = ClientTLSConnection::new(ip_addr,timeout, DEFAULT_SIZE); + let mut _conn_new = ClientTLSConnection::new(ip_addr,timeout); assert_eq!(_conn_new.get_timeout(), Duration::from_secs(100)); @@ -273,16 +274,17 @@ mod tls_connection_test{ } #[tokio::test] - async fn send_timeout() { + async fn send() { let ip_addr = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1)); - let _port: u16 = 8088; - let timeout = Duration::from_secs(2); - - let conn_new = ClientTLSConnection::new(ip_addr,timeout, DEFAULT_SIZE); - let dns_query = DnsMessage::new(); - //let response = conn_new.send(dns_query).await; - - //assert!(response.is_err()); + let timeout = Duration::from_secs(100); + let mut _conn_new = ClientTLSConnection::new(ip_addr,timeout); + let mut domain_name = DomainName::new(); + domain_name.set_name("example.com".to_string()); + let question = DnsMessage::new_query_message(domain_name, Rrtype::A, Rclass::IN, 0, true, 0); + let mut dns_query = DnsMessage::new(); + + let response = _conn_new.send(dns_query).await; + assert_eq!(response.is_ok(), true); }