Skip to content

Commit 6b73495

Browse files
committed
tftp: negotiate blocksize and timeout
Signed-off-by: Benny Zlotnik <[email protected]>
1 parent c635cca commit 6b73495

File tree

2 files changed

+125
-55
lines changed

2 files changed

+125
-55
lines changed

contrib/drivers/tftp/src/jumpstarter_driver_tftp/server.py

Lines changed: 119 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class Opcode(IntEnum):
1414
DATA = 3
1515
ACK = 4
1616
ERROR = 5
17+
OACK = 6
1718

1819

1920
class TftpErrorCode(IntEnum):
@@ -136,7 +137,6 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]):
136137

137138
async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]):
138139
try:
139-
# Parse filename and mode from request
140140
parts = data[2:].split(b'\x00')
141141
if len(parts) < 2:
142142
self.logger.error(f"Invalid RRQ format from {addr}")
@@ -145,7 +145,18 @@ async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]):
145145
filename = parts[0].decode('utf-8')
146146
mode = parts[1].decode('utf-8').lower()
147147

148-
self.logger.info(f"RRQ from {addr}: '{filename}' in mode '{mode}'")
148+
options = {}
149+
i = 2
150+
while i < len(parts) - 1:
151+
try:
152+
opt_name = parts[i].decode('utf-8').lower()
153+
opt_value = parts[i + 1].decode('utf-8')
154+
options[opt_name] = opt_value
155+
i += 2
156+
except IndexError:
157+
break
158+
159+
self.logger.info(f"RRQ from {addr}: '{filename}' in mode '{mode}' with options {options}")
149160

150161
if mode not in ('netascii', 'octet'):
151162
self.logger.warning(f"Unsupported transfer mode '{mode}' from {addr}")
@@ -162,17 +173,40 @@ async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]):
162173
return
163174

164175
if not is_subpath(resolved_path, self.server.root_dir):
165-
self.logger.error(f"Access violation: {resolved_path} is outside the root directory")
176+
self.logger.error(f"Access violation: {resolved_path} is outside root directory")
166177
self._send_error(addr, TftpErrorCode.ACCESS_VIOLATION, "Access denied")
167178
return
168179

180+
negotiated_options = {}
181+
182+
if 'blksize' in options:
183+
try:
184+
requested_blksize = int(options['blksize'])
185+
if 512 <= requested_blksize <= 8192:
186+
negotiated_options['blksize'] = requested_blksize
187+
else:
188+
negotiated_options['blksize'] = 512
189+
except ValueError:
190+
negotiated_options['blksize'] = 512
191+
else:
192+
negotiated_options['blksize'] = self.server.block_size
193+
194+
if 'timeout' in options:
195+
try:
196+
requested_timeout = int(options['timeout'])
197+
if 1 <= requested_timeout <= 255:
198+
negotiated_options['timeout'] = requested_timeout
199+
except ValueError:
200+
pass
201+
169202
transfer = TftpReadTransfer(
170203
server=self.server,
171204
filepath=resolved_path,
172205
client_addr=addr,
173-
block_size=self.server.block_size,
174-
timeout=self.server.timeout,
175-
retries=self.server.retries
206+
block_size=negotiated_options['blksize'],
207+
timeout=negotiated_options.get('timeout', self.server.timeout),
208+
retries=self.server.retries,
209+
negotiated_options=negotiated_options if options else None
176210
)
177211
self.server.register_transfer(transfer)
178212
asyncio.create_task(transfer.start())
@@ -181,6 +215,16 @@ async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]):
181215
self.logger.error(f"Error handling RRQ from {addr}: {e}")
182216
self._send_error(addr, TftpErrorCode.NOT_DEFINED, str(e))
183217

