Commit dff54fd
[client rework]: Modify fetch as a context manager
Experiment with the @contextmanager decorator on RequestsFetcher.fetch() in
order to avoid unclosed connections.

Signed-off-by: Teodora Sechkova <[email protected]>
1 parent f7f5d6a commit dff54fd

File tree: 2 files changed, +98 -96 lines
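In outline, the commit turns RequestsFetcher.fetch() into a generator-based context manager: the body up to the yield runs on entry, the yielded chunk iterator is what the caller binds with `as`, and the finally clause runs on exit so the HTTP response is always closed. A minimal, self-contained sketch of that pattern (the FakeResponse class, payload, and chunk size are illustrative stand-ins, not code from this commit):

import io
from contextlib import contextmanager

CHUNK_SIZE = 4  # stand-in for tuf.settings.CHUNK_SIZE


class FakeResponse(object):
    """Illustrative stand-in for a streaming requests.Response."""

    def __init__(self, payload):
        self.raw = io.BytesIO(payload)

    def close(self):
        self.raw.close()


@contextmanager
def fetch(url, required_length):
    # Everything before the yield runs when the with-statement is entered.
    response = FakeResponse(b'example payload')  # the real code does network I/O

    def chunks():
        bytes_received = 0
        while bytes_received < required_length:
            data = response.raw.read(
                min(CHUNK_SIZE, required_length - bytes_received))
            if not data:
                break
            bytes_received += len(data)
            yield data

    try:
        yield chunks()  # the caller binds this iterator with 'as'
    finally:
        # Runs when the with-block exits: the response is closed even if
        # the caller breaks out of its loop early or an exception is raised.
        response.close()


# Usage mirrors the new code in tuf/download.py:
with fetch('https://example.com/f.txt', 15) as chunk_iter:
    for chunk in chunk_iter:
        print(chunk)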

tuf/download.py (+35 -35)
@@ -195,42 +195,42 @@ def _download_file(url, required_length, fetcher, STRICT_REQUIRED_LENGTH=True):
   average_download_speed = 0
   number_of_bytes_received = 0
 
-  try:
-    chunks = fetcher.fetch(url, required_length)
-    start_time = timeit.default_timer()
-    for chunk in chunks:
+  with fetcher.fetch(url, required_length) as chunks:
+    try:
+      start_time = timeit.default_timer()
+      for chunk in chunks:
+
+        stop_time = timeit.default_timer()
+        temp_file.write(chunk)
+
+        # Measure the average download speed.
+        number_of_bytes_received += len(chunk)
+        seconds_spent_receiving = stop_time - start_time
+        average_download_speed = number_of_bytes_received / seconds_spent_receiving
+
+        if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
+          logger.debug('The average download speed dropped below the minimum'
+              ' average download speed set in tuf.settings.py. Stopping the'
+              ' download!')
+          break
+
+        else:
+          logger.debug('The average download speed has not dipped below the'
+              ' minimum average download speed set in tuf.settings.py.')
+
+      # Does the total number of downloaded bytes match the required length?
+      _check_downloaded_length(number_of_bytes_received, required_length,
+          STRICT_REQUIRED_LENGTH=STRICT_REQUIRED_LENGTH,
+          average_download_speed=average_download_speed)
+
+    except Exception:
+      # Close 'temp_file'. Any written data is lost.
+      temp_file.close()
+      logger.debug('Could not download URL: ' + repr(url))
+      raise
 
-      stop_time = timeit.default_timer()
-      temp_file.write(chunk)
-
-      # Measure the average download speed.
-      number_of_bytes_received += len(chunk)
-      seconds_spent_receiving = stop_time - start_time
-      average_download_speed = number_of_bytes_received / seconds_spent_receiving
-
-      if average_download_speed < tuf.settings.MIN_AVERAGE_DOWNLOAD_SPEED:
-        logger.debug('The average download speed dropped below the minimum'
-            ' average download speed set in tuf.settings.py. Stopping the'
-            ' download!')
-        break
-
-      else:
-        logger.debug('The average download speed has not dipped below the'
-            ' minimum average download speed set in tuf.settings.py.')
-
-    # Does the total number of downloaded bytes match the required length?
-    _check_downloaded_length(number_of_bytes_received, required_length,
-        STRICT_REQUIRED_LENGTH=STRICT_REQUIRED_LENGTH,
-        average_download_speed=average_download_speed)
-
-  except Exception:
-    # Close 'temp_file'. Any written data is lost.
-    temp_file.close()
-    logger.debug('Could not download URL: ' + repr(url))
-    raise
-
-  else:
-    return temp_file
+    else:
+      return temp_file

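Why the with statement helps here: in the old code, response.close() sat in a finally clause inside the chunks() generator, and a generator's finally only runs once the generator is exhausted or finalized. If the caller breaks out of the loop early, as the slow-download check above does, the connection can stay open until garbage collection. A standalone illustration of that behavior (not code from this commit; the prompt finalization via del relies on CPython's reference counting):

def leaky():
    try:
        yield 1
        yield 2
    finally:
        print('cleanup ran')

gen = leaky()
for item in gen:
    break              # leaves the generator suspended; no cleanup yet
print('after break')   # printed BEFORE 'cleanup ran'
del gen                # cleanup runs only when the generator is finalized

Entering fetcher.fetch(...) through a with statement instead makes the cleanup point explicit: leaving the block always invokes the context manager's __exit__, which resumes the generator into its finally clause.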
tuf/requests_fetcher.py (+63 -61)
1010
import six
1111
import logging
1212
import time
13+
from contextlib import contextmanager
1314

1415
import urllib3.exceptions
1516

@@ -47,7 +48,7 @@ def __init__(self):
     # minimize subtle security issues. Some cookies may not be HTTP-safe.
     self._sessions = {}
 
-
+  @contextmanager
   def fetch(self, url, required_length):
     """Fetches the contents of HTTP/HTTPS url from a remote server.
 
@@ -64,70 +65,71 @@ def fetch(self, url, required_length):
     Returns:
       A bytes iterator
     """
-    # Get a customized session for each new schema+hostname combination.
-    session = self._get_session(url)
-
-    # Get the requests.Response object for this URL.
-    #
-    # Defer downloading the response body with stream=True.
-    # Always set the timeout. This timeout value is interpreted by requests as:
-    #  - connect timeout (max delay before first byte is received)
-    #  - read (gap) timeout (max delay between bytes received)
-    response = session.get(url, stream=True,
-        timeout=tuf.settings.SOCKET_TIMEOUT)
-    # Check response status.
     try:
-      response.raise_for_status()
-    except requests.HTTPError as e:
-      status = e.response.status_code
-      raise tuf.exceptions.FetcherHTTPError(str(e), status)
-
-
-    # Define a generator function to be returned by fetch. This way the caller
-    # of fetch can differentiate between connection and actual data download
-    # and measure download times accordingly.
-    def chunks():
+      # Get a customized session for each new schema+hostname combination.
+      session = self._get_session(url)
+
+      # Get the requests.Response object for this URL.
+      #
+      # Defer downloading the response body with stream=True.
+      # Always set the timeout. This timeout value is interpreted by requests as:
+      #  - connect timeout (max delay before first byte is received)
+      #  - read (gap) timeout (max delay between bytes received)
+      response = session.get(url, stream=True,
+          timeout=tuf.settings.SOCKET_TIMEOUT)
+      # Check response status.
       try:
-        bytes_received = 0
-        while True:
-          # We download a fixed chunk of data in every round. This is so that we
-          # can defend against slow retrieval attacks. Furthermore, we do not
-          # wish to download an extremely large file in one shot.
-          # Before beginning the round, sleep (if set) for a short amount of
-          # time so that the CPU is not hogged in the while loop.
-          if tuf.settings.SLEEP_BEFORE_ROUND:
-            time.sleep(tuf.settings.SLEEP_BEFORE_ROUND)
-
-          read_amount = min(
-              tuf.settings.CHUNK_SIZE, required_length - bytes_received)
-
-          # NOTE: This may not handle some servers adding a Content-Encoding
-          # header, which may cause urllib3 to misbehave:
-          # https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
-          data = response.raw.read(read_amount)
-          bytes_received += len(data)
-
-          # We might have no more data to read. Check number of bytes downloaded.
-          if not data:
-            logger.debug('Downloaded ' + repr(bytes_received) + '/' +
-                repr(required_length) + ' bytes.')
-
-            # Finally, we signal that the download is complete.
-            break
-
-          yield data
-
-          if bytes_received >= required_length:
-            break
-
-      except urllib3.exceptions.ReadTimeoutError as e:
-        raise tuf.exceptions.SlowRetrievalError(str(e))
-
-      finally:
+        response.raise_for_status()
+      except requests.HTTPError as e:
+        status = e.response.status_code
+        raise tuf.exceptions.FetcherHTTPError(str(e), status)
+
+
+      # Define a generator function to be returned by fetch. This way the caller
+      # of fetch can differentiate between connection and actual data download
+      # and measure download times accordingly.
+      def chunks():
+        try:
+          bytes_received = 0
+          while True:
+            # We download a fixed chunk of data in every round. This is so that we
+            # can defend against slow retrieval attacks. Furthermore, we do not wish
+            # to download an extremely large file in one shot.
+            # Before beginning the round, sleep (if set) for a short amount of time
+            # so that the CPU is not hogged in the while loop.
+            if tuf.settings.SLEEP_BEFORE_ROUND:
+              time.sleep(tuf.settings.SLEEP_BEFORE_ROUND)
+
+            read_amount = min(
+                tuf.settings.CHUNK_SIZE, required_length - bytes_received)
+
+            # NOTE: This may not handle some servers adding a Content-Encoding
+            # header, which may cause urllib3 to misbehave:
+            # https://github.com/pypa/pip/blob/404838abcca467648180b358598c597b74d568c9/src/pip/_internal/download.py#L547-L582
+            data = response.raw.read(read_amount)
+            bytes_received += len(data)
+
+            # We might have no more data to read. Check number of bytes downloaded.
+            if not data:
+              logger.debug('Downloaded ' + repr(bytes_received) + '/' +
+                  repr(required_length) + ' bytes.')
+
+              # Finally, we signal that the download is complete.
+              break
+
+            yield data
+
+            if bytes_received >= required_length:
+              break
+
+        except urllib3.exceptions.ReadTimeoutError as e:
+          raise tuf.exceptions.SlowRetrievalError(str(e))
+
+      yield chunks()
+
+    finally:
       response.close()
 
-    return chunks()
-
 
 
   def _get_session(self, url):
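A note on the caller-side contract after this change: decorated with @contextmanager, fetch() returns a context manager rather than a bare iterator, so it must be entered with a with statement, as tuf/download.py now does. A sketch, with a placeholder URL and length:

fetcher = RequestsFetcher()
downloaded = b''

# fetch() now returns a context manager wrapping the generator; the
# with-statement binds the yielded chunk iterator to 'chunks'.
with fetcher.fetch('https://example.com/root.json', 1024) as chunks:
    for chunk in chunks:
        downloaded += chunk
# Leaving the block resumes fetch() inside its finally clause, so
# response.close() runs on every exit path (normal, break, or raise).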
