Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds metadata to FlyteFile #3160

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 32 additions & 12 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def noop(): ...
@dataclass
class FlyteFile(SerializableType, os.PathLike, typing.Generic[T], DataClassJSONMixin):
path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore
metadata: typing.Optional[dict[str, str]] = None
"""
Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int
exists for Flyte's Integer type) we need to create one so that users can express that their tasks take
Expand Down Expand Up @@ -158,18 +159,24 @@ def t2() -> flytekit_typing.FlyteFile["csv"]:
return "/tmp/local_file.csv"
"""

def _serialize(self) -> typing.Dict[str, str]:
def _serialize(self) -> typing.Dict[str, typing.Any]:
lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"path": lv.scalar.blob.uri}
out = {"path": lv.scalar.blob.uri}
if lv.metadata:
out["metadata"] = lv.metadata
return out

@classmethod
def _deserialize(cls, value) -> "FlyteFile":
return FlyteFilePathTransformer().dict_to_flyte_file(dict_obj=value, expected_python_type=cls)

@model_serializer
def serialize_flyte_file(self) -> Dict[str, str]:
def serialize_flyte_file(self) -> Dict[str, typing.Any]:
lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"path": lv.scalar.blob.uri}
out = {"path": lv.scalar.blob.uri}
if lv.metadata:
out["metadata"] = lv.metadata
return out

@model_validator(mode="after")
def deserialize_flyte_file(self, info) -> "FlyteFile":
Expand All @@ -188,7 +195,8 @@ def deserialize_flyte_file(self, info) -> "FlyteFile":
),
uri=self.path,
)
)
),
metadata=self.metadata,
),
type(self),
)
Expand Down Expand Up @@ -281,6 +289,7 @@ def __init__(
path: typing.Union[str, os.PathLike],
downloader: typing.Callable = noop,
remote_path: typing.Optional[typing.Union[os.PathLike, str, bool]] = None,
metadata: typing.Optional[dict[str, str]] = None,
):
"""
FlyteFile's init method.
Expand All @@ -295,6 +304,7 @@ def __init__(
# Make this field public, so that the dataclass transformer can set a value for it
# https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
self.path = path
self.metadata = metadata
self._downloader = downloader
self._downloaded = False
self._remote_path = remote_path
Expand Down Expand Up @@ -538,7 +548,9 @@ async def async_to_literal(
# If the object has a remote source, then we just convert it back. This means that if someone is just
# going back and forth between a FlyteFile Python value and a Blob Flyte IDL value, we don't do anything.
if python_val._remote_source is not None:
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=python_val._remote_source)))
return Literal(
scalar=Scalar(blob=Blob(metadata=meta, uri=python_val._remote_source)), metadata=python_val.metadata
)

# If the user specified the remote_path to be False, that means no matter what, do not upload. Also if the
# path given is already a remote path, say https://www.google.com, the concept of uploading to the Flyte
Expand Down Expand Up @@ -593,10 +605,15 @@ async def async_to_literal(
else:
remote_path = await ctx.file_access.async_put_raw_data(source_path, **headers)
# If the source path is a local file, the remote path will be a remote storage path.
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path)))))
return Literal(
scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path)))),
metadata=getattr(python_val, "metadata", None),
)
# If not uploading, then we can only take the original source path as the uri.
else:
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=source_path)))
return Literal(
scalar=Scalar(blob=Blob(metadata=meta, uri=source_path)), metadata=getattr(python_val, "metadata", None)
)

@staticmethod
def get_additional_headers(source_path: str | os.PathLike) -> typing.Dict[str, str]:
Expand All @@ -608,6 +625,7 @@ def dict_to_flyte_file(
self, dict_obj: typing.Dict[str, str], expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
) -> FlyteFile:
path = dict_obj.get("path", None)
metadata = dict_obj.get("metadata", None)

if path is None:
raise ValueError("FlyteFile's path should not be None")
Expand All @@ -624,7 +642,8 @@ def dict_to_flyte_file(
),
uri=path,
)
)
),
metadata=metadata,
),
expected_python_type,
)
Expand Down Expand Up @@ -704,6 +723,7 @@ async def async_to_python_value(

try:
uri = lv.scalar.blob.uri
metadata = lv.metadata
except AttributeError:
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

Expand All @@ -718,7 +738,7 @@ async def async_to_python_value(
# In this condition, we still return a FlyteFile instance, but it's a simple one that has no downloading tricks
# Using is instead of issubclass because FlyteFile does actually subclass it
if expected_python_type is os.PathLike:
return FlyteFile(uri)
return FlyteFile(path=uri, metadata=metadata)

# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type
expected_python_type = get_underlying_type(expected_python_type)
Expand All @@ -730,15 +750,15 @@ async def async_to_python_value(
# This is a local file path, like /usr/local/my_file, don't mess with it. Certainly, downloading it doesn't
# make any sense.
if not ctx.file_access.is_remote(uri):
return expected_python_type(uri) # type: ignore
return expected_python_type(path=uri, metadata=metadata) # type: ignore

# For the remote case, return an FlyteFile object that can download
local_path = ctx.file_access.get_random_local_path(uri)

_downloader = partial(ctx.file_access.get_data, remote_path=uri, local_path=local_path, is_multipart=False)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader)
ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader, metadata=metadata)
ff._remote_source = uri
return ff

Expand Down
18 changes: 12 additions & 6 deletions tests/flytekit/integration/remote/workflows/basic/flytefile.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Optional
from flytekit import task, workflow
from flytekit.types.file import FlyteFile


@task
def create_ff(file_path: str) -> FlyteFile:
def create_ff(file_path: str, info: str) -> FlyteFile:
"""Create a FlyteFile."""
return FlyteFile(path=file_path)
return FlyteFile(path=file_path, metadata={"info": info})


@task
def read_ff(ff: FlyteFile) -> None:
def read_ff(ff: FlyteFile, info: Optional[str] = None) -> None:
"""Read input FlyteFile.

