From 3ae8b9b1f8d2afee34f2630102afe3824782821c Mon Sep 17 00:00:00 2001 From: joalopez Date: Thu, 30 Jan 2025 11:33:35 -0300 Subject: [PATCH] changes to concurrency --- src/nameserver.rs | 195 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 172 insertions(+), 23 deletions(-) diff --git a/src/nameserver.rs b/src/nameserver.rs index 637dc6ad..e1743e04 100644 --- a/src/nameserver.rs +++ b/src/nameserver.rs @@ -62,26 +62,6 @@ impl NameServer { } } - /*async fn handle_request(&self, socket: Arc>, - data: Vec, - addr: std::net::SocketAddr) { - let mut message = DnsMessage::from_bytes(&data).expect("Error al parsear el mensaje"); - - let rrs_to_add = NameServer::search_query(&self.zone, &message); - - if rrs_to_add.len() > 0 { - NameServer::add_rrs(&mut message, &rrs_to_add) - } - let response = message.to_bytes(); - // lock the socket and send the response - let mut sock = socket.lock().await; - if let Err(e) = sock.send_to(&response, addr).await { - eprintln!("Failed to send response to {}: {}", addr, e); - } else { - println!("Sent response to {:?}", addr); - } - }*/ - fn add_rrs(msg :&mut DnsMessage, rrs: &Vec) { msg.set_answer(rrs.clone()); let mut header = msg.get_header(); @@ -107,9 +87,178 @@ impl NameServer { #[cfg(test)] -mod tests { +mod ns_tests { use super::*; - pub fn test_response() { - + use futures_util::future; + use tokio::time::{timeout, Duration}; + use crate::message::rdata::mx_rdata::MxRdata; + use crate::message::DnsMessage; + use crate::domain_name::DomainName; + use crate::message::resource_record::ResourceRecord; + use crate::message::rdata::a_rdata::ARdata; + use crate::message::rdata::Rdata; + use crate::message::rclass::Rclass; + use crate::message::rrtype::Rrtype; + use crate::message::rcode::Rcode; + + #[tokio::test] + async fn test_name_server_init() { + let mut server = NameServer { + zone: vec![ + // Some sample RRs (will be cleared by init) + { + let mut rr = ResourceRecord::new(Rdata::A(ARdata::new())); + rr.set_name(DomainName::new_from_string("example.com".to_string())); + rr.set_rclass(Rclass::IN); + rr.set_type_code(Rrtype::A); + rr + } + ], + shared_sock: Arc::new(Mutex::new( + UdpSocket::bind("127.0.0.1:0").await.unwrap() + )), + }; + + server.init("127.0.0.1:0").await.unwrap(); + assert_eq!(server.zone.len(), 0, "Zone should be cleared on init"); + + let socket = server.shared_sock.lock().await; + assert_ne!(socket.local_addr().unwrap().port(), 0, "Should bind a valid port"); + } + + #[test] + fn test_name_server_add_rrs() { + let mut message = DnsMessage::new(); + let mut rr = ResourceRecord::new(Rdata::A(ARdata::new())); + rr.set_name(DomainName::new_from_string("example.com".to_string())); + rr.set_rclass(Rclass::IN); + rr.set_type_code(Rrtype::A); + + let rrs = vec![rr.clone()]; + NameServer::add_rrs(&mut message, &rrs); + + assert_eq!(message.get_answer().len(), 1, "Should have one RR in answer"); + assert_eq!(message.get_answer()[0], rr); + + let header = message.get_header(); + assert!(header.get_aa(), "AA flag should be set"); + assert!(header.get_qr(), "QR flag should be set (response)"); + assert_eq!(header.get_ancount(), 1, "ANCOUT should be 1"); + assert_eq!(header.get_rcode(), Rcode::NOERROR, "Rcode should be NOERROR"); + } + + #[test] + fn test_name_server_search_query() { + let mut rr_a = ResourceRecord::new(Rdata::A(ARdata::new())); + rr_a.set_name(DomainName::new_from_string("example.com".to_string())); + rr_a.set_type_code(Rrtype::A); + rr_a.set_rclass(Rclass::IN); + + let mut rr_mx = ResourceRecord::new(Rdata::MX(MxRdata::new())); // dummy Rdata, for example + rr_mx.set_name(DomainName::new_from_string("example.com".to_string())); + rr_mx.set_type_code(Rrtype::MX); + rr_mx.set_rclass(Rclass::IN); + + let mut rr_other = ResourceRecord::new(Rdata::A(ARdata::new())); + rr_other.set_name(DomainName::new_from_string("example.org".to_string())); + rr_other.set_type_code(Rrtype::A); + rr_other.set_rclass(Rclass::IN); + + let zone = vec![rr_a.clone(), rr_mx, rr_other]; + + let mut query_message = DnsMessage::new(); + { + let mut question = query_message.get_question(); + question.set_qname(DomainName::new_from_string("example.com".to_string())); + question.set_rrtype(Rrtype::A); + question.set_rclass(Rclass::IN); + query_message.set_question(question); + } + + let found = NameServer::search_query(&zone, &query_message); + assert_eq!(found.len(), 1, "Should find exactly one matching RR"); + assert_eq!(found[0], rr_a, "Should match the A record for example.com"); } + + #[tokio::test] + async fn test_name_server_run() { + let mut server = NameServer { + zone: vec![], + shared_sock: Arc::new(Mutex::new( + UdpSocket::bind("127.0.0.1:0").await.unwrap() + )), + }; + + let local_addr = server.shared_sock.lock().await.local_addr().unwrap(); + let handle = tokio::spawn(async move { + // This loop runs indefinitely, so we rely on a timeout in the test + let _ = server.run(&local_addr.to_string()).await; + }); + + //time to bind/listen + tokio::time::sleep(Duration::from_millis(100)).await; + + let test_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + let query_bytes = DnsMessage::new().to_bytes(); + let _ = test_socket.send_to(&query_bytes, &local_addr).await; + let _ = timeout(Duration::from_millis(300), handle).await; + } + #[ignore] + #[tokio::test] + async fn test_concurrency_with_timeout() { + let mut server = NameServer { + zone: vec![], // or some test zone RRs if needed + shared_sock: Arc::new(Mutex::new( + UdpSocket::bind("127.0.0.1:0").await.unwrap() + )), + }; + + // Get the server's local address + let local_addr = server.shared_sock.lock().await.local_addr().unwrap(); + + // Run the server in a background task + let server_task = tokio::spawn(async move { + let _ = server.run(&local_addr.to_string()).await; + }); + + // We will spawn multiple parallel "clients" + let num_clients = 5; + let tasks = (0..num_clients).map(|i| { + tokio::spawn(async move { + + let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap(); + + let query = DnsMessage::new().to_bytes(); + sock.send_to(&query, local_addr).await.unwrap(); + + + let mut buf = vec![0u8; 1024]; + let (bytes_received, from_addr) = sock.recv_from(&mut buf).await.unwrap(); + + + assert!(bytes_received > 0, "Should receive some data from the server"); + assert_eq!(from_addr.ip(), local_addr.ip(), "Response should come from the server IP"); + + + i + }) + }); + + // Wait for all client tasks to complete + let results = futures_util::future::join_all(tasks).await; + for (i, res) in results.into_iter().enumerate() { + let val = res.expect("Task panicked"); + assert_eq!(val, i); + } + + // After client tasks are done, we forcibly stop the server with a timeout + match tokio::time::timeout(Duration::from_millis(300), server_task).await { + Ok(_) => { + } + Err(_elapsed) => { + } + } + } + + } \ No newline at end of file