Skip to content

Commit 6204ef9

Browse files
committed
tftp: remove checksum caching
Signed-off-by: Benny Zlotnik <[email protected]>
1 parent 464025e commit 6204ef9

File tree

2 files changed

+12
-119
lines changed

2 files changed

+12
-119
lines changed

packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py

Lines changed: 7 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -17,36 +17,28 @@
1717

1818
class TftpError(Exception):
1919
"""Base exception for TFTP server errors"""
20-
2120
pass
2221

23-
2422
class ServerNotRunning(TftpError):
2523
"""Server is not running"""
26-
2724
pass
2825

29-
3026
class FileNotFound(TftpError):
3127
"""File not found"""
32-
3328
pass
3429

35-
3630
@dataclass(kw_only=True)
3731
class Tftp(Driver):
3832
"""TFTP Server driver for Jumpstarter"""
3933

4034
root_dir: str = "/var/lib/tftpboot"
4135
host: str = field(default='')
4236
port: int = 69
43-
checksum_suffix: str = ".sha256"
4437
server: Optional["TftpServer"] = field(init=False, default=None)
4538
server_thread: Optional[threading.Thread] = field(init=False, default=None)
4639
_shutdown_event: threading.Event = field(init=False, default_factory=threading.Event)
4740
_loop_ready: threading.Event = field(init=False, default_factory=threading.Event)
4841
_loop: Optional[asyncio.AbstractEventLoop] = field(init=False, default=None)
49-
_checksums: dict[str, str] = field(default_factory=dict)
5042

5143
def __post_init__(self):
5244
super().__post_init__()
@@ -73,10 +65,7 @@ def _start_server(self):
7365
asyncio.set_event_loop(self._loop)
7466
self.server = TftpServer(host=self.host, port=self.port, root_dir=self.root_dir)
7567
try:
76-
# Signal that the loop is ready
7768
self._loop_ready.set()
78-
79-
# Run the server until shutdown is requested
8069
self._loop.run_until_complete(self._run_server())
8170
except Exception as e:
8271
self.logger.error(f"Error running TFTP server: {e}")
@@ -86,7 +75,6 @@ def _start_server(self):
8675
self._loop.close()
8776
except Exception as e:
8877
self.logger.error(f"Error during event loop cleanup: {e}")
89-
9078
self._loop = None
9179
self.logger.info("TFTP server thread completed")
9280

@@ -111,11 +99,9 @@ def start(self):
11199
self.logger.warning("TFTP server is already running")
112100
return
113101

114-
# Clear any previous shutdown state
115102
self._shutdown_event.clear()
116103
self._loop_ready.clear()
117104

118-
# Start the server thread
119105
self.server_thread = threading.Thread(target=self._start_server, daemon=True)
120106
self.server_thread.start()
121107

@@ -133,7 +119,6 @@ def stop(self):
133119
return
134120

135121
self.logger.info("Initiating TFTP server shutdown")
136-
137122
self._shutdown_event.set()
138123
self.server_thread.join(timeout=10)
139124
if self.server_thread.is_alive():
@@ -148,7 +133,6 @@ def list_files(self) -> list[str]:
148133

149134
@export
150135
async def put_file(self, filename: str, src_stream, client_checksum: str):
151-
"""Compute and store checksum at write time"""
152136
file_path = os.path.join(self.root_dir, filename)
153137

154138
try:
@@ -160,47 +144,37 @@ async def put_file(self, filename: str, src_stream, client_checksum: str):
160144
async for chunk in src:
161145
await dst.send(chunk)
162146

163-
current_checksum = self._compute_checksum(file_path)
164-
self._checksums[filename] = current_checksum
165-
self._write_checksum_file(filename, current_checksum)
166147
return filename
167148
except Exception as e:
168149
raise TftpError(f"Failed to upload file: {str(e)}") from e
169150

170-
171151
@export
172152
def delete_file(self, filename: str):
173-
"""Delete file and its checksum file"""
174153
file_path = os.path.join(self.root_dir, filename)
175-
checksum_path = self._get_checksum_path(filename)
176154

