-
Notifications
You must be signed in to change notification settings - Fork 310
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[wip] Pandas Dataframe in Dataclass #3116
base: master
Are you sure you want to change the base?
Changes from all commits
28db576
fc506fe
de150b1
2a3bffe
1625381
d9d438e
c2579e4
b5f2a6f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,5 +1,6 @@ | ||||||||||||||||||||||||
from __future__ import annotations | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
import dataclasses | ||||||||||||||||||||||||
from dataclasses import dataclass, fields, make_dataclass, is_dataclass, MISSING | ||||||||||||||||||||||||
import asyncio | ||||||||||||||||||||||||
import collections | ||||||||||||||||||||||||
import copy | ||||||||||||||||||||||||
|
@@ -129,7 +130,6 @@ | |||||||||||||||||||||||
lit.scalar.structured_dataset.uri | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
class TypeTransformerFailedError(TypeError, AssertionError, ValueError): ... | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
@@ -683,6 +683,40 @@ | |||||||||||||||||||||||
Set `FLYTE_USE_OLD_DC_FORMAT=true` to use the old JSON-based format. | ||||||||||||||||||||||||
Note: This is deprecated and will be removed in the future. | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
import pandas as pd | ||||||||||||||||||||||||
from flytekit.types.file import FlyteFile | ||||||||||||||||||||||||
from flytekit.types.directory import FlyteDirectory | ||||||||||||||||||||||||
from flytekit.types.structured.structured_dataset import StructuredDataset | ||||||||||||||||||||||||
from flytekit.types.schema import FlyteSchema | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
from typing import get_type_hints, Type, Dict | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def transform_dataclass(cls, memo=None): | ||||||||||||||||||||||||
FLYTE_TYPES = [FlyteFile, FlyteDirectory, StructuredDataset, FlyteSchema] | ||||||||||||||||||||||||
if cls in FLYTE_TYPES: | ||||||||||||||||||||||||
return cls | ||||||||||||||||||||||||
if memo is None: | ||||||||||||||||||||||||
memo = {} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if cls in memo: | ||||||||||||||||||||||||
return memo[cls] | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
cls_hints = get_type_hints(cls) | ||||||||||||||||||||||||
new_field_defs = [] | ||||||||||||||||||||||||
for field in fields(cls): | ||||||||||||||||||||||||
orig_type = cls_hints[field.name] | ||||||||||||||||||||||||
if orig_type == pd.DataFrame: | ||||||||||||||||||||||||
new_type = StructuredDataset | ||||||||||||||||||||||||
elif is_dataclass(orig_type): | ||||||||||||||||||||||||
new_type = transform_dataclass(orig_type, memo) | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
new_type = orig_type | ||||||||||||||||||||||||
new_field_defs.append((field.name, new_type)) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
new_cls = make_dataclass("FlyteModified" + cls.__name__, new_field_defs) | ||||||||||||||||||||||||
memo[cls] = new_cls | ||||||||||||||||||||||||
return new_cls | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if isinstance(python_val, dict): | ||||||||||||||||||||||||
json_str = json.dumps(python_val) | ||||||||||||||||||||||||
return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) | ||||||||||||||||||||||||
|
@@ -694,16 +728,17 @@ | |||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
self._make_dataclass_serializable(python_val, python_type) | ||||||||||||||||||||||||
new_python_type = transform_dataclass(python_type) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# JSON serialization using mashumaro's DataClassJSONMixin | ||||||||||||||||||||||||
if isinstance(python_val, DataClassJSONMixin): | ||||||||||||||||||||||||
json_str = python_val.to_json() | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
try: | ||||||||||||||||||||||||
encoder = self._json_encoder[python_type] | ||||||||||||||||||||||||
encoder = self._json_encoder[new_python_type] | ||||||||||||||||||||||||
except KeyError: | ||||||||||||||||||||||||
encoder = JSONEncoder(python_type) | ||||||||||||||||||||||||
self._json_encoder[python_type] = encoder | ||||||||||||||||||||||||
encoder = JSONEncoder(new_python_type) | ||||||||||||||||||||||||
self._json_encoder[new_python_type] = encoder | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
try: | ||||||||||||||||||||||||
json_str = encoder.encode(python_val) | ||||||||||||||||||||||||
|
@@ -729,7 +764,43 @@ | |||||||||||||||||||||||
f"user defined datatypes in Flytekit" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
import pandas as pd | ||||||||||||||||||||||||
from flytekit.types.file import FlyteFile | ||||||||||||||||||||||||
from flytekit.types.directory import FlyteDirectory | ||||||||||||||||||||||||
from flytekit.types.structured.structured_dataset import StructuredDataset | ||||||||||||||||||||||||
from flytekit.types.schema import FlyteSchema | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
from typing import get_type_hints, Type, Dict | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def transform_dataclass(cls, memo=None): | ||||||||||||||||||||||||
FLYTE_TYPES = [FlyteFile, FlyteDirectory, StructuredDataset, FlyteSchema] | ||||||||||||||||||||||||
if cls in FLYTE_TYPES: | ||||||||||||||||||||||||
return cls | ||||||||||||||||||||||||
Comment on lines
+777
to
+779
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider moving constant to module level
Consider moving the Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #e17dbe Is this a valid issue, or was it incorrectly flagged by the Agent?
Comment on lines
+777
to
+779
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider using immutable type for constants
Consider using a tuple or frozenset instead of list for Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #e17dbe Is this a valid issue, or was it incorrectly flagged by the Agent?
|
||||||||||||||||||||||||
if memo is None: | ||||||||||||||||||||||||
memo = {} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if cls in memo: | ||||||||||||||||||||||||
return memo[cls] | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
cls_hints = get_type_hints(cls) | ||||||||||||||||||||||||
new_field_defs = [] | ||||||||||||||||||||||||
for field in fields(cls): | ||||||||||||||||||||||||
orig_type = cls_hints[field.name] | ||||||||||||||||||||||||
if orig_type == pd.DataFrame: | ||||||||||||||||||||||||
new_type = StructuredDataset | ||||||||||||||||||||||||
elif is_dataclass(orig_type): | ||||||||||||||||||||||||
new_type = transform_dataclass(orig_type, memo) | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
new_type = orig_type | ||||||||||||||||||||||||
new_field_defs.append((field.name, new_type)) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
new_cls = make_dataclass("FlyteModified" + cls.__name__, new_field_defs) | ||||||||||||||||||||||||
memo[cls] = new_cls | ||||||||||||||||||||||||
return new_cls | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
self._make_dataclass_serializable(python_val, python_type) | ||||||||||||||||||||||||
new_python_type = transform_dataclass(python_type) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# The `to_json` integrated through mashumaro's `DataClassJSONMixin` allows for more | ||||||||||||||||||||||||
# functionality than JSONEncoder | ||||||||||||||||||||||||
|
@@ -742,10 +813,10 @@ | |||||||||||||||||||||||
# The function looks up or creates a MessagePackEncoder specifically designed for the object's type. | ||||||||||||||||||||||||
# This encoder is then used to convert a data class into MessagePack Bytes. | ||||||||||||||||||||||||
try: | ||||||||||||||||||||||||
encoder = self._msgpack_encoder[python_type] | ||||||||||||||||||||||||
encoder = self._msgpack_encoder[new_python_type] | ||||||||||||||||||||||||
except KeyError: | ||||||||||||||||||||||||
encoder = MessagePackEncoder(python_type) | ||||||||||||||||||||||||
self._msgpack_encoder[python_type] = encoder | ||||||||||||||||||||||||
encoder = MessagePackEncoder(new_python_type) | ||||||||||||||||||||||||
self._msgpack_encoder[new_python_type] = encoder | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
try: | ||||||||||||||||||||||||
msgpack_bytes = encoder.encode(python_val) | ||||||||||||||||||||||||
|
@@ -836,6 +907,9 @@ | |||||||||||||||||||||||
} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if not dataclasses.is_dataclass(python_type): | ||||||||||||||||||||||||
import pandas as pd | ||||||||||||||||||||||||
if isinstance(python_val, pd.DataFrame): | ||||||||||||||||||||||||
python_val = StructuredDataset(dataframe=python_val, file_format="parquet") | ||||||||||||||||||||||||
Comment on lines
+910
to
+912
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider dedicated transformer for DataFrame conversion
Consider moving the pandas DataFrame conversion logic to a dedicated transformer class instead of handling it in the Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #fedbf7 Is this a valid issue, or was it incorrectly flagged by the Agent?
|
||||||||||||||||||||||||
return python_val | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Transform str to FlyteFile or FlyteDirectory so that mashumaro can serialize the path. | ||||||||||||||||||||||||
|
@@ -874,6 +948,10 @@ | |||||||||||||||||||||||
if t == int: | ||||||||||||||||||||||||
return int(val) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
import pandas as pd | ||||||||||||||||||||||||
if t == pd.DataFrame: | ||||||||||||||||||||||||
return val().open(dataframe_type=pd.DataFrame).all() | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if isinstance(val, list): | ||||||||||||||||||||||||
# Handle nested List. e.g. [[1, 2], [3, 4]] | ||||||||||||||||||||||||
return list(map(lambda x: self._fix_val_int(ListTransformer.get_sub_type(t), x), val)) | ||||||||||||||||||||||||
|
@@ -918,7 +996,8 @@ | |||||||||||||||||||||||
self._msgpack_decoder[expected_python_type] = decoder | ||||||||||||||||||||||||
dc = decoder.decode(binary_idl_object.value) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
return dc | ||||||||||||||||||||||||
# return dc | ||||||||||||||||||||||||
return self._fix_dataclass_int(expected_python_type, dc) | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
@@ -929,28 +1008,78 @@ | |||||||||||||||||||||||
"user defined datatypes in Flytekit" | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
import pandas as pd | ||||||||||||||||||||||||
from flytekit.types.structured.structured_dataset import StructuredDataset | ||||||||||||||||||||||||
from typing import get_type_hints, Type, Dict | ||||||||||||||||||||||||
import pandas as pd | ||||||||||||||||||||||||
from flytekit.types.file import FlyteFile | ||||||||||||||||||||||||
from flytekit.types.directory import FlyteDirectory | ||||||||||||||||||||||||
from flytekit.types.structured.structured_dataset import StructuredDataset | ||||||||||||||||||||||||
from flytekit.types.schema import FlyteSchema | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
from typing import get_type_hints, Type, Dict | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def convert_dataclass(instance, target_cls): | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider adding type hints to function
The Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #b16268 Is this a valid issue, or was it incorrectly flagged by the Agent?
|
||||||||||||||||||||||||
if not (is_dataclass(instance) and is_dataclass(target_cls)): | ||||||||||||||||||||||||
return instance | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
kwargs = {} | ||||||||||||||||||||||||
target_fields = {f.name: f.type for f in fields(target_cls)} | ||||||||||||||||||||||||
for field in fields(instance.__class__): | ||||||||||||||||||||||||
if field.name in target_fields: | ||||||||||||||||||||||||
value = getattr(instance, field.name) | ||||||||||||||||||||||||
if is_dataclass(value) and is_dataclass(target_fields[field.name]): | ||||||||||||||||||||||||
value = convert_dataclass(value, target_fields[field.name]) | ||||||||||||||||||||||||
kwargs[field.name] = value | ||||||||||||||||||||||||
return target_cls(**kwargs) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def transform_dataclass(cls, memo=None): | ||||||||||||||||||||||||
if memo is None: | ||||||||||||||||||||||||
memo = {} | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if cls in memo: | ||||||||||||||||||||||||
return memo[cls] | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
cls_hints = get_type_hints(cls) | ||||||||||||||||||||||||
new_field_defs = [] | ||||||||||||||||||||||||
for field in fields(cls): | ||||||||||||||||||||||||
orig_type = cls_hints[field.name] | ||||||||||||||||||||||||
if orig_type == pd.DataFrame: | ||||||||||||||||||||||||
new_type = StructuredDataset | ||||||||||||||||||||||||
elif is_dataclass(orig_type): | ||||||||||||||||||||||||
new_type = transform_dataclass(orig_type, memo) | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
new_type = orig_type | ||||||||||||||||||||||||
new_field_defs.append((field.name, new_type)) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
new_cls = make_dataclass("FlyteModified" + cls.__name__, new_field_defs) | ||||||||||||||||||||||||
memo[cls] = new_cls | ||||||||||||||||||||||||
return new_cls | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
new_expected_python_type = transform_dataclass(expected_python_type) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if lv.scalar and lv.scalar.binary: | ||||||||||||||||||||||||
return self.from_binary_idl(lv.scalar.binary, expected_python_type) # type: ignore | ||||||||||||||||||||||||
return convert_dataclass(self.from_binary_idl(lv.scalar.binary, new_expected_python_type), expected_python_type) # type: ignore | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
json_str = _json_format.MessageToJson(lv.scalar.generic) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# The `from_json` function is provided from mashumaro's `DataClassJSONMixin`. | ||||||||||||||||||||||||
# It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder | ||||||||||||||||||||||||
# We can't use hasattr(expected_python_type, "from_json") here because we rely on mashumaro's API to customize the deserialization behavior for Flyte types. | ||||||||||||||||||||||||
if issubclass(expected_python_type, DataClassJSONMixin): | ||||||||||||||||||||||||
dc = expected_python_type.from_json(json_str) # type: ignore | ||||||||||||||||||||||||
if issubclass(new_expected_python_type, DataClassJSONMixin): | ||||||||||||||||||||||||
dc = new_expected_python_type.from_json(json_str) # type: ignore | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
# The function looks up or creates a JSONDecoder specifically designed for the object's type. | ||||||||||||||||||||||||
# This decoder is then used to convert a JSON string into a data class. | ||||||||||||||||||||||||
try: | ||||||||||||||||||||||||
decoder = self._json_decoder[expected_python_type] | ||||||||||||||||||||||||
decoder = self._json_decoder[new_expected_python_type] | ||||||||||||||||||||||||
except KeyError: | ||||||||||||||||||||||||
decoder = JSONDecoder(expected_python_type) | ||||||||||||||||||||||||
self._json_decoder[expected_python_type] = decoder | ||||||||||||||||||||||||
decoder = JSONDecoder(new_expected_python_type) | ||||||||||||||||||||||||
self._json_decoder[new_expected_python_type] = decoder | ||||||||||||||||||||||||
Comment on lines
+1077
to
+1078
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider caching decoder with original type
Consider storing the decoder in the original Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #fedbf7 Is this a valid issue, or was it incorrectly flagged by the Agent?
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
dc = decoder.decode(json_str) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
return self._fix_dataclass_int(expected_python_type, dc) | ||||||||||||||||||||||||
return convert_dataclass(self._fix_dataclass_int(new_expected_python_type, dc), expected_python_type) | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider validating dataclass field compatibility
The conversion from Code suggestionCheck the AI-generated fix before applying
Suggested change
Code Review Run #b16268 Is this a valid issue, or was it incorrectly flagged by the Agent?
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
# This ensures that calls with the same literal type returns the same dataclass. For example, `pyflyte run`` | ||||||||||||||||||||||||
# command needs to call guess_python_type to get the TypeEngine-derived dataclass. Without caching here, separate | ||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider using
python_type
instead ofnew_python_type
when storing the encoder inself._json_encoder
. The encoder is created usingpython_type
but stored withnew_python_type
, which could lead to inconsistencies in encoder lookup.Code suggestion
Code Review Run #c4bd83
Is this a valid issue, or was it incorrectly flagged by the Agent?