Skip to content

Commit f55b40a

Browse files
greentimrenardeinsidematthayes
authored
✨ support for files bigger than 1MB in sync (#509)
* files that are larger than the 1MB limit allowed in the dbfs put API v2.0 will use the streaming api calls to upload the content and avoid the http 400 MAX_BLOCK_SIZE_EXCEEDED error * merge with upstream changes * fix test * add changelog --------- Co-authored-by: Ivan Trusov <[email protected]> Co-authored-by: Matt Hayes <[email protected]>
1 parent a37e44e commit f55b40a

File tree

5 files changed

+134
-45
lines changed

5 files changed

+134
-45
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222

2323
- 📌 switch from using `retry` to `tenacity`
2424

25+
### Added
26+
- ✨ support for files bigger than 1MB in sync
27+
2528
## [0.8.8] - 2022-02-22
2629

2730
# Fixed

dbx/sync/clients.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ async def _api(
9898
more_opts = {"ssl": ssl} if ssl is not None else {}
9999
async with session.post(url=url, json=json_data, headers=headers, **more_opts) as resp:
100100
if resp.status in ok_status:
101-
break
101+
return await resp.json()
102102
if resp.status == 429:
103103
dbx_echo("Rate limited")
104104
await _rate_limit_sleep(resp)
@@ -197,14 +197,49 @@ async def put(
197197
path = f"{self.base_path}/{sub_path}"
198198
with open(full_source_path, "rb") as f:
199199
contents = base64.b64encode(f.read()).decode("ascii")
200-
await self._api_put(
201-
api_base_path=self.api_base_path,
202-
path=path,
203-
session=session,
204-
api_token=self.api_token,
205-
contents=contents,
206-
ssl=self.ssl,
207-
)
200+
201+
if len(contents) <= 1024 * 1024:
202+
await self._api_put(
203+
api_base_path=self.api_base_path,
204+
path=path,
205+
session=session,
206+
api_token=self.api_token,
207+
contents=contents,
208+
ssl=self.ssl,
209+
)
210+
else:
211+
dbx_echo(f"Streaming {path}")
212+
213+
resp = await self._api(
214+
url=f"{self.api_base_path}/create",
215+
path=path,
216+
session=session,
217+
api_token=self.api_token,
218+
ssl=self.ssl,
219+
overwrite=True,
220+
)
221+
handle = resp.get("handle")
222+
import textwrap
223+
224+
chunks = textwrap.wrap(contents, 1024 * 1024)
225+
for chunk in chunks:
226+
await self._api(
227+
url=f"{self.api_base_path}/add-block",
228+
path=path,
229+
session=session,
230+
api_token=self.api_token,
231+
ssl=self.ssl,
232+
handle=handle,
233+
data=chunk,
234+
)
235+
await self._api(
236+
url=f"{self.api_base_path}/close",
237+
path=path,
238+
session=session,
239+
api_token=self.api_token,
240+
ssl=self.ssl,
241+
handle=handle,
242+
)
208243

209244

210245
class ReposClient(BaseClient):

tests/unit/sync/clients/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,12 @@ def dummy_file_path() -> str:
1818
with open(file_path, "w") as f:
1919
f.write("yo")
2020
yield file_path
21+
22+
23+
@pytest.fixture
24+
def dummy_file_path_2mb() -> str:
25+
with temporary_directory() as tempdir:
26+
file_path = os.path.join(tempdir, "file")
27+
with open(file_path, "w") as f:
28+
f.write("y" * 1024 * 2048)
29+
yield file_path

tests/unit/sync/clients/test_dbfs_client.py

Lines changed: 69 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import asyncio
22
import base64
3-
from unittest.mock import AsyncMock, MagicMock, PropertyMock
3+
import textwrap
4+
from tests.unit.sync.utils import create_async_with_result
5+
from unittest.mock import AsyncMock, MagicMock, PropertyMock, call
46

57
import pytest
68

@@ -22,7 +24,7 @@ def test_init(client):
2224

2325
def test_delete(client: DBFSClient):
2426
session = MagicMock()
25-
resp = MagicMock()
27+
resp = AsyncMock()
2628
setattr(type(resp), "status", PropertyMock(return_value=200))
2729
session.post.return_value = create_async_with_result(resp)
2830
asyncio.run(client.delete(sub_path="foo/bar", session=session))
@@ -39,7 +41,7 @@ def test_delete_secure(client: DBFSClient):
3941
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=False)
4042
client = DBFSClient(base_path="/tmp/foo", config=mock_config)
4143
session = MagicMock()
42-
resp = MagicMock()
44+
resp = AsyncMock()
4345
setattr(type(resp), "status", PropertyMock(return_value=200))
4446
session.post.return_value = create_async_with_result(resp)
4547
asyncio.run(client.delete(sub_path="foo/bar", session=session))
@@ -50,21 +52,6 @@ def test_delete_secure(client: DBFSClient):
5052
assert session.post.call_args[1]["ssl"] is True
5153

5254

53-
def test_delete_secure(client: DBFSClient):
54-
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=True)
55-
client = DBFSClient(base_path="/tmp/foo", config=mock_config)
56-
session = MagicMock()
57-
resp = MagicMock()
58-
setattr(type(resp), "status", PropertyMock(return_value=200))
59-
session.post.return_value = create_async_with_result(resp)
60-
asyncio.run(client.delete(sub_path="foo/bar", session=session))
61-
62-
assert session.post.call_count == 1
63-
assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/delete"
64-
assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"}
65-
assert session.post.call_args[1]["ssl"] is False
66-
67-
6855
def test_delete_backslash(client: DBFSClient):
6956
session = MagicMock()
7057
resp = MagicMock()
@@ -82,7 +69,7 @@ def test_delete_no_path(client: DBFSClient):
8269

8370
def test_delete_recursive(client: DBFSClient):
8471
session = MagicMock()
85-
resp = MagicMock()
72+
resp = AsyncMock()
8673
setattr(type(resp), "status", PropertyMock(return_value=200))
8774
session.post.return_value = create_async_with_result(resp)
8875
asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True))
@@ -98,7 +85,7 @@ def test_delete_rate_limited(client: DBFSClient):
9885
rate_limit_resp = MagicMock()
9986
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
10087

