Skip to content

Commit e885ae4

Browse files
committed
Replace file objects used as function arguments
By replacing file objects passed as function arguments with the read file content, we simplify temporary file objects' life-cycle management. Temporary files are handled in a single function. This is done for metadata files, which are fully read into memory right after download anyway. The same is not true for target files, which should preferably be treated in chunks, so target download and verification still deal with file objects. _check_hashes is split into two functions: one dealing correctly with file objects and one using file content directly. Signed-off-by: Teodora Sechkova <[email protected]>
1 parent 482ea1c commit e885ae4

File tree

2 files changed

+39
-40
lines changed

2 files changed

+39
-40
lines changed

tuf/client_rework/metadata_wrapper.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ def __init__(self, meta):
2222
self._meta = meta
2323

2424
@classmethod
25-
def from_json_object(cls, tmp_file):
25+
def from_json_object(cls, raw_data):
2626
"""Loads JSON-formatted TUF metadata from a file object."""
27-
raw_data = tmp_file.read()
2827
# Use local scope import to avoid circular import errors
2928
# pylint: disable=import-outside-toplevel
3029
from tuf.api.serialization.json import JSONDeserializer

tuf/client_rework/updater_rework.py

+38-38
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import fnmatch
1111
import logging
1212
import os
13-
from typing import BinaryIO, Dict, Optional, TextIO
13+
from typing import Dict, Optional
1414

1515
from securesystemslib import exceptions as sslib_exceptions
1616
from securesystemslib import hash as sslib_hash
@@ -159,9 +159,9 @@ def download_target(self, target: Dict, destination_directory: str):
159159
temp_obj = download.download_file(
160160
file_mirror, target["fileinfo"]["length"], self._fetcher
161161
)
162-
162+
_check_file_length(temp_obj, target["fileinfo"]["length"])
163163
temp_obj.seek(0)
164-
self._verify_target_file(temp_obj, target)
164+
_check_hashes_obj(temp_obj, target["fileinfo"]["hashes"])
165165
break
166166

167167
except Exception as exception: # pylint: disable=broad-except
@@ -308,7 +308,7 @@ def _root_mirrors_download(self, root_mirrors: Dict) -> "RootWrapper":
308308
)
309309

310310
temp_obj.seek(0)
311-
intermediate_root = self._verify_root(temp_obj)
311+
intermediate_root = self._verify_root(temp_obj.read())
312312
# When we reach this point, a root file has been successfully
313313
# downloaded and verified so we can exit the loop.
314314
break
@@ -356,7 +356,7 @@ def _load_timestamp(self) -> None:
356356
)
357357

358358
temp_obj.seek(0)
359-
verified_timestamp = self._verify_timestamp(temp_obj)
359+
verified_timestamp = self._verify_timestamp(temp_obj.read())
360360
break
361361

362362
except Exception as exception: # pylint: disable=broad-except
@@ -410,7 +410,7 @@ def _load_snapshot(self) -> None:
410410
)
411411

412412
temp_obj.seek(0)
413-
verified_snapshot = self._verify_snapshot(temp_obj)
413+
verified_snapshot = self._verify_snapshot(temp_obj.read())
414414
break
415415

416416
except Exception as exception: # pylint: disable=broad-except
@@ -465,7 +465,7 @@ def _load_targets(self, targets_role: str, parent_role: str) -> None:
465465

466466
temp_obj.seek(0)
467467
verified_targets = self._verify_targets(
468-
temp_obj, targets_role, parent_role
468+
temp_obj.read(), targets_role, parent_role
469469
)
470470
break
471471

@@ -487,12 +487,12 @@ def _load_targets(self, targets_role: str, parent_role: str) -> None:
487487
self._get_full_meta_name(targets_role, extension=".json")
488488
)
489489

490-
def _verify_root(self, temp_obj: TextIO) -> RootWrapper:
490+
def _verify_root(self, file_content: bytes) -> RootWrapper:
491491
"""
492492
TODO
493493
"""
494494

495-
intermediate_root = RootWrapper.from_json_object(temp_obj)
495+
intermediate_root = RootWrapper.from_json_object(file_content)
496496

497497
# Check for an arbitrary software attack
498498
trusted_root = self._metadata["root"]
@@ -505,7 +505,6 @@ def _verify_root(self, temp_obj: TextIO) -> RootWrapper:
505505

506506
# Check for a rollback attack.
507507
if intermediate_root.version < trusted_root.version:
508-
temp_obj.close()
509508
raise exceptions.ReplayedMetadataError(
510509
"root", intermediate_root.version(), trusted_root.version()
511510
)
@@ -514,11 +513,11 @@ def _verify_root(self, temp_obj: TextIO) -> RootWrapper:
514513

515514
return intermediate_root
516515

517-
def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
516+
def _verify_timestamp(self, file_content: bytes) -> TimestampWrapper:
518517
"""
519518
TODO
520519
"""
521-
intermediate_timestamp = TimestampWrapper.from_json_object(temp_obj)
520+
intermediate_timestamp = TimestampWrapper.from_json_object(file_content)
522521

