Commit 0e0e8cf: Revert changes

vbhagwat committed Jan 24, 2025
1 parent c1e7659
Showing 1 changed file with 82 additions and 103 deletions.

sdk/python/feast/infra/contrib/spark_kafka_processor.py (185 changes: 82 additions & 103 deletions)
@@ -3,14 +3,13 @@
 from typing import List, Optional, Set, Union, no_type_check
 
 import pandas as pd
+import pyarrow
 from pyspark import SparkContext
-from pyspark.sql import DataFrame, Row, SparkSession
-from pyspark.sql import functions as F
+from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.avro.functions import from_avro
 from pyspark.sql.column import Column, _to_java_column
 from pyspark.sql.functions import col, from_json
 from pyspark.sql.streaming import StreamingQuery
-from pyspark.sql.window import Window
 
 from feast import FeatureView
 from feast.data_format import AvroFormat, ConfluentAvroFormat, JsonFormat, StreamFormat
@@ -21,12 +20,12 @@
     StreamProcessor,
     StreamTable,
 )
-from feast.infra.key_encoding_utils import serialize_entity_key
 from feast.infra.materialization.contrib.spark.spark_materialization_engine import (
     _SparkSerializedArtifacts,
 )
 from feast.infra.provider import get_provider
 from feast.stream_feature_view import StreamFeatureView
+from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping
 
 
 class SparkProcessorConfig(ProcessorConfig):
@@ -48,9 +47,9 @@ def _from_confluent_avro(column: Column, abris_config) -> Column:
 
 
 def _to_abris_config(
-    schema_registry_config: dict,
-    record_name: str,
-    record_namespace: str,
+        schema_registry_config: dict,
+        record_name: str,
+        record_namespace: str,
 ):
     """:return: za.co.absa.abris.config.FromAvroConfig"""
     topic = schema_registry_config["schema.registry.topic"]
@@ -73,12 +72,12 @@ class SparkKafkaProcessor(StreamProcessor):
     join_keys: List[str]
 
     def __init__(
-        self,
-        *,
-        fs: FeatureStore,
-        sfv: Union[StreamFeatureView, FeatureView],
-        config: ProcessorConfig,
-        preprocess_fn: Optional[MethodType] = None,
+            self,
+            *,
+            fs: FeatureStore,
+            sfv: Union[StreamFeatureView, FeatureView],
+            config: ProcessorConfig,
+            preprocess_fn: Optional[MethodType] = None,
     ):
         if not isinstance(sfv.stream_source, KafkaSource):
             raise ValueError("data source is not kafka source")
@@ -120,7 +119,7 @@ def __init__(
 
     def _create_infra_if_necessary(self):
         if self.fs.config.online_store is not None and getattr(
-            self.fs.config.online_store, "lazy_table_creation", False
+                self.fs.config.online_store, "lazy_table_creation", False
         ):
             print(
                 f"Online store {self.fs.config.online_store.__class__.__name__} supports lazy table creation and it is enabled"
@@ -136,7 +135,7 @@ def _create_infra_if_necessary(self):
             )
 
     def ingest_stream_feature_view(
-        self, to: PushMode = PushMode.ONLINE
+            self, to: PushMode = PushMode.ONLINE
     ) -> StreamingQuery:
         self._create_infra_if_necessary()
         ingested_stream_df = self._ingest_stream_data()
Expand Down Expand Up @@ -224,8 +223,8 @@ def _construct_transformation_plan(self, df: StreamTable) -> StreamTable:
if self.sfv.stream_source is not None:
if self.sfv.stream_source.field_mapping is not None:
for (
field_mapping_key,
field_mapping_value,
field_mapping_key,
field_mapping_value,
) in self.sfv.stream_source.field_mapping.items():
df = df.withColumn(field_mapping_value, df[field_mapping_key])

Expand Down Expand Up @@ -256,100 +255,81 @@ def _construct_transformation_plan(self, df: StreamTable) -> StreamTable:

def _write_stream_data_expedia(self, df: StreamTable, to: PushMode):
"""
Streamlines data writing logic
Ensures materialization logic in sync with stream ingestion.
Support only write to online store. No support for preprocess_fn also.
In Spark 3.2.2, toPandas() is throwing error when the dataframe has Boolean columns.
To fix this error, we need spark 3.4.0 or numpy < 1.20.0 but feast needs numpy >= 1.22.
Switching to use mapInPandas to solve the problem for boolean columns and
toPandas() also load all data into driver's memory.
Error Message:
AttributeError: module 'numpy' has no attribute 'bool'.
`np.bool` was a deprecated alias for the builtin `bool`.
To avoid this error in existing code, use `bool` by itself.
Doing this will not modify any behavior and is safe.
If you specifically wanted the numpy scalar type, use `np.bool_` here.
"""

def online_write_with_connector(config, table, data, progress, spark):
"""
Write a batch of features to the online store.
"""
project = config.project
keyspace = self._keyspace
fqtable = f"{project}_{table.name}"

def prepare_rows():
"""
Transform data into a list of Spark Row objects for insertion.
"""
rows = []
for entity_key, values, timestamp, created_ts in data:
entity_key_bin = serialize_entity_key(
entity_key,
entity_key_serialization_version=config.entity_key_serialization_version,
).hex()

for feature_name, val in values.items():
rows.append(
Row(
feature_name=feature_name,
entity_key=entity_key_bin,
feature_value=val.SerializeToString(),
event_timestamp=timestamp,
created_timestamp=created_ts,
)
)

if progress:
progress(1)

return rows

rows = prepare_rows()
if rows:
df = spark.createDataFrame(rows)

# Write to ScyllaDB using the Cassandra connector
# TODO: Support writing to offline store and preprocess_fn. Remove _write_stream_data method

# Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema.
def batch_write_pandas_df(iterator, spark_serialized_artifacts, join_keys):
for pdf in iterator:
(
df.write.format("org.apache.spark.sql.cassandra")
.options(table=fqtable, keyspace=keyspace)
.mode("append")
.save()
feature_view,
online_store,
repo_config,
) = spark_serialized_artifacts.unserialize()

if isinstance(feature_view, StreamFeatureView):
ts_field = feature_view.timestamp_field
else:
ts_field = feature_view.stream_source.timestamp_field

# Extract the latest feature values for each unique entity row (i.e. the join keys).
pdf = (
pdf.sort_values(by=[*join_keys, ts_field], ascending=False)
.groupby(join_keys)
.nth(0)
)

if progress:
progress(1)

def batch_write(
sdf: DataFrame,
batch_id: int,
spark_serialized_artifacts,
join_keys,
feature_view,
spark_session,
):
"""
Write each batch of data to the online store.
"""
start_time = time.time()
table = pyarrow.Table.from_pandas(pdf)
if feature_view.batch_source.field_mapping is not None:
table = _run_pyarrow_field_mapping(
table, feature_view.batch_source.field_mapping
)

# Extract latest feature values per entity and write to the online store
latest_df = (
sdf.withColumn(
"row_number",
F.row_number().over(
Window.partitionBy(*join_keys).orderBy(
F.desc(feature_view.timestamp_field)
)
),
join_key_to_value_type = {
entity.name: entity.dtype.to_value_type()
for entity in feature_view.entity_columns
}
rows_to_write = _convert_arrow_to_proto(
table, feature_view, join_key_to_value_type
)
online_store.online_write_batch(
repo_config,
feature_view,
rows_to_write,
lambda x: None,
)
.filter(F.col("row_number") == 1)
.drop("row_number")
)

rows_to_write = (
latest_df.collect()
) # Convert to rows for online_write_with_connector

# Deserialize artifacts and write the batch
feature_view, online_store, repo_config = (
spark_serialized_artifacts.unserialize()
)
online_write_with_connector(
repo_config, feature_view, rows_to_write, None, spark_session
)
yield pd.DataFrame([pd.Series(range(1, 2))]) # dummy result

def batch_write(
sdf: DataFrame,
batch_id: int,
spark_serialized_artifacts,
join_keys,
feature_view,
):
start_time = time.time()
sdf.mapInPandas(
lambda x: batch_write_pandas_df(
x, spark_serialized_artifacts, join_keys
),
"status int",
).count() # dummy action to force evaluation
print(
f"Time taken to write batch {batch_id}: {(time.time() - start_time) * 1000:.2f} ms"
f"Time taken to write batch {batch_id} is: {(time.time() - start_time) * 1000:.2f} ms"
)

query = (
@@ -363,7 +343,6 @@ def batch_write(
                     self.spark_serialized_artifacts,
                     self.join_keys,
                     self.sfv,
-                    self.spark,
                 )
             )
             .start()
Expand Down Expand Up @@ -416,4 +395,4 @@ def batch_write(row: DataFrame, batch_id: int):
)

query.awaitTermination(timeout=self.query_timeout)
return query
return query
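
Note: the central transformation in the restored `batch_write_pandas_df` is the pandas deduplication that keeps only the newest row per join key before writing to the online store. The following is a minimal, standalone sketch of that rule (not part of this commit); the `driver_id`, `event_timestamp`, and `conv_rate` columns are made-up illustration data.

```python
import pandas as pd

# Toy micro-batch: two entities, several events each.
batch = pd.DataFrame(
    {
        "driver_id": [1001, 1001, 1002, 1002],
        "event_timestamp": pd.to_datetime(
            ["2025-01-24 10:00", "2025-01-24 10:05", "2025-01-24 09:55", "2025-01-24 10:01"]
        ),
        "conv_rate": [0.1, 0.2, 0.5, 0.6],
    }
)

join_keys = ["driver_id"]
ts_field = "event_timestamp"

# Same dedup rule as batch_write_pandas_df: sort newest-first, then take the
# first row of each join-key group, i.e. "latest value wins" per entity.
latest = (
    batch.sort_values(by=[*join_keys, ts_field], ascending=False)
    .groupby(join_keys)
    .nth(0)
)
print(latest)
# One row per driver_id: conv_rate 0.2 for 1001 and 0.6 for 1002.
```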

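For context, here is a self-contained sketch of the `foreachBatch` plus `mapInPandas` wiring that the re-added implementation relies on. It is an illustration of the pattern rather than Feast's actual processor: it uses Spark's built-in `rate` source instead of Kafka, and the checkpoint path, output schema, and `write_partition` body are placeholders.

```python
from typing import Iterator

import pandas as pd
from pyspark.sql import DataFrame, SparkSession

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Built-in "rate" source so the example runs without a Kafka cluster.
stream_df = spark.readStream.format("rate").option("rowsPerSecond", 10).load()


def write_partition(chunks: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # Runs on executors: each pandas chunk is handled independently, so the
    # micro-batch is never collected on the driver the way toPandas() would be.
    written = 0
    for pdf in chunks:
        written += len(pdf)  # stand-in for an online-store write
    yield pd.DataFrame({"rows_written": [written]})


def batch_write(sdf: DataFrame, batch_id: int) -> None:
    # count() is a dummy action that forces the mapInPandas stage to execute.
    total = sdf.mapInPandas(write_partition, "rows_written long").count()
    print(f"batch {batch_id}: {total} partition summaries written")


query = (
    stream_df.writeStream.outputMode("update")
    .option("checkpointLocation", "/tmp/rate_checkpoint")  # placeholder path
    .trigger(processingTime="10 seconds")
    .foreachBatch(batch_write)
    .start()
)
query.awaitTermination(timeout=30)
```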