218+
def _send_oack(self, addr: Tuple[str, int], options: dict):
219+
"""Send Option Acknowledgment (OACK) packet."""
220+
oack_data = Opcode.OACK.to_bytes(2, 'big')
221+
for opt_name, opt_value in options.items():
222+
oack_data += f"{opt_name}\0{str(opt_value)}\0".encode('utf-8')
223+
224+
if self.transport:
225+
self.transport.sendto(oack_data, addr)
226+
self.logger.debug(f"Sent OACK to {addr} with options {options}")
227+
184228
def _send_error(self, addr: Tuple[str, int], error_code: TftpErrorCode, message: str):
185229
error_packet = (
186230
Opcode.ERROR.to_bytes(2, 'big') +
@@ -232,16 +276,14 @@ async def cleanup(self):
232276

233277

234278
class TftpReadTransfer(TftpTransfer):
235-
"""
236-
Handles a TFTP Read (RRQ) transfer.
237-
"""
238-
239279
def __init__(self, server: TftpServer, filepath: pathlib.Path, client_addr: Tuple[str, int],
240-
block_size: int, timeout: float, retries: int):
280+
block_size: int, timeout: float, retries: int, negotiated_options: Optional[dict] = None):
241281
super().__init__(server, filepath, client_addr, block_size, timeout, retries)
242-
self.block_num = 1
282+
self.block_num = 0
243283
self.ack_received = asyncio.Event()
244284
self.last_ack = 0
285+
self.negotiated_options = negotiated_options
286+
self.oack_confirmed = False
245287

246288
async def start(self):
247289
self.logger.info(f"Starting read transfer of '{self.filepath.name}' to {self.client_addr}")
@@ -256,89 +298,117 @@ async def start(self):
256298
self.logger.debug(f"Transfer bound to local {local_addr}")
257299

258300
try:
301+
if self.negotiated_options:
302+
oack_packet = self._create_oack_packet()
303+
if not await self._send_with_retries(oack_packet, is_oack=True):
304+
self.logger.error("Failed to get acknowledgment for OACK")
305+
return
306+
self.block_num = 1
307+
259308
async with aiofiles.open(self.filepath, 'rb') as f:
260309
while True:
261310
if self.server.shutdown_event.is_set():
262311
self.logger.info(f"Server shutdown detected, stopping transfer to {self.client_addr}")
263312
break
313+
264314
data = await f.read(self.block_size)
265-
if data:
266-
packet = (
267-
Opcode.DATA.to_bytes(2, 'big') +
268-
self.block_num.to_bytes(2, 'big') +
269-
data
270-
)
315+
if not data and self.block_num == 1:
316+
# Empty file case
317+
packet = self._create_data_packet(b'')
318+
await self._send_with_retries(packet)
319+
break
320+
elif data:
321+
packet = self._create_data_packet(data)
271322
success = await self._send_with_retries(packet)
272323
if not success:
273324
self.logger.error(f"Failed to send block {self.block_num} to {self.client_addr}")
274325
break
326+
275327
self.logger.debug(f"Block {self.block_num} sent successfully")
276328
self.block_num += 1
277329

278-
# If the data read is less than block_size, this is the last packet
279330
if len(data) < self.block_size:
280-
self.logger.info(f"Final block {self.block_num - 1} reached for {self.client_addr}")
331+
self.logger.info(f"Final block {self.block_num - 1} sent")
281332
break
282333
else:
283-
# If no data is returned, it means the file size is an exact multiple of block_size
284-
# Send an extra empty DATA packet to signal end of transfer
285-
packet = (
286-
Opcode.DATA.to_bytes(2, 'big') +
287-
self.block_num.to_bytes(2, 'big') +
288-
b''
289-
)
334+
# End of file reached
335+
packet = self._create_data_packet(b'')
290336
success = await self._send_with_retries(packet)
291337
if not success:
292-
self.logger.error(
293-
f"Failed to send final empty block {self.block_num} "
294-
f"to {self.client_addr}"
295-
)
338+
self.logger.error(f"Failed to send final block {self.block_num}")
296339
break
297-
self.logger.info(f"Transfer complete to {self.client_addr}, final block {self.block_num}")
340+
self.logger.info(f"Transfer complete, final block {self.block_num}")
298341
break
299342

300343
except Exception as e:
301344
self.logger.error(f"Error during read transfer: {e}")
302345
finally:
303346
await self.cleanup()
304347

