Skip to content

Commit ee4720b

Browse files
authored
Merge pull request #241 from bennyz/tftp-checksum
tftp: add checksum validation
2 parents 675b96c + 42c92d2 commit ee4720b

File tree

6 files changed

+126
-97
lines changed

6 files changed

+126
-97
lines changed
Lines changed: 24 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,37 @@
11
import logging
2-
import time
32

43
import pytest
54
from jumpstarter_driver_tftp.driver import FileNotFound, TftpError
65
from jumpstarter_testing.pytest import JumpstarterTest
76

87
log = logging.getLogger(__name__)
98

10-
119
class TestResource(JumpstarterTest):
1210
filter_labels = {"board": "rpi4"}
1311

1412
@pytest.fixture()
15-
def test_tftp_upload(self, client):
13+
def setup_tftp(self, client):
14+
# Move the setup code to a fixture
15+
client.tftp.start()
16+
yield client
17+
client.tftp.stop()
18+
19+
def test_tftp_operations(self, setup_tftp):
20+
client = setup_tftp
21+
test_file = "test.bin"
22+
23+
# Create test file
24+
with open(test_file, "wb") as f:
25+
f.write(b"Hello from TFTP streaming test!")
26+
1627
try:
17-
client.tftp.start()
18-
print("TFTP server started")
19-
20-
time.sleep(1)
21-
22-
test_file = "test.bin"
23-
with open(test_file, "wb") as f:
24-
f.write(b"Hello from TFTP streaming test!")
25-
26-
try:
27-
client.tftp.put_local_file(test_file)
28-
print(f"Successfully uploaded {test_file}")
29-
30-
files = client.tftp.list_files()
31-
print(f"Files in TFTP root: {files}")
32-
33-
if test_file in files:
34-
client.tftp.delete_file(test_file)
35-
print(f"Successfully deleted {test_file}")
36-
else:
37-
print(f"Warning: {test_file} not found in TFTP root")
38-
39-
except TftpError as e:
40-
print(f"TFTP operation failed: {e}")
41-
except FileNotFound as e:
42-
print(f"File not found: {e}")
43-
44-
except Exception as e:
45-
print(f"Error: {e}")
46-
finally:
47-
try:
48-
client.tftp.stop()
49-
print("TFTP server stopped")
50-
except Exception as e:
51-
print(f"Error stopping server: {e}")
28+
# Test upload
29+
client.tftp.put_local_file(test_file)
30+
assert test_file in client.tftp.list_files()
31+
32+
# Test delete
33+
client.tftp.delete_file(test_file)
34+
assert test_file not in client.tftp.list_files()
35+
36+
except (TftpError, FileNotFound) as e:
37+
pytest.fail(f"Test failed: {e}")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
CHUNK_SIZE = 1024 * 1024 * 4 # 4MB

packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import hashlib
12
from dataclasses import dataclass
23
from pathlib import Path
34

45
from jumpstarter_driver_opendal.adapter import OpendalAdapter
56
from opendal import Operator
67

8+
from . import CHUNK_SIZE
79
from jumpstarter.client import DriverClient
810

911

@@ -46,37 +48,30 @@ def list_files(self) -> list[str]:
4648
return self.call("list_files")
4749

4850
def put_file(self, operator: Operator, path: str):
49-
"""
50-
Upload a file to the TFTP server using an OpenDAL operator
51+
filename = Path(path).name
52+
client_checksum = self._compute_checksum(operator, path)
5153

52-
Args:
53-
operator (Operator): OpenDAL operator for accessing the source storage
54-
path (str): Path to the file in the source storage system
54+
if self.call("check_file_checksum", filename, client_checksum):
55+
self.logger.info(f"Skipping upload of identical file: {filename}")
56+
return filename
5557

56-
Returns:
57-
str: Name of the uploaded file
58-
"""
59-
filename = Path(path).name
6058
with OpendalAdapter(client=self, operator=operator, path=path, mode="rb") as handle:
61-
return self.call("put_file", filename, handle)
59+
return self.call("put_file", filename, handle, client_checksum)
6260

6361
def put_local_file(self, filepath: str):
64-
"""
65-
Upload a file from the local filesystem to the TFTP server
66-
Note: this doesn't use TFTP to upload.
62+
absolute = Path(filepath).resolve()
63+
filename = absolute.name
6764

68-
Args:
69-
filepath (str): Path to the local file to upload
65+
operator = Operator("fs", root="/")
66+
client_checksum = self._compute_checksum(operator, str(absolute))
7067

71-
Returns:
72-
str: Name of the uploaded file
68+
if self.call("check_file_checksum", filename, client_checksum):
69+
self.logger.info(f"Skipping upload of identical file: {filename}")
70+
return filename
7371

74-
Example:
75-
>>> client.put_local_file("/path/to/local/file.txt")
76-
"""
77-
absolute = Path(filepath).resolve()
78-
with OpendalAdapter(client=self, operator=Operator("fs", root="/"), path=str(absolute), mode="rb") as handle:
79-
return self.call("put_file", absolute.name, handle)
72+
self.logger.info(f"checksum: {client_checksum}")
73+
with OpendalAdapter(client=self, operator=operator, path=str(absolute), mode="rb") as handle:
74+
return self.call("put_file", filename, handle, client_checksum)
8075

