Skip to content

Commit 2f882a5

Browse files
authored
Merge pull request #36 from jumpstarter-dev/opendal
Implement file access with opendal
2 parents b61ab5c + b970e2c commit 2f882a5

File tree

13 files changed

+347
-43
lines changed

13 files changed

+347
-43
lines changed

jumpstarter/common/aiohttp.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from dataclasses import dataclass
2+
3+
from aiohttp import StreamReader
4+
from anyio import EndOfStream
5+
from anyio.abc import ByteReceiveStream
6+
7+
8+
@dataclass(kw_only=True)
class AiohttpStream(ByteReceiveStream):
    """Expose an aiohttp ``StreamReader`` through the ``ByteReceiveStream`` interface."""

    # Reader that supplies the bytes returned by receive().
    stream: StreamReader

    async def receive(self, max_bytes=65536):
        """Read one chunk from the wrapped reader.

        :param int max_bytes: value forwarded to the reader as ``n``
        :return: the bytes that were read
        :raises EndOfStream: when the reader hands back an empty result
        """
        chunk = await self.stream.read(n=max_bytes)
        if len(chunk) > 0:
            return chunk
        # An empty read is reported to the caller as the end of the stream.
        raise EndOfStream

    async def aclose(self):
        """Do nothing: this wrapper does not close the reader it was given."""
        return None

jumpstarter/common/opendal.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from dataclasses import dataclass
2+
3+
from anyio import EndOfStream
4+
from anyio.abc import ByteReceiveStream
5+
from opendal import AsyncFile
6+
7+
8+
@dataclass(kw_only=True)
class AsyncFileStream(ByteReceiveStream):
    """Expose an opendal ``AsyncFile`` through the ``ByteReceiveStream`` interface."""

    # File object that supplies the bytes returned by receive().
    file: AsyncFile

    async def receive(self, max_bytes=65536):
        """Read one chunk from the wrapped file.

        :param int max_bytes: value forwarded to the file as ``size``
        :return: the bytes that were read
        :raises EndOfStream: when the file hands back an empty result
        """
        chunk = await self.file.read(size=max_bytes)
        if len(chunk) > 0:
            return chunk
        # An empty read is reported to the caller as the end of the stream.
        raise EndOfStream

    async def aclose(self):
        """Close the wrapped file."""
        await self.file.close()

jumpstarter/drivers/base.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
"""
44

55
from abc import ABCMeta, abstractmethod
6+
from contextlib import asynccontextmanager
67
from dataclasses import dataclass, field
78
from typing import Any
89
from uuid import UUID, uuid4
910

11+
import aiohttp
1012
from anyio.from_thread import BlockingPortal
1113
from grpc import StatusCode
1214

1315
from jumpstarter.common import Metadata
16+
from jumpstarter.common.aiohttp import AiohttpStream
1417
from jumpstarter.common.streams import (
1518
create_memory_stream,
1619
forward_server_stream,
@@ -22,6 +25,8 @@
2225
MARKER_STREAMCALL,
2326
MARKER_STREAMING_DRIVERCALL,
2427
)
28+
from jumpstarter.drivers.resources import ClientStreamResource, PresignedRequestResource, Resource
29+
from jumpstarter.drivers.streams import DriverStreamRequest, ResourceStreamRequest, StreamRequest
2530
from jumpstarter.v1 import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc
2631

2732

@@ -83,14 +88,16 @@ async def Stream(self, request_iterator, context):
8388
"""
8489
metadata = dict(context.invocation_metadata())
8590

86-
match metadata["kind"]:
87-
case "connect":
88-
method = await self.__lookup_drivercall(metadata["method"], context, MARKER_STREAMCALL)
91+
request = StreamRequest.validate_json(metadata["request"], strict=True)
92+
93+
match request:
94+
case DriverStreamRequest(method=driver_method):
95+
method = await self.__lookup_drivercall(driver_method, context, MARKER_STREAMCALL)
8996

9097
async for v in method(request_iterator, context):
9198
yield v
9299

93-
case "resource":
100+
case ResourceStreamRequest():
94101
remote, resource = create_memory_stream()
95102

96103
resource_uuid = uuid4()
@@ -133,6 +140,16 @@ def items(self, parent=None):
133140

134141
return [(self.uuid, parent.uuid if parent else None, self)]
135142