305-
async def _send_with_retries(self, packet: bytes) -> bool:
348+
def _create_oack_packet(self) -> bytes:
349+
"""Create OACK packet with negotiated options."""
350+
packet = Opcode.OACK.to_bytes(2, 'big')
351+
for opt_name, opt_value in self.negotiated_options.items():
352+
packet += f"{opt_name}\0{str(opt_value)}\0".encode('utf-8')
353+
return packet
354+
355+
def _create_data_packet(self, data: bytes) -> bytes:
356+
"""Create DATA packet with block number and data."""
357+
return (
358+
Opcode.DATA.to_bytes(2, 'big') +
359+
self.block_num.to_bytes(2, 'big') +
360+
data
361+
)
362+
363+
def _send_packet(self, packet: bytes):
364+
"""
365+
Sends a packet to the client.
366+
"""
367+
self.transport.sendto(packet)
368+
if packet[0:2] == Opcode.DATA.to_bytes(2, 'big'):
369+
block = int.from_bytes(packet[2:4], 'big')
370+
data_length = len(packet) - 4
371+
self.logger.debug(f"Sent DATA block {block} ({data_length} bytes) to {self.client_addr}")
372+
elif packet[0:2] == Opcode.OACK.to_bytes(2, 'big'):
373+
self.logger.debug(f"Sent OACK to {self.client_addr}")
374+
375+
async def _send_with_retries(self, packet: bytes, is_oack: bool = False) -> bool:
306376
self.current_packet = packet
377+
expected_block = 0 if is_oack else self.block_num
378+
307379
for attempt in range(1, self.retries + 1):
308380
try:
309381
self._send_packet(packet)
310-
self.logger.debug(f"Sent DATA block {self.block_num}, waiting for ACK (Attempt {attempt})")
382+
self.logger.debug(f"Sent {'OACK' if is_oack else 'DATA'} block {expected_block}, waiting for ACK (Attempt {attempt})")
311383
self.ack_received.clear()
312384
await asyncio.wait_for(self.ack_received.wait(), timeout=self.timeout)
313385

314-
if self.last_ack == self.block_num:
315-
self.logger.debug(f"ACK received for block {self.block_num}")
386+
if self.last_ack == expected_block:
387+
self.logger.debug(f"ACK received for block {expected_block}")
316388
return True
317389
else:
318-
self.logger.warning(f"Received wrong ACK: expected {self.block_num}, got {self.last_ack}")
390+
self.logger.warning(f"Received wrong ACK: expected {expected_block}, got {self.last_ack}")
319391

320392
except asyncio.TimeoutError:
321-
self.logger.warning(f"Timeout waiting for ACK of block {self.block_num} (Attempt {attempt})")
393+
self.logger.warning(f"Timeout waiting for ACK of block {expected_block} (Attempt {attempt})")
322394

323395
return False
324396

325-
def _send_packet(self, packet: bytes):
326-
"""
327-
Sends a DATA packet to the client.
328-
"""
329-
self.transport.sendto(packet)
330-
block = int.from_bytes(packet[2:4], 'big')
331-
data_length = len(packet) - 4
332-
self.logger.debug(f"Sent DATA block {block} ({data_length} bytes) to {self.client_addr}")
333-
334397
def handle_ack(self, block_num: int):
335398
self.logger.debug(f"Received ACK for block {block_num} from {self.client_addr}")
399+
400+
# special handling for OACK acknowledgment
401+
if not self.oack_confirmed and self.negotiated_options and block_num == 0:
402+
self.oack_confirmed = True
403+
self.last_ack = block_num
404+
self.ack_received.set()
405+
return
406+
336407
if block_num == self.block_num:
337408
self.last_ack = block_num
338409
self.ack_received.set()
339410
elif block_num == self.block_num - 1:
340-
# Duplicate ACK for previous block, resend current packet
341-
self.logger.warning(f"Duplicate ACK for block {block_num} received, resending DATA block {self.block_num}")
411+
self.logger.warning(f"Duplicate ACK for block {block_num} received, resending block {self.block_num}")
342412
self.transport.sendto(self.current_packet)
343413
else:
344414
self.logger.warning(f"Out of sequence ACK: expected {self.block_num}, got {block_num}")

uv.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)