Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds metadata to FlyteFile #3160

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 26 additions & 12 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
@dataclass
class FlyteFile(SerializableType, os.PathLike, typing.Generic[T], DataClassJSONMixin):
path: typing.Union[str, os.PathLike] = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore
metadata: typing.Optional[dict[str, str]] = None
"""
Since there is no native Python implementation of files and directories for the Flyte Blob type, (like how int
exists for Flyte's Integer type) we need to create one so that users can express that their tasks take
Expand Down Expand Up @@ -158,18 +159,18 @@
return "/tmp/local_file.csv"
"""

def _serialize(self) -> typing.Dict[str, str]:
def _serialize(self) -> typing.Dict[str, typing.Any]:
lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"path": lv.scalar.blob.uri}
return {"path": lv.scalar.blob.uri, "metadata": lv.metadata}

Check warning on line 164 in flytekit/types/file/file.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/file/file.py#L164

Added line #L164 was not covered by tests

@classmethod
def _deserialize(cls, value) -> "FlyteFile":
    """Rebuild a FlyteFile from its serialized form.

    Delegates to ``FlyteFilePathTransformer().dict_to_flyte_file``, passing ``value`` as
    ``dict_obj`` and ``cls`` as the expected Python type.

    NOTE(review): ``value`` is presumably the dict emitted by ``_serialize`` -- confirm.
    """
    return FlyteFilePathTransformer().dict_to_flyte_file(dict_obj=value, expected_python_type=cls)

@model_serializer
def serialize_flyte_file(self) -> Dict[str, str]:
def serialize_flyte_file(self) -> Dict[str, typing.Any]:
lv = FlyteFilePathTransformer().to_literal(FlyteContextManager.current_context(), self, type(self), None)
return {"path": lv.scalar.blob.uri}
return {"path": lv.scalar.blob.uri, "metadata": lv.metadata}

@model_validator(mode="after")
def deserialize_flyte_file(self, info) -> "FlyteFile":
Expand All @@ -188,7 +189,8 @@
),
uri=self.path,
)
)
),
metadata=self.metadata,
),
type(self),
)
Expand Down Expand Up @@ -281,6 +283,7 @@
path: typing.Union[str, os.PathLike],
downloader: typing.Callable = noop,
remote_path: typing.Optional[typing.Union[os.PathLike, str, bool]] = None,
metadata: typing.Optional[dict[str, str]] = None,
):
"""
FlyteFile's init method.
Expand All @@ -295,6 +298,7 @@
# Make this field public, so that the dataclass transformer can set a value for it
# https://github.com/flyteorg/flytekit/blob/bcc8541bd6227b532f8462563fe8aac902242b21/flytekit/core/type_engine.py#L298
self.path = path
self.metadata = metadata
self._downloader = downloader
self._downloaded = False
self._remote_path = remote_path
Expand Down Expand Up @@ -538,7 +542,9 @@
# If the object has a remote source, then we just convert it back. This means that if someone is just
# going back and forth between a FlyteFile Python value and a Blob Flyte IDL value, we don't do anything.
if python_val._remote_source is not None:
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=python_val._remote_source)))
return Literal(

Check warning on line 545 in flytekit/types/file/file.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/file/file.py#L545

Added line #L545 was not covered by tests
scalar=Scalar(blob=Blob(metadata=meta, uri=python_val._remote_source)), metadata=python_val.metadata
)

# If the user specified the remote_path to be False, that means no matter what, do not upload. Also if the
# path given is already a remote path, say https://www.google.com, the concept of uploading to the Flyte
Expand Down Expand Up @@ -593,10 +599,15 @@
else:
remote_path = await ctx.file_access.async_put_raw_data(source_path, **headers)
# If the source path is a local file, the remote path will be a remote storage path.
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path)))))
return Literal(
scalar=Scalar(blob=Blob(metadata=meta, uri=unquote(str(remote_path)))),
metadata=getattr(python_val, "metadata", None),
)
# If not uploading, then we can only take the original source path as the uri.
else:
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=source_path)))
return Literal(
scalar=Scalar(blob=Blob(metadata=meta, uri=source_path)), metadata=getattr(python_val, "metadata", None)
)

@staticmethod
def get_additional_headers(source_path: str | os.PathLike) -> typing.Dict[str, str]:
Expand All @@ -608,6 +619,7 @@
self, dict_obj: typing.Dict[str, str], expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike]
) -> FlyteFile:
path = dict_obj.get("path", None)
metadata = dict_obj.get("metadata", None)

if path is None:
raise ValueError("FlyteFile's path should not be None")
Expand All @@ -624,7 +636,8 @@
),
uri=path,
)
)
),
metadata=metadata,
),
expected_python_type,
)
Expand Down Expand Up @@ -704,6 +717,7 @@

try:
uri = lv.scalar.blob.uri
metadata = lv.metadata
except AttributeError:
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

Expand All @@ -718,7 +732,7 @@
# In this condition, we still return a FlyteFile instance, but it's a simple one that has no downloading tricks
# Using is instead of issubclass because FlyteFile does actually subclass it
if expected_python_type is os.PathLike:
return FlyteFile(uri)
return FlyteFile(path=uri, metadata=metadata)

Check warning on line 735 in flytekit/types/file/file.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/file/file.py#L735

Added line #L735 was not covered by tests

# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type
expected_python_type = get_underlying_type(expected_python_type)
Expand All @@ -730,15 +744,15 @@
# This is a local file path, like /usr/local/my_file, don't mess with it. Certainly, downloading it doesn't
# make any sense.
if not ctx.file_access.is_remote(uri):
return expected_python_type(uri) # type: ignore
return expected_python_type(path=uri, metadata=metadata) # type: ignore

# For the remote case, return a FlyteFile object that can download
local_path = ctx.file_access.get_random_local_path(uri)

_downloader = partial(ctx.file_access.get_data, remote_path=uri, local_path=local_path, is_multipart=False)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader)
ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader, metadata=metadata)

Check warning on line 755 in flytekit/types/file/file.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/file/file.py#L755

Added line #L755 was not covered by tests
ff._remote_source = uri
return ff

Expand Down
18 changes: 12 additions & 6 deletions tests/flytekit/integration/remote/workflows/basic/flytefile.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Optional
from flytekit import task, workflow
from flytekit.types.file import FlyteFile


@task
def create_ff(file_path: str) -> FlyteFile:
def create_ff(file_path: str, info: str) -> FlyteFile:
"""Create a FlyteFile."""
return FlyteFile(path=file_path)
return FlyteFile(path=file_path, metadata={"info": info})


@task
def read_ff(ff: FlyteFile) -> None:
def read_ff(ff: FlyteFile, info: Optional[str] = None) -> None:
"""Read input FlyteFile.

