From db2d50d92103a8e5ed95730995bf1ccc97f6611c Mon Sep 17 00:00:00 2001 From: xxsc0529 Date: Thu, 5 Feb 2026 14:18:47 +0800 Subject: [PATCH 1/2] fix: preserve table options when adding sparse vector indexes --- pyobvector/client/ob_vec_client.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pyobvector/client/ob_vec_client.py b/pyobvector/client/ob_vec_client.py index 5c21dc6..50912c0 100644 --- a/pyobvector/client/ob_vec_client.py +++ b/pyobvector/client/ob_vec_client.py @@ -97,14 +97,20 @@ def create_table_with_index_params( ) if sparse_vidxs is not None and len(sparse_vidxs) > 0: create_table_sql = str(CreateTable(table).compile(self.engine)) - new_sql = create_table_sql[: create_table_sql.rfind(")")] + # Preserve table options (e.g. ORGANIZATION=heap) after the closing ")" + last_paren = create_table_sql.rfind(")") + table_options_suffix = create_table_sql[ + last_paren: + ] # e.g. ")ORGANIZATION=heap" + new_sql = create_table_sql[:last_paren] for sparse_vidx in sparse_vidxs: sparse_params = sparse_vidx._parse_kwargs() if "type" in sparse_params: new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (type={sparse_params['type']}, distance=inner_product)" else: new_sql += f",\n\tVECTOR INDEX {sparse_vidx.index_name}({sparse_vidx.field_name}) with (distance=inner_product)" - new_sql += "\n)" + # Restore table options after the new closing ")" + new_sql += "\n)" + table_options_suffix[1:] conn.execute(text(new_sql)) else: table.create(self.engine, checkfirst=True) From 52028a1972268753c1fe146305025d31037e92c9 Mon Sep 17 00:00:00 2001 From: xxsc0529 Date: Tue, 10 Mar 2026 10:58:11 +0800 Subject: [PATCH 2/2] feat: ObClient embedded SeekDB support, docs, tests and CI - ObClient/ObVecClient: path=, engine=, pyseekdb_client= for embedded SeekDB - seekdb_engine: create_embedded_engine, create_engine_from_client, DBAPI wrapper - SeekdbRemoteClient: connect embedded (path/pyseekdb_client) or remote - dialect: has_table handles SeekDB RuntimeError for missing table - ob_client: check_table_exists/drop_table_if_exist safe for SeekDB - Optional dependency: pyobvector[pyseekdb] - README: Embedded SeekDB mode section and install note - tests/test_seekdb_embedded.py: connection, create_table, insert, ann_search - CI: test-embedded-seekdb job runs tests/test_seekdb_embedded.py Made-with: Cursor --- .github/workflows/ci.yml | 24 ++- README.md | 73 ++++++- pyobvector/__init__.py | 1 + pyobvector/client/__init__.py | 53 ++++- pyobvector/client/collection_schema.py | 5 +- pyobvector/client/fts_index_param.py | 5 +- pyobvector/client/index_param.py | 3 +- pyobvector/client/milvus_like_client.py | 69 +++--- pyobvector/client/ob_client.py | 142 ++++++++----- pyobvector/client/ob_vec_client.py | 56 +++-- pyobvector/client/ob_vec_json_table_client.py | 17 +- pyobvector/client/partitions.py | 27 ++- pyobvector/client/seekdb_engine.py | 156 ++++++++++++++ pyobvector/json_table/oceanbase_dialect.py | 7 +- pyobvector/json_table/virtual_data_type.py | 13 +- pyobvector/schema/array.py | 14 +- pyobvector/schema/dialect.py | 13 +- pyobvector/schema/geo_srid_point.py | 3 +- pyproject.toml | 6 +- tests/test_seekdb_embedded.py | 200 ++++++++++++++++++ 20 files changed, 723 insertions(+), 164 deletions(-) create mode 100644 pyobvector/client/seekdb_engine.py create mode 100644 tests/test_seekdb_embedded.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index dfec3b4..6b8709e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13", "3.14"] + python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"] image_tag: ["4.4.1.0-100000032025101610"] init_sql: ["ALTER SYSTEM ob_vector_memory_limit_percentage = 30; SET GLOBAL ob_query_timeout=100000000;"] test_filter: ["tests/test_hybrid_search.py::HybridSearchTest"] @@ -65,3 +65,25 @@ jobs: - name: Run tests run: | make test TEST_FILTER='${{ matrix.test_filter }}' + + test-embedded-seekdb: + name: Test embedded SeekDB + runs-on: ubuntu-latest + steps: + - name: Check out code + uses: actions/checkout@v6 + + - name: Install uv + uses: astral-sh/setup-uv@v6 + with: + python-version: "3.12" + + - name: Install dependencies + run: uv sync --dev + + - name: Install pyseekdb (optional dependency for embedded SeekDB) + run: uv pip install pyseekdb + + - name: Run embedded SeekDB tests + run: | + uv run python -m pytest tests/test_seekdb_embedded.py -v diff --git a/README.md b/README.md index 86e89b9..da79971 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,12 @@ uv sync pip install pyobvector==0.2.23 ``` +- for **embedded SeekDB** support (local SeekDB without server): + +```shell +pip install pyobvector[pyseekdb] +``` + ## Build Doc You can build document locally with `sphinx`: @@ -33,10 +39,11 @@ For detailed release notes and changelog, see [RELEASE_NOTES.md](RELEASE_NOTES.m ## Usage -`pyobvector` supports three modes: +`pyobvector` supports four modes: - `Milvus compatible mode`: You can use the `MilvusLikeClient` class to use vector storage in a way similar to the Milvus API - `SQLAlchemy hybrid mode`: You can use the vector storage function provided by the `ObVecClient` class and execute the relational database statement with the SQLAlchemy library. In this mode, you can regard `pyobvector` as an extension of SQLAlchemy. +- `Embedded SeekDB mode`: Use `ObVecClient` or `SeekdbRemoteClient` with local embedded SeekDB (no server). Same API as remote: `create_table`, `insert`, `ann_search`, etc. Requires optional dependency: `pip install pyobvector[pyseekdb]`. - `Hybrid Search mode`: You can use the `HybridSearch` class to perform hybrid search that combines full-text search and vector similarity search, with Elasticsearch-compatible query syntax. ### Milvus compatible mode @@ -264,6 +271,70 @@ engine = create_async_engine(connection_str) - For further usage in pure `SQLAlchemy` mode, please refer to [SQLAlchemy](https://www.sqlalchemy.org/) +### Embedded SeekDB mode + +Use the same ObClient/ObVecClient API with **embedded SeekDB** (local file, no server). Install the optional dependency: + +```shell +pip install pyobvector[pyseekdb] +``` + +- connect with path or with an existing `pyseekdb.Client`: + +```python +from pyobvector import SeekdbRemoteClient, ObVecClient +from pyobvector.client.ob_client import ObClient + +# Option 1: path to SeekDB data directory +client = SeekdbRemoteClient(path="./seekdb_data", database="test") + +# Option 2: use an existing pyseekdb.Client +import pyseekdb +pyseekdb_client = pyseekdb.Client(path="./seekdb_data", database="test") +client = SeekdbRemoteClient(pyseekdb_client=pyseekdb_client) + +# Option 3: ObVecClient directly +client = ObVecClient(path="./seekdb_data", db_name="test") + +assert isinstance(client, ObVecClient) +assert isinstance(client, ObClient) +``` + +- create table, insert, and ann search (same API as remote): + +```python +from sqlalchemy import Column, Integer, VARCHAR +from pyobvector import VECTOR, VectorIndex, l2_distance + +client.drop_table_if_exist("vec_table") +client.create_table( + table_name="vec_table", + columns=[ + Column("id", Integer, primary_key=True), + Column("title", VARCHAR(255)), + Column("vec", VECTOR(3)), + ], + indexes=[VectorIndex("vec_idx", "vec", params="distance=l2, type=hnsw, lib=vsag")], + mysql_organization="heap", +) +client.insert("vec_table", data=[ + {"id": 1, "title": "doc A", "vec": [1.0, 1.0, 1.0]}, + {"id": 2, "title": "doc B", "vec": [1.0, 2.0, 3.0]}, +]) +res = client.ann_search( + "vec_table", + vec_data=[1.0, 2.0, 3.0], + vec_column_name="vec", + distance_func=l2_distance, + with_dist=True, + topk=5, + output_column_names=["id", "title"], +) +client.drop_table_if_exist("vec_table") +``` + +- See `tests/test_seekdb_embedded.py` for more examples. + ### Hybrid Search Mode `pyobvector` supports hybrid search that combines full-text search and vector similarity search, with query syntax compatible with Elasticsearch. This allows you to perform semantic search with both keyword matching and vector similarity in a single query. diff --git a/pyobvector/__init__.py b/pyobvector/__init__.py index 440b5bf..d6033c7 100644 --- a/pyobvector/__init__.py +++ b/pyobvector/__init__.py @@ -64,6 +64,7 @@ from .json_table import OceanBase __all__ = [ + "SeekdbRemoteClient", "ObVecClient", "MilvusLikeClient", "ObVecJsonTableClient", diff --git a/pyobvector/client/__init__.py b/pyobvector/client/__init__.py index 486b85e..1a3a772 100644 --- a/pyobvector/client/__init__.py +++ b/pyobvector/client/__init__.py @@ -5,8 +5,11 @@ 2. `SQLAlchemy hybrid mode`: You can use the vector storage function provided by the `ObVecClient` class and execute the relational database statement with the SQLAlchemy library. In this mode, you can regard `pyobvector` as an extension of SQLAlchemy. +3. `Embedded SeekDB`: ObClient/ObVecClient support path= or pyseekdb_client= for embedded +SeekDB (pip install pyobvector[pyseekdb]). Same API as remote: create_table, insert, etc. -* ObVecClient MySQL client in SQLAlchemy hybrid mode +* SeekdbRemoteClient Connect to embedded (path= / pyseekdb_client=) or remote; returns ObVecClient +* ObVecClient MySQL/SeekDB client in SQLAlchemy hybrid mode (uri, path, or pyseekdb_client) * MilvusLikeClient Milvus compatible client * VecIndexType VecIndexType is used to specify vector index type for MilvusLikeClient * IndexParam Specify vector index parameters for MilvusLikeClient @@ -31,6 +34,9 @@ * FtsIndexParam Full Text Search index parameter """ +import os +from typing import Any + from .ob_vec_client import ObVecClient from .milvus_like_client import MilvusLikeClient from .ob_vec_json_table_client import ObVecJsonTableClient @@ -40,7 +46,52 @@ from .partitions import * from .fts_index_param import FtsParser, FtsIndexParam + +def _resolve_password(password: str) -> str: + return password or os.environ.get("SEEKDB_PASSWORD", "") + + +def SeekdbRemoteClient( + path: str | None = None, + uri: str | None = None, + host: str | None = None, + port: int | None = None, + tenant: str = "test", + database: str = "test", + user: str | None = None, + password: str = "", + pyseekdb_client: Any | None = None, + **kwargs: Any, +) -> Any: + """ + Connect to embedded SeekDB (path= or pyseekdb_client=) or remote OceanBase/SeekDB (uri/host=). + Returns ObVecClient with the same API (create_table, insert, ann_search, etc.). + Embedded requires: pip install pyobvector[pyseekdb] + """ + password = _resolve_password(password) + if pyseekdb_client is not None: + return ObVecClient(pyseekdb_client=pyseekdb_client, **kwargs) + if path is not None: + return ObVecClient(path=path, db_name=database, **kwargs) + if uri is None and host is not None: + port = port if port is not None else 2881 + uri = f"{host}:{port}" + if uri is None: + uri = "127.0.0.1:2881" + ob_user = user if user is not None else "root" + if "@" not in ob_user: + ob_user = f"{ob_user}@{tenant}" + return ObVecClient( + uri=uri, + user=ob_user, + password=password, + db_name=database, + **kwargs, + ) + + __all__ = [ + "SeekdbRemoteClient", "ObVecClient", "MilvusLikeClient", "ObVecJsonTableClient", diff --git a/pyobvector/client/collection_schema.py b/pyobvector/client/collection_schema.py index 06bb3cd..1000fab 100644 --- a/pyobvector/client/collection_schema.py +++ b/pyobvector/client/collection_schema.py @@ -1,7 +1,6 @@ """FieldSchema & CollectionSchema definition module to be compatible with Milvus.""" import copy -from typing import Optional from sqlalchemy import Column from .schema_type import DataType, convert_datatype_to_sqltype from .exceptions import * @@ -129,8 +128,8 @@ class CollectionSchema: def __init__( self, - fields: Optional[list[FieldSchema]] = None, - partitions: Optional[ObPartition] = None, + fields: list[FieldSchema] | None = None, + partitions: ObPartition | None = None, description: str = "", # ignored in oceanbase **kwargs, ): diff --git a/pyobvector/client/fts_index_param.py b/pyobvector/client/fts_index_param.py index 07de4ab..813cd7e 100644 --- a/pyobvector/client/fts_index_param.py +++ b/pyobvector/client/fts_index_param.py @@ -1,7 +1,6 @@ """A module to specify fts index parameters""" from enum import Enum -from typing import Optional, Union class FtsParser(Enum): @@ -28,13 +27,13 @@ def __init__( self, index_name: str, field_names: list[str], - parser_type: Optional[Union[FtsParser, str]] = None, + parser_type: FtsParser | str | None = None, ): self.index_name = index_name self.field_names = field_names self.parser_type = parser_type - def param_str(self) -> Optional[str]: + def param_str(self) -> str | None: """Convert parser type to string format for SQL.""" if self.parser_type is None: return None # Default Space parser, no need to specify diff --git a/pyobvector/client/index_param.py b/pyobvector/client/index_param.py index 35565bf..646376e 100644 --- a/pyobvector/client/index_param.py +++ b/pyobvector/client/index_param.py @@ -1,7 +1,6 @@ """A module to specify vector index parameters for MilvusLikeClient""" from enum import Enum -from typing import Union class VecIndexType(Enum): @@ -42,7 +41,7 @@ def __init__( self, index_name: str, field_name: str, - index_type: Union[VecIndexType, str], + index_type: VecIndexType | str, **kwargs, ): self.index_name = index_name diff --git a/pyobvector/client/milvus_like_client.py b/pyobvector/client/milvus_like_client.py index 8db07ac..2eb7cf9 100644 --- a/pyobvector/client/milvus_like_client.py +++ b/pyobvector/client/milvus_like_client.py @@ -2,7 +2,6 @@ import logging import json -from typing import Optional, Union from sqlalchemy.exc import NoSuchTableError from sqlalchemy import ( @@ -52,15 +51,15 @@ def create_schema(self, **kwargs) -> CollectionSchema: def create_collection( self, collection_name: str, - dimension: Optional[int] = None, + dimension: int | None = None, primary_field_name: str = "id", - id_type: Union[DataType, str] = DataType.INT64, + id_type: DataType | str = DataType.INT64, vector_field_name: str = "vector", metric_type: str = "l2", auto_id: bool = False, - timeout: Optional[float] = None, - schema: Optional[CollectionSchema] = None, # Used for custom setup - index_params: Optional[IndexParams] = None, # Used for custom setup + timeout: float | None = None, + schema: CollectionSchema | None = None, # Used for custom setup + index_params: IndexParams | None = None, # Used for custom setup max_length: int = 16384, **kwargs, ): # pylint: disable=unused-argument @@ -149,7 +148,7 @@ def create_collection( def get_collection_stats( self, collection_name: str, - timeout: Optional[float] = None, # pylint: disable=unused-argument + timeout: float | None = None, # pylint: disable=unused-argument ) -> dict: """Get collection row count. @@ -171,7 +170,7 @@ def get_collection_stats( def has_collection( self, collection_name: str, - timeout: Optional[float] = None, # pylint: disable=unused-argument + timeout: float | None = None, # pylint: disable=unused-argument ) -> bool: # pylint: disable=unused-argument """Check if collection exists. @@ -196,7 +195,7 @@ def rename_collection( self, old_name: str, new_name: str, - timeout: Optional[float] = None, # pylint: disable=unused-argument + timeout: float | None = None, # pylint: disable=unused-argument ) -> None: """Rename collection. @@ -236,7 +235,7 @@ def create_index( self, collection_name: str, index_params: IndexParams, - timeout: Optional[float] = None, + timeout: float | None = None, **kwargs, ): # pylint: disable=unused-argument """Create vector index with index params. @@ -269,7 +268,7 @@ def drop_index( self, collection_name: str, index_name: str, - timeout: Optional[float] = None, + timeout: float | None = None, **kwargs, ): # pylint: disable=unused-argument """Drop index on specified collection. @@ -357,15 +356,15 @@ def _parse_value_for_text_sql( def search( self, collection_name: str, - data: Union[list, dict], + data: list | dict, anns_field: str, with_dist: bool = False, flter=None, limit: int = 10, - output_fields: Optional[list[str]] = None, - search_params: Optional[dict] = None, - timeout: Optional[float] = None, # pylint: disable=unused-argument - partition_names: Optional[list[str]] = None, + output_fields: list[str] | None = None, + search_params: dict | None = None, + timeout: float | None = None, # pylint: disable=unused-argument + partition_names: list[str] | None = None, **kwargs, # pylint: disable=unused-argument ) -> list[dict]: """Perform ann search. @@ -482,9 +481,9 @@ def query( self, collection_name: str, flter=None, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, # pylint: disable=unused-argument - partition_names: Optional[list[str]] = None, + output_fields: list[str] | None = None, + timeout: float | None = None, # pylint: disable=unused-argument + partition_names: list[str] | None = None, **kwargs, # pylint: disable=unused-argument ) -> list[dict]: """Query records. @@ -549,10 +548,10 @@ def query( def get( self, collection_name: str, - ids: Union[list, str, int] = None, - output_fields: Optional[list[str]] = None, - timeout: Optional[float] = None, # pylint: disable=unused-argument - partition_names: Optional[list[str]] = None, + ids: list | str | int = None, + output_fields: list[str] | None = None, + timeout: float | None = None, # pylint: disable=unused-argument + partition_names: list[str] | None = None, **kwargs, # pylint: disable=unused-argument ) -> list[dict]: """Get records with specified primary field `ids`. @@ -592,7 +591,7 @@ def get( ) if isinstance(ids, list): where_in_clause = table.c[pkey_names[0]].in_(ids) - elif isinstance(ids, (str, int)): + elif isinstance(ids, str | int): where_in_clause = table.c[pkey_names[0]].in_([ids]) else: raise TypeError("'ids' is not a list/str/int") @@ -629,10 +628,10 @@ def get( def delete( self, collection_name: str, - ids: Optional[Union[list, str, int]] = None, - timeout: Optional[float] = None, # pylint: disable=unused-argument + ids: list | str | int | None = None, + timeout: float | None = None, # pylint: disable=unused-argument flter=None, - partition_name: Optional[str] = "", + partition_name: str | None = "", **kwargs, # pylint: disable=unused-argument ) -> dict: """Delete data in collection. @@ -667,7 +666,7 @@ def delete( ) if isinstance(ids, list): where_in_clause = table.c[pkey_names[0]].in_(ids) - elif isinstance(ids, (str, int)): + elif isinstance(ids, str | int): where_in_clause = table.c[pkey_names[0]].in_([ids]) else: raise TypeError("'ids' is not a list/str/int") @@ -691,9 +690,9 @@ def delete( def insert( self, collection_name: str, - data: Union[dict, list[dict]], - timeout: Optional[float] = None, - partition_name: Optional[str] = "", + data: dict | list[dict], + timeout: float | None = None, + partition_name: str | None = "", ) -> None: # pylint: disable=unused-argument """Insert data into collection. @@ -717,10 +716,10 @@ def insert( def upsert( self, collection_name: str, - data: Union[dict, list[dict]], - timeout: Optional[float] = None, # pylint: disable=unused-argument - partition_name: Optional[str] = "", - ) -> list[Union[str, int]]: + data: dict | list[dict], + timeout: float | None = None, # pylint: disable=unused-argument + partition_name: str | None = "", + ) -> list[str | int]: """Update data in table. If primary key is duplicated, replace it. Args: diff --git a/pyobvector/client/ob_client.py b/pyobvector/client/ob_client.py index 2715f0b..1254b04 100644 --- a/pyobvector/client/ob_client.py +++ b/pyobvector/client/ob_client.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Union +from typing import Any from urllib.parse import quote import sqlalchemy.sql.functions as func_mod @@ -18,7 +18,6 @@ and_, ) from sqlalchemy.dialects import registry -from sqlalchemy.exc import NoSuchTableError from .index_param import IndexParams from .partitions import ObPartition @@ -40,8 +39,35 @@ logger.setLevel(logging.DEBUG) +def _get_ob_version_from_engine(engine: Any) -> "ObVersion": + """Get ObVersion from engine; supports both OceanBase (OB_VERSION) and SeekDB (VERSION).""" + with engine.connect() as conn: + with conn.begin(): + try: + res = conn.execute(text("SELECT OB_VERSION() FROM DUAL")) + version = [r[0] for r in res][0] + except Exception: + try: + res = conn.execute(text("SELECT VERSION()")) + version = [r[0] for r in res][0] + except Exception: + version = "4.3.3.0" + vs = str(version).strip() + parts = vs.split(".") + if len(parts) >= 4: + return ObVersion.from_db_version_string(".".join(parts[:4])) + if len(parts) == 3: + return ObVersion.from_db_version_string(vs + ".0") + return ObVersion.from_db_version_nums(4, 3, 3, 0) + + class ObClient: - """The OceanBase Client""" + """ + OceanBase / SeekDB client. Supports: + - Remote: uri + user + password + db_name + - Embedded SeekDB: path= or pyseekdb_client= (requires pip install pyobvector[pyseekdb]) + - External engine: engine= + """ def __init__( self, @@ -49,7 +75,10 @@ def __init__( user: str = "root@test", password: str = "", db_name: str = "test", - **kwargs, + path: str | None = None, + engine: Any | None = None, + pyseekdb_client: Any | None = None, + **kwargs: Any, ): registry.register( "mysql.oceanbase", "pyobvector.schema.dialect", "OceanBaseDialect" @@ -64,23 +93,33 @@ def __init__( setattr(func_mod, "st_dwithin", st_dwithin) setattr(func_mod, "st_astext", st_astext) - user = quote(user, safe="") - password = quote(password, safe="") + engine_kw = {k: v for k, v in kwargs.items() if k != "pyseekdb_client"} + + if engine is not None: + self.engine = engine + elif pyseekdb_client is not None: + from .seekdb_engine import create_engine_from_client + + self.engine = create_engine_from_client(pyseekdb_client, **engine_kw) + elif path is not None: + from .seekdb_engine import create_embedded_engine + + self.engine = create_embedded_engine(path, database=db_name, **engine_kw) + else: + user_quoted = quote(user, safe="") + password_quoted = quote(password, safe="") + connection_str = f"mysql+oceanbase://{user_quoted}:{password_quoted}@{uri}/{db_name}?charset=utf8mb4" + self.engine = create_engine(connection_str, **engine_kw) - connection_str = ( - f"mysql+oceanbase://{user}:{password}@{uri}/{db_name}?charset=utf8mb4" - ) - self.engine = create_engine(connection_str, **kwargs) self.metadata_obj = MetaData() - self.metadata_obj.reflect(bind=self.engine) + try: + self.metadata_obj.reflect(bind=self.engine) + except Exception as e: + logger.debug("metadata reflect skipped: %s", e) - with self.engine.connect() as conn: - with conn.begin(): - res = conn.execute(text("SELECT OB_VERSION() FROM DUAL")) - version = [r[0] for r in res][0] - self.ob_version = ObVersion.from_db_version_string(version) + self.ob_version = _get_ob_version_from_engine(self.engine) - def refresh_metadata(self, tables: Optional[list[str]] = None): + def refresh_metadata(self, tables: list[str] | None = None): """Reload metadata from the database. Args: @@ -132,24 +171,28 @@ def _insert_partition_hint_for_query_sql(self, sql: str, partition_hint: str): + sql[first_space_after_from:] ) - def check_table_exists(self, table_name: str): - """Check if table exists. - - Args: - table_name (string): table name - - Returns: - bool: True if table exists, False otherwise - """ - inspector = inspect(self.engine) - return inspector.has_table(table_name) + def check_table_exists(self, table_name: str) -> bool: + """Check if table exists. Safe for embedded SeekDB (uses SHOW TABLES when needed).""" + try: + inspector = inspect(self.engine) + return inspector.has_table(table_name) + except Exception: + try: + with self.engine.connect() as conn: + r = conn.execute( + text("SHOW TABLES LIKE :name"), + {"name": table_name}, + ) + return r.fetchone() is not None + except Exception: + return False def create_table( self, table_name: str, columns: list[Column], - indexes: Optional[list[Index]] = None, - partitions: Optional[ObPartition] = None, + indexes: list[Index] | None = None, + partitions: ObPartition | None = None, **kwargs, ): """Create a table. @@ -191,16 +234,13 @@ def prepare_index_params(cls): """Create `IndexParams` to hold index configuration.""" return IndexParams() - def drop_table_if_exist(self, table_name: str): - """Drop table if exists.""" - try: - table = Table(table_name, self.metadata_obj, autoload_with=self.engine) - except NoSuchTableError: - return + def drop_table_if_exist(self, table_name: str) -> None: + """Drop table if exists. Safe for embedded SeekDB (avoids autoload on missing table).""" with self.engine.connect() as conn: with conn.begin(): - table.drop(self.engine, checkfirst=True) - self.metadata_obj.remove(table) + conn.execute(text(f"DROP TABLE IF EXISTS `{table_name}`")) + if table_name in self.metadata_obj.tables: + self.metadata_obj.remove(self.metadata_obj.tables[table_name]) def drop_index(self, table_name: str, index_name: str): """drop index on specified table. @@ -214,8 +254,8 @@ def drop_index(self, table_name: str, index_name: str): def insert( self, table_name: str, - data: Union[dict, list[dict]], - partition_name: Optional[str] = "", + data: dict | list[dict], + partition_name: str | None = "", ): """Insert data into table. @@ -246,8 +286,8 @@ def insert( def upsert( self, table_name: str, - data: Union[dict, list[dict]], - partition_name: Optional[str] = "", + data: dict | list[dict], + partition_name: str | None = "", ): """Update data in table. If primary key is duplicated, replace it. @@ -279,7 +319,7 @@ def update( table_name: str, values_clause, where_clause=None, - partition_name: Optional[str] = "", + partition_name: str | None = "", ): """Update data in table. @@ -323,9 +363,9 @@ def update( def delete( self, table_name: str, - ids: Optional[Union[list, str, int]] = None, + ids: list | str | int | None = None, where_clause=None, - partition_name: Optional[str] = "", + partition_name: str | None = "", ): """Delete data in table. @@ -343,7 +383,7 @@ def delete( if len(pkey_names) == 1: if isinstance(ids, list): where_in_clause = table.c[pkey_names[0]].in_(ids) - elif isinstance(ids, (str, int)): + elif isinstance(ids, str | int): where_in_clause = table.c[pkey_names[0]].in_([ids]) else: raise TypeError("'ids' is not a list/str/int") @@ -369,11 +409,11 @@ def delete( def get( self, table_name: str, - ids: Optional[Union[list, str, int]] = None, + ids: list | str | int | None = None, where_clause=None, - output_column_name: Optional[list[str]] = None, - partition_names: Optional[list[str]] = None, - n_limits: Optional[int] = None, + output_column_name: list[str] | None = None, + partition_names: list[str] | None = None, + n_limits: int | None = None, ): """Get records with specified primary field `ids`. @@ -400,7 +440,7 @@ def get( if ids is not None and len(pkey_names) == 1: if isinstance(ids, list): where_in_clause = table.c[pkey_names[0]].in_(ids) - elif isinstance(ids, (str, int)): + elif isinstance(ids, str | int): where_in_clause = table.c[pkey_names[0]].in_([ids]) else: raise TypeError("'ids' is not a list/str/int") diff --git a/pyobvector/client/ob_vec_client.py b/pyobvector/client/ob_vec_client.py index 50912c0..138af3c 100644 --- a/pyobvector/client/ob_vec_client.py +++ b/pyobvector/client/ob_vec_client.py @@ -1,7 +1,7 @@ """OceanBase Vector Store Client.""" import logging -from typing import Optional, Union +from typing import Any import numpy as np from sqlalchemy import ( @@ -38,9 +38,21 @@ def __init__( user: str = "root@test", password: str = "", db_name: str = "test", - **kwargs, + path: str | None = None, + engine: Any | None = None, + pyseekdb_client: Any | None = None, + **kwargs: Any, ): - super().__init__(uri, user, password, db_name, **kwargs) + super().__init__( + uri=uri, + user=user, + password=password, + db_name=db_name, + path=path, + engine=engine, + pyseekdb_client=pyseekdb_client, + **kwargs, + ) if self.ob_version < ObVersion.from_db_version_nums(4, 3, 3, 0): raise ClusterVersionException( @@ -49,7 +61,7 @@ def __init__( % ("Vector Store", "4.3.3.0"), ) - def _get_sparse_vector_index_params(self, vidxs: Optional[IndexParams]): + def _get_sparse_vector_index_params(self, vidxs: IndexParams | None): if vidxs is None: return None return [vidx for vidx in vidxs if vidx.is_index_type_sparse_vector()] @@ -58,10 +70,10 @@ def create_table_with_index_params( self, table_name: str, columns: list[Column], - indexes: Optional[list[Index]] = None, - vidxs: Optional[IndexParams] = None, - fts_idxs: Optional[list[FtsIndexParam]] = None, - partitions: Optional[ObPartition] = None, + indexes: list[Index] | None = None, + vidxs: IndexParams | None = None, + fts_idxs: list[FtsIndexParam] | None = None, + partitions: ObPartition | None = None, **kwargs, ): """Create table with optional index_params. @@ -149,7 +161,7 @@ def create_index( is_vec_index: bool, index_name: str, column_names: list[str], - vidx_params: Optional[str] = None, + vidx_params: str | None = None, **kw, ): """Create common index or vector index. @@ -279,18 +291,18 @@ def get_ob_hnsw_ef_search(self) -> int: def ann_search( self, table_name: str, - vec_data: Union[list, dict], + vec_data: list | dict, vec_column_name: str, distance_func, with_dist: bool = False, topk: int = 10, - output_column_names: Optional[list[str]] = None, - output_columns: Optional[Union[list, tuple]] = None, - extra_output_cols: Optional[list] = None, + output_column_names: list[str] | None = None, + output_columns: list | tuple | None = None, + extra_output_cols: list | None = None, where_clause=None, - partition_names: Optional[list[str]] = None, - idx_name_hint: Optional[list[str]] = None, - distance_threshold: Optional[float] = None, + partition_names: list[str] | None = None, + idx_name_hint: list[str] | None = None, + distance_threshold: float | None = None, **kwargs, ): # pylint: disable=unused-argument """Perform ann search. @@ -321,7 +333,7 @@ def ann_search( columns = [] if output_columns: - if isinstance(output_columns, (list, tuple)): + if isinstance(output_columns, list | tuple): columns = list(output_columns) else: columns = [output_columns] @@ -407,11 +419,11 @@ def post_ann_search( distance_func, with_dist: bool = False, topk: int = 10, - output_column_names: Optional[list[str]] = None, - extra_output_cols: Optional[list] = None, + output_column_names: list[str] | None = None, + extra_output_cols: list | None = None, where_clause=None, - partition_names: Optional[list[str]] = None, - str_list: Optional[list[str]] = None, + partition_names: list[str] | None = None, + str_list: list[str] | None = None, **kwargs, ): # pylint: disable=unused-argument """Perform post ann search. @@ -493,7 +505,7 @@ def precise_search( vec_column_name: str, distance_func, topk: int = 10, - output_column_names: Optional[list[str]] = None, + output_column_names: list[str] | None = None, where_clause=None, **kwargs, ): # pylint: disable=unused-argument diff --git a/pyobvector/client/ob_vec_json_table_client.py b/pyobvector/client/ob_vec_json_table_client.py index 9ba823d..92cef2d 100644 --- a/pyobvector/client/ob_vec_json_table_client.py +++ b/pyobvector/client/ob_vec_json_table_client.py @@ -1,7 +1,6 @@ import json import logging import re -from typing import Optional, Union from sqlalchemy import ( Column, @@ -134,7 +133,7 @@ def reflect(self, engine: Engine): def __init__( self, - user_id: Optional[str], + user_id: str | None, admin_id: str, uri: str = "127.0.0.1:2881", user: str = "root@test", @@ -182,8 +181,8 @@ def perform_json_table_sql( self, sql: str, select_with_data_id: bool = False, - opt_user_id: Optional[str] = None, - ) -> Union[Optional[CursorResult], int]: + opt_user_id: str | None = None, + ) -> CursorResult | None | int: """Perform common SQL that operates on JSON Table.""" ast = parse_one(sql, dialect="oceanbase") if isinstance(ast, exp.Create): @@ -346,7 +345,7 @@ def _handle_create_json_table(self, ast: Expression): def _check_table_exists(self, jtable_name: str) -> bool: return jtable_name in self.jmetadata.meta_cache - def _check_col_exists(self, jtable_name: str, col_name: str) -> Optional[dict]: + def _check_col_exists(self, jtable_name: str, col_name: str) -> dict | None: if not self._check_table_exists(jtable_name): return None for col_meta in self.jmetadata.meta_cache[jtable_name]: @@ -715,7 +714,7 @@ def _handle_alter_json_table(self, ast: Expression): session.close() def _handle_jtable_dml_insert( - self, ast: Expression, opt_user_id: Optional[str] = None + self, ast: Expression, opt_user_id: str | None = None ): real_user_id = opt_user_id or self.user_id @@ -799,7 +798,7 @@ def _handle_jtable_dml_insert( return n_new_records def _handle_jtable_dml_update( - self, ast: Expression, opt_user_id: Optional[str] = None + self, ast: Expression, opt_user_id: str | None = None ): real_user_id = opt_user_id or self.user_id @@ -855,7 +854,7 @@ def _handle_jtable_dml_update( return res.rowcount def _handle_jtable_dml_delete( - self, ast: Expression, opt_user_id: Optional[str] = None + self, ast: Expression, opt_user_id: str | None = None ): real_user_id = opt_user_id or self.user_id @@ -899,7 +898,7 @@ def _handle_jtable_dml_select( self, ast: Expression, select_with_data_id: bool = False, - opt_user_id: Optional[str] = None, + opt_user_id: str | None = None, ): real_user_id = opt_user_id or self.user_id diff --git a/pyobvector/client/partitions.py b/pyobvector/client/partitions.py index 77dc643..de41fa7 100644 --- a/pyobvector/client/partitions.py +++ b/pyobvector/client/partitions.py @@ -1,6 +1,5 @@ """A module to do compilation of OceanBase Parition Clause.""" -from typing import Optional, Union import logging from dataclasses import dataclass from .enum import IntEnum @@ -73,7 +72,7 @@ class RangeListPartInfo: """ part_name: str - part_upper_bound_expr: Union[list, str, int] + part_upper_bound_expr: list | str | int def get_part_expr_str(self): """Parse part_upper_bound_expr to text SQL.""" @@ -93,8 +92,8 @@ def __init__( self, is_range_columns: bool, range_part_infos: list[RangeListPartInfo], - range_expr: Optional[str] = None, - col_name_list: Optional[list[str]] = None, + range_expr: str | None = None, + col_name_list: list[str] | None = None, ): super().__init__(PartType.RangeColumns if is_range_columns else PartType.Range) self.range_part_infos = range_part_infos @@ -153,8 +152,8 @@ def __init__( self, is_range_columns: bool, range_part_infos: list[RangeListPartInfo], - range_expr: Optional[str] = None, - col_name_list: Optional[list[str]] = None, + range_expr: str | None = None, + col_name_list: list[str] | None = None, ): super().__init__(is_range_columns, range_part_infos, range_expr, col_name_list) self.is_sub = True @@ -194,8 +193,8 @@ def __init__( self, is_list_columns: bool, list_part_infos: list[RangeListPartInfo], - list_expr: Optional[str] = None, - col_name_list: Optional[list[str]] = None, + list_expr: str | None = None, + col_name_list: list[str] | None = None, ): super().__init__(PartType.ListColumns if is_list_columns else PartType.List) self.list_part_infos = list_part_infos @@ -253,8 +252,8 @@ def __init__( self, is_list_columns: bool, list_part_infos: list[RangeListPartInfo], - list_expr: Optional[str] = None, - col_name_list: Optional[list[str]] = None, + list_expr: str | None = None, + col_name_list: list[str] | None = None, ): super().__init__(is_list_columns, list_part_infos, list_expr, col_name_list) self.is_sub = True @@ -291,7 +290,7 @@ def __init__( self, hash_expr: str, hash_part_name_list: list[str] = None, - part_count: Optional[int] = None, + part_count: int | None = None, ): super().__init__(PartType.Hash) self.hash_expr = hash_expr @@ -341,7 +340,7 @@ def __init__( self, hash_expr: str, hash_part_name_list: list[str] = None, - part_count: Optional[int] = None, + part_count: int | None = None, ): super().__init__(hash_expr, hash_part_name_list, part_count) self.is_sub = True @@ -369,7 +368,7 @@ def __init__( self, col_name_list: list[str], key_part_name_list: list[str] = None, - part_count: Optional[int] = None, + part_count: int | None = None, ): super().__init__(PartType.Key) self.col_name_list = col_name_list @@ -423,7 +422,7 @@ def __init__( self, col_name_list: list[str], key_part_name_list: list[str] = None, - part_count: Optional[int] = None, + part_count: int | None = None, ): super().__init__(col_name_list, key_part_name_list, part_count) self.is_sub = True diff --git a/pyobvector/client/seekdb_engine.py b/pyobvector/client/seekdb_engine.py new file mode 100644 index 0000000..37f279f --- /dev/null +++ b/pyobvector/client/seekdb_engine.py @@ -0,0 +1,156 @@ +""" +Build a SQLAlchemy Engine from pyseekdb embedded client so ObClient/ObVecClient +work the same for both remote and embedded SeekDB. + +Requires optional dependency: pip install pyobvector[pyseekdb] +""" + +import re +from collections.abc import Mapping, Sequence +from typing import Any + +from sqlalchemy import create_engine +from sqlalchemy.pool import NullPool + + +def _pyformat_to_format(sql: str, params: Any) -> tuple[str, list[Any]]: + """Convert SQLAlchemy pyformat (%(name)s) + dict params to %s + list for pyseekdb.""" + if not isinstance(params, Mapping): + return sql, list(params) if params is not None else [] + + # Find placeholder names in order: %(name)s + pattern = re.compile(r"%\(([^)]+)\)s") + names = pattern.findall(sql) + if not names: + return sql, [] + + values = [params[n] for n in names] + new_sql = pattern.sub("%s", sql) + return new_sql, values + + +def _execute_via_pyseekdb(client: Any, sql: str, params: Any) -> list[dict[str, Any]]: + """Execute SQL via pyseekdb SeekdbEmbeddedClient; accepts dict or list params.""" + sql, param_list = _pyformat_to_format(sql, params) + conn = client.get_raw_connection() + return client._execute_query_with_cursor( + conn, sql, param_list, use_context_manager=False + ) + + +class _SeekdbCursor: + """DBAPI-2 style Cursor delegating to pyseekdb SeekdbEmbeddedClient.""" + + def __init__(self, client: Any) -> None: + self._client = client + self._description: list[tuple[str]] | None = None + self._rows: list[tuple] | None = None + self.rowcount = -1 + + def execute(self, operation: str, parameters: Sequence[Any] | None = None) -> None: + result = _execute_via_pyseekdb(self._client, operation, parameters or ()) + if not result: + self._description = None + self._rows = [] + self.rowcount = 0 + return + + def make_desc(name: str) -> tuple: + return (name, None, None, None, None, None, None) + + first = result[0] + if isinstance(first, dict): + keys = list(first.keys()) + self._description = [make_desc(k) for k in keys] + self._rows = [tuple(row[k] for k in keys) for row in result] + else: + n = len(first) + self._description = [make_desc(f"column_{i}") for i in range(n)] + self._rows = [ + tuple(row) if not isinstance(row, tuple) else row for row in result + ] + self.rowcount = len(self._rows) + + def fetchall(self) -> list[tuple]: + return self._rows or [] + + def fetchone(self) -> tuple | None: + if not self._rows: + return None + return self._rows.pop(0) + + @property + def description(self) -> list[tuple[str]] | None: + return self._description + + def close(self) -> None: + self._rows = None + self._description = None + + +class _SeekdbConnection: + """DBAPI-2 style Connection holding a pyseekdb SeekdbEmbeddedClient.""" + + def __init__(self, client: Any) -> None: + self._client = client + + def cursor(self) -> _SeekdbCursor: + return _SeekdbCursor(self._client) + + def close(self) -> None: + if hasattr(self._client, "_cleanup"): + self._client._cleanup() + + def commit(self) -> None: + pass + + def rollback(self) -> None: + pass + + def character_set_name(self) -> str: + return "utf8mb4" + + +def create_engine_from_client(pyseekdb_client: Any, **kwargs: Any): + """ + Create a SQLAlchemy Engine from an existing pyseekdb.Client. + + Use when you have client = pyseekdb.Client(path=..., database=...). + """ + from sqlalchemy.dialects import registry + + registry.register( + "mysql.oceanbase", "pyobvector.schema.dialect", "OceanBaseDialect" + ) + server = getattr(pyseekdb_client, "_server", None) + if server is None: + raise ValueError( + "pyseekdb_client must be a pyseekdb.Client instance (has _server). " + "Create with: pyseekdb.Client(path='./seekdb.db', database='test')" + ) + database = getattr(server, "database", "test") + + def creator() -> _SeekdbConnection: + return _SeekdbConnection(server) + + return create_engine( + "mysql+oceanbase://root:@127.0.0.1:2881/" + database, + creator=creator, + poolclass=NullPool, + **kwargs, + ) + + +def create_embedded_engine(path: str, database: str = "test", **kwargs: Any): + """ + Create a SQLAlchemy Engine from embedded SeekDB using official pyseekdb.Client(). + """ + try: + import pyseekdb + except ImportError as e: + raise ImportError( + "Embedded SeekDB requires: pip install pyobvector[pyseekdb]" + ) from e + + client = pyseekdb.Client(path=path, database=database) + return create_engine_from_client(client, **kwargs) diff --git a/pyobvector/json_table/oceanbase_dialect.py b/pyobvector/json_table/oceanbase_dialect.py index 2a07f65..c8010ec 100644 --- a/pyobvector/json_table/oceanbase_dialect.py +++ b/pyobvector/json_table/oceanbase_dialect.py @@ -1,4 +1,3 @@ -import typing as t from sqlglot import parser, exp, Expression from sqlglot.dialects.mysql import MySQL from sqlglot.tokens import TokenType @@ -30,7 +29,7 @@ class Parser(MySQL.Parser): "CHANGE": lambda self: self._parse_change_table_column(), } - def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]: + def _parse_alter_table_alter(self) -> exp.Expression | None: if self._match_texts(self.ALTER_ALTER_PARSERS): return self.ALTER_ALTER_PARSERS[self._prev.text.upper()](self) @@ -70,7 +69,7 @@ def _parse_alter_table_alter(self) -> t.Optional[exp.Expression]: using=self._match(TokenType.USING) and self._parse_assignment(), ) - def _parse_drop(self, exists: bool = False) -> t.Union[exp.Drop, exp.Command]: + def _parse_drop(self, exists: bool = False) -> exp.Drop | exp.Command: temporary = self._match(TokenType.TEMPORARY) materialized = self._match_text_seq("MATERIALIZED") @@ -111,7 +110,7 @@ def _parse_drop(self, exists: bool = False) -> t.Union[exp.Drop, exp.Command]: concurrently=concurrently, ) - def _parse_change_table_column(self) -> t.Optional[exp.Expression]: + def _parse_change_table_column(self) -> exp.Expression | None: self._match(TokenType.COLUMN) origin_col = self._parse_field(any_token=True) column = self._parse_field() diff --git a/pyobvector/json_table/virtual_data_type.py b/pyobvector/json_table/virtual_data_type.py index 57e2ac7..937c715 100644 --- a/pyobvector/json_table/virtual_data_type.py +++ b/pyobvector/json_table/virtual_data_type.py @@ -1,7 +1,6 @@ from datetime import datetime from decimal import Decimal, InvalidOperation, ROUND_DOWN from enum import Enum -from typing import Optional from typing import Annotated from pydantic import BaseModel, Field, AfterValidator, create_model @@ -25,16 +24,16 @@ class JsonTableDataType(BaseModel): class JsonTableBool(JsonTableDataType): type: JType = Field(default=JType.J_BOOL) - val: Optional[bool] + val: bool | None class JsonTableTimestamp(JsonTableDataType): type: JType = Field(default=JType.J_TIMESTAMP) - val: Optional[datetime] + val: datetime | None def check_varchar_len_with_length(length: int): - def check_varchar_len(x: Optional[str]): + def check_varchar_len(x: str | None): if x is None: return None if len(x) > length: @@ -54,7 +53,7 @@ def get_json_table_varchar_type(self): "type": (JType, JType.J_VARCHAR), "val": ( Annotated[ - Optional[str], + str | None, AfterValidator(check_varchar_len_with_length(self.length)), ], ..., @@ -106,7 +105,7 @@ def get_json_table_decimal_type(self): "type": (JType, JType.J_DECIMAL), "val": ( Annotated[ - Optional[float], + float | None, AfterValidator( check_and_parse_decimal(self.ndigits, self.decimal_p) ), @@ -119,7 +118,7 @@ def get_json_table_decimal_type(self): class JsonTableInt(JsonTableDataType): type: JType = Field(default=JType.J_INT) - val: Optional[int] + val: int | None def val2json(val): diff --git a/pyobvector/schema/array.py b/pyobvector/schema/array.py index 61d17ee..8f42889 100644 --- a/pyobvector/schema/array.py +++ b/pyobvector/schema/array.py @@ -1,7 +1,7 @@ """ARRAY: An extended data type for SQLAlchemy""" import json -from typing import Any, Optional, Union +from typing import Any from collections.abc import Sequence from sqlalchemy.sql.type_api import TypeEngine @@ -14,7 +14,7 @@ class ARRAY(UserDefinedType): cache_ok = True _string = String() - def __init__(self, item_type: Union[TypeEngine, type]): + def __init__(self, item_type: TypeEngine | type): """Construct an ARRAY. Args: @@ -63,7 +63,7 @@ def bind_processor(self, dialect): item_proc = item_type.dialect_impl(dialect).bind_processor(dialect) - def process(value: Optional[Union[Sequence[Any] | str]]) -> Optional[str]: + def process(value: Sequence[Any] | str | None) -> str | None: if value is None: return None if isinstance(value, str): @@ -71,7 +71,7 @@ def process(value: Optional[Union[Sequence[Any] | str]]) -> Optional[str]: return value def convert(val): - if isinstance(val, (list, tuple)): + if isinstance(val, list | tuple): return [convert(v) for v in val] if item_proc: return item_proc(val) @@ -90,12 +90,12 @@ def result_processor(self, dialect, coltype): item_proc = item_type.dialect_impl(dialect).result_processor(dialect, coltype) - def process(value: Optional[str]) -> Optional[list[Any]]: + def process(value: str | None) -> list[Any] | None: if value is None: return None def convert(val): - if isinstance(val, (list, tuple)): + if isinstance(val, list | tuple): return [convert(v) for v in val] if item_proc: return item_proc(val) @@ -115,7 +115,7 @@ def literal_processor(self, dialect): def process(value: Sequence[Any]) -> str: def convert(val): - if isinstance(val, (list, tuple)): + if isinstance(val, list | tuple): return [convert(v) for v in val] if item_proc: return item_proc(val) diff --git a/pyobvector/schema/dialect.py b/pyobvector/schema/dialect.py index d9a469a..aa6cea1 100644 --- a/pyobvector/schema/dialect.py +++ b/pyobvector/schema/dialect.py @@ -2,6 +2,7 @@ from sqlalchemy import util from sqlalchemy.dialects.mysql import aiomysql, pymysql +from sqlalchemy.engine import reflection from .reflection import OceanBaseTableDefinitionParser from .vector import VECTOR @@ -12,7 +13,7 @@ class OceanBaseDialect(pymysql.MySQLDialect_pymysql): # not change dialect name, since it is a subclass of pymysql.MySQLDialect_pymysql # name = "oceanbase" - """Ocenbase dialect.""" + """OceanBase dialect. Compatible with SeekDB (embedded); has_table treats missing table as False.""" supports_statement_cache = True @@ -22,6 +23,16 @@ def __init__(self, **kwargs): self.ischema_names["SPARSEVECTOR"] = SPARSE_VECTOR self.ischema_names["point"] = POINT + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + """Override so SeekDB RuntimeError for non-existent table is treated as False.""" + try: + return super().has_table(connection, table_name, schema=schema, **kw) + except RuntimeError as e: + if "doesn't exist" in str(e) or "1146" in str(e): + return False + raise + @util.memoized_property def _tabledef_parser(self): """return the MySQLTableDefinitionParser, generate if needed. diff --git a/pyobvector/schema/geo_srid_point.py b/pyobvector/schema/geo_srid_point.py index 094a671..9d2c943 100644 --- a/pyobvector/schema/geo_srid_point.py +++ b/pyobvector/schema/geo_srid_point.py @@ -1,6 +1,5 @@ """Point: OceanBase GIS data type for SQLAlchemy""" -from typing import Optional from sqlalchemy.types import UserDefinedType, String @@ -13,7 +12,7 @@ class POINT(UserDefinedType): def __init__( self, # lat_long: Tuple[float, float], - srid: Optional[int] = None, + srid: int | None = None, ): """Init Latitude and Longitude.""" super(UserDefinedType, self).__init__() diff --git a/pyproject.toml b/pyproject.toml index 71fc763..4d46a01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ authors = [{name="shanhaikang.shk",email="shanhaikang.shk@oceanbase.com"}] readme = "README.md" license = "Apache-2.0" keywords = ["oceanbase", "vector store", "obvector"] -requires-python = ">=3.9" +requires-python = ">=3.10" dependencies = [ "numpy>=1.17.0", @@ -17,6 +17,10 @@ dependencies = [ "pydantic>=2.7.0,<3" ] +[project.optional-dependencies] +# pyseekdb requires Python>=3.11; only install on 3.11+ so uv sync --dev works on 3.10 +pyseekdb = ["pyseekdb>=0.1.0; python_version >= '3.11'"] + [project.urls] Homepage = "https://github.com/oceanbase/pyobvector" Repository = "https://github.com/oceanbase/pyobvector.git" diff --git a/tests/test_seekdb_embedded.py b/tests/test_seekdb_embedded.py new file mode 100644 index 0000000..2284077 --- /dev/null +++ b/tests/test_seekdb_embedded.py @@ -0,0 +1,200 @@ +""" +Tests for embedded SeekDB via ObClient/ObVecClient (path= or pyseekdb_client=). + +Requires optional dependency: pip install pyobvector[pyseekdb] +Tests are skipped when pyseekdb is not installed or when pylibseekdb (embedded runtime) is not available. +""" + +import tempfile +import unittest +from pathlib import Path + +try: + import pyseekdb # noqa: F401 + + PYSEEKDB_AVAILABLE = True +except ImportError: + PYSEEKDB_AVAILABLE = False + +try: + import pylibseekdb # noqa: F401 + + PYLIBSEEKDB_AVAILABLE = True +except ImportError: + PYLIBSEEKDB_AVAILABLE = False + + +def _skip_if_no_embedded(): + """Skip if pyseekdb or pylibseekdb not available (no embedded SeekDB).""" + if not PYSEEKDB_AVAILABLE: + raise unittest.SkipTest( + "pyseekdb not installed; run: pip install pyobvector[pyseekdb]" + ) + if not PYLIBSEEKDB_AVAILABLE: + raise unittest.SkipTest("pylibseekdb not available (embedded SeekDB runtime)") + + +@unittest.skipIf( + not PYSEEKDB_AVAILABLE, + "pyseekdb not installed; run: pip install pyobvector[pyseekdb]", +) +class TestSeekdbEmbeddedConnection(unittest.TestCase): + """Test ObClient/ObVecClient with embedded SeekDB (path= or pyseekdb_client=).""" + + def setUp(self) -> None: + _skip_if_no_embedded() + self.tmpdir = tempfile.mkdtemp(prefix="pyobvector_seekdb_") + self.db_path = str(Path(self.tmpdir) / "seekdb_data") + Path(self.db_path).mkdir(parents=True, exist_ok=True) + + def tearDown(self) -> None: + import shutil + + if hasattr(self, "tmpdir") and Path(self.tmpdir).exists(): + try: + shutil.rmtree(self.tmpdir, ignore_errors=True) + except Exception: + pass + + def test_seekdb_remote_client_path_returns_ob_vec_client(self): + from pyobvector import SeekdbRemoteClient, ObVecClient + from pyobvector.client.ob_client import ObClient + + client = SeekdbRemoteClient(path=self.db_path, database="test") + self.assertIsInstance(client, ObVecClient) + self.assertIsInstance(client, ObClient) + self.assertIsNotNone(client.engine) + self.assertIsNotNone(client.ob_version) + + def test_ob_vec_client_path(self): + from pyobvector import ObVecClient + + client = ObVecClient(path=self.db_path, db_name="test") + self.assertIsInstance(client, ObVecClient) + self.assertIsNotNone(client.engine) + self.assertIsNotNone(client.ob_version) + + def test_ob_vec_client_pyseekdb_client(self): + import pyseekdb + from pyobvector import ObVecClient + + pyseekdb_client = pyseekdb.Client(path=self.db_path, database="test") + client = ObVecClient(pyseekdb_client=pyseekdb_client) + self.assertIsInstance(client, ObVecClient) + self.assertIsNotNone(client.engine) + + def test_create_table_insert_drop(self): + """Test create_table, insert, and drop_table_if_exist via ObClient API.""" + from sqlalchemy import Column, Integer, VARCHAR + + from pyobvector import SeekdbRemoteClient, ObVecClient + from pyobvector.client.ob_client import ObClient + + client = SeekdbRemoteClient(path=self.db_path, database="test") + self.assertIsInstance(client, ObVecClient) + self.assertIsInstance(client, ObClient) + + table_name = "embed_api_table" + client.drop_table_if_exist(table_name) + self.assertFalse(client.check_table_exists(table_name)) + + client.create_table( + table_name=table_name, + columns=[ + Column("id", Integer, primary_key=True), + Column("name", VARCHAR(64)), + ], + ) + self.assertTrue(client.check_table_exists(table_name)) + + client.insert( + table_name, + data=[ + {"id": 1, "name": "alice"}, + {"id": 2, "name": "bob"}, + ], + ) + + from sqlalchemy import text + + with client.engine.connect() as conn: + with conn.begin(): + res = conn.execute( + text(f"SELECT id, name FROM `{table_name}` ORDER BY id") + ) + rows = res.fetchall() + self.assertEqual(len(rows), 2) + self.assertEqual(rows[0], (1, "alice")) + self.assertEqual(rows[1], (2, "bob")) + + client.drop_table_if_exist(table_name) + self.assertFalse(client.check_table_exists(table_name)) + + def test_vector_table_and_ann_search(self): + """Test create table with vector index, insert, and ann_search.""" + from sqlalchemy import Column, Integer, VARCHAR + + from pyobvector import ( + SeekdbRemoteClient, + VECTOR, + VectorIndex, + l2_distance, + ) + + client = SeekdbRemoteClient(path=self.db_path, database="test") + table_name = "embed_vec_table" + client.drop_table_if_exist(table_name) + + client.create_table( + table_name=table_name, + columns=[ + Column("id", Integer, primary_key=True), + Column("title", VARCHAR(255)), + Column("vec", VECTOR(3)), + ], + indexes=[ + VectorIndex( + "vec_idx", "vec", params="distance=l2, type=hnsw, lib=vsag" + ), + ], + mysql_organization="heap", + ) + client.insert( + table_name, + data=[ + {"id": 1, "title": "doc A", "vec": [1.0, 1.0, 1.0]}, + {"id": 2, "title": "doc B", "vec": [1.0, 2.0, 3.0]}, + {"id": 3, "title": "doc C", "vec": [3.0, 2.0, 1.0]}, + ], + ) + + res = client.ann_search( + table_name=table_name, + vec_data=[1.0, 2.0, 3.0], + vec_column_name="vec", + distance_func=l2_distance, + with_dist=True, + topk=3, + output_column_names=["id", "title"], + ) + rows = res.fetchall() + self.assertGreaterEqual(len(rows), 1) + self.assertEqual(len(rows[0]), 3) # id, title, distance + + client.drop_table_if_exist(table_name) + + +class TestSeekdbEmbeddedWithoutPyseekdb(unittest.TestCase): + """Test that using path= without pyseekdb raises a clear ImportError.""" + + def test_path_raises_without_pyseekdb(self): + from pyobvector import SeekdbRemoteClient + + if PYSEEKDB_AVAILABLE: + self.skipTest("pyseekdb is installed; cannot test ImportError") + + with tempfile.TemporaryDirectory() as tmpdir: + with self.assertRaises(ImportError) as ctx: + SeekdbRemoteClient(path=tmpdir, database="test") + self.assertIn("pyseekdb", str(ctx.exception).lower()) + self.assertIn("pyobvector[pyseekdb]", str(ctx.exception))