101-
success_resp = MagicMock()
88+
success_resp = AsyncMock()
10289
setattr(type(success_resp), "status", PropertyMock(return_value=200))
10390
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))
10491

@@ -118,7 +105,7 @@ def test_delete_rate_limited_retry_after(client: DBFSClient):
118105
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
119106
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))
120107

121-
success_resp = MagicMock()
108+
success_resp = AsyncMock()
122109
setattr(type(success_resp), "status", PropertyMock(return_value=200))
123110

124111
session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
@@ -146,7 +133,7 @@ def test_delete_unauthorized(client: DBFSClient):
146133

147134
def test_mkdirs(client: DBFSClient):
148135
session = MagicMock()
149-
resp = MagicMock()
136+
resp = AsyncMock()
150137
setattr(type(resp), "status", PropertyMock(return_value=200))
151138
session.post.return_value = create_async_with_result(resp)
152139
asyncio.run(client.mkdirs(sub_path="foo/bar", session=session))
@@ -179,7 +166,7 @@ def test_mkdirs_rate_limited(client: DBFSClient):
179166
rate_limit_resp = MagicMock()
180167
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
181168

182-
success_resp = MagicMock()
169+
success_resp = AsyncMock()
183170
setattr(type(success_resp), "status", PropertyMock(return_value=200))
184171
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))
185172

@@ -199,7 +186,7 @@ def test_mkdirs_rate_limited_retry_after(client: DBFSClient):
199186
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
200187
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))
201188