177155
if not os.path.exists(file_path):
178156
raise FileNotFound(f"File {filename} not found")
179157

180158
try:
181159
os.remove(file_path)
182-
if os.path.exists(checksum_path):
183-
os.remove(checksum_path)
184-
self._checksums.pop(filename, None)
160+
return filename
185161
except Exception as e:
186162
raise TftpError(f"Failed to delete {filename}") from e
187163

188164
@export
189165
def check_file_checksum(self, filename: str, client_checksum: str) -> bool:
190-
"""
191-
check if the checksum of the file matches the client checksum
192-
"""
193-
194166
file_path = os.path.join(self.root_dir, filename)
167+
self.logger.debug(f"checking checksum for file: {filename}")
168+
self.logger.debug(f"file path: {file_path}")
169+
195170
if not os.path.exists(file_path):
171+
self.logger.debug(f"File {filename} does not exist")
196172
return False
197173

198174
current_checksum = self._compute_checksum(file_path)
175+
self.logger.debug(f"Computed checksum: {current_checksum}")
176+
self.logger.debug(f"Client checksum: {client_checksum}")
199177

200-
self._checksums[filename] = current_checksum
201-
self._write_checksum_file(filename, current_checksum)
202-
203-
self.logger.debug(f"Client checksum: {client_checksum}, server checksum: {current_checksum}")
204178
return current_checksum == client_checksum
205179

206180
@export
@@ -216,28 +190,6 @@ def close(self):
216190
self.stop()
217191
super().close()
218192

219-
def _get_checksum_path(self, filename: str) -> str:
220-
return os.path.join(self.root_dir, f"{filename}{self.checksum_suffix}")
221-
222-
def _read_checksum_file(self, filename: str) -> Optional[str]:
223-
try:
224-
checksum_path = self._get_checksum_path(filename)
225-
if os.path.exists(checksum_path):
226-
with open(checksum_path, 'r') as f:
227-
return f.read().strip()
228-
except Exception as e:
229-
self.logger.warning(f"Failed to read checksum file for {filename}: {e}")
230-
return None
231-
232-
def _write_checksum_file(self, filename: str, checksum: str):
233-
"""Write checksum to the checksum file"""
234-
try:
235-
checksum_path = self._get_checksum_path(filename)
236-
with open(checksum_path, 'w') as f:
237-
f.write(f"{checksum}\n")
238-
except Exception as e:
239-
self.logger.error(f"Failed to write checksum file for {filename}: {e}")
240-
241193
def _compute_checksum(self, path: str) -> str:
242194
hasher = hashlib.sha256()
243195
with open(path, "rb") as f:

packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py

Lines changed: 5 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import os
33
import tempfile
44
from pathlib import Path
5-
from typing import Optional
65
from uuid import uuid4
76

87
import anyio
@@ -12,7 +11,6 @@
1211
from jumpstarter_driver_tftp.driver import (
1312
FileNotFound,
1413
Tftp,
15-
TftpError,
1614
)
1715

1816
from jumpstarter.common.resources import ClientStreamResource
@@ -23,14 +21,12 @@ def temp_dir():
2321
with tempfile.TemporaryDirectory() as tmpdir:
2422
yield tmpdir
2523

26-
2724
@pytest.fixture
2825
def server(temp_dir):
2926
server = Tftp(root_dir=temp_dir, host="127.0.0.1")
3027
yield server
3128
server.close()
3229

33-
3430
@pytest.mark.anyio
3531
async def test_tftp_file_operations(server):
3632
filename = "test.txt"
@@ -64,103 +60,48 @@ async def send_data():
6460
with pytest.raises(FileNotFound):
6561
server.delete_file("nonexistent.txt")
6662

67-
6863
def test_tftp_host_config(temp_dir):
6964
custom_host = "192.168.1.1"
7065
server = Tftp(root_dir=temp_dir, host=custom_host)
7166
assert server.get_host() == custom_host
7267

73-
7468
def test_tftp_root_directory_creation(temp_dir):
7569
new_dir = os.path.join(temp_dir, "new_tftp_root")
7670
server = Tftp(root_dir=new_dir)
7771
assert os.path.exists(new_dir)
7872
server.close()
7973

