Skip to content

Commit 660eb25

Browse files
committed
Replace file objects used as function arguments
By replacing file objects passed as function arguments with the read file content, we simplify the life-cycle management of temporary file objects: temporary files are now handled in a single function. This is done for metadata files, which are fully read into memory right after download anyway. The same is not true for target files, which should preferably be processed in chunks, so target download and verification still deal with file objects. _check_hashes is split into two functions: one dealing correctly with file objects and one operating directly on file content. Signed-off-by: Teodora Sechkova <[email protected]>
1 parent ad760fd commit 660eb25

File tree

2 files changed

+39
-40
lines changed

2 files changed

+39
-40
lines changed

tuf/client_rework/metadata_wrapper.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ def __init__(self, meta):
2222
self._meta = meta
2323

2424
@classmethod
25-
def from_json_object(cls, tmp_file):
25+
def from_json_object(cls, raw_data):
2626
"""Loads JSON-formatted TUF metadata from a file object."""
27-
raw_data = tmp_file.read()
2827
# Use local scope import to avoid circular import errors
2928
# pylint: disable=import-outside-toplevel
3029
from tuf.api.serialization.json import JSONDeserializer

tuf/client_rework/updater_rework.py

+38-38
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import fnmatch
1111
import logging
1212
import os
13-
from typing import BinaryIO, Dict, Optional, TextIO
13+
from typing import Dict, Optional
1414

1515
from securesystemslib import exceptions as sslib_exceptions
1616
from securesystemslib import hash as sslib_hash
@@ -158,9 +158,9 @@ def download_target(self, target: Dict, destination_directory: str):
158158
temp_obj = download.download_file(
159159
file_mirror, target["fileinfo"]["length"], self._fetcher
160160
)
161-
161+
_check_file_length(temp_obj, target["fileinfo"]["length"])
162162
temp_obj.seek(0)
163-
self._verify_target_file(temp_obj, target)
163+
_check_hashes_obj(temp_obj, target["fileinfo"]["hashes"])
164164
break
165165

166166
except Exception as exception:
@@ -297,7 +297,7 @@ def _root_mirrors_download(self, root_mirrors: Dict) -> "RootWrapper":
297297
)
298298

299299
temp_obj.seek(0)
300-
intermediate_root = self._verify_root(temp_obj)
300+
intermediate_root = self._verify_root(temp_obj.read())
301301
# When we reach this point, a root file has been successfully
302302
# downloaded and verified so we can exit the loop.
303303
break
@@ -344,7 +344,7 @@ def _load_timestamp(self) -> None:
344344
)
345345

346346
temp_obj.seek(0)
347-
verified_timestamp = self._verify_timestamp(temp_obj)
347+
verified_timestamp = self._verify_timestamp(temp_obj.read())
348348
break
349349

350350
except Exception as exception: # pylint: disable=broad-except
@@ -397,7 +397,7 @@ def _load_snapshot(self) -> None:
397397
)
398398

399399
temp_obj.seek(0)
400-
verified_snapshot = self._verify_snapshot(temp_obj)
400+
verified_snapshot = self._verify_snapshot(temp_obj.read())
401401
break
402402

403403
except Exception as exception: # pylint: disable=broad-except
@@ -451,7 +451,7 @@ def _load_targets(self, targets_role: str, parent_role: str) -> None:
451451

452452
temp_obj.seek(0)
453453
verified_targets = self._verify_targets(
454-
temp_obj, targets_role, parent_role
454+
temp_obj.read(), targets_role, parent_role
455455
)
456456
break
457457

@@ -472,12 +472,12 @@ def _load_targets(self, targets_role: str, parent_role: str) -> None:
472472
self._get_full_meta_name(targets_role, extension=".json")
473473
)
474474

475-
def _verify_root(self, temp_obj: TextIO) -> RootWrapper:
475+
def _verify_root(self, file_content: bytes) -> RootWrapper:
476476
"""
477477
TODO
478478
"""
479479

480-
intermediate_root = RootWrapper.from_json_object(temp_obj)
480+
intermediate_root = RootWrapper.from_json_object(file_content)
481481

482482
# Check for an arbitrary software attack
483483
trusted_root = self._metadata["root"]
@@ -490,7 +490,6 @@ def _verify_root(self, temp_obj: TextIO) -> RootWrapper:
490490

491491
# Check for a rollback attack.
492492
if intermediate_root.version < trusted_root.version:
493-
temp_obj.close()
494493
raise exceptions.ReplayedMetadataError(
495494
"root", intermediate_root.version(), trusted_root.version()
496495
)
@@ -499,11 +498,11 @@ def _verify_root(self, temp_obj: TextIO) -> RootWrapper:
499498

500499
return intermediate_root
501500

502-
def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
501+
def _verify_timestamp(self, file_content: bytes) -> TimestampWrapper:
503502
"""
504503
TODO
505504
"""
506-
intermediate_timestamp = TimestampWrapper.from_json_object(temp_obj)
505+
intermediate_timestamp = TimestampWrapper.from_json_object(file_content)
507506

