From d91539540f86eb155b99aa98a00d9bf064945469 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Wed, 29 Jan 2025 10:55:10 +0200 Subject: [PATCH 1/5] tftp: add checksum validation Signed-off-by: Benny Zlotnik --- .../examples/tftp_test.py | 62 +++++-------- .../jumpstarter_driver_tftp/client.py | 49 +++++----- .../jumpstarter_driver_tftp/driver.py | 89 +++++++++++++++++-- .../jumpstarter_driver_tftp/driver_test.py | 78 +++++++++++++++- 4 files changed, 207 insertions(+), 71 deletions(-) diff --git a/packages/jumpstarter-driver-tftp/examples/tftp_test.py b/packages/jumpstarter-driver-tftp/examples/tftp_test.py index ba7b7067..735fcc14 100644 --- a/packages/jumpstarter-driver-tftp/examples/tftp_test.py +++ b/packages/jumpstarter-driver-tftp/examples/tftp_test.py @@ -1,5 +1,4 @@ import logging -import time import pytest from jumpstarter_driver_tftp.driver import FileNotFound, TftpError @@ -7,45 +6,32 @@ log = logging.getLogger(__name__) - class TestResource(JumpstarterTest): filter_labels = {"board": "rpi4"} @pytest.fixture() - def test_tftp_upload(self, client): + def setup_tftp(self, client): + # Move the setup code to a fixture + client.tftp.start() + yield client + client.tftp.stop() + + def test_tftp_operations(self, setup_tftp): + client = setup_tftp + test_file = "test.bin" + + # Create test file + with open(test_file, "wb") as f: + f.write(b"Hello from TFTP streaming test!") + try: - client.tftp.start() - print("TFTP server started") - - time.sleep(1) - - test_file = "test.bin" - with open(test_file, "wb") as f: - f.write(b"Hello from TFTP streaming test!") - - try: - client.tftp.put_local_file(test_file) - print(f"Successfully uploaded {test_file}") - - files = client.tftp.list_files() - print(f"Files in TFTP root: {files}") - - if test_file in files: - client.tftp.delete_file(test_file) - print(f"Successfully deleted {test_file}") - else: - print(f"Warning: {test_file} not found in TFTP root") - - except TftpError as e: - print(f"TFTP operation failed: {e}") - except FileNotFound as e: - print(f"File not found: {e}") - - except Exception as e: - print(f"Error: {e}") - finally: - try: - client.tftp.stop() - print("TFTP server stopped") - except Exception as e: - print(f"Error stopping server: {e}") + # Test upload + client.tftp.put_local_file(test_file) + assert test_file in client.tftp.list_files() + + # Test delete + client.tftp.delete_file(test_file) + assert test_file not in client.tftp.list_files() + + except (TftpError, FileNotFound) as e: + pytest.fail(f"Test failed: {e}") diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py index cc22910c..5989d87f 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py @@ -1,3 +1,5 @@ +import hashlib +import logging from dataclasses import dataclass from pathlib import Path @@ -6,6 +8,7 @@ from jumpstarter.client import DriverClient +logger = logging.getLogger(__name__) @dataclass(kw_only=True) class TftpServerClient(DriverClient): @@ -46,37 +49,30 @@ def list_files(self) -> list[str]: return self.call("list_files") def put_file(self, operator: Operator, path: str): - """ - Upload a file to the TFTP server using an OpenDAL operator + filename = Path(path).name + client_checksum = self._compute_checksum(operator, path) - Args: - operator (Operator): OpenDAL operator for accessing the source storage - path (str): Path to the file in the source storage system + if self.call("check_file_checksum", filename, client_checksum): + logger.info(f"Skipping upload of identical file: {filename}") + return filename - Returns: - str: Name of the uploaded file - """ - filename = Path(path).name with OpendalAdapter(client=self, operator=operator, path=path, mode="rb") as handle: - return self.call("put_file", filename, handle) + return self.call("put_file", filename, handle, client_checksum) def put_local_file(self, filepath: str): - """ - Upload a file from the local filesystem to the TFTP server - Note: this doesn't use TFTP to upload. + absolute = Path(filepath).resolve() + filename = absolute.name - Args: - filepath (str): Path to the local file to upload + operator = Operator("fs", root="/") + client_checksum = self._compute_checksum(operator, str(absolute)) - Returns: - str: Name of the uploaded file + if self.call("check_file_checksum", filename, client_checksum): + logger.info(f"Skipping upload of identical file: {filename}") + return filename - Example: - >>> client.put_local_file("/path/to/local/file.txt") - """ - absolute = Path(filepath).resolve() - with OpendalAdapter(client=self, operator=Operator("fs", root="/"), path=str(absolute), mode="rb") as handle: - return self.call("put_file", absolute.name, handle) + logger.info(f"checksum: {client_checksum}") + with OpendalAdapter(client=self, operator=operator, path=str(absolute), mode="rb") as handle: + return self.call("put_file", filename, handle, client_checksum) def delete_file(self, filename: str): """ @@ -108,3 +104,10 @@ def get_port(self) -> int: int: The port number (default is 69) """ return self.call("get_port") + + def _compute_checksum(self, operator: Operator, path: str) -> str: + hasher = hashlib.sha256() + with operator.open(path, "rb") as f: + while chunk := f.read(8192): + hasher.update(chunk) + return hasher.hexdigest() diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index 0d6e3030..61bda095 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -1,4 +1,5 @@ import asyncio +import hashlib import logging import os import socket @@ -52,15 +53,18 @@ class Tftp(Driver): root_dir: str = "/var/lib/tftpboot" host: str = field(default_factory=get_default_ip) port: int = 69 + checksum_suffix: str = ".sha256" server: Optional["TftpServer"] = field(init=False, default=None) server_thread: Optional[threading.Thread] = field(init=False, default=None) _shutdown_event: threading.Event = field(init=False, default_factory=threading.Event) _loop_ready: threading.Event = field(init=False, default_factory=threading.Event) _loop: Optional[asyncio.AbstractEventLoop] = field(init=False, default=None) + _checksums: dict[str, str] = field(default_factory=dict) def __post_init__(self): super().__post_init__() os.makedirs(self.root_dir, exist_ok=True) + self._initialize_checksums() @classmethod def client(cls) -> str: @@ -145,11 +149,11 @@ def list_files(self) -> list[str]: return os.listdir(self.root_dir) @export - async def put_file(self, filename: str, src_stream): - """Handle file upload using streaming""" - try: - file_path = os.path.join(self.root_dir, filename) + async def put_file(self, filename: str, src_stream, client_checksum: str): + """Only called when we know we need to upload""" + file_path = os.path.join(self.root_dir, filename) + try: if not Path(file_path).resolve().is_relative_to(Path(self.root_dir).resolve()): raise TftpError("Invalid target path") @@ -158,20 +162,47 @@ async def put_file(self, filename: str, src_stream): async for chunk in src: await dst.send(chunk) + self._checksums[filename] = client_checksum + self._write_checksum_file(filename, client_checksum) return filename - except Exception as e: raise TftpError(f"Failed to upload file: {str(e)}") from e + @export def delete_file(self, filename: str): + """Delete file and its checksum file""" + file_path = os.path.join(self.root_dir, filename) + checksum_path = self._get_checksum_path(filename) + + if not os.path.exists(file_path): + raise FileNotFound(f"File {filename} not found") + try: - os.remove(os.path.join(self.root_dir, filename)) - except FileNotFoundError as err: - raise FileNotFound(f"File {filename} not found") from err + os.remove(file_path) + if os.path.exists(checksum_path): + os.remove(checksum_path) + self._checksums.pop(filename, None) except Exception as e: raise TftpError(f"Failed to delete {filename}") from e + @export + def check_file_checksum(self, filename: str, client_checksum: str) -> bool: + """Check if file exists with matching checksum""" + file_path = os.path.join(self.root_dir, filename) + if not os.path.exists(file_path): + return False + + current_checksum = self._compute_checksum(file_path) + stored_checksum = self._read_checksum_file(filename) + + if stored_checksum != current_checksum: + self._write_checksum_file(filename, current_checksum) + self._checksums[filename] = current_checksum + + logger.debug(f"Client checksum: {client_checksum}, server checksum: {current_checksum}") + return current_checksum == client_checksum + @export def get_host(self) -> str: return self.host @@ -184,3 +215,45 @@ def close(self): if self.server_thread is not None: self.stop() super().close() + + def _get_checksum_path(self, filename: str) -> str: + return os.path.join(self.root_dir, f"{filename}{self.checksum_suffix}") + + def _read_checksum_file(self, filename: str) -> Optional[str]: + try: + checksum_path = self._get_checksum_path(filename) + if os.path.exists(checksum_path): + with open(checksum_path, 'r') as f: + return f.read().strip() + except Exception as e: + logger.warning(f"Failed to read checksum file for {filename}: {e}") + return None + + def _write_checksum_file(self, filename: str, checksum: str): + """Write checksum to the checksum file""" + try: + checksum_path = self._get_checksum_path(filename) + with open(checksum_path, 'w') as f: + f.write(f"{checksum}\n") + except Exception as e: + logger.error(f"Failed to write checksum file for {filename}: {e}") + + def _compute_checksum(self, path: str) -> str: + hasher = hashlib.sha256() + with open(path, "rb") as f: + while chunk := f.read(8192): + hasher.update(chunk) + return hasher.hexdigest() + + def _initialize_checksums(self): + self._checksums.clear() + for filename in os.listdir(self.root_dir): + if filename.endswith(self.checksum_suffix): + continue + file_path = os.path.join(self.root_dir, filename) + if os.path.isfile(file_path): + stored_checksum = self._read_checksum_file(filename) + current_checksum = self._compute_checksum(file_path) + if stored_checksum != current_checksum: + self._write_checksum_file(filename, current_checksum) + self._checksums[filename] = current_checksum diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py index 4bd74163..dd7cc9c2 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py @@ -1,6 +1,8 @@ +import hashlib import os import tempfile from pathlib import Path +from typing import Optional from uuid import uuid4 import anyio @@ -32,6 +34,7 @@ def server(temp_dir): async def test_tftp_file_operations(server): filename = "test.txt" test_data = b"Hello" + client_checksum = hashlib.sha256(test_data).hexdigest() send_stream, receive_stream = create_memory_object_stream(max_buffer_size=10) @@ -46,8 +49,7 @@ async def send_data(): async with anyio.create_task_group() as tg: tg.start_soon(send_data) - - await server.put_file(filename, resource_handle) + await server.put_file(filename, resource_handle, client_checksum) files = server.list_files() assert filename in files @@ -75,6 +77,78 @@ def test_tftp_root_directory_creation(temp_dir): server.close() +@pytest.mark.anyio +async def test_tftp_checksum_validation(server): + filename = "test_checksum.txt" + test_data = b"Hello world" + modified_data = b"Modified content" + + def compute_checksum(data: bytes) -> str: + return hashlib.sha256(data).hexdigest() + + initial_checksum = await _upload_file(server, filename, test_data) + assert filename in server.list_files() + assert compute_checksum(test_data) == initial_checksum + + # Second upload with same data should be skipped + same_data_checksum = await _upload_file(server, filename, test_data) + assert same_data_checksum == initial_checksum + + modified_checksum = await _upload_file(server, filename, modified_data) + assert modified_checksum != initial_checksum + assert Path(server.root_dir).joinpath(filename).read_bytes() == modified_data + + empty_checksum = await _upload_file(server, "empty.txt", b"") + assert empty_checksum == hashlib.sha256(b"").hexdigest() + +@pytest.mark.anyio +async def test_tftp_detect_corrupted_file(server): + filename = "corrupted.txt" + original_data = b"Original Data" + client_checksum = hashlib.sha256(original_data).hexdigest() + + await _upload_file(server, filename, original_data) + assert server.check_file_checksum(filename, client_checksum) + + file_path = Path(server.root_dir, filename) + with open(file_path, "wb") as f: + f.write(b"Corrupted Data") + + assert not server.check_file_checksum(filename, client_checksum) + +@pytest.mark.anyio +async def test_tftp_reupload_different_checksum(server): + filename = "reupload.txt" + initial_data = b"Initial Data" + updated_data = b"Updated Data" + initial_checksum = hashlib.sha256(initial_data).hexdigest() + updated_checksum = hashlib.sha256(updated_data).hexdigest() + + await _upload_file(server, filename, initial_data) + assert server.check_file_checksum(filename, initial_checksum) + assert Path(server.root_dir, filename).read_bytes() == initial_data + + await _upload_file(server, filename, updated_data, client_checksum=updated_checksum) + assert server.check_file_checksum(filename, updated_checksum) + assert Path(server.root_dir, filename).read_bytes() == updated_data + @pytest.fixture def anyio_backend(): return "asyncio" + +async def _upload_file(server, filename: str, data: bytes, client_checksum: Optional[str] = None) -> str: + send_stream, receive_stream = create_memory_object_stream() + resource_uuid = uuid4() + server.resources[resource_uuid] = receive_stream + resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json") + client_checksum = client_checksum or hashlib.sha256(data).hexdigest() + + async def send_data(): + await send_stream.send(data) + await send_stream.aclose() + + async with anyio.create_task_group() as tg: + tg.start_soon(send_data) + await server.put_file(filename, resource_handle, client_checksum) + + return server._compute_checksum(os.path.join(server.root_dir, filename)) From e8c76e97596ebb9b1b0828ebe3cb8e362113da75 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Wed, 29 Jan 2025 11:09:38 +0200 Subject: [PATCH 2/5] tftp: fix test paths Signed-off-by: Benny Zlotnik --- packages/jumpstarter-driver-tftp/pyproject.toml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/packages/jumpstarter-driver-tftp/pyproject.toml b/packages/jumpstarter-driver-tftp/pyproject.toml index 830c47bb..50dfe17e 100644 --- a/packages/jumpstarter-driver-tftp/pyproject.toml +++ b/packages/jumpstarter-driver-tftp/pyproject.toml @@ -32,13 +32,12 @@ raw-options = { 'root' = '../../'} Homepage = "https://jumpstarter.dev" source_archive = "https://github.com/jumpstarter-dev/repo/archive/{commit_hash}.zip" -# [tool.pytest.ini_options] -# #addopts = "--cov --cov-report=html --cov-report=xml" -# log_cli = true -# log_cli_level = "INFO" -# # testpaths = ["src"] -# asyncio_mode = "auto" +[tool.pytest.ini_options] +log_cli = true +log_cli_level = "INFO" +testpaths = ["jumpstarter_driver_tftp"] +asyncio_mode = "auto" [build-system] requires = ["hatchling", "hatch-vcs"] -build-backend = "hatchling.build" \ No newline at end of file +build-backend = "hatchling.build" From 83d2e81c54d6b1a86206cc9e0113d912940592e5 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Wed, 29 Jan 2025 13:47:46 +0200 Subject: [PATCH 3/5] tftp: add path traversal test Signed-off-by: Benny Zlotnik --- .../jumpstarter_driver_tftp/driver_test.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py index dd7cc9c2..11407e38 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py @@ -12,6 +12,7 @@ from jumpstarter_driver_tftp.driver import ( FileNotFound, Tftp, + TftpError, ) from jumpstarter.common.resources import ClientStreamResource @@ -152,3 +153,14 @@ async def send_data(): await server.put_file(filename, resource_handle, client_checksum) return server._compute_checksum(os.path.join(server.root_dir, filename)) + +@pytest.mark.anyio +async def test_tftp_path_traversal_attempt(server): + malicious_filename = "../../evil.txt" + + resource_uuid = uuid4() + server.resources[resource_uuid] = None + resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json") + + with pytest.raises(TftpError, match="Invalid target path"): + await server.put_file(malicious_filename, resource_handle, "checksum") From 464025e8e5de187f1baac06b8b46769af7e55292 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Fri, 31 Jan 2025 10:55:58 +0200 Subject: [PATCH 4/5] tftp: increase chunk size and compute on write Signed-off-by: Benny Zlotnik --- .../jumpstarter_driver_tftp/__init__.py | 1 + .../jumpstarter_driver_tftp/client.py | 11 +++-- .../jumpstarter_driver_tftp/driver.py | 44 +++++++------------ 3 files changed, 23 insertions(+), 33 deletions(-) diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py index e69de29b..fc318846 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/__init__.py @@ -0,0 +1 @@ +CHUNK_SIZE = 1024 * 1024 * 4 # 4MB diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py index 5989d87f..24081eea 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/client.py @@ -1,14 +1,13 @@ import hashlib -import logging from dataclasses import dataclass from pathlib import Path from jumpstarter_driver_opendal.adapter import OpendalAdapter from opendal import Operator +from . import CHUNK_SIZE from jumpstarter.client import DriverClient -logger = logging.getLogger(__name__) @dataclass(kw_only=True) class TftpServerClient(DriverClient): @@ -53,7 +52,7 @@ def put_file(self, operator: Operator, path: str): client_checksum = self._compute_checksum(operator, path) if self.call("check_file_checksum", filename, client_checksum): - logger.info(f"Skipping upload of identical file: {filename}") + self.logger.info(f"Skipping upload of identical file: {filename}") return filename with OpendalAdapter(client=self, operator=operator, path=path, mode="rb") as handle: @@ -67,10 +66,10 @@ def put_local_file(self, filepath: str): client_checksum = self._compute_checksum(operator, str(absolute)) if self.call("check_file_checksum", filename, client_checksum): - logger.info(f"Skipping upload of identical file: {filename}") + self.logger.info(f"Skipping upload of identical file: {filename}") return filename - logger.info(f"checksum: {client_checksum}") + self.logger.info(f"checksum: {client_checksum}") with OpendalAdapter(client=self, operator=operator, path=str(absolute), mode="rb") as handle: return self.call("put_file", filename, handle, client_checksum) @@ -108,6 +107,6 @@ def get_port(self) -> int: def _compute_checksum(self, operator: Operator, path: str) -> str: hasher = hashlib.sha256() with operator.open(path, "rb") as f: - while chunk := f.read(8192): + while chunk := f.read(CHUNK_SIZE): hasher.update(chunk) return hasher.hexdigest() diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index 88aafff0..06969d85 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -11,6 +11,7 @@ from jumpstarter_driver_tftp.server import TftpServer +from . import CHUNK_SIZE from jumpstarter.driver import Driver, export @@ -37,7 +38,7 @@ class Tftp(Driver): """TFTP Server driver for Jumpstarter""" root_dir: str = "/var/lib/tftpboot" - host: str = field(default=None) + host: str = field(default='') port: int = 69 checksum_suffix: str = ".sha256" server: Optional["TftpServer"] = field(init=False, default=None) @@ -50,7 +51,7 @@ class Tftp(Driver): def __post_init__(self): super().__post_init__() os.makedirs(self.root_dir, exist_ok=True) - if self.host is None: + if self.host == '': self.host = self.get_default_ip() def get_default_ip(self): @@ -147,7 +148,7 @@ def list_files(self) -> list[str]: @export async def put_file(self, filename: str, src_stream, client_checksum: str): - """Only called when we know we need to upload""" + """Compute and store checksum at write time""" file_path = os.path.join(self.root_dir, filename) try: @@ -159,8 +160,9 @@ async def put_file(self, filename: str, src_stream, client_checksum: str): async for chunk in src: await dst.send(chunk) - self._checksums[filename] = client_checksum - self._write_checksum_file(filename, client_checksum) + current_checksum = self._compute_checksum(file_path) + self._checksums[filename] = current_checksum + self._write_checksum_file(filename, current_checksum) return filename except Exception as e: raise TftpError(f"Failed to upload file: {str(e)}") from e @@ -185,19 +187,20 @@ def delete_file(self, filename: str): @export def check_file_checksum(self, filename: str, client_checksum: str) -> bool: - """Check if file exists with matching checksum""" + """ + check if the checksum of the file matches the client checksum + """ + file_path = os.path.join(self.root_dir, filename) if not os.path.exists(file_path): return False current_checksum = self._compute_checksum(file_path) - stored_checksum = self._read_checksum_file(filename) - if stored_checksum != current_checksum: - self._write_checksum_file(filename, current_checksum) - self._checksums[filename] = current_checksum + self._checksums[filename] = current_checksum + self._write_checksum_file(filename, current_checksum) - logger.debug(f"Client checksum: {client_checksum}, server checksum: {current_checksum}") + self.logger.debug(f"Client checksum: {client_checksum}, server checksum: {current_checksum}") return current_checksum == client_checksum @export @@ -223,7 +226,7 @@ def _read_checksum_file(self, filename: str) -> Optional[str]: with open(checksum_path, 'r') as f: return f.read().strip() except Exception as e: - logger.warning(f"Failed to read checksum file for {filename}: {e}") + self.logger.warning(f"Failed to read checksum file for {filename}: {e}") return None def _write_checksum_file(self, filename: str, checksum: str): @@ -233,24 +236,11 @@ def _write_checksum_file(self, filename: str, checksum: str): with open(checksum_path, 'w') as f: f.write(f"{checksum}\n") except Exception as e: - logger.error(f"Failed to write checksum file for {filename}: {e}") + self.logger.error(f"Failed to write checksum file for {filename}: {e}") def _compute_checksum(self, path: str) -> str: hasher = hashlib.sha256() with open(path, "rb") as f: - while chunk := f.read(8192): + while chunk := f.read(CHUNK_SIZE): hasher.update(chunk) return hasher.hexdigest() - - def _initialize_checksums(self): - self._checksums.clear() - for filename in os.listdir(self.root_dir): - if filename.endswith(self.checksum_suffix): - continue - file_path = os.path.join(self.root_dir, filename) - if os.path.isfile(file_path): - stored_checksum = self._read_checksum_file(filename) - current_checksum = self._compute_checksum(file_path) - if stored_checksum != current_checksum: - self._write_checksum_file(filename, current_checksum) - self._checksums[filename] = current_checksum From 6204ef9ae85c30ad86987b81da2c720773052f80 Mon Sep 17 00:00:00 2001 From: Benny Zlotnik Date: Sat, 1 Feb 2025 11:10:58 +0200 Subject: [PATCH 5/5] tftp: remove checksum caching Signed-off-by: Benny Zlotnik --- .../jumpstarter_driver_tftp/driver.py | 62 ++--------------- .../jumpstarter_driver_tftp/driver_test.py | 69 ++----------------- 2 files changed, 12 insertions(+), 119 deletions(-) diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py index 06969d85..c89bca0a 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver.py @@ -17,22 +17,16 @@ class TftpError(Exception): """Base exception for TFTP server errors""" - pass - class ServerNotRunning(TftpError): """Server is not running""" - pass - class FileNotFound(TftpError): """File not found""" - pass - @dataclass(kw_only=True) class Tftp(Driver): """TFTP Server driver for Jumpstarter""" @@ -40,13 +34,11 @@ class Tftp(Driver): root_dir: str = "/var/lib/tftpboot" host: str = field(default='') port: int = 69 - checksum_suffix: str = ".sha256" server: Optional["TftpServer"] = field(init=False, default=None) server_thread: Optional[threading.Thread] = field(init=False, default=None) _shutdown_event: threading.Event = field(init=False, default_factory=threading.Event) _loop_ready: threading.Event = field(init=False, default_factory=threading.Event) _loop: Optional[asyncio.AbstractEventLoop] = field(init=False, default=None) - _checksums: dict[str, str] = field(default_factory=dict) def __post_init__(self): super().__post_init__() @@ -73,10 +65,7 @@ def _start_server(self): asyncio.set_event_loop(self._loop) self.server = TftpServer(host=self.host, port=self.port, root_dir=self.root_dir) try: - # Signal that the loop is ready self._loop_ready.set() - - # Run the server until shutdown is requested self._loop.run_until_complete(self._run_server()) except Exception as e: self.logger.error(f"Error running TFTP server: {e}") @@ -86,7 +75,6 @@ def _start_server(self): self._loop.close() except Exception as e: self.logger.error(f"Error during event loop cleanup: {e}") - self._loop = None self.logger.info("TFTP server thread completed") @@ -111,11 +99,9 @@ def start(self): self.logger.warning("TFTP server is already running") return - # Clear any previous shutdown state self._shutdown_event.clear() self._loop_ready.clear() - # Start the server thread self.server_thread = threading.Thread(target=self._start_server, daemon=True) self.server_thread.start() @@ -133,7 +119,6 @@ def stop(self): return self.logger.info("Initiating TFTP server shutdown") - self._shutdown_event.set() self.server_thread.join(timeout=10) if self.server_thread.is_alive(): @@ -148,7 +133,6 @@ def list_files(self) -> list[str]: @export async def put_file(self, filename: str, src_stream, client_checksum: str): - """Compute and store checksum at write time""" file_path = os.path.join(self.root_dir, filename) try: @@ -160,47 +144,37 @@ async def put_file(self, filename: str, src_stream, client_checksum: str): async for chunk in src: await dst.send(chunk) - current_checksum = self._compute_checksum(file_path) - self._checksums[filename] = current_checksum - self._write_checksum_file(filename, current_checksum) return filename except Exception as e: raise TftpError(f"Failed to upload file: {str(e)}") from e - @export def delete_file(self, filename: str): - """Delete file and its checksum file""" file_path = os.path.join(self.root_dir, filename) - checksum_path = self._get_checksum_path(filename) if not os.path.exists(file_path): raise FileNotFound(f"File {filename} not found") try: os.remove(file_path) - if os.path.exists(checksum_path): - os.remove(checksum_path) - self._checksums.pop(filename, None) + return filename except Exception as e: raise TftpError(f"Failed to delete {filename}") from e @export def check_file_checksum(self, filename: str, client_checksum: str) -> bool: - """ - check if the checksum of the file matches the client checksum - """ - file_path = os.path.join(self.root_dir, filename) + self.logger.debug(f"checking checksum for file: {filename}") + self.logger.debug(f"file path: {file_path}") + if not os.path.exists(file_path): + self.logger.debug(f"File {filename} does not exist") return False current_checksum = self._compute_checksum(file_path) + self.logger.debug(f"Computed checksum: {current_checksum}") + self.logger.debug(f"Client checksum: {client_checksum}") - self._checksums[filename] = current_checksum - self._write_checksum_file(filename, current_checksum) - - self.logger.debug(f"Client checksum: {client_checksum}, server checksum: {current_checksum}") return current_checksum == client_checksum @export @@ -216,28 +190,6 @@ def close(self): self.stop() super().close() - def _get_checksum_path(self, filename: str) -> str: - return os.path.join(self.root_dir, f"{filename}{self.checksum_suffix}") - - def _read_checksum_file(self, filename: str) -> Optional[str]: - try: - checksum_path = self._get_checksum_path(filename) - if os.path.exists(checksum_path): - with open(checksum_path, 'r') as f: - return f.read().strip() - except Exception as e: - self.logger.warning(f"Failed to read checksum file for {filename}: {e}") - return None - - def _write_checksum_file(self, filename: str, checksum: str): - """Write checksum to the checksum file""" - try: - checksum_path = self._get_checksum_path(filename) - with open(checksum_path, 'w') as f: - f.write(f"{checksum}\n") - except Exception as e: - self.logger.error(f"Failed to write checksum file for {filename}: {e}") - def _compute_checksum(self, path: str) -> str: hasher = hashlib.sha256() with open(path, "rb") as f: diff --git a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py index 11407e38..3f0f6911 100644 --- a/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py +++ b/packages/jumpstarter-driver-tftp/jumpstarter_driver_tftp/driver_test.py @@ -2,7 +2,6 @@ import os import tempfile from pathlib import Path -from typing import Optional from uuid import uuid4 import anyio @@ -12,7 +11,6 @@ from jumpstarter_driver_tftp.driver import ( FileNotFound, Tftp, - TftpError, ) from jumpstarter.common.resources import ClientStreamResource @@ -23,14 +21,12 @@ def temp_dir(): with tempfile.TemporaryDirectory() as tmpdir: yield tmpdir - @pytest.fixture def server(temp_dir): server = Tftp(root_dir=temp_dir, host="127.0.0.1") yield server server.close() - @pytest.mark.anyio async def test_tftp_file_operations(server): filename = "test.txt" @@ -64,44 +60,17 @@ async def send_data(): with pytest.raises(FileNotFound): server.delete_file("nonexistent.txt") - def test_tftp_host_config(temp_dir): custom_host = "192.168.1.1" server = Tftp(root_dir=temp_dir, host=custom_host) assert server.get_host() == custom_host - def test_tftp_root_directory_creation(temp_dir): new_dir = os.path.join(temp_dir, "new_tftp_root") server = Tftp(root_dir=new_dir) assert os.path.exists(new_dir) server.close() - -@pytest.mark.anyio -async def test_tftp_checksum_validation(server): - filename = "test_checksum.txt" - test_data = b"Hello world" - modified_data = b"Modified content" - - def compute_checksum(data: bytes) -> str: - return hashlib.sha256(data).hexdigest() - - initial_checksum = await _upload_file(server, filename, test_data) - assert filename in server.list_files() - assert compute_checksum(test_data) == initial_checksum - - # Second upload with same data should be skipped - same_data_checksum = await _upload_file(server, filename, test_data) - assert same_data_checksum == initial_checksum - - modified_checksum = await _upload_file(server, filename, modified_data) - assert modified_checksum != initial_checksum - assert Path(server.root_dir).joinpath(filename).read_bytes() == modified_data - - empty_checksum = await _upload_file(server, "empty.txt", b"") - assert empty_checksum == hashlib.sha256(b"").hexdigest() - @pytest.mark.anyio async def test_tftp_detect_corrupted_file(server): filename = "corrupted.txt" @@ -109,40 +78,23 @@ async def test_tftp_detect_corrupted_file(server): client_checksum = hashlib.sha256(original_data).hexdigest() await _upload_file(server, filename, original_data) + assert server.check_file_checksum(filename, client_checksum) file_path = Path(server.root_dir, filename) - with open(file_path, "wb") as f: - f.write(b"Corrupted Data") + file_path.write_bytes(b"corrupted Data") assert not server.check_file_checksum(filename, client_checksum) -@pytest.mark.anyio -async def test_tftp_reupload_different_checksum(server): - filename = "reupload.txt" - initial_data = b"Initial Data" - updated_data = b"Updated Data" - initial_checksum = hashlib.sha256(initial_data).hexdigest() - updated_checksum = hashlib.sha256(updated_data).hexdigest() - - await _upload_file(server, filename, initial_data) - assert server.check_file_checksum(filename, initial_checksum) - assert Path(server.root_dir, filename).read_bytes() == initial_data - - await _upload_file(server, filename, updated_data, client_checksum=updated_checksum) - assert server.check_file_checksum(filename, updated_checksum) - assert Path(server.root_dir, filename).read_bytes() == updated_data - @pytest.fixture def anyio_backend(): return "asyncio" -async def _upload_file(server, filename: str, data: bytes, client_checksum: Optional[str] = None) -> str: +async def _upload_file(server, filename: str, data: bytes) -> str: send_stream, receive_stream = create_memory_object_stream() resource_uuid = uuid4() server.resources[resource_uuid] = receive_stream resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json") - client_checksum = client_checksum or hashlib.sha256(data).hexdigest() async def send_data(): await send_stream.send(data) @@ -150,17 +102,6 @@ async def send_data(): async with anyio.create_task_group() as tg: tg.start_soon(send_data) - await server.put_file(filename, resource_handle, client_checksum) - - return server._compute_checksum(os.path.join(server.root_dir, filename)) - -@pytest.mark.anyio -async def test_tftp_path_traversal_attempt(server): - malicious_filename = "../../evil.txt" - - resource_uuid = uuid4() - server.resources[resource_uuid] = None - resource_handle = ClientStreamResource(uuid=resource_uuid).model_dump(mode="json") + await server.put_file(filename, resource_handle, hashlib.sha256(data).hexdigest()) - with pytest.raises(TftpError, match="Invalid target path"): - await server.put_file(malicious_filename, resource_handle, "checksum") + return hashlib.sha256(data).hexdigest()