From 73af563a06db215d86708554fc1a8c1c2af8cee6 Mon Sep 17 00:00:00 2001 From: Manisha4 Date: Fri, 21 Feb 2025 10:07:24 -0800 Subject: [PATCH] Adding in new changes --- .../cassandra_online_store.py | 96 ++++++++++--------- sdk/python/feast/sort_key.py | 12 ++- .../test_cassandra_online_store.py | 55 ++++++++++- 3 files changed, 112 insertions(+), 51 deletions(-) diff --git a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py index a412acc2ab..1f927e9062 100644 --- a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py +++ b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py @@ -25,6 +25,7 @@ from datetime import datetime from functools import partial from queue import Queue +from tokenize import Double from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union from cassandra.auth import PlainTextAuthProvider @@ -49,6 +50,17 @@ from feast.rate_limiter import SlidingWindowRateLimiter from feast.repo_config import FeastConfigBaseModel from feast.sorted_feature_view import SortedFeatureView +from feast.types import ( + Bool, + Bytes, + Float32, + Float64, + Int32, + Int64, + String, + UnixTimestamp, + from_value_type, +) # Error messages E_CASSANDRA_UNEXPECTED_CONFIGURATION_CLASS = ( @@ -686,7 +698,12 @@ def _drop_table( logger.info(f"Deleting table {fqtable}.") session.execute(drop_cql) - def _create_table(self, config: RepoConfig, project: str, table: Union[FeatureView, SortedFeatureView]): + def _create_table( + self, + config: RepoConfig, + project: str, + table: Union[FeatureView, SortedFeatureView], + ): """Handle the CQL (low-level) creation of a table.""" session: Session = self._get_session(config) keyspace: str = self._keyspace @@ -710,42 +727,43 @@ def _create_table(self, config: RepoConfig, project: str, table: Union[FeatureVi session.execute(create_cql) def _build_sorted_table_cql( - self, project: str, table: SortedFeatureView, fqtable: str + self, project: str, table: SortedFeatureView, fqtable: str ) -> str: """ Build the CQL statement for creating a SortedFeatureView table with custom entity and sort key columns. """ - # Define columns for entity columns. - entity_columns = [ - f"{col.name} {self._get_cql_type(col.value_type)}" - for col in table.spec.entity_columns + feature_columns = [ + f"{feature.name} {self._get_cql_type(feature.dtype)}" + for feature in table.features ] - # Define columns and ordering for sort keys. sort_key_columns = [ - f"{sk.name} {self._get_cql_type(sk.value_type)}" for sk in table.spec.sort_keys + f"{sk.name} {self._get_cql_type(from_value_type(sk.value_type))}" + for sk in table.sort_keys ] + sort_key_orders = [ f"{sk.name} {'ASC' if sk.default_sort_order == SortOrder.Enum.ASC else 'DESC'}" - for sk in table.spec.sort_keys + for sk in table.sort_keys ] - all_columns = entity_columns + sort_key_columns - sort_key_names = ", ".join([col.split()[0] for col in sort_key_columns]) - create_cql = f""" - CREATE TABLE IF NOT EXISTS {fqtable} ( - entity_key TEXT, - {', '.join(all_columns)}, - event_ts TIMESTAMP, - created_ts TIMESTAMP, - PRIMARY KEY ((entity_key), {sort_key_names}) - ) WITH CLUSTERING ORDER BY ({', '.join(sort_key_orders)}) - AND COMMENT='project={project}, feature_view={table.name}'; - """ + feature_columns_str = ",".join(feature_columns) + + create_cql = ( + f"CREATE TABLE IF NOT EXISTS {fqtable} (\n" + f" entity_key TEXT,\n" + f" {feature_columns_str},\n" + f" event_ts TIMESTAMP,\n" + f" created_ts TIMESTAMP,\n" + f" PRIMARY KEY ((entity_key), {sort_key_names})\n" + f") WITH CLUSTERING ORDER BY ({', '.join(sort_key_orders)})\n" + f"AND COMMENT='project={project}, feature_view={table.name}';" + ) return create_cql.strip() + def _get_cql_statement( self, config: RepoConfig, op_name: str, fqtable: str, **kwargs ): @@ -784,34 +802,18 @@ def _get_cql_type(self, value_type: ValueType) -> str: """Map Feast value types to Cassandra CQL data types.""" # Mapping for scalar types. scalar_mapping = { - ValueType.BYTES: "BLOB", - ValueType.STRING: "TEXT", - ValueType.INT32: "INT", - ValueType.INT64: "BIGINT", - ValueType.DOUBLE: "DOUBLE", - ValueType.FLOAT: "FLOAT", - ValueType.BOOL: "BOOLEAN", - ValueType.UNIX_TIMESTAMP: "TIMESTAMP", - } - - # Mapping for list types. - list_mapping = { - ValueType.BYTES_LIST: "BLOB", - ValueType.STRING_LIST: "TEXT", - ValueType.INT32_LIST: "INT", - ValueType.INT64_LIST: "BIGINT", - ValueType.DOUBLE_LIST: "DOUBLE", - ValueType.FLOAT_LIST: "FLOAT", - ValueType.BOOL_LIST: "BOOLEAN", - ValueType.UNIX_TIMESTAMP_LIST: "TIMESTAMP", + Bytes: "BLOB", + String: "TEXT", + Int32: "INT", + Int64: "BIGINT", + Double: "DOUBLE", + Float32: "FLOAT", + Float64: "FLOAT", + Bool: "BOOLEAN", + UnixTimestamp: "TIMESTAMP", } if value_type in scalar_mapping: return scalar_mapping[value_type] - elif value_type in list_mapping: - # Use CQL's collection type for lists. - return f"list<{list_mapping[value_type]}>" - elif value_type in {ValueType.UNKNOWN, ValueType.NULL}: - raise ValueError(f"Unsupported value type: {value_type}") else: - raise ValueError(f"Unsupported value type: {value_type}") \ No newline at end of file + raise ValueError(f"Unsupported type: {value_type}") diff --git a/sdk/python/feast/sort_key.py b/sdk/python/feast/sort_key.py index 24590d6a3d..0e72741fb8 100644 --- a/sdk/python/feast/sort_key.py +++ b/sdk/python/feast/sort_key.py @@ -1,5 +1,5 @@ import warnings -from typing import Dict, Optional +from typing import Dict, Optional, Union from typeguard import typechecked @@ -10,6 +10,7 @@ from feast.protos.feast.core.SortedFeatureView_pb2 import ( SortOrder, ) +from feast.types import ComplexFeastType, PrimitiveFeastType from feast.value_type import ValueType warnings.simplefilter("ignore", DeprecationWarning) @@ -38,13 +39,18 @@ class SortKey: def __init__( self, name: str, - value_type: ValueType, + value_type: Union[ValueType, PrimitiveFeastType, ComplexFeastType], default_sort_order: SortOrder.Enum.ValueType = SortOrder.ASC, tags: Optional[Dict[str, str]] = None, description: str = "", ): self.name = name - self.value_type = value_type + if isinstance(value_type, ValueType): + self.value_type = value_type + elif isinstance(value_type, (PrimitiveFeastType, ComplexFeastType)): + self.value_type = value_type.to_value_type() + else: + raise ValueError(f"Unsupported value type: {value_type}") self.default_sort_order = default_sort_order self.tags = tags or {} self.description = description diff --git a/sdk/python/tests/unit/infra/online_store/test_cassandra_online_store.py b/sdk/python/tests/unit/infra/online_store/test_cassandra_online_store.py index ba37dc4441..e60d766f06 100644 --- a/sdk/python/tests/unit/infra/online_store/test_cassandra_online_store.py +++ b/sdk/python/tests/unit/infra/online_store/test_cassandra_online_store.py @@ -1,10 +1,15 @@ +import textwrap + import pytest -from feast import FeatureView +from feast import Entity, FeatureView, Field from feast.infra.offline_stores.file_source import FileSource from feast.infra.online_stores.contrib.cassandra_online_store.cassandra_online_store import ( CassandraOnlineStore, ) +from feast.protos.feast.core.SortedFeatureView_pb2 import SortOrder +from feast.sorted_feature_view import SortedFeatureView, SortKey +from feast.types import Int64, String @pytest.fixture @@ -13,6 +18,31 @@ def file_source(): return file_source +@pytest.fixture +def sorted_feature_view(file_source): + return SortedFeatureView( + name="test_sorted_feature_view", + entities=[Entity(name="entity1", join_keys=["entity1_id"])], + source=FileSource(name="my_file_source", path="test.parquet"), + schema=[ + Field(name="feature1", dtype=Int64), + Field(name="feature2", dtype=String), + ], + sort_keys=[ + SortKey( + name="sort_key1", + value_type=Int64, + default_sort_order=SortOrder.Enum.ASC, # use the enum value + ), + SortKey( + name="sort_key2", + value_type=String, + default_sort_order=SortOrder.Enum.DESC, + ), + ], + ) + + def test_fq_table_name_v1_within_limit(file_source): keyspace = "test_keyspace" project = "test_project" @@ -71,3 +101,26 @@ def test_fq_table_name_invalid_version(file_source): with pytest.raises(ValueError) as excinfo: CassandraOnlineStore._fq_table_name(keyspace, project, table, 3) assert "Unknown table name format version: 3" in str(excinfo.value) + + +def test_build_sorted_table_cql(sorted_feature_view): + project = "test_project" + fqtable = "test_keyspace.test_project_test_sorted_feature_view" + + expected_cql = textwrap.dedent("""\ + CREATE TABLE IF NOT EXISTS test_keyspace.test_project_test_sorted_feature_view ( + entity_key TEXT, + feature1 BIGINT,feature2 TEXT, + event_ts TIMESTAMP, + created_ts TIMESTAMP, + PRIMARY KEY ((entity_key), sort_key1, sort_key2) + ) WITH CLUSTERING ORDER BY (sort_key1 ASC, sort_key2 DESC) + AND COMMENT='project=test_project, feature_view=test_sorted_feature_view'; + """).strip() + + cassandra_online_store = CassandraOnlineStore() + actual_cql = cassandra_online_store._build_sorted_table_cql( + project, sorted_feature_view, fqtable + ) + + assert actual_cql == expected_cql