This can be used in the case in which a FlyteFile is created
Expand All @@ -19,6 +20,11 @@ def read_ff(ff: FlyteFile) -> None:
content = f.read()
print(f"FILE CONTENT | {content}")

if info:
assert ff.metadata["info"] == info
else:
assert ff.metadata is None


@task
def create_and_read_ff(file_path: str) -> FlyteFile:
Expand All @@ -41,9 +47,9 @@ def create_and_read_ff(file_path: str) -> FlyteFile:


@workflow
def wf(remote_file_path: str) -> None:
ff_1 = create_ff(file_path=remote_file_path)
read_ff(ff=ff_1)
def wf(remote_file_path: str, info: str = "abc") -> None:
ff_1 = create_ff(file_path=remote_file_path, info=info)
read_ff(ff=ff_1, info=info)
ff_2 = create_and_read_ff(file_path=remote_file_path)
read_ff(ff=ff_2)

Expand Down
59 changes: 58 additions & 1 deletion tests/flytekit/unit/types/file/test_file.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import tempfile
from pathlib import Path
from typing import Optional
from dataclasses import dataclass

import pytest
from flytekit import task, workflow
from flytekit import task, workflow, current_context
from flytekit.types.file import FlyteFile


Expand Down Expand Up @@ -70,3 +71,59 @@ def _verify_msg(ff: FlyteFile) -> None:

ff_4 = wf(source_path=source_path, use_pathlike_src_path=True, remote_path=remote_path)
_verify_msg(ff_4)


def test_metadata():
    """Metadata set on a FlyteFile by one task is readable from a downstream task."""

    @task
    def create_file() -> FlyteFile:
        # Write a small text file into the task's working directory and
        # record the content length as string metadata on the FlyteFile.
        content = "hello there"
        new_file = Path(current_context().working_directory) / "my_file.txt"
        new_file.write_text(content)
        return FlyteFile(path=new_file, metadata={"length": str(len(content))})

    @task
    def read_metadata(file: FlyteFile) -> Optional[dict]:
        # Hand back whatever metadata arrived with the input file.
        return file.metadata

    @workflow
    def wf() -> Optional[dict]:
        produced = create_file()
        return read_metadata(file=produced)

    output = wf()
    assert output["length"] == "11"


@dataclass
class SimpleDC:
    """Minimal dataclass with a single FlyteFile field, used by the dataclass metadata test."""

    file: FlyteFile


def test_metadata_with_dataclass():
    """Metadata on a FlyteFile nested in a dataclass is readable from a downstream task."""

    @task
    def create_dc() -> SimpleDC:
        # Create a file in the working directory and wrap it, with metadata,
        # inside the dataclass that is returned from the task.
        my_file = Path(current_context().working_directory) / "file.txt"
        my_file.write_text("hello there!")
        tagged = FlyteFile(path=my_file, metadata={"HELLO": "WORLD"})
        return SimpleDC(file=tagged)

    @task
    def get_metadata(dc: SimpleDC) -> dict:
        # Missing or empty metadata is reported as an empty dict.
        return dc.file.metadata if dc.file.metadata else {}

    @workflow
    def wf() -> dict:
        produced = create_dc()
        return get_metadata(dc=produced)

    output = wf()
    assert output["HELLO"] == "WORLD"
2 changes: 1 addition & 1 deletion tests/flytekit/unit/utils/test_pbhash.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,4 @@ class TestFileStruct(DataClassJsonMixin):
lt = tf.get_literal_type(TestFileStruct)
lv = tf.to_literal(ctx, o, TestFileStruct, lt)

assert compute_hash_string(lv.to_flyte_idl()) == "Hp/cWul3sBI5r8XKdVzAlvNBJ4OSX9L2d/SADI8+YOY="
assert compute_hash_string(lv.to_flyte_idl()) == "gqjGZ84q3Tz80PX3RYYZQ+bHz8zVoYYo+uvdOyMwHB0="
Loading