This can be used in the case in which a FlyteFile is created
Expand All @@ -19,6 +20,11 @@ def read_ff(ff: FlyteFile) -> None:
content = f.read()
print(f"FILE CONTENT | {content}")

if info:
assert ff.metadata["info"] == info
else:
assert ff.metadata is None


@task
def create_and_read_ff(file_path: str) -> FlyteFile:
Expand All @@ -41,9 +47,9 @@ def create_and_read_ff(file_path: str) -> FlyteFile:


@workflow
def wf(remote_file_path: str) -> None:
ff_1 = create_ff(file_path=remote_file_path)
read_ff(ff=ff_1)
def wf(remote_file_path: str, info: str = "abc") -> None:
ff_1 = create_ff(file_path=remote_file_path, info=info)
read_ff(ff=ff_1, info=info)
ff_2 = create_and_read_ff(file_path=remote_file_path)
read_ff(ff=ff_2)

Expand Down
59 changes: 58 additions & 1 deletion tests/flytekit/unit/types/file/test_file.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import tempfile
from pathlib import Path
from typing import Optional
from dataclasses import dataclass

import pytest
from flytekit import task, workflow
from flytekit import task, workflow, current_context
from flytekit.types.file import FlyteFile


Expand Down Expand Up @@ -70,3 +71,59 @@ def _verify_msg(ff: FlyteFile) -> None:

ff_4 = wf(source_path=source_path, use_pathlike_src_path=True, remote_path=remote_path)
_verify_msg(ff_4)


def test_metadata():

@task
def create_file() -> FlyteFile:
ctx = current_context()
wd = Path(ctx.working_directory)
new_file = wd / "my_file.txt"

content = "hello there"
new_file.write_text(content)
return FlyteFile(path=new_file, metadata={"length": str(len(content))})

@task
def read_metadata(file: FlyteFile) -> Optional[dict]:
return file.metadata

@workflow
def wf() -> Optional[dict]:
file = create_file()
return read_metadata(file=file)

output = wf()
assert output["length"] == "11"


@dataclass
class SimpleDC:
file: FlyteFile


def test_metadata_with_dataclass():
@task
def create_dc() -> SimpleDC:
ctx = current_context()
wd = Path(ctx.working_directory)
my_file = wd / "file.txt"
my_file.write_text("hello there!")
return SimpleDC(file=FlyteFile(path=my_file, metadata={"HELLO": "WORLD"}))


@task
def get_metadata(dc: SimpleDC) -> dict:
if dc.file.metadata:
return dc.file.metadata
else:
return {}

@workflow
def wf() -> dict:
dc = create_dc()
return get_metadata(dc=dc)

output = wf()
assert output["HELLO"] == "WORLD"
Loading