Skip to content

Commit

Permalink
fix: ClientTLSConnection::send was updated
Browse files Browse the repository at this point in the history
  • Loading branch information
Jotape24 committed Jan 30, 2025
1 parent 30a83a4 commit 1b4a444
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 97 deletions.
2 changes: 1 addition & 1 deletion examples/dotls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn main() -> Result<(), ClientError> {
Rrtype::A,
Rclass::IN,
0,
false,
true,
1);

rt.block_on(async {
Expand Down
176 changes: 80 additions & 96 deletions src/client/tls_connection.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

use crate::message::DnsMessage;
use crate::message::rdata::Rdata;
use crate::message::rdata::a_rdata::ARdata;
Expand All @@ -12,7 +11,7 @@ use std::io::Error as IoError;
use std::io::ErrorKind;
use tokio::io::AsyncWriteExt;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::net::{lookup_host, TcpStream};
use std::net::IpAddr;
use std::net::SocketAddr;
use tokio::time::Duration;
Expand All @@ -21,6 +20,7 @@ use tokio_rustls::rustls::ClientConfig;
use tokio_rustls::TlsConnector;
use std::sync::Arc;


#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ClientTLSConnection {
/// Client address
Expand Down Expand Up @@ -58,102 +58,65 @@ impl ClientSecurity for ClientTLSConnection {

/// creates socket tcp, sends query and receive response
async fn send(self, dns_query: DnsMessage) -> Result<Vec<u8>, ClientError> {
// async fn send(self, dns_query: DnsMessage) -> Result<(Vec<u8>, IpAddr), ClientError> {
//let root_store = RootCertStore::empty();
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") {
roots.add(cert).unwrap();
}
let config = ClientConfig::builder()
// Configure the root certificate store with platform-native certificates
let mut roots = rustls::RootCertStore::empty();
for cert in rustls_native_certs::load_native_certs().expect("could not load platform certs") {
roots.add(cert).unwrap();
}
let config = ClientConfig::builder()
.with_root_certificates(roots)
.with_no_client_auth();
// get the domain name to a srting
let dns_name_from_message = dns_query.get_question().get_qname().to_string();
let server_name = ServerName::try_from(dns_name_from_message).expect("invalid DNS name");
//let name_server= ServerName::try_from(dns_query.get_question().get_qname().to_string()).expect("invalid DNS name");
//let mut conn = rustls::ClientConnection::new(Arc::new(config), server_name).unwrap();
let connector = TlsConnector::from(Arc::new(config));

let conn_timeout: Duration = self.get_timeout();
//let bytes: Vec<u8> = dns_query.to_bytes();

let server_addr:SocketAddr = SocketAddr::new(self.get_server_addr(), 453);
// let name= dns_query.get_question().get_qname().to_string();
let stream = TcpStream::connect(server_addr).await;

let stream = match stream {
Ok(stream) => stream,
Err(e) => return Err(ClientError::from(e)),
};
//Verify that the connected IP matches the expected IP
let actual_ip = stream.peer_addr()?.ip();
let expected_ip = self.get_server_addr();
if actual_ip != expected_ip {
return Err(ClientError::Io(IoError::new(
ErrorKind::PermissionDenied,
format!("IP mismatch: expected {}, got {}", expected_ip, actual_ip),
)).into());
}
let tls_stream = connector.connect(server_name, stream).await;

//let mut tls = rustls::Stream::new(&mut conn, &mut stream);

let bytes = dns_query.to_bytes();
let msg_length = bytes.len() as u16;
let full_msg = [&msg_length.to_be_bytes(), bytes.as_slice()].concat();


// let mut stream: TcpStream = TcpStream::connect_timeout(&server_addr,timeout)?;
//let conn_task = TcpStream::connect(&server_addr).await;




// Handle the result of the TLS connection
let mut tls_stream_result = match tls_stream {
Ok(stream) => stream,
Err(e) => return Err(ClientError::Io(IoError::new(ErrorKind::Other, format!("TLS connection error: {}", e))).into()),
};
tls_stream_result.write_all(&full_msg).await?;
// Read response
let msg_size_response: [u8; 2] = [0; 2];
//tls.read_exact(msg_size_response).await?;

//let response_length = u16::from_be_bytes(msg_size_response) as usize;


//tls.read_exact(&mut response);

let tls_msg_len: u16 = (msg_size_response[0] as u16) << 8 | msg_size_response[1] as u16;
let mut vec_msg: Vec<u8> = Vec::new();
let ip = self.get_server_addr();
let mut additionals = dns_query.get_additional();
let mut ar = ARdata::new();
ar.set_address(ip);
let a_rdata = Rdata::A(ar);
let rr = ResourceRecord::new(a_rdata);
additionals.push(rr);


while vec_msg.len() < tls_msg_len as usize {
let mut msg = [0; 512];
let read_task = tls_stream_result.read(&mut msg);
let number_of_bytes_msg_result = match timeout(conn_timeout, read_task).await {
Ok(n) => n,
Err(_) => return Err(ClientError::Io(IoError::new(ErrorKind::TimedOut, format!("Error: timeout"))).into()),
};

let number_of_bytes_msg = match number_of_bytes_msg_result {
Ok(n) if n > 0 => n,
_ => return Err(IoError::new(ErrorKind::Other, format!("Error: no data received "))).map_err(Into::into),

};

vec_msg.extend_from_slice(&msg[..number_of_bytes_msg]);
}

return Ok(vec_msg);

// Resolve the server's IP address to a domain name
let server_name_res = Self::resolve_hostname(self.get_server_addr()).await;
let server_name = match server_name_res {
Ok(server_name_str) => ServerName::try_from(server_name_str).expect("invalid DNS name"),
Err(_) => return Err(ClientError::FormatError("Unable to resolve the IP address to a valid domain.")),
};

// Create a TLS connector with the configured certificates
let connector = TlsConnector::from(Arc::new(config));

// Connect to the DNS server over TCP on port 853
let server_addr: SocketAddr = SocketAddr::new(self.get_server_addr(), 853);
let stream = TcpStream::connect(server_addr).await.map_err(|e| ClientError::from(e))?;

// Verify that the connected IP matches the expected IP
let actual_ip = stream.peer_addr()?.ip();
let expected_ip = self.get_server_addr();
if actual_ip != expected_ip {
return Err(ClientError::Io(IoError::new(
ErrorKind::PermissionDenied,
format!("IP mismatch: expected {}, got {}", expected_ip, actual_ip),
)).into());
}

// // Establish the TLS connection
let mut tls_stream = connector.connect(server_name, stream).await.map_err(|e| {
ClientError::Io(IoError::new(ErrorKind::Other, format!("TLS connection error: {}", e)))
})?;


// Prepare the DNS query message
let bytes = dns_query.to_bytes();
let msg_length = bytes.len() as u16;
let full_msg = [&msg_length.to_be_bytes(), bytes.as_slice()].concat();

// Send the DNS query over the TLS connection
tls_stream.write_all(&full_msg).await?;

// Read the size of the response
let mut msg_size_response: [u8; 2] = [0; 2];
tls_stream.read_exact(&mut msg_size_response).await?;
let tls_msg_len: u16 = u16::from_be_bytes(msg_size_response);

// Read the full DNS response
let mut response = vec![0u8; tls_msg_len as usize];
tls_stream.read_exact(&mut response).await?;

// Return the response
Ok(response)
}
}

//Getters
Expand All @@ -167,6 +130,27 @@ impl ClientTLSConnection {
return self.timeout.clone();
}

/// Resolves the IP to a domain name or returns an error if it cannot be resolved.
async fn resolve_hostname(ip: IpAddr) -> Result<String, String> {
let socket_addr = format!("{}:843", ip); // Use port 443 (HTTPS) or the appropriate one
match lookup_host(socket_addr).await {
Ok(mut addrs) => {
// If the IP is resolved, return the domain name
if let Some(SocketAddr::V4(addr)) = addrs.next() {
return Ok(addr.ip().to_string());
}
}
Err(_) => {
// If resolution fails, return an error
return Err("Could not resolve the IP to a domain name.".to_string());
}
}

// If no domain is found, return an error
Err("Unable to resolve the IP address to a valid domain.".to_string())
}



}

Expand Down Expand Up @@ -266,7 +250,7 @@ mod tls_connection_test{

#[tokio::test]
async fn send() {
let ip_addr = IpAddr::V4(Ipv4Addr::new(192, 168, 0, 1));
let ip_addr = IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1));
let timeout = Duration::from_secs(100);
let mut _conn_new = ClientTLSConnection::new(ip_addr,timeout);
let mut domain_name = DomainName::new();
Expand Down

0 comments on commit 1b4a444

Please sign in to comment.