diff --git a/CHANGELOG.md b/CHANGELOG.md index 52e7fadf..e360b92f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - 📌 switch from using `retry` to `tenacity` +### Added +- ✨ support for files bigger than 1MB in sync + ## [0.8.8] - 2022-02-22 # Fixed diff --git a/dbx/sync/clients.py b/dbx/sync/clients.py index 4358fbcf..b7a1b955 100644 --- a/dbx/sync/clients.py +++ b/dbx/sync/clients.py @@ -98,7 +98,7 @@ async def _api( more_opts = {"ssl": ssl} if ssl is not None else {} async with session.post(url=url, json=json_data, headers=headers, **more_opts) as resp: if resp.status in ok_status: - break + return await resp.json() if resp.status == 429: dbx_echo("Rate limited") await _rate_limit_sleep(resp) @@ -197,14 +197,49 @@ async def put( path = f"{self.base_path}/{sub_path}" with open(full_source_path, "rb") as f: contents = base64.b64encode(f.read()).decode("ascii") - await self._api_put( - api_base_path=self.api_base_path, - path=path, - session=session, - api_token=self.api_token, - contents=contents, - ssl=self.ssl, - ) + + if len(contents) <= 1024 * 1024: + await self._api_put( + api_base_path=self.api_base_path, + path=path, + session=session, + api_token=self.api_token, + contents=contents, + ssl=self.ssl, + ) + else: + dbx_echo(f"Streaming {path}") + + resp = await self._api( + url=f"{self.api_base_path}/create", + path=path, + session=session, + api_token=self.api_token, + ssl=self.ssl, + overwrite=True, + ) + handle = resp.get("handle") + import textwrap + + chunks = textwrap.wrap(contents, 1024 * 1024) + for chunk in chunks: + await self._api( + url=f"{self.api_base_path}/add-block", + path=path, + session=session, + api_token=self.api_token, + ssl=self.ssl, + handle=handle, + data=chunk, + ) + await self._api( + url=f"{self.api_base_path}/close", + path=path, + session=session, + api_token=self.api_token, + ssl=self.ssl, + handle=handle, + ) class ReposClient(BaseClient): diff --git a/tests/unit/sync/clients/conftest.py b/tests/unit/sync/clients/conftest.py index b929906f..1baedd76 100644 --- a/tests/unit/sync/clients/conftest.py +++ b/tests/unit/sync/clients/conftest.py @@ -18,3 +18,12 @@ def dummy_file_path() -> str: with open(file_path, "w") as f: f.write("yo") yield file_path + + +@pytest.fixture +def dummy_file_path_2mb() -> str: + with temporary_directory() as tempdir: + file_path = os.path.join(tempdir, "file") + with open(file_path, "w") as f: + f.write("y" * 1024 * 2048) + yield file_path diff --git a/tests/unit/sync/clients/test_dbfs_client.py b/tests/unit/sync/clients/test_dbfs_client.py index ac1656a6..e983390f 100644 --- a/tests/unit/sync/clients/test_dbfs_client.py +++ b/tests/unit/sync/clients/test_dbfs_client.py @@ -1,6 +1,8 @@ import asyncio import base64 -from unittest.mock import AsyncMock, MagicMock, PropertyMock +import textwrap +from tests.unit.sync.utils import create_async_with_result +from unittest.mock import AsyncMock, MagicMock, PropertyMock, call import pytest @@ -22,7 +24,7 @@ def test_init(client): def test_delete(client: DBFSClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session)) @@ -39,7 +41,7 @@ def test_delete_secure(client: DBFSClient): mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=False) client = DBFSClient(base_path="/tmp/foo", config=mock_config) session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session)) @@ -50,21 +52,6 @@ def test_delete_secure(client: DBFSClient): assert session.post.call_args[1]["ssl"] is True -def test_delete_secure(client: DBFSClient): - mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=True) - client = DBFSClient(base_path="/tmp/foo", config=mock_config) - session = MagicMock() - resp = MagicMock() - setattr(type(resp), "status", PropertyMock(return_value=200)) - session.post.return_value = create_async_with_result(resp) - asyncio.run(client.delete(sub_path="foo/bar", session=session)) - - assert session.post.call_count == 1 - assert session.post.call_args[1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/delete" - assert session.post.call_args[1]["json"] == {"path": "dbfs:/tmp/foo/foo/bar"} - assert session.post.call_args[1]["ssl"] is False - - def test_delete_backslash(client: DBFSClient): session = MagicMock() resp = MagicMock() @@ -82,7 +69,7 @@ def test_delete_no_path(client: DBFSClient): def test_delete_recursive(client: DBFSClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True)) @@ -98,7 +85,7 @@ def test_delete_rate_limited(client: DBFSClient): rate_limit_resp = MagicMock() setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None})) @@ -118,7 +105,7 @@ def test_delete_rate_limited_retry_after(client: DBFSClient): setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1})) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)] @@ -146,7 +133,7 @@ def test_delete_unauthorized(client: DBFSClient): def test_mkdirs(client: DBFSClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.mkdirs(sub_path="foo/bar", session=session)) @@ -179,7 +166,7 @@ def test_mkdirs_rate_limited(client: DBFSClient): rate_limit_resp = MagicMock() setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None})) @@ -199,7 +186,7 @@ def test_mkdirs_rate_limited_retry_after(client: DBFSClient): setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1})) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)] @@ -227,7 +214,7 @@ def test_mkdirs_unauthorized(client: DBFSClient): def test_put(client: DBFSClient, dummy_file_path: str): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) @@ -244,6 +231,61 @@ def test_put(client: DBFSClient, dummy_file_path: str): assert is_dbfs_user_agent(session.post.call_args[1]["headers"]["user-agent"]) +def test_put_max_block_size_exceeded(client: DBFSClient, dummy_file_path_2mb: str): + expected_handle = 1234 + + async def mock_json(*args, **kwargs): + return {"handle": expected_handle} + + def mock_post(url, *args, **kwargs): + resp = AsyncMock() + setattr(type(resp), "status", PropertyMock(return_value=200)) + if "/api/2.0/dbfs/put" in url: + contents = kwargs.get("json").get("contents") + if len(contents) > 1024 * 1024: # replicate the api error thrown when contents exceeds max allowed + setattr(type(resp), "status", PropertyMock(return_value=400)) + elif "/api/2.0/dbfs/create" in url: + # return a mock response json + resp.json = MagicMock(side_effect=mock_json) + + return create_async_with_result(resp) + + session = AsyncMock() + post = MagicMock(side_effect=mock_post) + session.post = post + + asyncio.run(client.put(sub_path="foo/bar", full_source_path=dummy_file_path_2mb, session=session)) + + with open(dummy_file_path_2mb, "r") as f: + expected_contents = f.read() + + chunks = textwrap.wrap(base64.b64encode(bytes(expected_contents, encoding="utf8")).decode("ascii"), 1024 * 1024) + + assert session.post.call_count == len(chunks) + 2 + assert session.post.call_args_list[0][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/create" + assert session.post.call_args_list[1][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block" + assert session.post.call_args_list[2][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block" + assert session.post.call_args_list[3][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/add-block" + assert session.post.call_args_list[4][1]["url"] == "http://fakehost.asdf/api/2.0/dbfs/close" + + assert session.post.call_args_list[0][1]["json"] == { + "path": "dbfs:/tmp/foo/foo/bar", + "overwrite": True, + } + + for i, chunk in enumerate(chunks): + assert session.post.call_args_list[i + 1][1]["json"] == { + "data": chunk, + "path": "dbfs:/tmp/foo/foo/bar", + "handle": expected_handle, + }, f"invalid json for chunk {i}" + + assert session.post.call_args_list[4][1]["json"] == { + "path": "dbfs:/tmp/foo/foo/bar", + "handle": expected_handle, + } + + def test_put_backslash(client: DBFSClient, dummy_file_path: str): session = MagicMock() resp = MagicMock() @@ -267,7 +309,7 @@ def test_put_rate_limited(client: DBFSClient, dummy_file_path: str): rate_limit_resp = MagicMock() setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None})) @@ -291,7 +333,7 @@ def test_put_rate_limited_retry_after(client: DBFSClient, dummy_file_path: str): setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1})) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)] diff --git a/tests/unit/sync/clients/test_repos_client.py b/tests/unit/sync/clients/test_repos_client.py index 947bc6ff..ed2c033e 100644 --- a/tests/unit/sync/clients/test_repos_client.py +++ b/tests/unit/sync/clients/test_repos_client.py @@ -33,7 +33,7 @@ def test_init(mock_config): def test_delete(client: ReposClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session)) @@ -50,7 +50,7 @@ def test_delete_secure(client: ReposClient): mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=False) client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config) session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session)) @@ -65,7 +65,7 @@ def test_delete_insecure(client: ReposClient): mock_config = mocked_props(token="fake-token", host="http://fakehost.asdf/", insecure=True) client = ReposClient(user="foo@somewhere.com", repo_name="my-repo", config=mock_config) session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session)) @@ -93,7 +93,7 @@ def test_delete_no_path(client: ReposClient): def test_delete_recursive(client: ReposClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.delete(sub_path="foo/bar", session=session, recursive=True)) @@ -109,7 +109,7 @@ def test_delete_rate_limited(client: ReposClient): rate_limit_resp = MagicMock() setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None})) @@ -129,7 +129,7 @@ def test_delete_rate_limited_retry_after(client: ReposClient): setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1})) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)] @@ -157,7 +157,7 @@ def test_delete_unauthorized(client: ReposClient): def test_mkdirs(client: ReposClient): session = MagicMock() - resp = MagicMock() + resp = AsyncMock() setattr(type(resp), "status", PropertyMock(return_value=200)) session.post.return_value = create_async_with_result(resp) asyncio.run(client.mkdirs(sub_path="foo/bar", session=session)) @@ -190,7 +190,7 @@ def test_mkdirs_rate_limited(client: ReposClient): rate_limit_resp = MagicMock() setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": None})) @@ -210,7 +210,7 @@ def test_mkdirs_rate_limited_retry_after(client: ReposClient): setattr(type(rate_limit_resp), "status", PropertyMock(return_value=429)) setattr(type(rate_limit_resp), "headers", PropertyMock(return_value={"Retry-After": 1})) - success_resp = MagicMock() + success_resp = AsyncMock() setattr(type(success_resp), "status", PropertyMock(return_value=200)) session.post.side_effect = [create_async_with_result(rate_limit_resp), create_async_with_result(success_resp)]