508507
# Check for an arbitrary software attack
509508
trusted_root = self._metadata["root"]
@@ -517,7 +516,6 @@ def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
517516
intermediate_timestamp.signed.version
518517
<= self._metadata["timestamp"].version
519518
):
520-
temp_obj.close()
521519
raise exceptions.ReplayedMetadataError(
522520
"root",
523521
intermediate_timestamp.version(),
@@ -529,7 +527,6 @@ def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
529527
intermediate_timestamp.snapshot.version
530528
<= self._metadata["timestamp"].snapshot["version"]
531529
):
532-
temp_obj.close()
533530
raise exceptions.ReplayedMetadataError(
534531
"root",
535532
intermediate_timestamp.snapshot.version(),
@@ -540,24 +537,23 @@ def _verify_timestamp(self, temp_obj: TextIO) -> TimestampWrapper:
540537

541538
return intermediate_timestamp
542539

543-
def _verify_snapshot(self, temp_obj: TextIO) -> SnapshotWrapper:
540+
def _verify_snapshot(self, file_content: bytes) -> SnapshotWrapper:
544541
"""
545542
TODO
546543
"""
547544

548545
# Check against timestamp metadata
549546
if self._metadata["timestamp"].snapshot.get("hash"):
550547
_check_hashes(
551-
temp_obj, self._metadata["timestamp"].snapshot.get("hash")
548+
file_content, self._metadata["timestamp"].snapshot.get("hash")
552549
)
553550

554-
intermediate_snapshot = SnapshotWrapper.from_json_object(temp_obj)
551+
intermediate_snapshot = SnapshotWrapper.from_json_object(file_content)
555552

556553
if (
557554
intermediate_snapshot.version
558555
!= self._metadata["timestamp"].snapshot["version"]
559556
):
560-
temp_obj.close()
561557
raise exceptions.BadVersionNumberError
562558

563559
# Check for an arbitrary software attack
@@ -573,15 +569,14 @@ def _verify_snapshot(self, temp_obj: TextIO) -> SnapshotWrapper:
573569
target_role["version"]
574570
!= self._metadata["snapshot"].meta[target_role]["version"]
575571
):
576-
temp_obj.close()
577572
raise exceptions.BadVersionNumberError
578573

579574
intermediate_snapshot.expires()
580575

581576
return intermediate_snapshot
582577

583578
def _verify_targets(
584-
self, temp_obj: TextIO, filename: str, parent_role: str
579+
self, file_content: bytes, filename: str, parent_role: str
585580
) -> TargetsWrapper:
586581
"""
587582
TODO
@@ -590,15 +585,14 @@ def _verify_targets(
590585
# Check against timestamp metadata
591586
if self._metadata["snapshot"].role(filename).get("hash"):
592587
_check_hashes(
593-
temp_obj, self._metadata["snapshot"].targets.get("hash")
588+
file_content, self._metadata["snapshot"].targets.get("hash")
594589
)
595590

596-
intermediate_targets = TargetsWrapper.from_json_object(temp_obj)
591+
intermediate_targets = TargetsWrapper.from_json_object(file_content)
597592
if (
598593
intermediate_targets.version
599594
!= self._metadata["snapshot"].role(filename)["version"]
600595
):
601-
temp_obj.close()
602596
raise exceptions.BadVersionNumberError
603597

604598
# Check for an arbitrary software attack
@@ -612,15 +606,6 @@ def _verify_targets(
612606

613607
return intermediate_targets
614608

615-
@staticmethod
616-
def _verify_target_file(temp_obj: BinaryIO, targetinfo: Dict) -> None:
617-
"""
618-
TODO
619-
"""
620-
621-
_check_file_length(temp_obj, targetinfo["fileinfo"]["length"])
622-
_check_hashes(temp_obj, targetinfo["fileinfo"]["hashes"])
623-
624609
def _preorder_depth_first_walk(self, target_filepath) -> Dict:
625610
"""
626611
TODO
@@ -849,19 +834,34 @@ def _check_file_length(file_object, trusted_file_length):
849834
)
850835

851836

852-
def _check_hashes(file_object, trusted_hashes):
837+
def _check_hashes_obj(file_object, trusted_hashes):
838+
"""
839+
TODO
840+
"""
841+
for algorithm, trusted_hash in trusted_hashes.items():
842+
digest_object = sslib_hash.digest_fileobject(file_object, algorithm)
843+
844+
computed_hash = digest_object.hexdigest()
845+
846+
# Raise an exception if any of the hashes are incorrect.
847+
if trusted_hash != computed_hash:
848+
raise sslib_exceptions.BadHashError(trusted_hash, computed_hash)
849+
850+
logger.info(
851+
"The file's " + algorithm + " hash is" " correct: " + trusted_hash
852+
)
853+
854+
855+
def _check_hashes(file_content, trusted_hashes):
853856
"""
854857
TODO
855858
"""
856859
# Verify each trusted hash of 'trusted_hashes'. If all are valid, simply
857860
# return.
858861
for algorithm, trusted_hash in trusted_hashes.items():
859862
digest_object = sslib_hash.digest(algorithm)
860-
# Ensure we read from the beginning of the file object
861-
# TODO: should we store file position (before the loop) and reset
862-
# after we seek about?
863-
file_object.seek(0)
864-
digest_object.update(file_object.read())
863+
864+
digest_object.update(file_content)
865865
computed_hash = digest_object.hexdigest()
866866

867867
# Raise an exception if any of the hashes are incorrect.

0 commit comments

Comments
 (0)