changes to concurrency
joalopez1206 committed Jan 30, 2025
1 parent 6c1f96e commit 3ae8b9b
Showing 1 changed file with 172 additions and 23 deletions.
195 changes: 172 additions & 23 deletions src/nameserver.rs
@@ -62,26 +62,6 @@ impl NameServer {
}
}

/*async fn handle_request(&self, socket: Arc<Mutex<UdpSocket>>,
data: Vec<u8>,
addr: std::net::SocketAddr) {
let mut message = DnsMessage::from_bytes(&data).expect("Error parsing the message");
let rrs_to_add = NameServer::search_query(&self.zone, &message);
if rrs_to_add.len() > 0 {
NameServer::add_rrs(&mut message, &rrs_to_add)
}
let response = message.to_bytes();
// lock the socket and send the response
let mut sock = socket.lock().await;
if let Err(e) = sock.send_to(&response, addr).await {
eprintln!("Failed to send response to {}: {}", addr, e);
} else {
println!("Sent response to {:?}", addr);
}
}*/

fn add_rrs(msg: &mut DnsMessage, rrs: &Vec<ResourceRecord>) {
msg.set_answer(rrs.clone());
let mut header = msg.get_header();
@@ -107,9 +87,178 @@ impl NameServer {


#[cfg(test)]
mod ns_tests {
use super::*;

use futures_util::future;
use tokio::time::{timeout, Duration};
use crate::message::rdata::mx_rdata::MxRdata;
use crate::message::DnsMessage;
use crate::domain_name::DomainName;
use crate::message::resource_record::ResourceRecord;
use crate::message::rdata::a_rdata::ARdata;
use crate::message::rdata::Rdata;
use crate::message::rclass::Rclass;
use crate::message::rrtype::Rrtype;
use crate::message::rcode::Rcode;

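// init() should clear any preloaded zone records and leave the shared socket
// bound to a usable (non-zero) local port.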
#[tokio::test]
async fn test_name_server_init() {
let mut server = NameServer {
zone: vec![
// Some sample RRs (will be cleared by init)
{
let mut rr = ResourceRecord::new(Rdata::A(ARdata::new()));
rr.set_name(DomainName::new_from_string("example.com".to_string()));
rr.set_rclass(Rclass::IN);
rr.set_type_code(Rrtype::A);
rr
}
],
shared_sock: Arc::new(Mutex::new(
UdpSocket::bind("127.0.0.1:0").await.unwrap()
)),
};

server.init("127.0.0.1:0").await.unwrap();
assert_eq!(server.zone.len(), 0, "Zone should be cleared on init");

let socket = server.shared_sock.lock().await;
assert_ne!(socket.local_addr().unwrap().port(), 0, "Should bind a valid port");
}

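// add_rrs() should copy the records into the answer section and mark the message
// as an authoritative response (QR and AA set, ANCOUNT updated, NOERROR).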
#[test]
fn test_name_server_add_rrs() {
let mut message = DnsMessage::new();
let mut rr = ResourceRecord::new(Rdata::A(ARdata::new()));
rr.set_name(DomainName::new_from_string("example.com".to_string()));
rr.set_rclass(Rclass::IN);
rr.set_type_code(Rrtype::A);

let rrs = vec![rr.clone()];
NameServer::add_rrs(&mut message, &rrs);

assert_eq!(message.get_answer().len(), 1, "Should have one RR in answer");
assert_eq!(message.get_answer()[0], rr);

let header = message.get_header();
assert!(header.get_aa(), "AA flag should be set");
assert!(header.get_qr(), "QR flag should be set (response)");
assert_eq!(header.get_ancount(), 1, "ANCOUNT should be 1");
assert_eq!(header.get_rcode(), Rcode::NOERROR, "Rcode should be NOERROR");
}

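// search_query() should return only the zone records whose name, type, and
// class all match the question.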
#[test]
fn test_name_server_search_query() {
let mut rr_a = ResourceRecord::new(Rdata::A(ARdata::new()));
rr_a.set_name(DomainName::new_from_string("example.com".to_string()));
rr_a.set_type_code(Rrtype::A);
rr_a.set_rclass(Rclass::IN);

let mut rr_mx = ResourceRecord::new(Rdata::MX(MxRdata::new())); // MX record with dummy Rdata
rr_mx.set_name(DomainName::new_from_string("example.com".to_string()));
rr_mx.set_type_code(Rrtype::MX);
rr_mx.set_rclass(Rclass::IN);

let mut rr_other = ResourceRecord::new(Rdata::A(ARdata::new()));
rr_other.set_name(DomainName::new_from_string("example.org".to_string()));
rr_other.set_type_code(Rrtype::A);
rr_other.set_rclass(Rclass::IN);

let zone = vec![rr_a.clone(), rr_mx, rr_other];

let mut query_message = DnsMessage::new();
{
let mut question = query_message.get_question();
question.set_qname(DomainName::new_from_string("example.com".to_string()));
question.set_rrtype(Rrtype::A);
question.set_rclass(Rclass::IN);
query_message.set_question(question);
}

let found = NameServer::search_query(&zone, &query_message);
assert_eq!(found.len(), 1, "Should find exactly one matching RR");
assert_eq!(found[0], rr_a, "Should match the A record for example.com");
}

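// Smoke test: start run() in the background, send it a single query, and
// bound the wait on the server task with a timeout, since run() never returns.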
#[tokio::test]
async fn test_name_server_run() {
let mut server = NameServer {
zone: vec![],
shared_sock: Arc::new(Mutex::new(
UdpSocket::bind("127.0.0.1:0").await.unwrap()
)),
};

let local_addr = server.shared_sock.lock().await.local_addr().unwrap();
let handle = tokio::spawn(async move {
// This loop runs indefinitely, so we rely on a timeout in the test
let _ = server.run(&local_addr.to_string()).await;
});

// give the server time to bind and start listening
tokio::time::sleep(Duration::from_millis(100)).await;

let test_socket = UdpSocket::bind("127.0.0.1:0").await.unwrap();
let query_bytes = DnsMessage::new().to_bytes();
let _ = test_socket.send_to(&query_bytes, &local_addr).await;
let _ = timeout(Duration::from_millis(300), handle).await;
}
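
// Concurrency test: several clients on separate sockets query the server in
// parallel, and each one expects a reply coming from the server's address.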
#[ignore]
#[tokio::test]
async fn test_concurrency_with_timeout() {
let mut server = NameServer {
zone: vec![], // or some test zone RRs if needed
shared_sock: Arc::new(Mutex::new(
UdpSocket::bind("127.0.0.1:0").await.unwrap()
)),
};

// Get the server's local address
let local_addr = server.shared_sock.lock().await.local_addr().unwrap();

// Run the server in a background task
let server_task = tokio::spawn(async move {
let _ = server.run(&local_addr.to_string()).await;
});

// We will spawn multiple parallel "clients"
let num_clients = 5;
let tasks = (0..num_clients).map(|i| {
tokio::spawn(async move {

let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();

let query = DnsMessage::new().to_bytes();
sock.send_to(&query, local_addr).await.unwrap();


let mut buf = vec![0u8; 1024];
let (bytes_received, from_addr) = sock.recv_from(&mut buf).await.unwrap();


assert!(bytes_received > 0, "Should receive some data from the server");
assert_eq!(from_addr.ip(), local_addr.ip(), "Response should come from the server IP");


i
})
});

// Wait for all client tasks to complete
let results = future::join_all(tasks).await;
for (i, res) in results.into_iter().enumerate() {
let val = res.expect("Task panicked");
assert_eq!(val, i);
}

// The server task never returns on its own, so just bound the wait; either
// outcome (completed or timed out) is acceptable, and the runtime tears the
// task down when the test ends.
let _ = timeout(Duration::from_millis(300), server_task).await;
}


}