523522
# Check for an arbitrary software attack
524523
trusted_root = self._metadata["root"]
@@ -532,7 +531,6 @@ def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
532531
intermediate_timestamp.signed.version
533532
<= self._metadata["timestamp"].version
534533
):
535-
temp_obj.close()
536534
raise exceptions.ReplayedMetadataError(
537535
"root",
538536
intermediate_timestamp.version(),
@@ -544,7 +542,6 @@ def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
544542
intermediate_timestamp.snapshot.version
545543
<= self._metadata["timestamp"].snapshot["version"]
546544
):
547-
temp_obj.close()
548545
raise exceptions.ReplayedMetadataError(
549546
"root",
550547
intermediate_timestamp.snapshot.version(),
@@ -555,24 +552,23 @@ def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
555552

556553
return intermediate_timestamp
557554

558-
def _verify_snapshot(self, temp_obj: TextIO) -> SnapshotWrapper:
555+
def _verify_snapshot(self, file_content: bytes) -> SnapshotWrapper:
559556
"""
560557
TODO
561558
"""
562559

563560
# Check against timestamp metadata
564561
if self._metadata["timestamp"].snapshot.get("hash"):
565562
_check_hashes(
566-
temp_obj, self._metadata["timestamp"].snapshot.get("hash")
563+
file_content, self._metadata["timestamp"].snapshot.get("hash")
567564
)
568565

569-
intermediate_snapshot = SnapshotWrapper.from_json_object(temp_obj)
566+
intermediate_snapshot = SnapshotWrapper.from_json_object(file_content)
570567

571568
if (
572569
intermediate_snapshot.version
573570
!= self._metadata["timestamp"].snapshot["version"]
574571
):
575-
temp_obj.close()
576572
raise exceptions.BadVersionNumberError
577573

578574
# Check for an arbitrary software attack
@@ -588,15 +584,14 @@ def _verify_snapshot(self, temp_obj: TextIO) -> SnapshotWrapper:
588584
target_role["version"]
589585
!= self._metadata["snapshot"].meta[target_role]["version"]
590586
):
591-
temp_obj.close()
592587
raise exceptions.BadVersionNumberError
593588

594589
intermediate_snapshot.expires()
595590

596591
return intermediate_snapshot
597592

598593
def _verify_targets(
599-
self, temp_obj: TextIO, filename: str, parent_role: str
594+
self, file_content: bytes, filename: str, parent_role: str
600595
) -> TargetsWrapper:
601596
"""
602597
TODO
@@ -605,15 +600,14 @@ def _verify_targets(
605600
# Check against timestamp metadata
606601
if self._metadata["snapshot"].role(filename).get("hash"):
607602
_check_hashes(
608-
temp_obj, self._metadata["snapshot"].targets.get("hash")
603+
file_content, self._metadata["snapshot"].targets.get("hash")
609604
)
610605

611-
intermediate_targets = TargetsWrapper.from_json_object(temp_obj)
606+
intermediate_targets = TargetsWrapper.from_json_object(file_content)
612607
if (
613608
intermediate_targets.version
614609
!= self._metadata["snapshot"].role(filename)["version"]
615610
):
616-
temp_obj.close()
617611
raise exceptions.BadVersionNumberError
618612

619613
# Check for an arbitrary software attack
@@ -627,15 +621,6 @@ def _verify_targets(
627621

628622
return intermediate_targets
629623

630-
@staticmethod
631-
def _verify_target_file(temp_obj: BinaryIO, targetinfo: Dict) -> None:
632-
"""
633-
TODO
634-
"""
635-
636-
_check_file_length(temp_obj, targetinfo["fileinfo"]["length"])
637-
_check_hashes(temp_obj, targetinfo["fileinfo"]["hashes"])
638-
639624
def _preorder_depth_first_walk(self, target_filepath) -> Dict:
640625
"""
641626
TODO
@@ -864,19 +849,34 @@ def _check_file_length(file_object, trusted_file_length):
864849
)
865850

866851

867-
def _check_hashes(file_object, trusted_hashes):
852+
def _check_hashes_obj(file_object, trusted_hashes):
853+
"""
854+
TODO
855+
"""
856+
for algorithm, trusted_hash in trusted_hashes.items():
857+
digest_object = sslib_hash.digest_fileobject(file_object, algorithm)
858+
859+
computed_hash = digest_object.hexdigest()
860+
861+
# Raise an exception if any of the hashes are incorrect.
862+
if trusted_hash != computed_hash:
863+
raise sslib_exceptions.BadHashError(trusted_hash, computed_hash)
864+
865+
logger.info(
866+
"The file's " + algorithm + " hash is" " correct: " + trusted_hash
867+
)
868+
869+
870+
def _check_hashes(file_content, trusted_hashes):
868871
"""
869872
TODO
870873
"""
871874
# Verify each trusted hash of 'trusted_hashes'. If all are valid, simply
872875
# return.
873876
for algorithm, trusted_hash in trusted_hashes.items():
874877
digest_object = sslib_hash.digest(algorithm)
875-
# Ensure we read from the beginning of the file object
876-
# TODO: should we store file position (before the loop) and reset
877-
# after we seek about?
878-
file_object.seek(0)
879-
digest_object.update(file_object.read())
878+
879+
digest_object.update(file_content)
880880
computed_hash = digest_object.hexdigest()
881881

882882
# Raise an exception if any of the hashes are incorrect.

0 commit comments

Comments
 (0)