143+
@asynccontextmanager
async def resource(self, handle: str):
    """Resolve a resource handle into a readable stream

    The handle is validated with the ``Resource`` adapter and dispatched on
    the model type it produces:

    * ``ClientStreamResource``: yields the entry of ``self.resources``
      stored under the handle's uuid.
    * ``PresignedRequestResource``: issues the described HTTP request and
      yields an ``AiohttpStream`` over the response content; the request
      stays open until the context exits.

    :param handle: serialized resource handle
    """
    # NOTE(review): the annotation says ``str``, but callers in this change
    # pass the output of ``model_dump(mode="json")`` -- confirm and retype.
    parsed = Resource.validate_python(handle)

    if isinstance(parsed, ClientStreamResource):
        yield self.resources[parsed.uuid]
    elif isinstance(parsed, PresignedRequestResource):
        async with aiohttp.request(
            parsed.method,
            parsed.url,
            headers=parsed.headers,
            raise_for_status=True,
        ) as response:
            yield AiohttpStream(stream=response.content)
152+
136153
async def __lookup_drivercall(self, name, context, marker):
137154
"""Lookup drivercall by method name
138155

jumpstarter/drivers/core.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,23 @@
44

55
from contextlib import asynccontextmanager
66
from dataclasses import dataclass
7+
from uuid import UUID
78

89
from anyio import create_task_group, sleep_forever
910
from anyio.streams.stapled import StapledObjectStream
1011
from google.protobuf import json_format, struct_pb2
1112
from grpc.aio import Channel
13+
from opendal import AsyncOperator
1214

1315
from jumpstarter.common import Metadata
16+
from jumpstarter.common.opendal import AsyncFileStream
1417
from jumpstarter.common.progress import ProgressStream
1518
from jumpstarter.common.streams import (
1619
create_memory_stream,
1720
forward_client_stream,
1821
)
22+
from jumpstarter.drivers.resources import ClientStreamResource, PresignedRequestResource
23+
from jumpstarter.drivers.streams import DriverStreamRequest, ResourceStreamRequest
1924
from jumpstarter.v1 import jumpstarter_pb2, jumpstarter_pb2_grpc, router_pb2_grpc
2025

2126

@@ -68,7 +73,7 @@ async def stream_async(self, method):
6873
async with forward_client_stream(
6974
self,
7075
device_stream,
71-
{"kind": "connect", "uuid": str(self.uuid), "method": method}.items(),
76+
{"request": DriverStreamRequest(uuid=self.uuid, method=method).model_dump_json()}.items(),
7277
):
7378
async with client_stream:
7479
yield client_stream
@@ -80,7 +85,7 @@ async def handle(client):
8085
async with forward_client_stream(
8186
self,
8287
client,
83-
{"kind": "connect", "uuid": str(self.uuid), "method": method}.items(),
88+
{"request": DriverStreamRequest(uuid=self.uuid, method=method).model_dump_json()}.items(),
8489
):
8590
await sleep_forever()
8691

@@ -104,6 +109,22 @@ async def resource_async(
104109
async with forward_client_stream(
105110
self,
106111
combined,
107-
{"kind": "resource", "uuid": str(self.uuid)}.items(),
112+
{"request": ResourceStreamRequest(uuid=self.uuid).model_dump_json()}.items(),
108113
):
109-
yield (await rx.receive()).decode()
114+
yield ClientStreamResource(uuid=UUID((await rx.receive()).decode())).model_dump(mode="json")
115+
116+
@asynccontextmanager
async def file_async(
    self,
    operator: AsyncOperator,
    path: str,
):
    """Share a file reachable through an opendal operator with the driver

    When the operator reports the ``presign`` capability, a presigned read
    request (expiring after 60 seconds) is created and yielded as a dumped
    ``PresignedRequestResource``. Otherwise the file is opened for binary
    reading and shared through ``resource_async``, whose handle is yielded.

    :param AsyncOperator operator: operator used to reach the file
    :param str path: path of the file within the operator
    """
    if not operator.capability().presign:
        file = await operator.open(path, "rb")
        # NOTE(review): nothing here closes ``file`` explicitly; presumably
        # resource_async closes the stream it is given -- confirm.
        async with self.resource_async(AsyncFileStream(file=file)) as handle:
            yield handle
    else:
        presigned = await operator.presign_read(path, expire_second=60)
        request = PresignedRequestResource(
            headers=presigned.headers,
            url=presigned.url,
            method=presigned.method,
        )
        yield request.model_dump(mode="json")

jumpstarter/drivers/dutlink/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from collections.abc import AsyncGenerator
33
from dataclasses import dataclass, field
44
from pathlib import Path
5-
from uuid import UUID
65

76
import pyudev
87
import usb.core
@@ -103,8 +102,9 @@ async def write(self, src: str):
103102
await sleep(1)
104103

105104
async with await FileWriteStream.from_path(self.storage_device) as stream:
106-
async for chunk in self.resources[UUID(src)]:
107-
await stream.send(chunk)
105+
async with self.resource(src) as res:
106+
async for chunk in res:
107+
await stream.send(chunk)
108108

109109

110110
@dataclass(kw_only=True)

jumpstarter/drivers/mixins.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from anyio import create_unix_listener
1313
from anyio.from_thread import BlockingPortal
14-
from anyio.streams.file import FileReadStream
14+
from opendal import Operator
1515
from pexpect.fdpexpect import fdspawn
1616

1717

@@ -70,15 +70,6 @@ class ResourceMixin:
7070
"""Resource"""
7171

7272
@contextmanager
73-
def local_file(
74-
self,
75-
filepath,
76-
):
77-
"""
78-
Share local file with driver
79-
80-
:param str filepath: path to file
81-
"""
82-
with self.portal.wrap_async_context_manager(self.portal.call(FileReadStream.from_path, filepath)) as file:
83-
with self.portal.wrap_async_context_manager(self.resource_async(file)) as uuid:
84-
yield uuid
73+
def file(self, operator: Operator, path: str):
    """
    Share a file reachable through an opendal operator with driver

    Synchronous counterpart of ``file_async``: the operator is converted with
    ``to_async_operator()`` and the async context manager is run through
    ``self.portal``.

    :param Operator operator: operator used to reach the file
    :param str path: path of the file within the operator
    """
    async_cm = self.file_async(operator.to_async_operator(), path)
    with self.portal.wrap_async_context_manager(async_cm) as handle:
        yield handle

jumpstarter/drivers/resources.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from typing import Annotated, Literal, Union
2+
from uuid import UUID
3+
4+
from pydantic import BaseModel, Field, TypeAdapter
5+
6+
7+
class ClientStreamResource(BaseModel):
    # Resource handle variant that refers to a stream by uuid.

    # Discriminator tag selecting this variant of the Resource union.
    kind: Literal["client_stream"] = "client_stream"

    # Identifier of the stream this handle refers to.
    uuid: UUID
10+
11+
12+
class PresignedRequestResource(BaseModel):
    # Resource handle variant that describes an HTTP request.

    # Discriminator tag selecting this variant of the Resource union.
    kind: Literal["presigned_request"] = "presigned_request"

    # Headers to send with the request.
    headers: dict[str, str]

    # Target URL of the request.
    url: str

    # HTTP method of the request.
    method: str
17+
18+
19+
# Tagged union of every resource handle variant; the "kind" field selects
# which model a payload is validated against.
_ResourceHandle = Annotated[
    Union[ClientStreamResource, PresignedRequestResource],
    Field(discriminator="kind"),
]

# Adapter used to validate resource handles.
Resource = TypeAdapter(_ResourceHandle)

jumpstarter/drivers/storage/base.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from abc import ABCMeta, abstractmethod
22
from tempfile import NamedTemporaryFile
3-
from uuid import UUID
43

54
import click
65
from anyio.streams.file import FileWriteStream
6+
from opendal import Operator
77

88
from jumpstarter.drivers import Driver, DriverClient, export
99
from jumpstarter.drivers.mixins import ResourceMixin
@@ -40,8 +40,12 @@ def off(self):
4040
def write(self, handle):
4141
return self.call("write", handle)
4242

43+
def write_file(self, operator: Operator, path: str):
    """
    Write a file reachable through an opendal operator

    The file is shared via ``self.file`` and the resulting handle is passed
    to the ``write`` driver call while the share is still open.

    :param Operator operator: operator used to reach the file
    :param str path: path of the file within the operator
    :return: result of the ``write`` call
    """
    with self.file(operator, path) as resource_handle:
        return self.call("write", resource_handle)
46+
4347
def write_local_file(self, filepath):
44-
with self.local_file(filepath) as handle:
48+
with self.file(Operator("fs", root="/"), filepath) as handle:
4549
return self.call("write", handle)
4650

4751
def cli(self):
@@ -90,5 +94,6 @@ async def off(self):
9094
async def write(self, src: str):
9195
with NamedTemporaryFile() as file:
9296
async with FileWriteStream(file) as stream:
93-
async for chunk in self.resources[UUID(src)]:
94-
await stream.send(chunk)
97+
async with self.resource(src) as res:
98+
async for chunk in res:
99+
await stream.send(chunk)
Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,36 @@
1-
import os
2-
from tempfile import NamedTemporaryFile
1+
from pathlib import Path
2+
from tempfile import TemporaryDirectory
3+
4+
import pytest
5+
from opendal import Operator
36

47
from jumpstarter.common.utils import serve
58
from jumpstarter.drivers.storage import MockStorageMux
69

710

8-
def test_drivers_mock_storage_mux():
11+
def test_drivers_mock_storage_mux_fs():
    """Write a file from a temporary "fs" operator via write_file and write_local_file."""
    payload = b"testcontent" * 1000

    with serve(MockStorageMux(name="storage")) as client, TemporaryDirectory() as tempdir:
        fs = Operator("fs", root=tempdir)
        fs.write("test", payload)

        # Same file, once addressed through the operator and once by absolute path.
        client.write_file(fs, "test")
        client.write_local_file(str(Path(tempdir) / "test"))
20+
21+
22+
@pytest.mark.skip(reason="require minio")
23+
def test_drivers_mock_storage_mux_s3():
924
with serve(MockStorageMux(name="storage")) as client:
10-
with NamedTemporaryFile(delete=False) as file:
11-
file.write(b"testcontent" * 1000)
12-
file.close()
25+
s3 = Operator(
26+
"s3",
27+
bucket="test",
28+
endpoint="http://127.0.0.1:9000",
29+
region="us-east-1",
30+
access_key_id="minioadmin",
31+
secret_access_key="minioadmin",
32+
)
1333

14-
client.off()
15-
client.dut()
16-
client.host()
17-
client.write_local_file(file.name)
34+
s3.write("test", b"testcontent" * 1000)
1835

19-
os.unlink(file.name)
36+
client.write_file(s3, "test")

jumpstarter/drivers/streams.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Annotated, Literal, Union
2+
from uuid import UUID
3+
4+
from pydantic import BaseModel, Field, TypeAdapter
5+
6+
7+
class ResourceStreamRequest(BaseModel):
    # Stream request variant tagged "resource".

    # Discriminator tag selecting this variant of the StreamRequest union.
    kind: Literal["resource"] = "resource"

    # Identifier carried by the request.
    uuid: UUID
10+
11+
12+
class DriverStreamRequest(BaseModel):
    # Stream request variant tagged "driver"; it names a method.

    # Discriminator tag selecting this variant of the StreamRequest union.
    kind: Literal["driver"] = "driver"

    # Identifier carried by the request.
    uuid: UUID

    # Name of the method the request refers to.
    method: str
16+
17+
18+
# Tagged union of every stream request variant; the "kind" field selects
# which model a payload is validated against.
_StreamRequestUnion = Annotated[
    Union[ResourceStreamRequest, DriverStreamRequest],
    Field(discriminator="kind"),
]

# Adapter used to validate stream requests.
StreamRequest = TypeAdapter(_StreamRequestUnion)

jumpstarter/exporter/session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from jumpstarter.common import Metadata
55
from jumpstarter.drivers.base import Driver
6+
from jumpstarter.drivers.streams import StreamRequest
67
from jumpstarter.v1 import (
78
jumpstarter_pb2,
89
jumpstarter_pb2_grpc,
@@ -51,5 +52,7 @@ async def StreamingDriverCall(self, request, context):
5152
async def Stream(self, request_iterator, context):
5253
metadata = dict(context.invocation_metadata())
5354

54-
async for v in self[UUID(metadata["uuid"])].Stream(request_iterator, context):
55+
request = StreamRequest.validate_json(metadata["request"], strict=True)
56+
57+
async for v in self[request.uuid].Stream(request_iterator, context):
5558
yield v

0 commit comments

Comments
 (0)