11
11
import requests
12
12
13
13
from fsspec .asyn import AsyncFileSystem , sync , sync_wrapper
14
+ from fsspec .callbacks import _DEFAULT_CALLBACK
14
15
from fsspec .exceptions import FSTimeoutError
15
16
from fsspec .spec import AbstractBufferedFile
16
17
from fsspec .utils import DEFAULT_BLOCK_SIZE , tokenize
@@ -58,6 +59,7 @@ def __init__(
58
59
asynchronous = False ,
59
60
loop = None ,
60
61
client_kwargs = None ,
62
+ get_client = get_client ,
61
63
** storage_options ,
62
64
):
63
65
"""
@@ -79,6 +81,10 @@ def __init__(
79
81
Passed to aiohttp.ClientSession, see
80
82
https://docs.aiohttp.org/en/stable/client_reference.html
81
83
For example, ``{'auth': aiohttp.BasicAuth('user', 'pass')}``
84
+ get_client: Callable[..., aiohttp.ClientSession]
85
+ A callable which takes keyword arguments and constructs
86
an aiohttp.ClientSession. Its state will be managed by
87
+ the HTTPFileSystem class.
82
88
storage_options: key-value
83
89
Any other parameters passed on to requests
84
90
cache_type, cache_options: defaults used in open
@@ -90,6 +96,7 @@ def __init__(
90
96
self .cache_type = cache_type
91
97
self .cache_options = cache_options
92
98
self .client_kwargs = client_kwargs or {}
99
+ self .get_client = get_client
93
100
self .kwargs = storage_options
94
101
self ._session = None
95
102
@@ -121,7 +128,7 @@ def close_session(loop, session):
121
128
122
129
async def set_session(self):
    """Return the HTTP client session, creating it on first use.

    The session is built by ``self.get_client`` with this instance's loop
    and ``client_kwargs`` and is cached on ``self._session``; later calls
    return the cached object.  For non-asynchronous instances a finalizer
    is registered so ``close_session`` is invoked with the loop and the
    session once this filesystem object is garbage collected.
    """
    # Fast path: a session was already created by an earlier call.
    if self._session is not None:
        return self._session

    created = await self.get_client(loop=self.loop, **self.client_kwargs)
    self._session = created
    if not self.asynchronous:
        # Synchronous users never get a chance to close the session
        # themselves, so tie its cleanup to the lifetime of ``self``.
        weakref.finalize(self, self.close_session, self.loop, created)
    return created
@@ -223,18 +230,61 @@ async def _cat_file(self, url, start=None, end=None, **kwargs):
223
230
self ._raise_not_found_for_status (r , url )
224
231
return out
225
232
226
- async def _get_file (self , rpath , lpath , chunk_size = 5 * 2 ** 20 , ** kwargs ):
233
async def _get_file(
    self, rpath, lpath, chunk_size=5 * 2 ** 20, callback=_DEFAULT_CALLBACK, **kwargs
):
    """Download the URL ``rpath`` into the local file ``lpath``.

    Parameters
    ----------
    rpath: str
        Remote URL to fetch with a GET request.
    lpath: str
        Local path, opened in ``"wb"`` mode and overwritten.
    chunk_size: int
        Maximum number of bytes requested from the response per read.
    callback: fsspec callback
        Receives the total size (from the ``content-length`` header, or
        ``None`` when it is missing or not an integer) and one relative
        update per chunk written.
    kwargs: key-value
        Merged over the instance-level ``self.kwargs`` and passed on to
        ``session.get``.
    """
    # Per-call kwargs take precedence over the instance defaults.
    kw = self.kwargs.copy()
    kw.update(kwargs)
    logger.debug(rpath)
    session = await self.set_session()
    async with session.get(rpath, **kw) as r:
        # Check the status first so that the length of an error page is
        # never reported to the callback as the size of the file.
        self._raise_not_found_for_status(r, rpath)

        try:
            size = int(r.headers["content-length"])
        except (ValueError, KeyError):
            size = None
        callback.set_size(size)

        with open(lpath, "wb") as fd:
            chunk = await r.content.read(chunk_size)
            # An empty read marks the end of the body; stop before writing
            # or reporting a zero-length chunk.
            while chunk:
                fd.write(chunk)
                callback.relative_update(len(chunk))
                chunk = await r.content.read(chunk_size)
254
+
255
async def _put_file(
    self,
    lpath,
    rpath,
    chunk_size=5 * 2 ** 20,
    callback=_DEFAULT_CALLBACK,
    method="post",
    **kwargs,
):
    """Upload the local file ``lpath`` to the URL ``rpath``.

    Parameters
    ----------
    lpath: str
        Local path of the file to send; opened in ``"rb"`` mode.
    rpath: str
        Remote URL that receives the request.
    chunk_size: int
        Number of bytes read from the local file per streamed chunk.
    callback: fsspec callback
        Receives the local file size and one relative update per chunk
        that has been handed to the HTTP client.
    method: str
        HTTP verb to use, ``"post"`` or ``"put"`` (case-insensitive).
    kwargs: key-value
        Merged over the instance-level ``self.kwargs`` and passed on to
        the session request method.

    Raises
    ------
    ValueError
        If ``method`` is neither ``"post"`` nor ``"put"``.
    """
    # Validate cheap arguments before any session is created.
    method = method.lower()
    if method not in ("post", "put"):
        raise ValueError(
            f"method has to be either 'post' or 'put', not: {method!r}"
        )

    async def gen_chunks():
        # Stream the file lazily so it is never fully held in memory.
        with open(lpath, "rb") as f:
            # Seeking to the end returns the file size; rewind afterwards.
            callback.set_size(f.seek(0, 2))
            f.seek(0)

            chunk = f.read(chunk_size)
            while chunk:
                yield chunk
                # Control returns here once the consumer asks for more,
                # i.e. after the previous chunk has been taken.
                callback.relative_update(len(chunk))
                chunk = f.read(chunk_size)

    # Per-call kwargs take precedence over the instance defaults.
    kw = self.kwargs.copy()
    kw.update(kwargs)
    session = await self.set_session()

    meth = getattr(session, method)
    async with meth(rpath, data=gen_chunks(), **kw) as resp:
        self._raise_not_found_for_status(resp, rpath)
238
288
239
289
async def _exists (self , path , ** kwargs ):
240
290
kw = self .kwargs .copy ()
@@ -316,22 +366,29 @@ async def _info(self, url, **kwargs):
316
366
which case size will be given as None (and certain operations on the
317
367
corresponding file will not work).
318
368
"""
319
- size = False
320
- kw = self .kwargs . copy ()
321
- kw . update ( kwargs )
369
+ info = {}
370
+ session = await self .set_session ()
371
+
322
372
for policy in ["head" , "get" ]:
323
373
try :
324
- session = await self .set_session ()
325
- size = await _file_size (url , size_policy = policy , session = session , ** kw )
326
- if size :
374
+ info .update (
375
+ await _file_info (
376
+ url ,
377
+ size_policy = policy ,
378
+ session = session ,
379
+ ** self .kwargs ,
380
+ ** kwargs ,
381
+ )
382
+ )
383
+ if info .get ("size" ) is not None :
327
384
break
328
- except Exception as e :
329
- logger . debug (( str ( e )))
330
- else :
331
- # get failed, so conclude URL does not exist
332
- if size is False :
333
- raise FileNotFoundError ( url )
334
- return {"name" : url , "size" : size or None , "type" : "file" }
385
+ except Exception as exc :
386
+ if policy == "get" :
387
+ # If get failed, then raise a FileNotFoundError
388
+ raise FileNotFoundError ( url ) from exc
389
+ logger . debug ( str ( exc ))
390
+
391
+ return {"name" : url , "size" : None , ** info , "type" : "file" }
335
392
336
393
async def _glob (self , path , ** kwargs ):
337
394
"""
@@ -613,6 +670,7 @@ def __init__(self, fs, url, mode="rb", loop=None, session=None, **kwargs):
613
670
614
671
async def cor ():
615
672
r = await self .session .get (url , ** kwargs ).__aenter__ ()
673
+ self .fs ._raise_not_found_for_status (r , url )
616
674
return r
617
675
618
676
self .r = sync (self .loop , cor )
@@ -654,8 +712,8 @@ async def get_range(session, url, start, end, file=None, **kwargs):
654
712
return out
655
713
656
714
657
- async def _file_size (url , session = None , size_policy = "head" , ** kwargs ):
658
- """Call HEAD on the server to get file size
715
+ async def _file_info (url , session , size_policy = "head" , ** kwargs ):
716
+ """Call HEAD on the server to get details about the file ( size/checksum etc.)
659
717
660
718
Default operation is to explicitly allow redirects and use encoding
661
719
'identity' (no compression) to get the true size of the target.
@@ -666,29 +724,38 @@ async def _file_size(url, session=None, size_policy="head", **kwargs):
666
724
head = kwargs .get ("headers" , {}).copy ()
667
725
head ["Accept-Encoding" ] = "identity"
668
726
kwargs ["headers" ] = head
669
- session = session or await get_client ()
727
+
728
+ info = {}
670
729
if size_policy == "head" :
671
730
r = await session .head (url , allow_redirects = ar , ** kwargs )
672
731
elif size_policy == "get" :
673
732
r = await session .get (url , allow_redirects = ar , ** kwargs )
674
733
else :
675
734
raise TypeError ('size_policy must be "head" or "get", got %s' "" % size_policy )
676
735
async with r :
677
- try :
678
- r .raise_for_status ()
736
+ r .raise_for_status ()
737
+
738
+ # TODO:
739
+ # recognise lack of 'Accept-Ranges',
740
+ # or 'Accept-Ranges': 'none' (not 'bytes')
741
+ # to mean streaming only, no random access => return None
742
+ if "Content-Length" in r .headers :
743
+ info ["size" ] = int (r .headers ["Content-Length" ])
744
+ elif "Content-Range" in r .headers :
745
+ info ["size" ] = int (r .headers ["Content-Range" ].split ("/" )[1 ])
746
+
747
+ for checksum_field in ["ETag" , "Content-MD5" , "Digest" ]:
748
+ if r .headers .get (checksum_field ):
749
+ info [checksum_field ] = r .headers [checksum_field ]
750
+
751
+ return info
752
+
679
753
680
- # TODO:
681
- # recognise lack of 'Accept-Ranges',
682
- # or 'Accept-Ranges': 'none' (not 'bytes')
683
- # to mean streaming only, no random access => return None
684
- if "Content-Length" in r .headers :
685
- return int (r .headers ["Content-Length" ])
686
- elif "Content-Range" in r .headers :
687
- return int (r .headers ["Content-Range" ].split ("/" )[1 ])
688
- except aiohttp .ClientResponseError :
689
- logger .debug ("Error retrieving file size" )
690
- return None
691
- r .close ()
754
async def _file_size(url, session=None, size_policy="head", **kwargs):
    """Return the size in bytes of the file at ``url``, or None if unknown.

    Thin wrapper around ``_file_info`` that keeps only the ``"size"`` entry.

    Parameters
    ----------
    url: str
        Address of the file.
    session: client session, optional
        Session used for the request.  When omitted, a temporary one is
        created with ``get_client()`` and closed again before returning.
    size_policy: str
        ``"head"`` or ``"get"``; forwarded to ``_file_info``.
    kwargs: key-value
        Forwarded to ``_file_info``.
    """
    # Only a session created here is ours to close; a caller-supplied one
    # stays open for the caller to reuse.
    owns_session = session is None
    if owns_session:
        session = await get_client()
    try:
        # Forward everything by keyword so no positional argument can end
        # up in ``_file_info``'s ``session`` slot.
        info = await _file_info(
            url, session=session, size_policy=size_policy, **kwargs
        )
    finally:
        if owns_session:
            # NOTE(review): assumes get_client() returns an
            # aiohttp.ClientSession, whose close() is a coroutine.
            await session.close()
    return info.get("size")
692
759
693
760
694
761
file_size = sync_wrapper (_file_size )
0 commit comments