80-
81-
@pytest.mark.anyio
82-
async def test_tftp_checksum_validation(server):
83-
filename = "test_checksum.txt"
84-
test_data = b"Hello world"
85-
modified_data = b"Modified content"
86-
87-
def compute_checksum(data: bytes) -> str:
88-
return hashlib.sha256(data).hexdigest()
89-
90-
initial_checksum = await _upload_file(server, filename, test_data)
91-
assert filename in server.list_files()
92-
assert compute_checksum(test_data) == initial_checksum
93-
94-
# Second upload with same data should be skipped
95-
same_data_checksum = await _upload_file(server, filename, test_data)
96-
assert same_data_checksum == initial_checksum
97-
98-
modified_checksum = await _upload_file(server, filename, modified_data)
99-
assert modified_checksum != initial_checksum
100-
assert Path(server.root_dir).joinpath(filename).read_bytes() == modified_data
101-
102-
empty_checksum = await _upload_file(server, "empty.txt", b"")
103-
assert empty_checksum == hashlib.sha256(b"").hexdigest()
104-
10574
@pytest.mark.anyio
10675
async def test_tftp_detect_corrupted_file(server):
10776
filename = "corrupted.txt"
10877
original_data = b"Original Data"
10978
client_checksum = hashlib.sha256(original_data).hexdigest()
11079

11180
await _upload_file(server, filename, original_data)
81+
11282
assert server.check_file_checksum(filename, client_checksum)
11383

11484
file_path = Path(server.root_dir, filename)
115-
with open(file_path, "wb") as f:
116-
f.write(b"Corrupted Data")
85+
file_path.write_bytes(b"corrupted Data")
11786

11887
assert not server.check_file_checksum(filename, client_checksum)
11988

120-
@pytest.mark.anyio
121-
async def test_tftp_reupload_different_checksum(server):
122-
filename = "reupload.txt"
123-
initial_data = b"Initial Data"
124-
updated_data = b"Updated Data"
125-
initial_checksum = hashlib.sha256(initial_data).hexdigest()
126-
updated_checksum = hashlib.sha256(updated_data).hexdigest()
127-
128-
await _upload_file(server, filename, initial_data)
129-
assert server.check_file_checksum(filename, initial_checksum)
130-
assert Path(server.root_dir, filename).read_bytes() == initial_data
131-
132-
await _upload_file(server, filename, updated_data, client_checksum=updated_checksum)
133-
assert server.check_file_checksum(filename, updated_checksum)
134-
assert Path(server.root_dir, filename).read_bytes() == updated_data
135-
13689
@pytest.fixture
13790
def anyio_backend():
13891
return "asyncio"
13992

140-
async def _upload_file(server, filename: str, data: bytes, client_checksum: Optional[str] = None) -> str:
93+
async def _upload_file(server, filename: str, data: bytes) -> str:
14194
send_stream, receive_stream = create_memory_object_stream()
14295
resource_uuid = uuid4()
14396
server.resources[resource_uuid] = receive_stream
14497
resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json")
145-
client_checksum = client_checksum or hashlib.sha256(data).hexdigest()
14698

14799
async def send_data():
148100
await send_stream.send(data)
149101
await send_stream.aclose()
150102

151103
async with anyio.create_task_group() as tg:
152104
tg.start_soon(send_data)
153-
await server.put_file(filename, resource_handle, client_checksum)
154-
155-
return server._compute_checksum(os.path.join(server.root_dir, filename))
156-
157-
@pytest.mark.anyio
158-
async def test_tftp_path_traversal_attempt(server):
159-
malicious_filename = "../../evil.txt"
160-
161-
resource_uuid = uuid4()
162-
server.resources[resource_uuid] = None
163-
resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json")
105+
await server.put_file(filename, resource_handle, hashlib.sha256(data).hexdigest())
164106

165-
with pytest.raises(TftpError, match="Invalid target path"):
166-
await server.put_file(malicious_filename, resource_handle, "checksum")
107+
return hashlib.sha256(data).hexdigest()

0 commit comments

Comments
 (0)