8176
def delete_file(self, filename: str):
8277
"""
@@ -108,3 +103,10 @@ def get_port(self) -> int:
108103
int: The port number (default is 69)
109104
"""
110105
return self.call("get_port")
106+
107+
def _compute_checksum(self, operator: Operator, path: str) -> str:
108+
hasher = hashlib.sha256()
109+
with operator.open(path, "rb") as f:
110+
while chunk := f.read(CHUNK_SIZE):
111+
hasher.update(chunk)
112+
return hasher.hexdigest()

packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import hashlib
23
import os
34
import socket
45
import threading
@@ -10,33 +11,28 @@
1011

1112
from jumpstarter_driver_tftp.server import TftpServer
1213

14+
from . import CHUNK_SIZE
1315
from jumpstarter.driver import Driver, export
1416

1517

1618
class TftpError(Exception):
1719
"""Base exception for TFTP server errors"""
18-
1920
pass
2021

21-
2222
class ServerNotRunning(TftpError):
2323
"""Server is not running"""
24-
2524
pass
2625

27-
2826
class FileNotFound(TftpError):
2927
"""File not found"""
30-
3128
pass
3229

33-
3430
@dataclass(kw_only=True)
3531
class Tftp(Driver):
3632
"""TFTP Server driver for Jumpstarter"""
3733

3834
root_dir: str = "/var/lib/tftpboot"
39-
host: str = field(default=None)
35+
host: str = field(default='')
4036
port: int = 69
4137
server: Optional["TftpServer"] = field(init=False, default=None)
4238
server_thread: Optional[threading.Thread] = field(init=False, default=None)
@@ -49,7 +45,7 @@ def __post_init__(self):
4945
super().__post_init__()
5046

5147
os.makedirs(self.root_dir, exist_ok=True)
52-
if self.host is None:
48+
if self.host == '':
5349
self.host = self.get_default_ip()
5450

5551
def get_default_ip(self):
@@ -71,10 +67,7 @@ def _start_server(self):
7167
asyncio.set_event_loop(self._loop)
7268
self.server = TftpServer(host=self.host, port=self.port, root_dir=self.root_dir)
7369
try:
74-
# Signal that the loop is ready
7570
self._loop_ready.set()
76-
77-
# Run the server until shutdown is requested
7871
self._loop.run_until_complete(self._run_server())
7972
except Exception as e:
8073
self.logger.error(f"Error running TFTP server: {e}")
@@ -84,7 +77,6 @@ def _start_server(self):
8477
self._loop.close()
8578
except Exception as e:
8679
self.logger.error(f"Error during event loop cleanup: {e}")
87-
8880
self._loop = None
8981
self.logger.info("TFTP server thread completed")
9082

@@ -109,11 +101,9 @@ def start(self):
109101
self.logger.warning("TFTP server is already running")
110102
return
111103

112-
# Clear any previous shutdown state
113104
self._shutdown_event.clear()
114105
self._loop_ready.clear()
115106

116-
# Start the server thread
117107
self.server_thread = threading.Thread(target=self._start_server, daemon=True)
118108
self.server_thread.start()
119109

@@ -131,7 +121,6 @@ def stop(self):
131121
return
132122

133123
self.logger.info("Initiating TFTP server shutdown")
134-
135124
self._shutdown_event.set()
136125
self.server_thread.join(timeout=10)
137126
if self.server_thread.is_alive():
@@ -145,11 +134,10 @@ def list_files(self) -> list[str]:
145134
return os.listdir(self.root_dir)
146135

147136
@export
148-
async def put_file(self, filename: str, src_stream):
149-
"""Handle file upload using streaming"""
150-
try:
151-
file_path = os.path.join(self.root_dir, filename)
137+
async def put_file(self, filename: str, src_stream, client_checksum: str):
138+
file_path = os.path.join(self.root_dir, filename)
152139

140+
try:
153141
if not Path(file_path).resolve().is_relative_to(Path(self.root_dir).resolve()):
154142
raise TftpError("Invalid target path")
155143

@@ -159,19 +147,38 @@ async def put_file(self, filename: str, src_stream):
159147
await dst.send(chunk)
160148

161149
return filename
162-
163150
except Exception as e:
164151
raise TftpError(f"Failed to upload file: {str(e)}") from e
165152

