diff --git a/docs/source/functions.md b/docs/source/functions.md index c66fc4d..567b022 100644 --- a/docs/source/functions.md +++ b/docs/source/functions.md @@ -2,7 +2,11 @@ ```python from sqlalchemy.orm import declarative_base -from sqlalchemy_declarative_extensions import declarative_database, Function, Functions +from sqlalchemy_declarative_extensions import declarative_database, Functions + +# Import dialect-specific Function for full feature support +from sqlalchemy_declarative_extensions.dialects.postgresql import Function +# from sqlalchemy_declarative_extensions.dialects.mysql import Function _Base = declarative_base() @@ -22,22 +26,49 @@ class Base(_Base): """, language="plpgsql", returns="trigger", + ), + Function( + "gimme_rows", + ''' + SELECT id, name + FROM dem_rowz + WHERE group_id = _group_id; + ''', + language="sql", + parameters=["_group_id int"], + returns="TABLE(id int, name text)", + volatility='stable', # PostgreSQL specific characteristic ) + + # Example MySQL function + # Function( + # "gimme_concat", + # "RETURN CONCAT(label, ': ', CAST(val AS CHAR));", + # parameters=["val INT", "label VARCHAR(50)"], + # returns="VARCHAR(100)", + # deterministic=True, # MySQL specific + # data_access='NO SQL', # MySQL specific + # security='INVOKER', # MySQL specific + # ), ) ``` ```{note} Functions options are wildly different across dialects. As such, you should likely always use -the diaelect-specific `Function` object. +the dialect-specific `Function` object (e.g., `sqlalchemy_declarative_extensions.dialects.postgresql.Function` +or `sqlalchemy_declarative_extensions.dialects.mysql.Function`) to access all available features. +The base `Function` provides only the most common subset of options. ``` ```{note} -Function behavior (for eaxmple...arguments) is not fully implemented at current time, -although it **should** be functional for the options it does support. Any ability to instantiate -an object which produces a syntax error should be considered a bug. Additionally, feature requests -for supporting more function options are welcome! +Function comparison logic now supports parsing and comparing function parameters (including name and type) +and various dialect-specific characteristics: + +* **PostgreSQL:** `LANGUAGE`, `VOLATILITY`, `SECURITY`, `RETURNS TABLE(...)` syntax. +* **MySQL:** `DETERMINISTIC`, `DATA ACCESS`, `SECURITY`. -In particular, the current function support is heavily oriented around support for defining triggers. +The comparison logic handles normalization (e.g., mapping `integer` to `int4` in PostgreSQL) to ensure +accurate idempotency checks during Alembic autogeneration. ``` ```{eval-rst} @@ -52,3 +83,8 @@ any dialect-specific options. .. autoapimodule:: sqlalchemy_declarative_extensions.dialects.postgresql.function :members: Function, Procedure ``` + +```{eval-rst} +.. autoapimodule:: sqlalchemy_declarative_extensions.dialects.mysql.function + :members: Function +``` diff --git a/src/sqlalchemy_declarative_extensions/dialects/mysql/function.py b/src/sqlalchemy_declarative_extensions/dialects/mysql/function.py index e3cb0c8..18fa78d 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/mysql/function.py +++ b/src/sqlalchemy_declarative_extensions/dialects/mysql/function.py @@ -42,12 +42,17 @@ def from_unknown_function(cls, f: base.Function) -> Self: language=f.language, schema=f.schema, returns=f.returns, + parameters=f.parameters, ) def to_sql_create(self) -> list[str]: components = ["CREATE FUNCTION"] - components.append(self.qualified_name + "()") + parameter_str = "" + if self.parameters: + parameter_str = ", ".join(self.parameters) + + components.append(f"{self.qualified_name}({parameter_str})") components.append(f"RETURNS {self.returns}") if self.deterministic: @@ -85,9 +90,32 @@ def modifies_sql(self): def normalize(self) -> Function: definition = textwrap.dedent(self.definition).strip() + + # Remove optional trailing semicolon for comparison robustness + if definition.endswith(";"): + definition = definition[:-1] + returns = self.returns.lower() + normalized_returns = type_map.get(returns, returns) + + normalized_parameters = None + if self.parameters: + normalized_parameters = [] + for param in self.parameters: + # Naive split, assumes 'name type' format + parts = param.split(maxsplit=1) + if len(parts) == 2: + name, type_str = parts + norm_type = type_map.get(type_str.lower(), type_str.lower()) + normalized_parameters.append(f"{name} {norm_type}") + else: + normalized_parameters.append(param) # Keep as is if format unexpected + return replace( - self, definition=definition, returns=type_map.get(returns, returns) + self, + definition=definition, + returns=normalized_returns, + parameters=normalized_parameters, ) diff --git a/src/sqlalchemy_declarative_extensions/dialects/mysql/query.py b/src/sqlalchemy_declarative_extensions/dialects/mysql/query.py index b9b99c2..56580f0 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/mysql/query.py +++ b/src/sqlalchemy_declarative_extensions/dialects/mysql/query.py @@ -91,10 +91,15 @@ def get_functions_mysql(connection: Connection) -> Sequence[BaseFunction]: functions = [] for f in connection.execute(functions_query, {"schema": database}).fetchall(): + parameters = None + if f.parameters: # Parameter string might be None if no parameters + parameters = [p.strip() for p in f.parameters.split(",")] + functions.append( Function( name=f.name, definition=f.definition, + parameters=parameters, security=( FunctionSecurity.definer if f.security == "DEFINER" diff --git a/src/sqlalchemy_declarative_extensions/dialects/mysql/schema.py b/src/sqlalchemy_declarative_extensions/dialects/mysql/schema.py index 5c50cab..c1b4a16 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/mysql/schema.py +++ b/src/sqlalchemy_declarative_extensions/dialects/mysql/schema.py @@ -1,4 +1,5 @@ from sqlalchemy import bindparam, column, table +from sqlalchemy.sql import func, text from sqlalchemy_declarative_extensions.sqlalchemy import select @@ -81,6 +82,21 @@ .where(routine_table.c.routine_type == "PROCEDURE") ) +# Need to query PARAMETERS separately to reconstruct the parameter list +parameters_subquery = ( + select( + column("SPECIFIC_NAME").label("routine_name"), + func.group_concat( + text("concat(PARAMETER_NAME, ' ', DTD_IDENTIFIER) ORDER BY ORDINAL_POSITION SEPARATOR ', '"), + ).label("parameters"), + ) + .select_from(table("PARAMETERS", schema="INFORMATION_SCHEMA")) + .where(column("SPECIFIC_SCHEMA") == bindparam("schema")) + .where(column("ROUTINE_TYPE") == "FUNCTION") + .group_by(column("SPECIFIC_NAME")) + .alias("parameters_sq") +) + functions_query = ( select( routine_table.c.routine_name.label("name"), @@ -89,6 +105,13 @@ routine_table.c.dtd_identifier.label("return_type"), routine_table.c.is_deterministic.label("deterministic"), routine_table.c.sql_data_access.label("data_access"), + parameters_subquery.c.parameters.label("parameters"), + ) + .select_from( # Join routines with the parameter subquery + routine_table.outerjoin( + parameters_subquery, + routine_table.c.routine_name == parameters_subquery.c.routine_name, + ) ) .where(routine_table.c.routine_schema == bindparam("schema")) .where(routine_table.c.routine_type == "FUNCTION") diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/__init__.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/__init__.py index 70e9d54..aa42db4 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/__init__.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/__init__.py @@ -1,6 +1,7 @@ from sqlalchemy_declarative_extensions.dialects.postgresql.function import ( Function, FunctionSecurity, + FunctionVolatility, ) from sqlalchemy_declarative_extensions.dialects.postgresql.grant import ( DefaultGrant, @@ -43,6 +44,7 @@ "Function", "FunctionGrants", "FunctionSecurity", + "FunctionVolatility", "Grant", "Grant", "GrantStatement", diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py index 5eec6dc..a4c6369 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py @@ -3,6 +3,7 @@ import enum import textwrap from dataclasses import dataclass, replace +from typing import List, Optional from sqlalchemy_declarative_extensions.function import base @@ -13,6 +14,47 @@ class FunctionSecurity(enum.Enum): definer = "DEFINER" +@enum.unique +class FunctionVolatility(enum.Enum): + VOLATILE = "VOLATILE" + STABLE = "STABLE" + IMMUTABLE = "IMMUTABLE" + + @classmethod + def from_provolatile(cls, provolatile: str) -> FunctionVolatility: + """Convert a `pg_proc.provolatile` value to a `FunctionVolatility` enum.""" + if provolatile == "v": + return cls.VOLATILE + if provolatile == "s": + return cls.STABLE + if provolatile == "i": + return cls.IMMUTABLE + raise ValueError(f"Invalid volatility: {provolatile}") + + +def normalize_arg(arg: str) -> str: + parts = arg.strip().split(maxsplit=1) + if len(parts) == 2: + name, type_str = parts + norm_type = type_map.get(type_str.lower(), type_str.lower()) + # Handle array types + if norm_type.endswith("[]"): + base_type = norm_type[:-2] + norm_base_type = type_map.get(base_type, base_type) + norm_type = f"{norm_base_type}[]" + + return f"{name} {norm_type}" + else: + # Handle case where it might just be the type (e.g., from DROP FUNCTION) + type_str = arg.strip() + norm_type = type_map.get(type_str.lower(), type_str.lower()) + if norm_type.endswith("[]"): + base_type = norm_type[:-2] + norm_base_type = type_map.get(base_type, base_type) + norm_type = f"{norm_base_type}[]" + return norm_type + + @dataclass class Function(base.Function): """Describes a PostgreSQL function. @@ -24,19 +66,32 @@ class Function(base.Function): security: FunctionSecurity = FunctionSecurity.invoker + #: Defines the parameters for the function, e.g. ["param1 int", "param2 varchar"] + parameters: Optional[List[str]] = None + + #: Defines the volatility of the function. + volatility: FunctionVolatility = FunctionVolatility.VOLATILE + def to_sql_create(self, replace=False) -> list[str]: components = ["CREATE"] if replace: components.append("OR REPLACE") + parameter_str = "" + if self.parameters: + parameter_str = ", ".join(self.parameters) + components.append("FUNCTION") - components.append(self.qualified_name + "()") + components.append(f"{self.qualified_name}({parameter_str})") components.append(f"RETURNS {self.returns}") if self.security == FunctionSecurity.definer: components.append("SECURITY DEFINER") + if self.volatility != FunctionVolatility.VOLATILE: + components.append(self.volatility.value) + components.append(f"LANGUAGE {self.language}") components.append(f"AS $${self.definition}$$") @@ -45,6 +100,20 @@ def to_sql_create(self, replace=False) -> list[str]: def to_sql_update(self) -> list[str]: return self.to_sql_create(replace=True) + def to_sql_drop(self) -> list[str]: + param_types = [] + if self.parameters: + for param in self.parameters: + # Naive split, assumes 'name type' or just 'type' format + parts = param.split(maxsplit=1) + if len(parts) == 2: + param_types.append(parts[1]) + else: + param_types.append(param) # Assume it's just the type if no space + + param_str = ", ".join(param_types) + return [f"DROP FUNCTION {self.qualified_name}({param_str});"] + def with_security(self, security: FunctionSecurity): return replace(self, security=security) @@ -53,9 +122,35 @@ def with_security_definer(self): def normalize(self) -> Function: definition = textwrap.dedent(self.definition) - returns = self.returns.lower() + + # Handle RETURNS TABLE(...) normalization + returns_lower = self.returns.lower().strip() + if returns_lower.startswith("table("): + # Basic normalization: lowercase and remove extra spaces + # This might need refinement for complex TABLE definitions + inner_content = returns_lower[len("table("):-1].strip() + cols = [normalize_arg(c) for c in inner_content.split(',')] + normalized_returns = f"table({', '.join(cols)})" + else: + # Normalize base return type (including array types) + norm_type = type_map.get(returns_lower, returns_lower) + if norm_type.endswith("[]"): + base = norm_type[:-2] + norm_base = type_map.get(base, base) + normalized_returns = f"{norm_base}[]" + else: + normalized_returns = norm_type + + # Normalize parameter types + normalized_parameters = None + if self.parameters: + normalized_parameters = [normalize_arg(p) for p in self.parameters] + return replace( - self, definition=definition, returns=type_map.get(returns, returns) + self, + definition=definition, + returns=normalized_returns, + parameters=normalized_parameters, # Use normalized parameters ) diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py index 4479183..bc651ed 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py @@ -14,6 +14,7 @@ from sqlalchemy_declarative_extensions.dialects.postgresql.function import ( Function, FunctionSecurity, + FunctionVolatility, ) from sqlalchemy_declarative_extensions.dialects.postgresql.procedure import ( Procedure, @@ -195,6 +196,7 @@ def get_procedures_postgresql(connection: Connection) -> Sequence[BaseProcedure] def get_functions_postgresql(connection: Connection) -> Sequence[BaseFunction]: functions = [] + for f in connection.execute(functions_query).fetchall(): name = f.name definition = f.source @@ -202,6 +204,10 @@ def get_functions_postgresql(connection: Connection) -> Sequence[BaseFunction]: schema = f.schema if f.schema != "public" else None function = Function( + parameters=( + [p.strip() for p in f.parameters.split(",")] if f.parameters else None + ), + volatility=FunctionVolatility.from_provolatile(f.volatility), name=name, definition=definition, language=language, @@ -211,7 +217,7 @@ def get_functions_postgresql(connection: Connection) -> Sequence[BaseFunction]: if f.security_definer else FunctionSecurity.invoker ), - returns=f.return_type, + returns=f.return_type_string or f.base_return_type, ) functions.append(function) diff --git a/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py b/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py index fc62198..e07d77f 100644 --- a/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py +++ b/src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py @@ -85,6 +85,7 @@ column("prorettype"), column("prosecdef"), column("prokind"), + column("provolatile"), ) pg_language = table( @@ -290,10 +291,13 @@ def _schema_not_pg(column=pg_namespace.c.nspname): pg_proc.c.proname.label("name"), pg_namespace.c.nspname.label("schema"), pg_language.c.lanname.label("language"), - pg_type.c.typname.label("return_type"), + pg_type.c.typname.label("base_return_type"), pg_proc.c.prosrc.label("source"), pg_proc.c.prosecdef.label("security_definer"), pg_proc.c.prokind.label("kind"), + func.pg_get_function_arguments(pg_proc.c.oid).label("parameters"), + pg_proc.c.provolatile.label("volatility"), + func.pg_get_function_result(pg_proc.c.oid).label("return_type_string"), ) .select_from( pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid) diff --git a/src/sqlalchemy_declarative_extensions/function/base.py b/src/sqlalchemy_declarative_extensions/function/base.py index d797f4d..2ae8ddc 100644 --- a/src/sqlalchemy_declarative_extensions/function/base.py +++ b/src/sqlalchemy_declarative_extensions/function/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field, replace -from typing import Iterable, Sequence +from typing import Iterable, List, Optional, Sequence from sqlalchemy import MetaData from typing_extensions import Self @@ -24,6 +24,12 @@ class Function: language: str = "sql" schema: str | None = None + #: Defines the parameters for the function, e.g. ["param1 int", "param2 varchar"] + parameters: Optional[List[str]] = None + """List of parameter definitions as strings, e.g., `['param1 int', 'param2 varchar']`. + The comparison logic parses these strings to compare parameter names and types. + """ + @classmethod def from_unknown_function(cls, f: Function) -> Self: if isinstance(f, cls): @@ -35,6 +41,7 @@ def from_unknown_function(cls, f: Function) -> Self: language=f.language, schema=f.schema, returns=f.returns, + parameters=f.parameters, ) @property @@ -54,6 +61,11 @@ def to_sql_update(self) -> list[str]: ] def to_sql_drop(self) -> list[str]: + # Base implementation can only reliably drop functions without parameters + # since function overloads are determined by parameter types in most SQL dialects. + # Dialect-specific implementations should handle cases with parameters. + if self.parameters: + raise NotImplementedError("Dropping functions with parameters must be implemented by dialect-specific subclasses") return [f"DROP FUNCTION {self.qualified_name}();"] def with_name(self, name: str): diff --git a/tests/examples/test_function_create_with_params_mysql/alembic.ini b/tests/examples/test_function_create_with_params_mysql/alembic.ini new file mode 100644 index 0000000..9a7dfa6 --- /dev/null +++ b/tests/examples/test_function_create_with_params_mysql/alembic.ini @@ -0,0 +1,36 @@ +[alembic] +script_location = migrations + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_mysql/migrations/env.py b/tests/examples/test_function_create_with_params_mysql/migrations/env.py new file mode 100644 index 0000000..bd392f9 --- /dev/null +++ b/tests/examples/test_function_create_with_params_mysql/migrations/env.py @@ -0,0 +1,27 @@ +from alembic import context + +# isort: split +from models import Base # Imports from the example's models.py +from sqlalchemy import engine_from_config, pool + +from sqlalchemy_declarative_extensions import register_alembic_events + +target_metadata = Base.metadata + +# Ensure functions=True is enabled +register_alembic_events(functions=True) + +connectable = context.config.attributes.get("connection", None) + +if connectable is None: + connectable = engine_from_config( + context.config.get_section(context.config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + +with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_mysql/migrations/script.py.mako b/tests/examples/test_function_create_with_params_mysql/migrations/script.py.mako new file mode 100644 index 0000000..50a043f --- /dev/null +++ b/tests/examples/test_function_create_with_params_mysql/migrations/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ''} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else 'pass'} + + +def downgrade(): + ${downgrades if downgrades else 'pass'} \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_mysql/migrations/versions/.gitkeep b/tests/examples/test_function_create_with_params_mysql/migrations/versions/.gitkeep new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/tests/examples/test_function_create_with_params_mysql/migrations/versions/.gitkeep @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_mysql/models.py b/tests/examples/test_function_create_with_params_mysql/models.py new file mode 100644 index 0000000..2993635 --- /dev/null +++ b/tests/examples/test_function_create_with_params_mysql/models.py @@ -0,0 +1,52 @@ +from sqlalchemy import Column, types +from sqlalchemy.orm import declarative_base + +from sqlalchemy_declarative_extensions import Functions, declarative_database +from sqlalchemy_declarative_extensions.dialects.mysql import ( + Function, + FunctionDataAccess, + FunctionSecurity, +) + +_Base = declarative_base() + + +@declarative_database +class Base(_Base): # type: ignore + __abstract__ = True + + functions = Functions().are( + Function( + "add_deterministic", + "RETURN i + 1;", + parameters=["i integer"], + returns="INTEGER", + deterministic=True, + data_access=FunctionDataAccess.no_sql, + ), + # NEW FUNCTION: Multiple params + Function( + "add_two_numbers", + "RETURN a + b;", + parameters=["a integer", "b integer"], + returns="INTEGER", + deterministic=True, + data_access=FunctionDataAccess.no_sql, + ), + # Complex function for Alembic test + Function( + "complex_processor", + "RETURN CONCAT(label, ': ', CAST(val AS CHAR));", + parameters=["val INT", "label VARCHAR(50)"], + returns="VARCHAR(100)", + deterministic=True, + data_access=FunctionDataAccess.no_sql, + security=FunctionSecurity.invoker, + ), + ) + + +# Include a dummy table just in case it's needed +class DummyTable(Base): + __tablename__ = "dummy_table" + id = Column(types.Integer, primary_key=True) \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_mysql/test_migrations.py b/tests/examples/test_function_create_with_params_mysql/test_migrations.py new file mode 100644 index 0000000..035c028 --- /dev/null +++ b/tests/examples/test_function_create_with_params_mysql/test_migrations.py @@ -0,0 +1,57 @@ +import pytest +from pytest_alembic import MigrationContext +from pytest_mock_resources import MysqlConfig, create_mysql_fixture +from sqlalchemy import text +from pytest_alembic import tests # Import the tests module + +# Use the mysql fixture provided by pytest-mock-resources +alembic_engine = create_mysql_fixture(scope="function", engine_kwargs={"echo": True}) + +@pytest.fixture(scope="session") +def pmr_mysql_config(): + '''Override the default config.''' + return MysqlConfig(image="mysql:8", port=None, ci_port=None) + +def test_apply_autogenerated_revision(alembic_runner: MigrationContext, alembic_engine): + '''Check that autogenerate detects the new function and the migration runs.''' + # Required for MySQL function creation without SUPER privilege + with alembic_engine.connect() as conn: + conn.execute(text("SET GLOBAL log_bin_trust_function_creators = ON;")) + conn.commit() # Commit the global setting change if needed + + # Generate the revision based on the models + alembic_runner.generate_revision(autogenerate=True, prevent_file_generation=False, message="Add MySQL functions") + + # Apply the generated migration + alembic_runner.migrate_up_one() + + # Verify the function exists and works + with alembic_engine.connect() as conn: + result = conn.execute(text("select add_deterministic(5)")).scalar() + assert result == 6 + + # Also verify the multi-param function + multi_result = conn.execute(text("select add_two_numbers(10, 20)")).scalar() + assert multi_result == 30 + + # Verify the complex function + complex_result = conn.execute(text("SELECT complex_processor(99, 'Complex')")).scalar() + assert complex_result == "Complex: 99" + + # Verify that a subsequent autogenerate is empty using pytest-alembic helper + tests.test_model_definitions_match_ddl(alembic_runner) + + # Apply the downgrade + alembic_runner.migrate_down_one() + + # Verify functions are gone by querying information_schema.ROUTINES + with alembic_engine.connect() as conn: + db_name = conn.execute(text("SELECT DATABASE()")).scalar() + count = conn.execute( + text("SELECT count(*) FROM information_schema.ROUTINES " + "WHERE ROUTINE_SCHEMA = :db_name " + "AND ROUTINE_TYPE = 'FUNCTION' " + "AND ROUTINE_NAME IN ('add_deterministic', 'add_two_numbers', 'complex_processor')"), + {"db_name": db_name} + ).scalar() + assert count == 0 \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_postgresql/alembic.ini b/tests/examples/test_function_create_with_params_postgresql/alembic.ini new file mode 100644 index 0000000..9a7dfa6 --- /dev/null +++ b/tests/examples/test_function_create_with_params_postgresql/alembic.ini @@ -0,0 +1,36 @@ +[alembic] +script_location = migrations + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_postgresql/migrations/env.py b/tests/examples/test_function_create_with_params_postgresql/migrations/env.py new file mode 100644 index 0000000..bd392f9 --- /dev/null +++ b/tests/examples/test_function_create_with_params_postgresql/migrations/env.py @@ -0,0 +1,27 @@ +from alembic import context + +# isort: split +from models import Base # Imports from the example's models.py +from sqlalchemy import engine_from_config, pool + +from sqlalchemy_declarative_extensions import register_alembic_events + +target_metadata = Base.metadata + +# Ensure functions=True is enabled +register_alembic_events(functions=True) + +connectable = context.config.attributes.get("connection", None) + +if connectable is None: + connectable = engine_from_config( + context.config.get_section(context.config.config_ini_section), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + +with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_postgresql/migrations/script.py.mako b/tests/examples/test_function_create_with_params_postgresql/migrations/script.py.mako new file mode 100644 index 0000000..50a043f --- /dev/null +++ b/tests/examples/test_function_create_with_params_postgresql/migrations/script.py.mako @@ -0,0 +1,24 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ''} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + + +def upgrade(): + ${upgrades if upgrades else 'pass'} + + +def downgrade(): + ${downgrades if downgrades else 'pass'} \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_postgresql/migrations/versions/.gitkeep b/tests/examples/test_function_create_with_params_postgresql/migrations/versions/.gitkeep new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/tests/examples/test_function_create_with_params_postgresql/migrations/versions/.gitkeep @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_postgresql/models.py b/tests/examples/test_function_create_with_params_postgresql/models.py new file mode 100644 index 0000000..b6e3db9 --- /dev/null +++ b/tests/examples/test_function_create_with_params_postgresql/models.py @@ -0,0 +1,42 @@ +from sqlalchemy import Column, types +from sqlalchemy.orm import declarative_base + +from sqlalchemy_declarative_extensions import Functions, declarative_database +from sqlalchemy_declarative_extensions.dialects.postgresql import Function, FunctionVolatility + +_Base = declarative_base() + + +@declarative_database +class Base(_Base): # type: ignore + __abstract__ = True + + functions = Functions().are( + Function( + "add_stable", + "SELECT i + 1;", + parameters=["i integer"], + returns="INTEGER", + volatility=FunctionVolatility.STABLE, + ), + # NEW FUNCTION: Multiple params, RETURNS TABLE, plpgsql + Function( + "get_users_by_group", + """ + BEGIN + -- Dummy implementation for testing definition + RETURN QUERY SELECT dt.id, dt.id::text as name FROM dummy_table dt WHERE dt.id = any(group_ids); + END; + """, + parameters=["group_ids integer[]"], + returns="TABLE(id integer, name text)", + language="plpgsql", # Requires plpgsql + volatility=FunctionVolatility.STABLE, + ), + ) + + +# Include a dummy table just in case it's needed +class DummyTable(Base): + __tablename__ = "dummy_table" + id = Column(types.Integer, primary_key=True) \ No newline at end of file diff --git a/tests/examples/test_function_create_with_params_postgresql/test_migrations.py b/tests/examples/test_function_create_with_params_postgresql/test_migrations.py new file mode 100644 index 0000000..5ea7588 --- /dev/null +++ b/tests/examples/test_function_create_with_params_postgresql/test_migrations.py @@ -0,0 +1,56 @@ +import pytest +from pytest_alembic import MigrationContext +from pytest_mock_resources import PostgresConfig, create_postgres_fixture +from sqlalchemy import text +from pathlib import Path +from pytest_alembic import tests # Import the tests module + +# Use the postgres fixture provided by pytest-mock-resources +alembic_engine = create_postgres_fixture(scope="function", engine_kwargs={"echo": True}) + +@pytest.fixture(scope="session") +def pmr_postgres_config(): + '''Override the default config to avoid port conflicts.''' + return PostgresConfig(port=None, ci_port=None) + +def test_apply_autogenerated_revision(alembic_runner: MigrationContext, alembic_engine): + '''Check that autogenerate detects the new function and the migration runs.''' + # Generate the revision based on the models + alembic_runner.generate_revision(autogenerate=True, prevent_file_generation=False, message="Add add_stable function") + + # Apply the generated migration + alembic_runner.migrate_up_one() + + # Verify the function exists and works + with alembic_engine.connect() as conn: + result = conn.execute(text("select add_stable(5)")).scalar() + assert result == 6 + + # Also verify the multi-param function + # Create some dummy data first + conn.execute(text("INSERT INTO dummy_table (id) VALUES (1), (3), (5)")) + conn.commit() + + multi_result = conn.execute(text("select * from get_users_by_group(ARRAY[1, 5, 7])")).fetchall() + + assert len(multi_result) == 2 + assert {row.id for row in multi_result} == {1, 5} + assert {row.name for row in multi_result} == {"1", "5"} + + # Verify that a subsequent autogenerate is empty using pytest-alembic helper + tests.test_model_definitions_match_ddl(alembic_runner) + + # Apply the downgrade + alembic_runner.migrate_down_one() + + # Verify functions are gone by querying pg_proc + with alembic_engine.connect() as conn: + # We need the schema OID first + schema_oid = conn.execute(text("SELECT oid FROM pg_namespace WHERE nspname = 'public'")).scalar() + + count = conn.execute( + text("SELECT count(*) FROM pg_proc " + "WHERE pronamespace = :schema_oid AND proname IN ('add_stable', 'get_users_by_group')"), + {"schema_oid": schema_oid} + ).scalar() + assert count == 0 \ No newline at end of file diff --git a/tests/function/test_alembic.py b/tests/function/test_alembic.py index 87d2045..757ec56 100644 --- a/tests/function/test_alembic.py +++ b/tests/function/test_alembic.py @@ -26,3 +26,14 @@ def test_function_leading_whitespace(pytester): @pytest.mark.alembic def test_function_rewriter(pytester): successful_test_run(pytester, count=1) + + +@pytest.mark.alembic +def test_function_create_with_params_postgresql(pytester): + successful_test_run(pytester, count=1) + + +@pytest.mark.alembic +@pytest.mark.mysql +def test_function_create_with_params_mysql(pytester): + successful_test_run(pytester, count=1) diff --git a/tests/function/test_create.py b/tests/function/test_create.py deleted file mode 100644 index 8268d41..0000000 --- a/tests/function/test_create.py +++ /dev/null @@ -1,52 +0,0 @@ -from pytest_mock_resources import create_postgres_fixture -from sqlalchemy import Column, Integer, text - -from sqlalchemy_declarative_extensions import ( - Function, - Functions, - declarative_database, - register_sqlalchemy_events, -) -from sqlalchemy_declarative_extensions.function.compare import compare_functions -from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base - -_Base = declarative_base() - - -@declarative_database -class Base(_Base): # type: ignore - __abstract__ = True - - functions = Functions().are( - Function( - "gimme", - "INSERT INTO foo (id) VALUES (DEFAULT); SELECT count(*) FROM foo;", - returns="INTEGER", - ) - ) - - -class Foo(Base): - __tablename__ = "foo" - - id = Column(Integer, primary_key=True) - - -register_sqlalchemy_events(Base.metadata, functions=True) - -pg = create_postgres_fixture(engine_kwargs={"echo": True}, session=True) - - -def test_create(pg): - Base.metadata.create_all(bind=pg.connection()) - pg.commit() - - result = pg.execute(text("SELECT gimme()")).scalar() - assert result == 1 - - result = pg.execute(text("SELECT gimme()")).scalar() - assert result == 2 - - connection = pg.connection() - diff = compare_functions(connection, Base.metadata.info["functions"]) - assert diff == [] diff --git a/tests/function/test_create_mysql.py b/tests/function/test_create_mysql.py index bf4c4cd..bd54097 100644 --- a/tests/function/test_create_mysql.py +++ b/tests/function/test_create_mysql.py @@ -6,7 +6,11 @@ declarative_database, register_sqlalchemy_events, ) -from sqlalchemy_declarative_extensions.dialects.mysql import Function +from sqlalchemy_declarative_extensions.dialects.mysql import ( + Function, + FunctionDataAccess, + FunctionSecurity, +) from sqlalchemy_declarative_extensions.function.compare import compare_functions from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base @@ -18,42 +22,82 @@ class Base(_Base): # type: ignore __abstract__ = True functions = Functions().are( + # Basic function Function( - "gimme", + "gimme_count", """ BEGIN - INSERT INTO foo (id) VALUES (DEFAULT); - RETURN (SELECT count(*) FROM foo); + DECLARE current_count INT; + INSERT INTO foo (id) VALUES (DEFAULT); + SELECT count(*) INTO current_count FROM foo; + RETURN current_count; END """, returns="INTEGER", - ).modifies_sql() + deterministic=False, # Explicitly non-deterministic due to insert + data_access=FunctionDataAccess.modifies_sql, + ), + # Function with parameters and deterministic + Function( + "add_deterministic", + "RETURN i + 1;", + parameters=["i integer"], + returns="INTEGER", + deterministic=True, + data_access=FunctionDataAccess.no_sql, # No SQL access + ), + # Complex function with multiple parameters and specific characteristics + Function( + "complex_processor", + "RETURN CONCAT(label, ': ', CAST(val AS CHAR));", + parameters=["val INT", "label VARCHAR(50)"], + returns="VARCHAR(100)", + deterministic=True, + data_access=FunctionDataAccess.no_sql, + security=FunctionSecurity.invoker, # Explicitly set security + ), ) class Foo(Base): __tablename__ = "foo" - - id = Column(Integer, primary_key=True) + id = Column(Integer, primary_key=True, autoincrement=True) register_sqlalchemy_events(Base.metadata, functions=True) -db = create_mysql_fixture(engine_kwargs={"echo": True}, session=True) +mysql = create_mysql_fixture(engine_kwargs={"echo": True}, session=True) -def test_create(db): - db.execute(text("SET GLOBAL log_bin_trust_function_creators = ON;")) +def test_create_mysql(mysql): + # Required for MySQL function creation without SUPER privilege + mysql.execute(text("SET GLOBAL log_bin_trust_function_creators = ON;")) - Base.metadata.create_all(bind=db.connection()) - db.commit() + Base.metadata.create_all(bind=mysql.connection()) + mysql.commit() - result = db.execute(text("SELECT gimme()")).scalar() + # Test gimme_count + result = mysql.execute(text("SELECT gimme_count()")).scalar() assert result == 1 - result = db.execute(text("SELECT gimme()")).scalar() + result = mysql.execute(text("SELECT gimme_count()")).scalar() assert result == 2 - connection = db.connection() + # Test add_deterministic + result_add = mysql.execute(text("SELECT add_deterministic(10)")).scalar() + assert result_add == 11 + + result_add_2 = mysql.execute(text("SELECT add_deterministic(1)")).scalar() + assert result_add_2 == 2 + + # Test complex_processor + result_complex = mysql.execute(text("SELECT complex_processor(123, 'Test')")).scalar() + assert result_complex == "Test: 123" + + result_complex_2 = mysql.execute(text("SELECT complex_processor(45, 'Another')")).scalar() + assert result_complex_2 == "Another: 45" + + # Verify comparison + connection = mysql.connection() diff = compare_functions(connection, Base.metadata.info["functions"]) - assert diff == [] + assert diff == [], f"Diff was: {diff}" # Added detail to assertion diff --git a/tests/function/test_create_postgresql.py b/tests/function/test_create_postgresql.py new file mode 100644 index 0000000..bc1e2ef --- /dev/null +++ b/tests/function/test_create_postgresql.py @@ -0,0 +1,83 @@ +from pytest_mock_resources import create_postgres_fixture +from sqlalchemy import Column, Integer, text + +from sqlalchemy_declarative_extensions import Functions, declarative_database, register_sqlalchemy_events +from sqlalchemy_declarative_extensions.dialects.postgresql import Function, FunctionVolatility +from sqlalchemy_declarative_extensions.function.compare import compare_functions +from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base + +_Base = declarative_base() + + +@declarative_database +class Base(_Base): # type: ignore + __abstract__ = True + + functions = Functions().are( + Function( + "gimme", + "INSERT INTO foo (id) VALUES (DEFAULT); SELECT count(*) FROM foo;", + returns="INTEGER", + ), + # New function with parameters and volatility + Function( + "add_stable", + "SELECT i + 1;", + parameters=["i integer"], + returns="INTEGER", + volatility=FunctionVolatility.STABLE, + ), + # Function returning TABLE + Function( + "generate_series_squared", + ''' + SELECT i, i*i + FROM generate_series(1, _limit) as i; + ''', + language="sql", + parameters=["_limit integer"], + returns="TABLE(num integer, num_squared integer)", + volatility=FunctionVolatility.IMMUTABLE, + ), + ) + + +class Foo(Base): + __tablename__ = "foo" + + id = Column(Integer, primary_key=True) + + +register_sqlalchemy_events(Base.metadata, functions=True) + +pg = create_postgres_fixture(engine_kwargs={"echo": True}, session=True) + + +def test_create(pg): + Base.metadata.create_all(bind=pg.connection()) + pg.commit() + + result = pg.execute(text("SELECT gimme()")).scalar() + assert result == 1 + + result = pg.execute(text("SELECT gimme()")).scalar() + assert result == 2 + + # Test function with parameters + result_params = pg.execute(text("SELECT add_stable(10)")).scalar() + assert result_params == 11 + + result_params_2 = pg.execute(text("SELECT add_stable(1)")).scalar() + assert result_params_2 == 2 + + # Test function returning table + result_table = pg.execute(text("SELECT * FROM generate_series_squared(3)")).fetchall() + assert result_table == [ + (1, 1), + (2, 4), + (3, 9), + ] + + connection = pg.connection() + diff = compare_functions(connection, Base.metadata.info["functions"]) + assert diff == []