202-
success_resp = MagicMock()
189+
success_resp = AsyncMock()
203190
setattr(type(success_resp), "status", PropertyMock(return_value=200))
204191

205192
session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
@@ -227,7 +214,7 @@ def test_mkdirs_unauthorized(client: DBFSClient):
227214

228215
def test_put(client: DBFSClient, dummy_file_path: str):
229216
session = MagicMock()
230-
resp = MagicMock()
217+
resp = AsyncMock()
231218
setattr(type(resp), "status", PropertyMock(return_value=200))
232219
session.post.return_value = create_async_with_result(resp)
233220

@@ -244,6 +231,61 @@ def test_put(client: DBFSClient, dummy_file_path: str):
244231
assert is_dbfs_user_agent(session.post.call_args[1]["headers"]["user-agent"])
245232

246233

234+
def test_put_max_block_size_exceeded(client: DBFSClient, dummy_file_path_2mb: str):
235+
expected_handle = 1234
236+
237+
async def mock_json(*args, **kwargs):
238+
return {"handle": expected_handle}
239+
240+
def mock_post(url, *args, **kwargs):
241+
resp = AsyncMock()
242+
setattr(type(resp), "status", PropertyMock(return_value=200))
243+
if "/api/2.0/dbfs/put" in url:
244+
contents = kwargs.get("json").get("contents")
245+
if len(contents) > 1024 * 1024: # replicate the api error thrown when contents exceeds max allowed
246+
setattr(type(resp), "status", PropertyMock(return_value=400))
247+
elif "/api/2.0/dbfs/create" in url:
248+
# return a mock response json
249+
resp.json = MagicMock(side_effect=mock_json)
250+
251+
return create_async_with_result(resp)
252+
253+
session = AsyncMock()
254+
post = MagicMock(side_effect=mock_post)
255+
session.post = post
256+
257+
asyncio.run(client.put(sub_path="foo/bar", full_source_path=dummy_file_path_2mb, session=session))
258+
259+
with open(dummy_file_path_2mb, "r") as f:
260+
expected_contents = f.read()
261+
262+
chunks = textwrap.wrap(base64.b64encode(bytes(expected_contents, encoding="utf8")).decode("ascii"), 1024 * 1024)
263+
264+
assert session.post.call_count == len(chunks) + 2
265+
assert session.post.call_args_list[0][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/create"
266+
assert session.post.call_args_list[1][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block"
267+
assert session.post.call_args_list[2][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block"
268+
assert session.post.call_args_list[3][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block"
269+
assert session.post.call_args_list[4][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/close"
270+
271+
assert session.post.call_args_list[0][1]["json"] == {
272+
"path": "dbfs:/tmp/foo/foo/bar",
273+
"overwrite": True,
274+
}
275+
276+
for i, chunk in enumerate(chunks):
277+
assert session.post.call_args_list[i + 1][1]["json"] == {
278+
"data": chunk,
279+
"path": "dbfs:/tmp/foo/foo/bar",
280+
"handle": expected_handle,
281+
}, f"invalid json for chunk {i}"
282+
283+
assert session.post.call_args_list[4][1]["json"] == {
284+
"path": "dbfs:/tmp/foo/foo/bar",
285+
"handle": expected_handle,
286+
}
287+
288+
247289
def test_put_backslash(client: DBFSClient, dummy_file_path: str):
248290
session = MagicMock()
249291
resp = MagicMock()
@@ -267,7 +309,7 @@ def test_put_rate_limited(client: DBFSClient, dummy_file_path: str):
267309
rate_limit_resp = MagicMock()
268310
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
269311

270-
success_resp = MagicMock()
312+
success_resp = AsyncMock()
271313
setattr(type(success_resp), "status", PropertyMock(return_value=200))
272314
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))
273315

@@ -291,7 +333,7 @@ def test_put_rate_limited_retry_after(client: DBFSClient, dummy_file_path: str):
291333
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
292334
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))
293335

294-
success_resp = MagicMock()
336+
success_resp = AsyncMock()
295337
setattr(type(success_resp), "status", PropertyMock(return_value=200))
296338

297339
session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]

