[wip] Pandas Dataframe in Dataclass #3116
base: master
@@ -1,5 +1,6 @@
from __future__ import annotations

import dataclasses
from dataclasses import dataclass, fields, make_dataclass, is_dataclass, MISSING
import asyncio
import collections
import copy

@@ -128,7 +129,6 @@ def modify_literal_uris(lit: Literal):
            lit.scalar.structured_dataset.uri
        )


class TypeTransformerFailedError(TypeError, AssertionError, ValueError): ...


@@ -682,6 +682,33 @@ def to_generic_literal(
        Set `FLYTE_USE_OLD_DC_FORMAT=true` to use the old JSON-based format.
        Note: This is deprecated and will be removed in the future.
        """
        import pandas as pd
        from flytekit.types.structured.structured_dataset import StructuredDataset
        from typing import get_type_hints, Type, Dict

        def transform_dataclass(cls, memo=None):
            if memo is None:
                memo = {}

            if cls in memo:
                return memo[cls]

            cls_hints = get_type_hints(cls)
            new_field_defs = []
            for field in fields(cls):
                orig_type = cls_hints[field.name]
                if orig_type == pd.DataFrame:
                    new_type = StructuredDataset
                elif is_dataclass(orig_type):
                    new_type = transform_dataclass(orig_type, memo)
                else:
                    new_type = orig_type
                new_field_defs.append((field.name, new_type))

            new_cls = make_dataclass("FlyteModified" + cls.__name__, new_field_defs)
            memo[cls] = new_cls
            return new_cls

        if isinstance(python_val, dict):
            json_str = json.dumps(python_val)
            return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct())))

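To make the effect of the added `transform_dataclass` helper concrete, the sketch below shows the kind of class it is expected to produce for a toy dataclass. `TrainingData` and `FlyteModifiedTrainingData` are illustrative names, not part of this PR.

```python
import pandas as pd
from dataclasses import dataclass, fields, make_dataclass
from flytekit.types.structured.structured_dataset import StructuredDataset


@dataclass
class TrainingData:  # illustrative dataclass, not from the PR
    name: str
    frame: pd.DataFrame


# transform_dataclass(TrainingData) is expected to build (via make_dataclass)
# a sibling class whose DataFrame-typed fields are retyped as StructuredDataset:
FlyteModifiedTrainingData = make_dataclass(
    "FlyteModifiedTrainingData",
    [("name", str), ("frame", StructuredDataset)],
)

print([(f.name, f.type) for f in fields(FlyteModifiedTrainingData)])
# -> [('name', str), ('frame', StructuredDataset)] (printed as class reprs)
```
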
@@ -693,6 +720,7 @@ def to_generic_literal(
            )

        self._make_dataclass_serializable(python_val, python_type)
        new_python_type = transform_dataclass(python_type)

        # JSON serialization using mashumaro's DataClassJSONMixin
        if isinstance(python_val, DataClassJSONMixin):

@@ -728,7 +756,36 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
                f"user defined datatypes in Flytekit"
            )

        import pandas as pd
        from flytekit.types.structured.structured_dataset import StructuredDataset
        from typing import get_type_hints, Type, Dict

        def transform_dataclass(cls, memo=None):
            if memo is None:
                memo = {}

            if cls in memo:
                return memo[cls]

            cls_hints = get_type_hints(cls)
            new_field_defs = []
            for field in fields(cls):
                orig_type = cls_hints[field.name]
                if orig_type == pd.DataFrame:
                    new_type = StructuredDataset
                elif is_dataclass(orig_type):
                    new_type = transform_dataclass(orig_type, memo)
                else:
                    new_type = orig_type
                new_field_defs.append((field.name, new_type))

            new_cls = make_dataclass("FlyteModified" + cls.__name__, new_field_defs)
            memo[cls] = new_cls
            return new_cls

        self._make_dataclass_serializable(python_val, python_type)
        new_python_type = transform_dataclass(python_type)

        # The `to_json` integrated through mashumaro's `DataClassJSONMixin` allows for more
        # functionality than JSONEncoder

@@ -741,10 +798,10 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
        # The function looks up or creates a MessagePackEncoder specifically designed for the object's type.
        # This encoder is then used to convert a data class into MessagePack Bytes.
        try:
-            encoder = self._msgpack_encoder[python_type]
+            encoder = self._msgpack_encoder[new_python_type]
        except KeyError:
-            encoder = MessagePackEncoder(python_type)
-            self._msgpack_encoder[python_type] = encoder
+            encoder = MessagePackEncoder(new_python_type)
+            self._msgpack_encoder[new_python_type] = encoder

        try:
            msgpack_bytes = encoder.encode(python_val)

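The `try`/`except KeyError` lookup above is the transformer's usual memoize-by-type pattern, now keyed by the rewritten dataclass type. A generic, hedged sketch of the same caching idea, detached from `MessagePackEncoder`:

```python
from typing import Any, Callable, Dict, Type

_encoder_cache: Dict[Type, Any] = {}


def get_or_build(tp: Type, factory: Callable[[Type], Any]) -> Any:
    # Build an encoder for `tp` the first time it is seen, then reuse it.
    try:
        return _encoder_cache[tp]
    except KeyError:
        encoder = factory(tp)
        _encoder_cache[tp] = encoder
        return encoder
```
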
@@ -835,6 +892,9 @@ def get_expected_type(python_val: T, types: tuple) -> Type[T | None]:
            }

        if not dataclasses.is_dataclass(python_type):
            import pandas as pd
            if isinstance(python_val, pd.DataFrame):
                python_val = StructuredDataset(dataframe=python_val, file_format="parquet")
Comment on lines +910 to +912

Consider dedicated transformer for DataFrame conversion

Consider moving the pandas DataFrame conversion logic to a dedicated transformer class instead of handling it inline in this method.

Code Review Run #fedbf7
            return python_val

        # Transform str to FlyteFile or FlyteDirectory so that mashumaro can serialize the path.

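The three added lines wrap a bare `pd.DataFrame` in a `StructuredDataset` so it can be serialized like any other structured dataset. A minimal sketch of that conversion factored into a standalone helper, roughly in the spirit of the reviewer's suggestion (the helper name is hypothetical):

```python
import pandas as pd
from flytekit.types.structured.structured_dataset import StructuredDataset


def wrap_dataframe(value):
    # Hypothetical helper: wrap in-memory pandas DataFrames so Flyte's
    # structured-dataset machinery can serialize them; pass everything else through.
    if isinstance(value, pd.DataFrame):
        return StructuredDataset(dataframe=value, file_format="parquet")
    return value
```
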
@@ -873,6 +933,10 @@ def _fix_val_int(self, t: typing.Type, val: typing.Any) -> typing.Any:
        if t == int:
            return int(val)

        import pandas as pd
        if t == pd.DataFrame:
            return val().open(dataframe_type=pd.DataFrame).all()

        if isinstance(val, list):
            # Handle nested List. e.g. [[1, 2], [3, 4]]
            return list(map(lambda x: self._fix_val_int(ListTransformer.get_sub_type(t), x), val))

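For reference, the `.open(...).all()` call used in the new branch is the standard flytekit pattern for materializing a `StructuredDataset` back into pandas; a hedged usage sketch (the function name is illustrative):

```python
import pandas as pd
from flytekit.types.structured.structured_dataset import StructuredDataset


def to_dataframe(sd: StructuredDataset) -> pd.DataFrame:
    # Declare the desired in-memory type, then load the full dataset.
    return sd.open(pd.DataFrame).all()
```
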
@@ -917,7 +981,8 @@ def from_binary_idl(self, binary_idl_object: Binary, expected_python_type: Type[
                self._msgpack_decoder[expected_python_type] = decoder
            dc = decoder.decode(binary_idl_object.value)

-            return dc
+            # return dc
+            return self._fix_dataclass_int(expected_python_type, dc)
        else:
            raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

@@ -928,28 +993,71 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
                "user defined datatypes in Flytekit"
            )

        import pandas as pd
        from flytekit.types.structured.structured_dataset import StructuredDataset
        from typing import get_type_hints, Type, Dict

        def convert_dataclass(instance, target_cls):
Consider adding type hints to function

The `convert_dataclass` function is missing type hints for its parameters and return value.

Code Review Run #b16268
            if not (is_dataclass(instance) and is_dataclass(target_cls)):
                return instance

            kwargs = {}
            target_fields = {f.name: f.type for f in fields(target_cls)}
            for field in fields(instance.__class__):
                if field.name in target_fields:
                    value = getattr(instance, field.name)
                    if is_dataclass(value) and is_dataclass(target_fields[field.name]):
                        value = convert_dataclass(value, target_fields[field.name])
                    kwargs[field.name] = value
            return target_cls(**kwargs)

        def transform_dataclass(cls, memo=None):
            if memo is None:
                memo = {}

            if cls in memo:
                return memo[cls]

            cls_hints = get_type_hints(cls)
            new_field_defs = []
            for field in fields(cls):
                orig_type = cls_hints[field.name]
                if orig_type == pd.DataFrame:
                    new_type = StructuredDataset
                elif is_dataclass(orig_type):
                    new_type = transform_dataclass(orig_type, memo)
                else:
                    new_type = orig_type
                new_field_defs.append((field.name, new_type))

            new_cls = make_dataclass("FlyteModified" + cls.__name__, new_field_defs)
            memo[cls] = new_cls
            return new_cls

        new_expected_python_type = transform_dataclass(expected_python_type)

        if lv.scalar and lv.scalar.binary:
-            return self.from_binary_idl(lv.scalar.binary, expected_python_type)  # type: ignore
+            return convert_dataclass(self.from_binary_idl(lv.scalar.binary, new_expected_python_type), expected_python_type)  # type: ignore

        json_str = _json_format.MessageToJson(lv.scalar.generic)

        # The `from_json` function is provided from mashumaro's `DataClassJSONMixin`.
        # It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder
        # We can't use hasattr(expected_python_type, "from_json") here because we rely on mashumaro's API to customize the deserialization behavior for Flyte types.
-        if issubclass(expected_python_type, DataClassJSONMixin):
-            dc = expected_python_type.from_json(json_str)  # type: ignore
+        if issubclass(new_expected_python_type, DataClassJSONMixin):
+            dc = new_expected_python_type.from_json(json_str)  # type: ignore
        else:
            # The function looks up or creates a JSONDecoder specifically designed for the object's type.
            # This decoder is then used to convert a JSON string into a data class.
            try:
-                decoder = self._json_decoder[expected_python_type]
+                decoder = self._json_decoder[new_expected_python_type]
            except KeyError:
-                decoder = JSONDecoder(expected_python_type)
-                self._json_decoder[expected_python_type] = decoder
+                decoder = JSONDecoder(new_expected_python_type)
+                self._json_decoder[new_expected_python_type] = decoder
Comment on lines +1077 to +1078

Consider caching decoder with original type

Consider storing the decoder under the original `expected_python_type` key rather than the transformed type.

Code Review Run #fedbf7

            dc = decoder.decode(json_str)

-        return self._fix_dataclass_int(expected_python_type, dc)
+        return convert_dataclass(self._fix_dataclass_int(new_expected_python_type, dc), expected_python_type)
Consider validating dataclass field compatibility

The conversion from the transformed dataclass back to `expected_python_type` does not validate that the field sets of the two classes are compatible.

Code Review Run #b16268

        # This ensures that calls with the same literal type returns the same dataclass. For example, `pyflyte run``
        # command needs to call guess_python_type to get the TypeEngine-derived dataclass. Without caching here, separate

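Taken together, `to_python_value` now decodes into the rewritten `FlyteModified*` class and then uses `convert_dataclass` to copy the values back into the user's original dataclass. A simplified, hedged sketch of that round trip with stand-in classes (not from the PR):

```python
from dataclasses import dataclass, fields, is_dataclass


@dataclass
class Original:  # stands in for the user's dataclass
    name: str
    count: int


@dataclass
class FlyteModifiedOriginal:  # stands in for the transform_dataclass output
    name: str
    count: int


def convert_back(instance, target_cls):
    # Same idea as the PR's convert_dataclass: copy fields that also exist on
    # the target dataclass (the PR version additionally recurses into nested dataclasses).
    if not (is_dataclass(instance) and is_dataclass(target_cls)):
        return instance
    target_names = {f.name for f in fields(target_cls)}
    kwargs = {f.name: getattr(instance, f.name) for f in fields(instance) if f.name in target_names}
    return target_cls(**kwargs)


decoded = FlyteModifiedOriginal(name="run-1", count=3)
print(convert_back(decoded, Original))  # Original(name='run-1', count=3)
```
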
Consider extracting the `transform_dataclass` function to avoid code duplication. This function appears multiple times in the codebase (lines 689-710, 764-785, and 1014-1035) with identical functionality.

Code Review Run #9c627a
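One possible shape for the extraction suggested above: a single module-level helper with type hints, shared by the three call sites. This is only a sketch of what that refactor could look like, not code from the PR:

```python
from dataclasses import fields, is_dataclass, make_dataclass
from typing import Dict, Optional, Type, get_type_hints

import pandas as pd

from flytekit.types.structured.structured_dataset import StructuredDataset


def transform_dataclass(cls: Type, memo: Optional[Dict[Type, Type]] = None) -> Type:
    """Return a dataclass mirroring `cls` with pd.DataFrame fields retyped as StructuredDataset."""
    if memo is None:
        memo = {}
    if cls in memo:
        return memo[cls]

    hints = get_type_hints(cls)
    new_field_defs = []
    for field in fields(cls):
        orig_type = hints[field.name]
        if orig_type == pd.DataFrame:
            new_type: Type = StructuredDataset
        elif is_dataclass(orig_type):
            new_type = transform_dataclass(orig_type, memo)
        else:
            new_type = orig_type
        new_field_defs.append((field.name, new_type))

    new_cls = make_dataclass("FlyteModified" + cls.__name__, new_field_defs)
    memo[cls] = new_cls
    return new_cls
```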