@@ -14,6 +14,7 @@ class Opcode(IntEnum):
14
14
DATA = 3
15
15
ACK = 4
16
16
ERROR = 5
17
+ OACK = 6
17
18
18
19
19
20
class TftpErrorCode (IntEnum ):
@@ -136,7 +137,6 @@ def datagram_received(self, data: bytes, addr: Tuple[str, int]):
136
137
137
138
async def _handle_read_request (self , data : bytes , addr : Tuple [str , int ]):
138
139
try :
139
- # Parse filename and mode from request
140
140
parts = data [2 :].split (b'\x00 ' )
141
141
if len (parts ) < 2 :
142
142
self .logger .error (f"Invalid RRQ format from { addr } " )
@@ -145,7 +145,18 @@ async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]):
145
145
filename = parts [0 ].decode ('utf-8' )
146
146
mode = parts [1 ].decode ('utf-8' ).lower ()
147
147
148
- self .logger .info (f"RRQ from { addr } : '{ filename } ' in mode '{ mode } '" )
148
+ options = {}
149
+ i = 2
150
+ while i < len (parts ) - 1 :
151
+ try :
152
+ opt_name = parts [i ].decode ('utf-8' ).lower ()
153
+ opt_value = parts [i + 1 ].decode ('utf-8' )
154
+ options [opt_name ] = opt_value
155
+ i += 2
156
+ except IndexError :
157
+ break
158
+
159
+ self .logger .info (f"RRQ from { addr } : '{ filename } ' in mode '{ mode } ' with options { options } " )
149
160
150
161
if mode not in ('netascii' , 'octet' ):
151
162
self .logger .warning (f"Unsupported transfer mode '{ mode } ' from { addr } " )
@@ -162,17 +173,40 @@ async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]):
162
173
return
163
174
164
175
if not is_subpath (resolved_path , self .server .root_dir ):
165
- self .logger .error (f"Access violation: { resolved_path } is outside the root directory" )
176
+ self .logger .error (f"Access violation: { resolved_path } is outside root directory" )
166
177
self ._send_error (addr , TftpErrorCode .ACCESS_VIOLATION , "Access denied" )
167
178
return
168
179
180
+ negotiated_options = {}
181
+
182
+ if 'blksize' in options :
183
+ try :
184
+ requested_blksize = int (options ['blksize' ])
185
+ if 512 <= requested_blksize <= 8192 :
186
+ negotiated_options ['blksize' ] = requested_blksize
187
+ else :
188
+ negotiated_options ['blksize' ] = 512
189
+ except ValueError :
190
+ negotiated_options ['blksize' ] = 512
191
+ else :
192
+ negotiated_options ['blksize' ] = self .server .block_size
193
+
194
+ if 'timeout' in options :
195
+ try :
196
+ requested_timeout = int (options ['timeout' ])
197
+ if 1 <= requested_timeout <= 255 :
198
+ negotiated_options ['timeout' ] = requested_timeout
199
+ except ValueError :
200
+ pass
201
+
169
202
transfer = TftpReadTransfer (
170
203
server = self .server ,
171
204
filepath = resolved_path ,
172
205
client_addr = addr ,
173
- block_size = self .server .block_size ,
174
- timeout = self .server .timeout ,
175
- retries = self .server .retries
206
+ block_size = negotiated_options ['blksize' ],
207
+ timeout = negotiated_options .get ('timeout' , self .server .timeout ),
208
+ retries = self .server .retries ,
209
+ negotiated_options = negotiated_options if options else None
176
210
)
177
211
self .server .register_transfer (transfer )
178
212
asyncio .create_task (transfer .start ())
@@ -181,6 +215,16 @@ async def _handle_read_request(self, data: bytes, addr: Tuple[str, int]):
181
215
self .logger .error (f"Error handling RRQ from { addr } : { e } " )
182
216
self ._send_error (addr , TftpErrorCode .NOT_DEFINED , str (e ))
183
217
218
+ def _send_oack (self , addr : Tuple [str , int ], options : dict ):
219
+ """Send Option Acknowledgment (OACK) packet."""
220
+ oack_data = Opcode .OACK .to_bytes (2 , 'big' )
221
+ for opt_name , opt_value in options .items ():
222
+ oack_data += f"{ opt_name } \0 { str (opt_value )} \0 " .encode ('utf-8' )
223
+
224
+ if self .transport :
225
+ self .transport .sendto (oack_data , addr )
226
+ self .logger .debug (f"Sent OACK to { addr } with options { options } " )
227
+
184
228
def _send_error (self , addr : Tuple [str , int ], error_code : TftpErrorCode , message : str ):
185
229
error_packet = (
186
230
Opcode .ERROR .to_bytes (2 , 'big' ) +
@@ -232,16 +276,14 @@ async def cleanup(self):
232
276
233
277
234
278
class TftpReadTransfer (TftpTransfer ):
235
- """
236
- Handles a TFTP Read (RRQ) transfer.
237
- """
238
-
239
279
def __init__ (self , server : TftpServer , filepath : pathlib .Path , client_addr : Tuple [str , int ],
240
- block_size : int , timeout : float , retries : int ):
280
+ block_size : int , timeout : float , retries : int , negotiated_options : Optional [ dict ] = None ):
241
281
super ().__init__ (server , filepath , client_addr , block_size , timeout , retries )
242
- self .block_num = 1
282
+ self .block_num = 0
243
283
self .ack_received = asyncio .Event ()
244
284
self .last_ack = 0
285
+ self .negotiated_options = negotiated_options
286
+ self .oack_confirmed = False
245
287
246
288
async def start (self ):
247
289
self .logger .info (f"Starting read transfer of '{ self .filepath .name } ' to { self .client_addr } " )
@@ -256,89 +298,117 @@ async def start(self):
256
298
self .logger .debug (f"Transfer bound to local { local_addr } " )
257
299
258
300
try :
301
+ if self .negotiated_options :
302
+ oack_packet = self ._create_oack_packet ()
303
+ if not await self ._send_with_retries (oack_packet , is_oack = True ):
304
+ self .logger .error ("Failed to get acknowledgment for OACK" )
305
+ return
306
+ self .block_num = 1
307
+
259
308
async with aiofiles .open (self .filepath , 'rb' ) as f :
260
309
while True :
261
310
if self .server .shutdown_event .is_set ():
262
311
self .logger .info (f"Server shutdown detected, stopping transfer to { self .client_addr } " )
263
312
break
313
+
264
314
data = await f .read (self .block_size )
265
- if data :
266
- packet = (
267
- Opcode .DATA .to_bytes (2 , 'big' ) +
268
- self .block_num .to_bytes (2 , 'big' ) +
269
- data
270
- )
315
+ if not data and self .block_num == 1 :
316
+ # Empty file case
317
+ packet = self ._create_data_packet (b'' )
318
+ await self ._send_with_retries (packet )
319
+ break
320
+ elif data :
321
+ packet = self ._create_data_packet (data )
271
322
success = await self ._send_with_retries (packet )
272
323
if not success :
273
324
self .logger .error (f"Failed to send block { self .block_num } to { self .client_addr } " )
274
325
break
326
+
275
327
self .logger .debug (f"Block { self .block_num } sent successfully" )
276
328
self .block_num += 1
277
329
278
- # If the data read is less than block_size, this is the last packet
279
330
if len (data ) < self .block_size :
280
- self .logger .info (f"Final block { self .block_num - 1 } reached for { self . client_addr } " )
331
+ self .logger .info (f"Final block { self .block_num - 1 } sent " )
281
332
break
282
333
else :
283
- # If no data is returned, it means the file size is an exact multiple of block_size
284
- # Send an extra empty DATA packet to signal end of transfer
285
- packet = (
286
- Opcode .DATA .to_bytes (2 , 'big' ) +
287
- self .block_num .to_bytes (2 , 'big' ) +
288
- b''
289
- )
334
+ # End of file reached
335
+ packet = self ._create_data_packet (b'' )
290
336
success = await self ._send_with_retries (packet )
291
337
if not success :
292
- self .logger .error (
293
- f"Failed to send final empty block { self .block_num } "
294
- f"to { self .client_addr } "
295
- )
338
+ self .logger .error (f"Failed to send final block { self .block_num } " )
296
339
break
297
- self .logger .info (f"Transfer complete to { self . client_addr } , final block { self .block_num } " )
340
+ self .logger .info (f"Transfer complete, final block { self .block_num } " )
298
341
break
299
342
300
343
except Exception as e :
301
344
self .logger .error (f"Error during read transfer: { e } " )
302
345
finally :
303
346
await self .cleanup ()
304
347
305
- async def _send_with_retries (self , packet : bytes ) -> bool :
348
+ def _create_oack_packet (self ) -> bytes :
349
+ """Create OACK packet with negotiated options."""
350
+ packet = Opcode .OACK .to_bytes (2 , 'big' )
351
+ for opt_name , opt_value in self .negotiated_options .items ():
352
+ packet += f"{ opt_name } \0 { str (opt_value )} \0 " .encode ('utf-8' )
353
+ return packet
354
+
355
+ def _create_data_packet (self , data : bytes ) -> bytes :
356
+ """Create DATA packet with block number and data."""
357
+ return (
358
+ Opcode .DATA .to_bytes (2 , 'big' ) +
359
+ self .block_num .to_bytes (2 , 'big' ) +
360
+ data
361
+ )
362
+
363
+ def _send_packet (self , packet : bytes ):
364
+ """
365
+ Sends a packet to the client.
366
+ """
367
+ self .transport .sendto (packet )
368
+ if packet [0 :2 ] == Opcode .DATA .to_bytes (2 , 'big' ):
369
+ block = int .from_bytes (packet [2 :4 ], 'big' )
370
+ data_length = len (packet ) - 4
371
+ self .logger .debug (f"Sent DATA block { block } ({ data_length } bytes) to { self .client_addr } " )
372
+ elif packet [0 :2 ] == Opcode .OACK .to_bytes (2 , 'big' ):
373
+ self .logger .debug (f"Sent OACK to { self .client_addr } " )
374
+
375
+ async def _send_with_retries (self , packet : bytes , is_oack : bool = False ) -> bool :
306
376
self .current_packet = packet
377
+ expected_block = 0 if is_oack else self .block_num
378
+
307
379
for attempt in range (1 , self .retries + 1 ):
308
380
try :
309
381
self ._send_packet (packet )
310
- self .logger .debug (f"Sent DATA block { self . block_num } , waiting for ACK (Attempt { attempt } )" )
382
+ self .logger .debug (f"Sent { 'OACK' if is_oack else ' DATA' } block { expected_block } , waiting for ACK (Attempt { attempt } )" )
311
383
self .ack_received .clear ()
312
384
await asyncio .wait_for (self .ack_received .wait (), timeout = self .timeout )
313
385
314
- if self .last_ack == self . block_num :
315
- self .logger .debug (f"ACK received for block { self . block_num } " )
386
+ if self .last_ack == expected_block :
387
+ self .logger .debug (f"ACK received for block { expected_block } " )
316
388
return True
317
389
else :
318
- self .logger .warning (f"Received wrong ACK: expected { self . block_num } , got { self .last_ack } " )
390
+ self .logger .warning (f"Received wrong ACK: expected { expected_block } , got { self .last_ack } " )
319
391
320
392
except asyncio .TimeoutError :
321
- self .logger .warning (f"Timeout waiting for ACK of block { self . block_num } (Attempt { attempt } )" )
393
+ self .logger .warning (f"Timeout waiting for ACK of block { expected_block } (Attempt { attempt } )" )
322
394
323
395
return False
324
396
325
- def _send_packet (self , packet : bytes ):
326
- """
327
- Sends a DATA packet to the client.
328
- """
329
- self .transport .sendto (packet )
330
- block = int .from_bytes (packet [2 :4 ], 'big' )
331
- data_length = len (packet ) - 4
332
- self .logger .debug (f"Sent DATA block { block } ({ data_length } bytes) to { self .client_addr } " )
333
-
334
397
def handle_ack (self , block_num : int ):
335
398
self .logger .debug (f"Received ACK for block { block_num } from { self .client_addr } " )
399
+
400
+ # special handling for OACK acknowledgment
401
+ if not self .oack_confirmed and self .negotiated_options and block_num == 0 :
402
+ self .oack_confirmed = True
403
+ self .last_ack = block_num
404
+ self .ack_received .set ()
405
+ return
406
+
336
407
if block_num == self .block_num :
337
408
self .last_ack = block_num
338
409
self .ack_received .set ()
339
410
elif block_num == self .block_num - 1 :
340
- # Duplicate ACK for previous block, resend current packet
341
- self .logger .warning (f"Duplicate ACK for block { block_num } received, resending DATA block { self .block_num } " )
411
+ self .logger .warning (f"Duplicate ACK for block { block_num } received, resending block { self .block_num } " )
342
412
self .transport .sendto (self .current_packet )
343
413
else :
344
414
self .logger .warning (f"Out of sequence ACK: expected { self .block_num } , got { block_num } " )
0 commit comments