diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py index bf476f1f..37e83a3a 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server.py @@ -14,6 +14,7 @@ class Opcode(IntEnum): DATA = 3 ACK = 4 ERROR = 5 + OACK = 6 class TftpErrorCode(IntEnum): @@ -32,27 +33,36 @@ class TftpServer: TFTP Server that handles read requests (RRQ). """ - def __init__( - self, host: str, port: int, root_dir: str, block_size: int = 512, timeout: float = 5.0, retries: int = 3 - ): + def __init__(self, host: str, port: int, root_dir: str, + block_size: int = 512, timeout: float = 5.0, retries: int = 3): self.host = host self.port = port self.root_dir = pathlib.Path(os.path.abspath(root_dir)) self.block_size = block_size self.timeout = timeout self.retries = retries - self.active_transfers: Set["TftpTransfer"] = set() + self.active_transfers: Set['TftpTransfer'] = set() self.shutdown_event = asyncio.Event() self.transport: Optional[asyncio.DatagramTransport] = None - self.protocol: Optional["TftpServerProtocol"] = None + self.protocol: Optional['TftpServerProtocol'] = None self.logger = logging.getLogger(self.__class__.__name__) + self.ready_event = asyncio.Event() + + @property + def address(self) -> Optional[Tuple[str, int]]: + """Get the server's bound address and port.""" + if self.transport: + return self.transport.get_extra_info('socket').getsockname() + return None async def start(self): self.logger.info(f"Starting TFTP server on {self.host}:{self.port}") loop = asyncio.get_running_loop() + self.ready_event.set() self.transport, self.protocol = await loop.create_datagram_endpoint( - lambda: TftpServerProtocol(self), local_addr=(self.host, self.port) + lambda: TftpServerProtocol(self), + local_addr=(self.host, self.port) ) try: @@ -82,11 +92,11 @@ async def shutdown(self): self.logger.info("Shutdown signal received for TFTP server") self.shutdown_event.set() - def register_transfer(self, transfer: "TftpTransfer"): + def register_transfer(self, transfer: 'TftpTransfer'): self.active_transfers.add(transfer) self.logger.debug(f"Registered transfer: {transfer}") - def unregister_transfer(self, transfer: "TftpTransfer"): + def unregister_transfer(self, transfer: 'TftpTransfer'): self.active_transfers.discard(transfer) self.logger.debug(f"Unregistered transfer: {transfer}") @@ -120,7 +130,7 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]): return try: - opcode = Opcode(int.from_bytes(data[0:2], "big")) + opcode = Opcode(int.from_bytes(data[0:2], 'big')) except ValueError: self.logger.error(f"Unknown opcode from {addr}") self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode") @@ -136,59 +146,171 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]): async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]): try: - # Parse filename and mode from request - parts = data[2:].split(b"\x00") - if len(parts) < 2: - self.logger.error(f"Invalid RRQ format from {addr}") - raise ValueError("Invalid RRQ format") - - filename = parts[0].decode("utf-8") - mode = parts[1].decode("utf-8").lower() - - self.logger.info(f"RRQ from {addr}: '{filename}' in mode '{mode}'") - - if mode not in ("netascii", "octet"): - self.logger.warning(f"Unsupported transfer mode '{mode}' from {addr}") - self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unsupported transfer mode") - return - - # Resolve file path securely - requested_path = self.server.root_dir / filename - resolved_path = requested_path.resolve() + filename, mode, options = self._parse_request(data) + self.logger.info(f"RRQ from {addr}: '{filename}' in mode '{mode}' with options {options}") - if not resolved_path.is_file(): - self.logger.error(f"File not found: {resolved_path}") - self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "File not found") + if not self._validate_mode(mode, addr): return - if not is_subpath(resolved_path, self.server.root_dir): - self.logger.error(f"Access violation: {resolved_path} is outside the root directory") - self._send_error(addr, TftpErrorCode.ACCESS_VIOLATION, "Access denied") + resolved_path = self._resolve_and_validate_path(filename, addr) + if not resolved_path: return - transfer = TftpReadTransfer( - server=self.server, - filepath=resolved_path, - client_addr=addr, - block_size=self.server.block_size, - timeout=self.server.timeout, - retries=self.server.retries, - ) - self.server.register_transfer(transfer) - asyncio.create_task(transfer.start()) + negotiated_options, blksize, timeout = self._negotiate_options(options) + self.logger.info(f"Negotiated options: {negotiated_options}") + await self._start_transfer(resolved_path, addr, blksize, timeout, negotiated_options) except Exception as e: self.logger.error(f"Error handling RRQ from {addr}: {e}") self._send_error(addr, TftpErrorCode.NOT_DEFINED, str(e)) + def _send_oack(self, addr: Tuple[str, int], options: dict): + """Send Option Acknowledgment (OACK) packet.""" + oack_data = Opcode.OACK.to_bytes(2, 'big') + for opt_name, opt_value in options.items(): + oack_data += f"{opt_name}\0{str(opt_value)}\0".encode('utf-8') + + if self.transport: + self.transport.sendto(oack_data, addr) + self.logger.debug(f"Sent OACK to {addr} with options {options}") + def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str): error_packet = ( - Opcode.ERROR.to_bytes(2, "big") + error_code.to_bytes(2, "big") + message.encode("utf-8") + b"\x00" + Opcode.ERROR.to_bytes(2, 'big') + + error_code.to_bytes(2, 'big') + + message.encode('utf-8') + b'\x00' ) if self.transport: self.transport.sendto(error_packet, addr) self.logger.debug(f"Sent ERROR {error_code.name} to {addr}: {message}") + def _parse_request(self, data: bytes) -> Tuple[str, str, dict]: + parts = data[2:].split(b'\x00') + if len(parts) < 2: + raise ValueError("Invalid RRQ format") + + filename = parts[0].decode('utf-8') + if len(filename) > 255: # RFC 1350 doesn't specify a limit + raise ValueError("Filename too long") + if not all(c.isprintable() and c not in '<>:"/\\|?*' for c in filename): + raise ValueError("Invalid characters in filename") + if '\x00' in filename: + raise ValueError("Null byte in filename") + mode = parts[1].decode('utf-8').lower() + options = self._parse_options(parts[2:]) + + return filename, mode, options + + + def _parse_options(self, option_parts: list) -> dict: + options = {} + i = 0 + while i < len(option_parts) - 1: + try: + opt_name = option_parts[i].decode('utf-8').lower() + opt_value = option_parts[i + 1].decode('utf-8') + options[opt_name] = opt_value + i += 2 + except Exception: + break + return options + + def _validate_mode(self, mode: str, addr: Tuple[str, int]) -> bool: + if mode not in ('netascii', 'octet'): + self.logger.warning(f"Unsupported transfer mode '{mode}' from {addr}") + self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unsupported transfer mode") + return False + return True + + def _resolve_and_validate_path(self, filename: str, addr: Tuple[str, int]) -> Optional[pathlib.Path]: + requested_path = self.server.root_dir / filename + resolved_path = requested_path.resolve() + + if not resolved_path.is_file(): + self.logger.error(f"File not found: {resolved_path}") + self._send_error(addr, TftpErrorCode.FILE_NOT_FOUND, "File not found") + return None + + if not is_subpath(resolved_path, self.server.root_dir): + self.logger.error(f"Access violation: {resolved_path} is outside root directory") + self._send_error(addr, TftpErrorCode.ACCESS_VIOLATION, "Access denied") + return None + + return resolved_path + + def _negotiate_block_size(self, requested_blksize: Optional[str]) -> int: + if requested_blksize is None: + return self.server.block_size + + try: + blksize = int(requested_blksize) + if 512 <= blksize <= 65464: + return blksize + else: + self.logger.warning( + f"Requested block size {blksize} out of range (512-65464), " + f"using default: {self.server.block_size}" + ) + return self.server.block_size + except ValueError: + self.logger.warning( + f"Invalid block size value '{requested_blksize}', " + f"using default: {self.server.block_size}" + ) + return self.server.block_size + + def _negotiate_timeout(self, requested_timeout: Optional[str]) -> float: + if requested_timeout is None: + return self.server.timeout + + try: + timeout = int(requested_timeout) + if 1 <= timeout <= 255: + return float(timeout) + else: + self.logger.warning( + f"Timeout value {timeout} out of range (1-255), " + f"using default: {self.server.timeout}" + ) + return self.server.timeout + except ValueError: + self.logger.warning( + f"Invalid timeout value '{requested_timeout}', " + f"using default: {self.server.timeout}" + ) + return self.server.timeout + + def _negotiate_options(self, options: dict) -> Tuple[dict, int, float]: + negotiated = {} + blksize = self.server.block_size + timeout = self.server.timeout + + if 'blksize' in options: + requested = options['blksize'] + blksize = self._negotiate_block_size(requested) + negotiated['blksize'] = blksize + + if 'timeout' in options: + requested = options['timeout'] + timeout = self._negotiate_timeout(requested) + negotiated['timeout'] = int(timeout) + + return negotiated, blksize, timeout + + + async def _start_transfer(self, filepath: pathlib.Path, addr: Tuple[str, int], + blksize: int, timeout: float, negotiated_options: dict): + transfer = TftpReadTransfer( + server=self.server, + filepath=filepath, + client_addr=addr, + block_size=blksize, + timeout=timeout, + retries=self.server.retries, + negotiated_options=negotiated_options + ) + self.server.register_transfer(transfer) + asyncio.create_task(transfer.start()) def is_subpath(path: pathlib.Path, root: pathlib.Path) -> bool: try: @@ -203,15 +325,8 @@ class TftpTransfer: Base class for TFTP transfers. """ - def __init__( - self, - server: TftpServer, - filepath: pathlib.Path, - client_addr: Tuple[str, int], - block_size: int, - timeout: float, - retries: int, - ): + def __init__(self, server: TftpServer, filepath: pathlib.Path, client_addr: Tuple[str, int], + block_size: int, timeout: float, retries: int): self.server = server self.filepath = filepath self.client_addr = client_addr @@ -219,7 +334,7 @@ def __init__( self.timeout = timeout self.retries = retries self.transport: Optional[asyncio.DatagramTransport] = None - self.protocol: Optional["TftpTransferProtocol"] = None + self.protocol: Optional['TftpTransferProtocol'] = None self.cleanup_task: Optional[asyncio.Task] = None self.logger = logging.getLogger(self.__class__.__name__) @@ -237,114 +352,178 @@ async def cleanup(self): class TftpReadTransfer(TftpTransfer): - """ - Handles a TFTP Read (RRQ) transfer. - """ - - def __init__( - self, - server: TftpServer, - filepath: pathlib.Path, - client_addr: Tuple[str, int], - block_size: int, - timeout: float, - retries: int, - ): - super().__init__(server, filepath, client_addr, block_size, timeout, retries) - self.block_num = 1 + def __init__(self, server: TftpServer, filepath: pathlib.Path, client_addr: Tuple[str, int], + block_size: int, timeout: float, retries: int, negotiated_options: Optional[dict] = None): + super().__init__( + server=server, + filepath=filepath, + client_addr=client_addr, + block_size=block_size, + timeout=timeout, + retries=retries + ) + self.block_num = 0 self.ack_received = asyncio.Event() self.last_ack = 0 + self.oack_confirmed = False + self.negotiated_options = negotiated_options + self.current_packet: Optional[bytes] = None async def start(self): self.logger.info(f"Starting read transfer of '{self.filepath.name}' to {self.client_addr}") - loop = asyncio.get_running_loop() - self.transport, self.protocol = await loop.create_datagram_endpoint( - lambda: TftpTransferProtocol(self), local_addr=("0.0.0.0", 0), remote_addr=self.client_addr - ) - local_addr = self.transport.get_extra_info("sockname") - self.logger.debug(f"Transfer bound to local {local_addr}") + if not await self._initialize_transfer(): + return try: - async with aiofiles.open(self.filepath, "rb") as f: - while True: - if self.server.shutdown_event.is_set(): - self.logger.info(f"Server shutdown detected, stopping transfer to {self.client_addr}") - break - data = await f.read(self.block_size) - if data: - packet = Opcode.DATA.to_bytes(2, "big") + self.block_num.to_bytes(2, "big") + data - success = await self._send_with_retries(packet) - if not success: - self.logger.error(f"Failed to send block {self.block_num} to {self.client_addr}") - break - self.logger.debug(f"Block {self.block_num} sent successfully") - self.block_num += 1 - - # If the data read is less than block_size, this is the last packet - if len(data) < self.block_size: - self.logger.info(f"Final block {self.block_num - 1} reached for {self.client_addr}") - break - else: - # If no data is returned, it means the file size is an exact multiple of block_size - # Send an extra empty DATA packet to signal end of transfer - packet = Opcode.DATA.to_bytes(2, "big") + self.block_num.to_bytes(2, "big") + b"" - success = await self._send_with_retries(packet) - if not success: - self.logger.error( - f"Failed to send final empty block {self.block_num} to {self.client_addr}" - ) - break - self.logger.info(f"Transfer complete to {self.client_addr}, final block {self.block_num}") - break + # if no options were negotiated, we can start sending data immediately + if not self.negotiated_options: + self.oack_confirmed = True + await self._perform_transfer() except Exception as e: self.logger.error(f"Error during read transfer: {e}") finally: await self.cleanup() - async def _send_with_retries(self, packet: bytes) -> bool: + async def _initialize_transfer(self) -> bool: + loop = asyncio.get_running_loop() + + self.transport, self.protocol = await loop.create_datagram_endpoint( + lambda: TftpTransferProtocol(self), + local_addr=('0.0.0.0', 0), + remote_addr=self.client_addr + ) + local_addr = self.transport.get_extra_info('sockname') + self.logger.debug(f"Transfer bound to local {local_addr}") + + # Only send OACK if we have non-default options to negotiate + if self.negotiated_options and ( + self.negotiated_options['blksize'] != 512 or + self.negotiated_options['timeout'] != self.server.timeout + ): + oack_packet = self._create_oack_packet() + if not await self._send_with_retries(oack_packet, is_oack=True): + self.logger.error("Failed to get acknowledgment for OACK") + return False + + self.block_num = 1 + return True + + async def _perform_transfer(self): + async with aiofiles.open(self.filepath, 'rb') as f: + while True: + if self.server.shutdown_event.is_set(): + self.logger.info(f"Server shutdown detected, stopping transfer to {self.client_addr}") + break + + data = await f.read(self.block_size) + if not await self._handle_data_block(data): + break + + async def _handle_data_block(self, data: bytes) -> bool: + """ + Handle sending a block of data to the client. + Returns False if transfer should stop, True if it should continue. + """ + if not data and self.block_num == 1: + # Empty file case + packet = self._create_data_packet(b'') + await self._send_with_retries(packet) + return False + elif data: + packet = self._create_data_packet(data) + success = await self._send_with_retries(packet) + if not success: + self.logger.error(f"Failed to send block {self.block_num} to {self.client_addr}") + return False + + self.logger.debug(f"Block {self.block_num} sent successfully") + self.block_num += 1 + + # wrap block number around if it exceeds 16 bits + self.block_num %= 65536 + + if len(data) < self.block_size: + self.logger.info(f"Final block {self.block_num - 1} sent") + return False + return True + else: + # EOF reached + packet = self._create_data_packet(b'') + success = await self._send_with_retries(packet) + if not success: + self.logger.error(f"Failed to send final block {self.block_num}") + else: + self.logger.info(f"Transfer complete, final block {self.block_num}") + return False + + def _create_oack_packet(self) -> bytes: + packet = Opcode.OACK.to_bytes(2, 'big') + for opt_name, opt_value in self.negotiated_options.items(): + packet += f"{opt_name}\0{str(opt_value)}\0".encode('utf-8') + return packet + + def _create_data_packet(self, data: bytes) -> bytes: + return ( + Opcode.DATA.to_bytes(2, 'big') + + self.block_num.to_bytes(2, 'big') + + data + ) + + def _send_packet(self, packet: bytes): + self.transport.sendto(packet) + if packet[0:2] == Opcode.DATA.to_bytes(2, 'big'): + block = int.from_bytes(packet[2:4], 'big') + data_length = len(packet) - 4 + self.logger.debug(f"Sent DATA block {block} ({data_length} bytes) to {self.client_addr}") + elif packet[0:2] == Opcode.OACK.to_bytes(2, 'big'): + self.logger.debug(f"Sent OACK to {self.client_addr}") + + async def _send_with_retries(self, packet: bytes, is_oack: bool = False) -> bool: self.current_packet = packet + expected_block = 0 if is_oack else self.block_num + for attempt in range(1, self.retries + 1): try: self._send_packet(packet) - self.logger.debug(f"Sent DATA block {self.block_num}, waiting for ACK (Attempt {attempt})") + self.logger.debug( + f"Sent {'OACK' if is_oack else 'DATA'} block {expected_block}, " + f"waiting for ACK (Attempt {attempt})" + ) self.ack_received.clear() await asyncio.wait_for(self.ack_received.wait(), timeout=self.timeout) - if self.last_ack == self.block_num: - self.logger.debug(f"ACK received for block {self.block_num}") + if self.last_ack == expected_block: + self.logger.debug(f"ACK received for block {expected_block}") return True else: - self.logger.warning(f"Received wrong ACK: expected {self.block_num}, got {self.last_ack}") + self.logger.warning(f"Received wrong ACK: expected {expected_block}, got {self.last_ack}") except asyncio.TimeoutError: - self.logger.warning(f"Timeout waiting for ACK of block {self.block_num} (Attempt {attempt})") + self.logger.warning(f"Timeout waiting for ACK of block {expected_block} (Attempt {attempt})") return False - def _send_packet(self, packet: bytes): - """ - Sends a DATA packet to the client. - """ - self.transport.sendto(packet) - block = int.from_bytes(packet[2:4], "big") - data_length = len(packet) - 4 - self.logger.debug(f"Sent DATA block {block} ({data_length} bytes) to {self.client_addr}") - def handle_ack(self, block_num: int): self.logger.debug(f"Received ACK for block {block_num} from {self.client_addr}") + + # special handling for OACK acknowledgment + if not self.oack_confirmed and self.negotiated_options and block_num == 0: + self.oack_confirmed = True + self.last_ack = block_num + self.ack_received.set() + return + if block_num == self.block_num: self.last_ack = block_num self.ack_received.set() elif block_num == self.block_num - 1: - # Duplicate ACK for previous block, resend current packet - self.logger.warning(f"Duplicate ACK for block {block_num} received, resending DATA block {self.block_num}") + self.logger.warning(f"Duplicate ACK for block {block_num} received, resending block {self.block_num}") self.transport.sendto(self.current_packet) else: self.logger.warning(f"Out of sequence ACK: expected {self.block_num}, got {block_num}") - class TftpTransferProtocol(asyncio.DatagramProtocol): """ Protocol for handling ACKs during a TFTP transfer. @@ -356,7 +535,7 @@ def __init__(self, transfer: TftpReadTransfer): def connection_made(self, transport: asyncio.DatagramTransport): self.transfer.transport = transport - local_addr = transport.get_extra_info("sockname") + local_addr = transport.get_extra_info('sockname') self.logger.debug(f"Transfer protocol connection established on {local_addr} for {self.transfer.client_addr}") def datagram_received(self, data: bytes, addr: Tuple[str, int]): @@ -366,24 +545,24 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]): return if len(data) < 4: - self.logger.warning(f"Received malformed ACK from {addr}") + self.logger.warning(f"Received malformed packet from {addr}") return try: - opcode = Opcode(int.from_bytes(data[0:2], "big")) + opcode = Opcode(int.from_bytes(data[0:2], 'big')) except ValueError: - self.logger.error(f"Unknown opcode in ACK from {addr}") - self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode in ACK") + self.logger.error(f"Unknown opcode from {addr}") + self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unknown opcode") return - if opcode != Opcode.ACK: - self.logger.warning(f"Expected ACK but got opcode {opcode} from {addr}") - self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Expected ACK") - return + if opcode == Opcode.ACK: + block_num = int.from_bytes(data[2:4], 'big') + self.logger.debug(f"Received ACK for block {block_num} from {addr}") + self.transfer.handle_ack(block_num) + else: + self.logger.warning(f"Unexpected opcode {opcode} from {addr}") + self._send_error(addr, TftpErrorCode.ILLEGAL_OPERATION, "Unexpected opcode") - block_num = int.from_bytes(data[2:4], "big") - self.logger.debug(f"Received ACK for block {block_num} from {addr}") - self.transfer.handle_ack(block_num) def error_received(self, exc): self.logger.error(f"Error received: {exc}") @@ -393,7 +572,9 @@ def connection_lost(self, exc): def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str): error_packet = ( - Opcode.ERROR.to_bytes(2, "big") + error_code.to_bytes(2, "big") + message.encode("utf-8") + b"\x00" + Opcode.ERROR.to_bytes(2, 'big') + + error_code.to_bytes(2, 'big') + + message.encode('utf-8') + b'\x00' ) if self.transfer.transport: self.transfer.transport.sendto(error_packet) diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py index 6ffa4453..5242f0c9 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py @@ -9,62 +9,80 @@ @pytest.fixture async def tftp_server(): - """Fixture to create and cleanup a TFTP server instance.""" with tempfile.TemporaryDirectory() as temp_dir: test_file_path = Path(temp_dir) / "test.txt" test_file_path.write_text("Hello, TFTP!") - server = TftpServer(host="127.0.0.1", port=0, root_dir=temp_dir) + server = TftpServer( + host="127.0.0.1", + port=0, + root_dir=temp_dir + ) + server_task = asyncio.create_task(server.start()) + + for _ in range(10): + if server.address is not None: + break + await asyncio.sleep(0.1) + else: + await server.shutdown() + server_task.cancel() + raise RuntimeError("Failed to bind TFTP server to a port.") - yield server, temp_dir + yield server, temp_dir, server.address[1] await server.shutdown() + await server_task + for task in asyncio.all_tasks(): + if not task.done() and task != asyncio.current_task(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass async def create_test_client(server_port): - """Helper function to create a test UDP client.""" loop = asyncio.get_running_loop() transport, protocol = await loop.create_datagram_endpoint( - asyncio.DatagramProtocol, remote_addr=("127.0.0.1", server_port) + asyncio.DatagramProtocol, + remote_addr=('127.0.0.1', 0) ) return transport, protocol - -@pytest.mark.anyio +@pytest.mark.asyncio async def test_server_startup_and_shutdown(tftp_server): """Test that server starts up and shuts down cleanly.""" - server, _ = tftp_server + server, temp_dir, server_port = tftp_server server_task = asyncio.create_task(server.start()) - await asyncio.sleep(0.1) + await server.ready_event.wait() await server.shutdown() - # Wait for server task to complete await server_task assert True - -@pytest.mark.anyio +@pytest.mark.asyncio async def test_read_request_for_existing_file(tftp_server): """Test reading an existing file from the server.""" - server, temp_dir = tftp_server + server, temp_dir, server_port = tftp_server server_task = asyncio.create_task(server.start()) - await asyncio.sleep(0.1) + await server.ready_event.wait() try: transport, _ = await create_test_client(server.port) rrq_packet = ( - Opcode.RRQ.to_bytes(2, "big") - + b"test.txt\x00" # Filename - + b"octet\x00" # Mode + Opcode.RRQ.to_bytes(2, 'big') + + b'test.txt\x00' + # filename + b'octet\x00' # mode ) transport.sendto(rrq_packet) - await asyncio.sleep(0.1) + await server.ready_event.wait() assert server.transport is not None @@ -73,23 +91,23 @@ async def test_read_request_for_existing_file(tftp_server): await server.shutdown() await server_task - -@pytest.mark.anyio +@pytest.mark.asyncio async def test_read_request_for_nonexistent_file(tftp_server): """Test reading a non-existent file returns appropriate error.""" - server, _ = tftp_server + server, temp_dir, server_port = tftp_server server_task = asyncio.create_task(server.start()) - await asyncio.sleep(0.1) try: transport, protocol = await create_test_client(server.port) - rrq_packet = Opcode.RRQ.to_bytes(2, "big") + b"nonexistent.txt\x00" + b"octet\x00" + rrq_packet = ( + Opcode.RRQ.to_bytes(2, 'big') + + b'nonexistent.txt\x00' + + b'octet\x00' + ) transport.sendto(rrq_packet) - await asyncio.sleep(0.1) - assert server.transport is not None finally: @@ -97,20 +115,22 @@ async def test_read_request_for_nonexistent_file(tftp_server): await server.shutdown() await server_task - -@pytest.mark.anyio +@pytest.mark.asyncio async def test_write_request_rejection(tftp_server): """Test that write requests are properly rejected (server is read-only).""" - server, _ = tftp_server + server, temp_dir, server_port = tftp_server server_task = asyncio.create_task(server.start()) - await asyncio.sleep(0.1) + try: transport, _ = await create_test_client(server.port) - wrq_packet = Opcode.WRQ.to_bytes(2, "big") + b"test.txt\x00" + b"octet\x00" + wrq_packet = ( + Opcode.WRQ.to_bytes(2, 'big') + + b'test.txt\x00' + + b'octet\x00' + ) transport.sendto(wrq_packet) - await asyncio.sleep(0.1) assert server.transport is not None @@ -119,17 +139,15 @@ async def test_write_request_rejection(tftp_server): await server.shutdown() await server_task - -@pytest.mark.anyio +@pytest.mark.asyncio async def test_invalid_packet_handling(tftp_server): - server, _ = tftp_server + server, temp_dir, server_port = tftp_server server_task = asyncio.create_task(server.start()) - await asyncio.sleep(0.1) + await server.ready_event.wait() try: transport, _ = await create_test_client(server.port) - transport.sendto(b"\x00\x01") - await asyncio.sleep(0.1) + transport.sendto(b'\x00\x01') assert server.transport is not None @@ -138,22 +156,24 @@ async def test_invalid_packet_handling(tftp_server): await server.shutdown() await server_task - -@pytest.mark.anyio +@pytest.mark.asyncio async def test_path_traversal_prevention(tftp_server): """Test that path traversal attempts are blocked.""" - server, _ = tftp_server + server, temp_dir, server_port = tftp_server server_task = asyncio.create_task(server.start()) - await asyncio.sleep(0.1) + await server.ready_event.wait() try: transport, _ = await create_test_client(server.port) - rrq_packet = Opcode.RRQ.to_bytes(2, "big") + b"../../../etc/passwd\x00" + b"octet\x00" + rrq_packet = ( + Opcode.RRQ.to_bytes(2, 'big') + + b'../../../etc/passwd\x00' + + b'octet\x00' + ) transport.sendto(rrq_packet) - await asyncio.sleep(0.1) assert server.transport is not None @@ -162,7 +182,116 @@ async def test_path_traversal_prevention(tftp_server): await server.shutdown() await server_task +@pytest.mark.asyncio +async def test_options_negotiation(tftp_server): + """Test that options (blksize, timeout) are properly negotiated.""" + server, temp_dir, server_port = tftp_server + server_task = asyncio.create_task(server.start()) + await server.ready_event.wait() -@pytest.fixture -def anyio_backend(): - return "asyncio" + try: + transport, _ = await create_test_client(server.port) + + # RRQ with options + rrq_packet = ( + Opcode.RRQ.to_bytes(2, 'big') + + b'test.txt\x00' + + b'octet\x00' + + b'blksize\x00' + + b'1024\x00' + + b'timeout\x00' + + b'3\x00' + ) + + transport.sendto(rrq_packet) + + assert server.transport is not None + + finally: + transport.close() + await server.shutdown() + await server_task + +@pytest.mark.asyncio +async def test_retry_mechanism(tftp_server): + server, _, server_port = tftp_server + + # make the test faster + server.timeout = 1 + + transport = None + + class TestProtocol(asyncio.DatagramProtocol): + def __init__(self): + self.received_packets = [] + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + self.received_packets.append(data) + + try: + loop = asyncio.get_running_loop() + transport, protocol = await loop.create_datagram_endpoint( + lambda: TestProtocol(), + local_addr=('127.0.0.1', 0) + ) + + assert transport is not None, "Failed to create transport" + + rrq_packet = ( + Opcode.RRQ.to_bytes(2, 'big') + + b'test.txt\x00' + + b'octet\x00' + ) + + transport.sendto(rrq_packet, ('127.0.0.1', server_port)) + + await asyncio.sleep(server.timeout * 2) + + data_packets = [p for p in protocol.received_packets + if p[0:2] == Opcode.DATA.to_bytes(2, 'big')] + + assert len(data_packets) > 1, "Server should have retried sending DATA packet" + + block_numbers = {int.from_bytes(p[2:4], 'big') for p in data_packets} + assert len(block_numbers) == 1, "All retried packets should be for the same block" + assert 1 in block_numbers, "First block number should be 1" + + except Exception as e: + pytest.fail(f"Test failed with error: {str(e)}") + + finally: + if transport is not None: + transport.close() + + +@pytest.mark.asyncio +async def test_invalid_options_handling(tftp_server): + server, temp_dir, server_port = tftp_server + server_task = asyncio.create_task(server.start()) + await server.ready_event.wait() + + try: + transport, _ = await create_test_client(server.port) + + rrq_packet = ( + Opcode.RRQ.to_bytes(2, 'big') + + b'test.txt\x00' + + b'octet\x00' + + b'blksize\x00' + + b'invalid\x00' + + b'timeout\x00' + + b'999999\x00' + ) + + transport.sendto(rrq_packet) + + assert server.transport is not None + + finally: + transport.close() + await server.shutdown() + await server_task