166153
@export
167154
def delete_file(self, filename: str):
155+
file_path = os.path.join(self.root_dir, filename)
156+
157+
if not os.path.exists(file_path):
158+
raise FileNotFound(f"File {filename} not found")
159+
168160
try:
169-
os.remove(os.path.join(self.root_dir, filename))
170-
except FileNotFoundError as err:
171-
raise FileNotFound(f"File {filename} not found") from err
161+
os.remove(file_path)
162+
return filename
172163
except Exception as e:
173164
raise TftpError(f"Failed to delete {filename}") from e
174165

166+
@export
167+
def check_file_checksum(self, filename: str, client_checksum: str) -> bool:
168+
file_path = os.path.join(self.root_dir, filename)
169+
self.logger.debug(f"checking checksum for file: {filename}")
170+
self.logger.debug(f"file path: {file_path}")
171+
172+
if not os.path.exists(file_path):
173+
self.logger.debug(f"File {filename} does not exist")
174+
return False
175+
176+
current_checksum = self._compute_checksum(file_path)
177+
self.logger.debug(f"Computed checksum: {current_checksum}")
178+
self.logger.debug(f"Client checksum: {client_checksum}")
179+
180+
return current_checksum == client_checksum
181+
175182
@export
176183
def get_host(self) -> str:
177184
return self.host
@@ -184,3 +191,10 @@ def close(self):
184191
if self.server_thread is not None:
185192
self.stop()
186193
super().close()
194+
195+
def _compute_checksum(self, path: str) -> str:
196+
hasher = hashlib.sha256()
197+
with open(path, "rb") as f:
198+
while chunk := f.read(CHUNK_SIZE):
199+
hasher.update(chunk)
200+
return hasher.hexdigest()

packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import hashlib
12
import os
23
import tempfile
34
from pathlib import Path
@@ -20,18 +21,17 @@ def temp_dir():
2021
with tempfile.TemporaryDirectory() as tmpdir:
2122
yield tmpdir
2223

23-
2424
@pytest.fixture
2525
def server(temp_dir):
2626
server = Tftp(root_dir=temp_dir, host="127.0.0.1")
2727
yield server
2828
server.close()
2929

30-
3130
@pytest.mark.anyio
3231
async def test_tftp_file_operations(server):
3332
filename = "test.txt"
3433
test_data = b"Hello"
34+
client_checksum = hashlib.sha256(test_data).hexdigest()
3535

3636
send_stream, receive_stream = create_memory_object_stream(max_buffer_size=10)
3737

@@ -46,8 +46,7 @@ async def send_data():
4646

4747
async with anyio.create_task_group() as tg:
4848
tg.start_soon(send_data)
49-
50-
await server.put_file(filename, resource_handle)
49+
await server.put_file(filename, resource_handle, client_checksum)
5150

5251
files = server.list_files()
5352
assert filename in files
@@ -61,20 +60,48 @@ async def send_data():
6160
with pytest.raises(FileNotFound):
6261
server.delete_file("nonexistent.txt")
6362

64-
6563
def test_tftp_host_config(temp_dir):
6664
custom_host = "192.168.1.1"
6765
server = Tftp(root_dir=temp_dir, host=custom_host)
6866
assert server.get_host() == custom_host
6967

70-
7168
def test_tftp_root_directory_creation(temp_dir):
7269
new_dir = os.path.join(temp_dir, "new_tftp_root")
7370
server = Tftp(root_dir=new_dir)
7471
assert os.path.exists(new_dir)
7572
server.close()
7673

74+
@pytest.mark.anyio
75+
async def test_tftp_detect_corrupted_file(server):
76+
filename = "corrupted.txt"
77+
original_data = b"Original Data"
78+
client_checksum = hashlib.sha256(original_data).hexdigest()
79+
80+
await _upload_file(server, filename, original_data)
81+
82+
assert server.check_file_checksum(filename, client_checksum)
83+
84+
file_path = Path(server.root_dir, filename)
85+
file_path.write_bytes(b"corrupted Data")
86+
87+
assert not server.check_file_checksum(filename, client_checksum)
7788

7889
@pytest.fixture
7990
def anyio_backend():
8091
return "asyncio"
92+
93+
async def _upload_file(server, filename: str, data: bytes) -> str:
94+
send_stream, receive_stream = create_memory_object_stream()
95+
resource_uuid = uuid4()
96+
server.resources[resource_uuid] = receive_stream
97+
resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json")
98+
99+
async def send_data():
100+
await send_stream.send(data)
101+
await send_stream.aclose()
102+
103+
async with anyio.create_task_group() as tg:
104+
tg.start_soon(send_data)
105+
await server.put_file(filename, resource_handle, hashlib.sha256(data).hexdigest())
106+
107+
return hashlib.sha256(data).hexdigest()

0 commit comments

Comments
 (0)