1111
1212from jumpstarter_driver_tftp .server import TftpServer
1313
14+ from . import CHUNK_SIZE
1415from jumpstarter .driver import Driver , export
1516
1617
@@ -37,7 +38,7 @@ class Tftp(Driver):
3738 """TFTP Server driver for Jumpstarter"""
3839
3940 root_dir : str = "/var/lib/tftpboot"
40- host : str = field (default = None )
41+ host : str = field (default = '' )
4142 port : int = 69
4243 checksum_suffix : str = ".sha256"
4344 server : Optional ["TftpServer" ] = field (init = False , default = None )
@@ -50,7 +51,7 @@ class Tftp(Driver):
5051 def __post_init__ (self ):
5152 super ().__post_init__ ()
5253 os .makedirs (self .root_dir , exist_ok = True )
53- if self .host is None :
54+ if self .host == '' :
5455 self .host = self .get_default_ip ()
5556
5657 def get_default_ip (self ):
@@ -147,7 +148,7 @@ def list_files(self) -> list[str]:
147148
148149 @export
149150 async def put_file (self , filename : str , src_stream , client_checksum : str ):
150- """Only called when we know we need to upload """
151+ """Compute and store checksum at write time """
151152 file_path = os .path .join (self .root_dir , filename )
152153
153154 try :
@@ -159,8 +160,9 @@ async def put_file(self, filename: str, src_stream, client_checksum: str):
159160 async for chunk in src :
160161 await dst .send (chunk )
161162
162- self ._checksums [filename ] = client_checksum
163- self ._write_checksum_file (filename , client_checksum )
163+ current_checksum = self ._compute_checksum (file_path )
164+ self ._checksums [filename ] = current_checksum
165+ self ._write_checksum_file (filename , current_checksum )
164166 return filename
165167 except Exception as e :
166168 raise TftpError (f"Failed to upload file: { str (e )} " ) from e
@@ -185,19 +187,20 @@ def delete_file(self, filename: str):
185187
186188 @export
187189 def check_file_checksum (self , filename : str , client_checksum : str ) -> bool :
188- """Check if file exists with matching checksum"""
190+ """
191+ check if the checksum of the file matches the client checksum
192+ """
193+
189194 file_path = os .path .join (self .root_dir , filename )
190195 if not os .path .exists (file_path ):
191196 return False
192197
193198 current_checksum = self ._compute_checksum (file_path )
194- stored_checksum = self ._read_checksum_file (filename )
195199
196- if stored_checksum != current_checksum :
197- self ._write_checksum_file (filename , current_checksum )
198- self ._checksums [filename ] = current_checksum
200+ self ._checksums [filename ] = current_checksum
201+ self ._write_checksum_file (filename , current_checksum )
199202
200- logger .debug (f"Client checksum: { client_checksum } , server checksum: { current_checksum } " )
203+ self . logger .debug (f"Client checksum: { client_checksum } , server checksum: { current_checksum } " )
201204 return current_checksum == client_checksum
202205
203206 @export
@@ -223,7 +226,7 @@ def _read_checksum_file(self, filename: str) -> Optional[str]:
223226 with open (checksum_path , 'r' ) as f :
224227 return f .read ().strip ()
225228 except Exception as e :
226- logger .warning (f"Failed to read checksum file for { filename } : { e } " )
229+ self . logger .warning (f"Failed to read checksum file for { filename } : { e } " )
227230 return None
228231
229232 def _write_checksum_file (self , filename : str , checksum : str ):
@@ -233,24 +236,11 @@ def _write_checksum_file(self, filename: str, checksum: str):
233236 with open (checksum_path , 'w' ) as f :
234237 f .write (f"{ checksum } \n " )
235238 except Exception as e :
236- logger .error (f"Failed to write checksum file for { filename } : { e } " )
239+ self . logger .error (f"Failed to write checksum file for { filename } : { e } " )
237240
238241 def _compute_checksum (self , path : str ) -> str :
239242 hasher = hashlib .sha256 ()
240243 with open (path , "rb" ) as f :
241- while chunk := f .read (8192 ):
244+ while chunk := f .read (CHUNK_SIZE ):
242245 hasher .update (chunk )
243246 return hasher .hexdigest ()
244-
245- def _initialize_checksums (self ):
246- self ._checksums .clear ()
247- for filename in os .listdir (self .root_dir ):
248- if filename .endswith (self .checksum_suffix ):
249- continue
250- file_path = os .path .join (self .root_dir , filename )
251- if os .path .isfile (file_path ):
252- stored_checksum = self ._read_checksum_file (filename )
253- current_checksum = self ._compute_checksum (file_path )
254- if stored_checksum != current_checksum :
255- self ._write_checksum_file (filename , current_checksum )
256- self ._checksums [filename ] = current_checksum
0 commit comments