tests/unit/sync/clients/test_repos_client.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def test_init(mock_config):
3333

3434
def test_delete(client: ReposClient):
3535
session = MagicMock()
36-
resp = MagicMock()
36+
resp = AsyncMock()
3737
setattr(type(resp), "status", PropertyMock(return_value=200))
3838
session.post.return_value = create_async_with_result(resp)
3939
asyncio.run(client.delete(sub_path="foo/bar", session=session))
@@ -50,7 +50,7 @@ def test_delete_secure(client: ReposClient):
5050
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=False)
5151
client = ReposClient(user="[email protected]", repo_name="my-repo", config=mock_config)
5252
session = MagicMock()
53-
resp = MagicMock()
53+
resp = AsyncMock()
5454
setattr(type(resp), "status", PropertyMock(return_value=200))
5555
session.post.return_value = create_async_with_result(resp)
5656
asyncio.run(client.delete(sub_path="foo/bar", session=session))
@@ -65,7 +65,7 @@ def test_delete_insecure(client: ReposClient):
6565
mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=True)
6666
client = ReposClient(user="[email protected]", repo_name="my-repo", config=mock_config)
6767
session = MagicMock()
68-
resp = MagicMock()
68+
resp = AsyncMock()
6969
setattr(type(resp), "status", PropertyMock(return_value=200))
7070
session.post.return_value = create_async_with_result(resp)
7171
asyncio.run(client.delete(sub_path="foo/bar", session=session))
@@ -93,7 +93,7 @@ def test_delete_no_path(client: ReposClient):
9393

9494
def test_delete_recursive(client: ReposClient):
9595
session = MagicMock()
96-
resp = MagicMock()
96+
resp = AsyncMock()
9797
setattr(type(resp), "status", PropertyMock(return_value=200))
9898
session.post.return_value = create_async_with_result(resp)
9999
asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True))
@@ -109,7 +109,7 @@ def test_delete_rate_limited(client: ReposClient):
109109
rate_limit_resp = MagicMock()
110110
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
111111

112-
success_resp = MagicMock()
112+
success_resp = AsyncMock()
113113
setattr(type(success_resp), "status", PropertyMock(return_value=200))
114114
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))
115115

@@ -129,7 +129,7 @@ def test_delete_rate_limited_retry_after(client: ReposClient):
129129
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
130130
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))
131131

132-
success_resp = MagicMock()
132+
success_resp = AsyncMock()
133133
setattr(type(success_resp), "status", PropertyMock(return_value=200))
134134

135135
session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]
@@ -157,7 +157,7 @@ def test_delete_unauthorized(client: ReposClient):
157157

158158
def test_mkdirs(client: ReposClient):
159159
session = MagicMock()
160-
resp = MagicMock()
160+
resp = AsyncMock()
161161
setattr(type(resp), "status", PropertyMock(return_value=200))
162162
session.post.return_value = create_async_with_result(resp)
163163
asyncio.run(client.mkdirs(sub_path="foo/bar", session=session))
@@ -190,7 +190,7 @@ def test_mkdirs_rate_limited(client: ReposClient):
190190
rate_limit_resp = MagicMock()
191191
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
192192

193-
success_resp = MagicMock()
193+
success_resp = AsyncMock()
194194
setattr(type(success_resp), "status", PropertyMock(return_value=200))
195195
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None}))
196196

@@ -210,7 +210,7 @@ def test_mkdirs_rate_limited_retry_after(client: ReposClient):
210210
setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429))
211211
setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1}))
212212

213-
success_resp = MagicMock()
213+
success_resp = AsyncMock()
214214
setattr(type(success_resp), "status", PropertyMock(return_value=200))
215215

216216
session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]

0 commit comments

Comments
 (0)