Skip to content

Commit

Permalink
tftp: update tests
Browse files Browse the repository at this point in the history
Signed-off-by: Benny Zlotnik <[email protected]>
  • Loading branch information
bennyz committed Jan 27, 2025
1 parent 7d4d2fb commit 96700ae
Show file tree
Hide file tree
Showing 3 changed files with 182 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(self, host: str, port: int, root_dir: str,
self.transport: Optional[asyncio.DatagramTransport] = None
self.protocol: Optional['TftpServerProtocol'] = None
self.logger = logging.getLogger(self.__class__.__name__)
self.ready_event = asyncio.Event()

@property
def address(self) -> Optional[Tuple[str, int]]:
Expand All @@ -58,6 +59,7 @@ async def start(self):
self.logger.info(f"Starting TFTP server on {self.host}:{self.port}")
loop = asyncio.get_running_loop()

self.ready_event.set()
self.transport, self.protocol = await loop.create_datagram_endpoint(
lambda: TftpServerProtocol(self),
local_addr=(self.host, self.port)
Expand Down
221 changes: 175 additions & 46 deletions packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,62 +9,80 @@

@pytest.fixture
async def tftp_server():
"""Fixture to create and cleanup a TFTP server instance."""
with tempfile.TemporaryDirectory() as temp_dir:
test_file_path = Path(temp_dir) / "test.txt"
test_file_path.write_text("Hello, TFTP!")

server = TftpServer(host="127.0.0.1", port=0, root_dir=temp_dir)
server = TftpServer(
host="127.0.0.1",
port=0,
root_dir=temp_dir
)
server_task = asyncio.create_task(server.start())

for _ in range(10):
if server.address is not None:
break
await asyncio.sleep(0.1)
else:
await server.shutdown()
server_task.cancel()
raise RuntimeError("Failed to bind TFTP server to a port.")

yield server, temp_dir
yield server, temp_dir, server.address[1]

await server.shutdown()
await server_task

for task in asyncio.all_tasks():
if not task.done() and task != asyncio.current_task():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

async def create_test_client(server_port):
"""Helper function to create a test UDP client."""
loop = asyncio.get_running_loop()
transport, protocol = await loop.create_datagram_endpoint(
asyncio.DatagramProtocol, remote_addr=("127.0.0.1", server_port)
asyncio.DatagramProtocol,
remote_addr=('127.0.0.1', 0)
)
return transport, protocol


@pytest.mark.anyio
@pytest.mark.asyncio
async def test_server_startup_and_shutdown(tftp_server):
"""Test that server starts up and shuts down cleanly."""
server, _ = tftp_server
server, temp_dir, server_port = tftp_server

server_task = asyncio.create_task(server.start())
await asyncio.sleep(0.1)
await server.ready_event.wait()

await server.shutdown()

# Wait for server task to complete
await server_task

assert True


@pytest.mark.anyio
@pytest.mark.asyncio
async def test_read_request_for_existing_file(tftp_server):
"""Test reading an existing file from the server."""
server, temp_dir = tftp_server
server, temp_dir, server_port = tftp_server

server_task = asyncio.create_task(server.start())
await asyncio.sleep(0.1)
await server.ready_event.wait()

try:
transport, _ = await create_test_client(server.port)

rrq_packet = (
Opcode.RRQ.to_bytes(2, "big")
+ b"test.txt\x00" # Filename
+ b"octet\x00" # Mode
Opcode.RRQ.to_bytes(2, 'big') +
b'test.txt\x00' + # filename
b'octet\x00' # mode
)

transport.sendto(rrq_packet)
await asyncio.sleep(0.1)
await server.ready_event.wait()

assert server.transport is not None

Expand All @@ -73,44 +91,46 @@ async def test_read_request_for_existing_file(tftp_server):
await server.shutdown()
await server_task


@pytest.mark.anyio
@pytest.mark.asyncio
async def test_read_request_for_nonexistent_file(tftp_server):
"""Test reading a non-existent file returns appropriate error."""
server, _ = tftp_server
server, temp_dir, server_port = tftp_server

server_task = asyncio.create_task(server.start())
await asyncio.sleep(0.1)

try:
transport, protocol = await create_test_client(server.port)

rrq_packet = Opcode.RRQ.to_bytes(2, "big") + b"nonexistent.txt\x00" + b"octet\x00"
rrq_packet = (
Opcode.RRQ.to_bytes(2, 'big') +
b'nonexistent.txt\x00' +
b'octet\x00'
)

transport.sendto(rrq_packet)
await asyncio.sleep(0.1)

assert server.transport is not None

finally:
transport.close()
await server.shutdown()
await server_task


@pytest.mark.anyio
@pytest.mark.asyncio
async def test_write_request_rejection(tftp_server):
"""Test that write requests are properly rejected (server is read-only)."""
server, _ = tftp_server
server, temp_dir, server_port = tftp_server
server_task = asyncio.create_task(server.start())
await asyncio.sleep(0.1)


try:
transport, _ = await create_test_client(server.port)
wrq_packet = Opcode.WRQ.to_bytes(2, "big") + b"test.txt\x00" + b"octet\x00"
wrq_packet = (
Opcode.WRQ.to_bytes(2, 'big') +
b'test.txt\x00' +
b'octet\x00'
)

transport.sendto(wrq_packet)
await asyncio.sleep(0.1)

assert server.transport is not None

Expand All @@ -119,17 +139,15 @@ async def test_write_request_rejection(tftp_server):
await server.shutdown()
await server_task


@pytest.mark.anyio
@pytest.mark.asyncio
async def test_invalid_packet_handling(tftp_server):
server, _ = tftp_server
server, temp_dir, server_port = tftp_server
server_task = asyncio.create_task(server.start())
await asyncio.sleep(0.1)
await server.ready_event.wait()

try:
transport, _ = await create_test_client(server.port)
transport.sendto(b"\x00\x01")
await asyncio.sleep(0.1)
transport.sendto(b'\x00\x01')

assert server.transport is not None

Expand All @@ -138,22 +156,24 @@ async def test_invalid_packet_handling(tftp_server):
await server.shutdown()
await server_task


@pytest.mark.anyio
@pytest.mark.asyncio
async def test_path_traversal_prevention(tftp_server):
"""Test that path traversal attempts are blocked."""
server, _ = tftp_server
server, temp_dir, server_port = tftp_server

server_task = asyncio.create_task(server.start())
await asyncio.sleep(0.1)
await server.ready_event.wait()

try:
transport, _ = await create_test_client(server.port)

rrq_packet = Opcode.RRQ.to_bytes(2, "big") + b"../../../etc/passwd\x00" + b"octet\x00"
rrq_packet = (
Opcode.RRQ.to_bytes(2, 'big') +
b'../../../etc/passwd\x00' +
b'octet\x00'
)

transport.sendto(rrq_packet)
await asyncio.sleep(0.1)

assert server.transport is not None

Expand All @@ -162,7 +182,116 @@ async def test_path_traversal_prevention(tftp_server):
await server.shutdown()
await server_task

@pytest.mark.asyncio
async def test_options_negotiation(tftp_server):
"""Test that options (blksize, timeout) are properly negotiated."""
server, temp_dir, server_port = tftp_server
server_task = asyncio.create_task(server.start())
await server.ready_event.wait()

@pytest.fixture
def anyio_backend():
return "asyncio"
try:
transport, _ = await create_test_client(server.port)

# RRQ with options
rrq_packet = (
Opcode.RRQ.to_bytes(2, 'big') +
b'test.txt\x00' +
b'octet\x00' +
b'blksize\x00' +
b'1024\x00' +
b'timeout\x00' +
b'3\x00'
)

transport.sendto(rrq_packet)

assert server.transport is not None

finally:
transport.close()
await server.shutdown()
await server_task

@pytest.mark.asyncio
async def test_retry_mechanism(tftp_server):
server, _, server_port = tftp_server

# make the test faster
server.timeout = 1

transport = None

class TestProtocol(asyncio.DatagramProtocol):
def __init__(self):
self.received_packets = []
self.transport = None

def connection_made(self, transport):
self.transport = transport

def datagram_received(self, data, addr):
self.received_packets.append(data)

try:
loop = asyncio.get_running_loop()
transport, protocol = await loop.create_datagram_endpoint(
lambda: TestProtocol(),
local_addr=('127.0.0.1', 0)
)

assert transport is not None, "Failed to create transport"

rrq_packet = (
Opcode.RRQ.to_bytes(2, 'big') +
b'test.txt\x00' +
b'octet\x00'
)

transport.sendto(rrq_packet, ('127.0.0.1', server_port))

await asyncio.sleep(server.timeout * 2)

data_packets = [p for p in protocol.received_packets
if p[0:2] == Opcode.DATA.to_bytes(2, 'big')]

assert len(data_packets) > 1, "Server should have retried sending DATA packet"

block_numbers = {int.from_bytes(p[2:4], 'big') for p in data_packets}
assert len(block_numbers) == 1, "All retried packets should be for the same block"
assert 1 in block_numbers, "First block number should be 1"

except Exception as e:
pytest.fail(f"Test failed with error: {str(e)}")

finally:
if transport is not None:
transport.close()


@pytest.mark.asyncio
async def test_invalid_options_handling(tftp_server):
server, temp_dir, server_port = tftp_server
server_task = asyncio.create_task(server.start())
await server.ready_event.wait()

try:
transport, _ = await create_test_client(server.port)

rrq_packet = (
Opcode.RRQ.to_bytes(2, 'big') +
b'test.txt\x00' +
b'octet\x00' +
b'blksize\x00' +
b'invalid\x00' +
b'timeout\x00' +
b'999999\x00'
)

transport.sendto(rrq_packet)

assert server.transport is not None

finally:
transport.close()
await server.shutdown()
await server_task
10 changes: 5 additions & 5 deletions packages/jumpstarter-driver-tftp/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ raw-options = { 'root' = '../../'}
Homepage = "https://jumpstarter.dev"
source_archive = "https://github.com/jumpstarter-dev/repo/archive/{commit_hash}.zip"

# [tool.pytest.ini_options]
[tool.pytest.ini_options]
# #addopts = "--cov --cov-report=html --cov-report=xml"
# log_cli = true
# log_cli_level = "INFO"
log_cli = true
log_cli_level = "INFO"
# # testpaths = ["src"]
# asyncio_mode = "auto"
asyncio_mode = "auto"

[build-system]
requires = ["hatchling", "hatch-vcs"]
build-backend = "hatchling.build"
build-backend = "hatchling.build"

0 comments on commit 96700ae

Please sign in to comment.