Skip to content

Commit 14a682e

Browse files
committed
feat: Get complex postgres argument components directly.
1 parent 9e31cc3 commit 14a682e

File tree

16 files changed

+334
-120
lines changed

16 files changed

+334
-120
lines changed

src/sqlalchemy_declarative_extensions/dialects/mysql/function.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ def normalize(self) -> Function:
109109
norm_type = type_map.get(type_str.lower(), type_str.lower())
110110
normalized_parameters.append(f"{name} {norm_type}")
111111
else:
112-
normalized_parameters.append(param) # Keep as is if format unexpected
112+
normalized_parameters.append(
113+
param
114+
) # Keep as is if format unexpected
113115

114116
return replace(
115117
self,

src/sqlalchemy_declarative_extensions/dialects/mysql/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def get_functions_mysql(connection: Connection) -> Sequence[BaseFunction]:
9292
functions = []
9393
for f in connection.execute(functions_query, {"schema": database}).fetchall():
9494
parameters = None
95-
if f.parameters: # Parameter string might be None if no parameters
95+
if f.parameters: # Parameter string might be None if no parameters
9696
parameters = [p.strip() for p in f.parameters.split(",")]
9797

9898
functions.append(

src/sqlalchemy_declarative_extensions/dialects/mysql/schema.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@
8787
select(
8888
column("SPECIFIC_NAME").label("routine_name"),
8989
func.group_concat(
90-
text("concat(PARAMETER_NAME, ' ', DTD_IDENTIFIER) ORDER BY ORDINAL_POSITION SEPARATOR ', '"),
90+
text(
91+
"concat(PARAMETER_NAME, ' ', DTD_IDENTIFIER) ORDER BY ORDINAL_POSITION SEPARATOR ', '"
92+
),
9193
).label("parameters"),
9294
)
9395
.select_from(table("PARAMETERS", schema="INFORMATION_SCHEMA"))
@@ -107,7 +109,7 @@
107109
routine_table.c.sql_data_access.label("data_access"),
108110
parameters_subquery.c.parameters.label("parameters"),
109111
)
110-
.select_from( # Join routines with the parameter subquery
112+
.select_from( # Join routines with the parameter subquery
111113
routine_table.outerjoin(
112114
parameters_subquery,
113115
routine_table.c.routine_name == parameters_subquery.c.routine_name,

src/sqlalchemy_declarative_extensions/dialects/postgresql/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from sqlalchemy_declarative_extensions.dialects.postgresql.function import (
22
Function,
3+
FunctionParam,
4+
FunctionReturn,
35
FunctionSecurity,
46
FunctionVolatility,
57
)
@@ -43,6 +45,8 @@
4345
"DefaultGrantTypes",
4446
"Function",
4547
"FunctionGrants",
48+
"FunctionParam",
49+
"FunctionReturn",
4650
"FunctionSecurity",
4751
"FunctionVolatility",
4852
"Grant",

src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py

Lines changed: 184 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
import enum
44
import textwrap
55
from dataclasses import dataclass, replace
6-
from typing import List, Optional
6+
from typing import Any, Literal, Sequence, cast
7+
8+
from sqlalchemy import Column
79

810
from sqlalchemy_declarative_extensions.function import base
911
from sqlalchemy_declarative_extensions.sql import quote_name
@@ -33,44 +35,38 @@ def from_provolatile(cls, provolatile: str) -> FunctionVolatility:
3335
raise ValueError(f"Invalid volatility: {provolatile}")
3436

3537

36-
def normalize_arg(arg: str) -> str:
37-
parts = arg.strip().split(maxsplit=1)
38-
if len(parts) == 2:
39-
name, type_str = parts
40-
norm_type = type_map.get(type_str.lower(), type_str.lower())
41-
# Handle array types
42-
if norm_type.endswith("[]"):
43-
base_type = norm_type[:-2]
44-
norm_base_type = type_map.get(base_type, base_type)
45-
norm_type = f"{norm_base_type}[]"
46-
47-
return f"{name} {norm_type}"
48-
else:
49-
# Handle case where it might just be the type (e.g., from DROP FUNCTION)
50-
type_str = arg.strip()
51-
norm_type = type_map.get(type_str.lower(), type_str.lower())
52-
if norm_type.endswith("[]"):
53-
base_type = norm_type[:-2]
54-
norm_base_type = type_map.get(base_type, base_type)
55-
norm_type = f"{norm_base_type}[]"
56-
return norm_type
38+
# def normalize_arg(arg: str) -> str:
39+
# parts = arg.strip().split(maxsplit=1)
40+
# if len(parts) == 2:
41+
# name, type_str = parts
42+
# norm_type = type_map.get(type_str.lower(), type_str.lower())
43+
# # Handle array types
44+
# if norm_type.endswith("[]"):
45+
# base_type = norm_type[:-2]
46+
# norm_base_type = type_map.get(base_type, base_type)
47+
# norm_type = f"{norm_base_type}[]"
48+
#
49+
# return f"{name} {norm_type}"
50+
# # Handle case where it might just be the type (e.g., from DROP FUNCTION)
51+
# type_str = arg.strip()
52+
# norm_type = type_map.get(type_str.lower(), type_str.lower())
53+
# if norm_type.endswith("[]"):
54+
# base_type = norm_type[:-2]
55+
# norm_base_type = type_map.get(base_type, base_type)
56+
# norm_type = f"{norm_base_type}[]"
57+
# return norm_type
5758

5859

5960
@dataclass
6061
class Function(base.Function):
6162
"""Describes a PostgreSQL function.
6263
63-
Many attributes are not currently supported. Support is **currently**
64-
minimal due to being a means to an end for defining triggers, but can certainly
65-
be evaluated/added on request.
64+
Not all functionality is currently implemented, but can be evaluated/added on request.
6665
"""
6766

6867
security: FunctionSecurity = FunctionSecurity.invoker
69-
70-
#: Defines the parameters for the function, e.g. ["param1 int", "param2 varchar"]
71-
parameters: Optional[List[str]] = None
72-
73-
#: Defines the volatility of the function.
68+
returns: FunctionReturn | str | None = None # type: ignore
69+
parameters: Sequence[FunctionParam | str] | None = None # type: ignore
7470
volatility: FunctionVolatility = FunctionVolatility.VOLATILE
7571

7672
def to_sql_create(self, replace=False) -> list[str]:
@@ -81,11 +77,15 @@ def to_sql_create(self, replace=False) -> list[str]:
8177

8278
parameter_str = ""
8379
if self.parameters:
84-
parameter_str = ", ".join(self.parameters)
80+
parameter_str = ", ".join(
81+
cast(FunctionParam, p).to_sql_create() for p in self.parameters
82+
)
8583

8684
components.append("FUNCTION")
8785
components.append(quote_name(self.qualified_name) + f"({parameter_str})")
88-
components.append(f"RETURNS {self.returns}")
86+
87+
returns = cast(FunctionReturn, self.returns)
88+
components.append(f"RETURNS {returns.to_sql_create()}")
8989

9090
if self.security == FunctionSecurity.definer:
9191
components.append("SECURITY DEFINER")
@@ -102,17 +102,12 @@ def to_sql_update(self) -> list[str]:
102102
return self.to_sql_create(replace=True)
103103

104104
def to_sql_drop(self) -> list[str]:
105-
param_types = []
105+
param_str = ""
106106
if self.parameters:
107-
for param in self.parameters:
108-
# Naive split, assumes 'name type' or just 'type' format
109-
parts = param.split(maxsplit=1)
110-
if len(parts) == 2:
111-
param_types.append(parts[1])
112-
else:
113-
param_types.append(param) # Assume it's just the type if no space
107+
param_str = ", ".join(
108+
cast(FunctionParam, p).to_sql_drop() for p in self.parameters
109+
)
114110

115-
param_str = ", ".join(param_types)
116111
return [f"DROP FUNCTION {self.qualified_name}({param_str});"]
117112

118113
def with_security(self, security: FunctionSecurity):
@@ -124,36 +119,162 @@ def with_security_definer(self):
124119
def normalize(self) -> Function:
125120
definition = textwrap.dedent(self.definition)
126121

127-
# Handle RETURNS TABLE(...) normalization
128-
returns_lower = self.returns.lower().strip()
129-
if returns_lower.startswith("table("):
130-
# Basic normalization: lowercase and remove extra spaces
131-
# This might need refinement for complex TABLE definitions
132-
inner_content = returns_lower[len("table("):-1].strip()
133-
cols = [normalize_arg(c) for c in inner_content.split(',')]
134-
normalized_returns = f"table({', '.join(cols)})"
135-
else:
136-
# Normalize base return type (including array types)
137-
norm_type = type_map.get(returns_lower, returns_lower)
138-
if norm_type.endswith("[]"):
139-
base = norm_type[:-2]
140-
norm_base = type_map.get(base, base)
141-
normalized_returns = f"{norm_base}[]"
142-
else:
143-
normalized_returns = norm_type
144-
145122
# Normalize parameter types
146-
normalized_parameters = None
123+
parameters = []
147124
if self.parameters:
148-
normalized_parameters = [normalize_arg(p) for p in self.parameters]
125+
parameters = [
126+
FunctionParam.from_unknown(p).normalize() for p in self.parameters
127+
]
128+
129+
input_parameters = [p for p in parameters if p.is_input]
130+
table_parameters = [p for p in parameters if p.is_table]
131+
132+
returns = FunctionReturn.from_unknown(self.returns, parameters=table_parameters)
133+
if returns:
134+
returns = returns.normalize()
149135

150136
return replace(
151137
self,
152138
definition=definition,
153-
returns=normalized_returns,
154-
parameters=normalized_parameters, # Use normalized parameters
139+
returns=returns,
140+
parameters=input_parameters,
141+
)
142+
143+
144+
@dataclass
145+
class FunctionParam:
146+
name: str
147+
type: str
148+
default: Any | None = None
149+
mode: Literal["i", "o", "b", "v", "t"] | None = None
150+
151+
@classmethod
152+
def from_unknown(
153+
cls, source_param: str | tuple[str, str] | FunctionParam
154+
) -> FunctionParam:
155+
if isinstance(source_param, FunctionParam):
156+
return source_param
157+
158+
if isinstance(source_param, tuple):
159+
return cls(*source_param)
160+
161+
name, type = source_param.strip().split(maxsplit=1)
162+
return cls(name, type)
163+
164+
def normalize(self) -> FunctionParam:
165+
type = self.type.lower()
166+
return replace(
167+
self,
168+
name=self.name.lower(),
169+
mode=self.mode or "i",
170+
type=type_map.get(type, type),
171+
default=str(self.default) if self.default is not None else None,
155172
)
156173

174+
def to_sql_create(self) -> str:
175+
result = ""
176+
if self.mode:
177+
result += {"o": "OUT ", "b": "INOUT ", "v": "VARIADIC ", "t": "TABLE "}.get(
178+
self.mode, ""
179+
)
180+
181+
result += f"{self.name} {self.type}"
182+
183+
if self.default is not None:
184+
result += f" DEFAULT {self.default}"
185+
return result
186+
187+
def to_sql_drop(self) -> str:
188+
return self.type
189+
190+
@property
191+
def is_input(self) -> bool:
192+
"""Check if the parameter is an input parameter."""
193+
return self.mode not in {"o", "t"}
194+
195+
@property
196+
def is_table(self) -> bool:
197+
return self.mode == "t"
198+
199+
200+
@dataclass
201+
class FunctionReturn:
202+
value: str | None = None
203+
table: Sequence[Column | tuple[str, str] | str] | None = None
204+
205+
@classmethod
206+
def from_unknown(
207+
cls,
208+
source: str | FunctionReturn | None,
209+
parameters: list[FunctionParam] | None = None,
210+
) -> FunctionReturn | None:
211+
if source is None:
212+
return None
213+
214+
if isinstance(source, FunctionReturn):
215+
return source
216+
217+
# Handle RETURNS TABLE(...) normalization
218+
returns_lower = source.lower().strip()
219+
if returns_lower.startswith("table("):
220+
assert parameters is not None, (
221+
"Parameters must be provided for TABLE return type"
222+
)
223+
224+
table_return_params = [
225+
(p.name, p.type) for p in parameters if p.mode == "t"
226+
]
227+
return cls(table=table_return_params)
228+
# # Basic normalization: lowercase and remove extra spaces
229+
# # This might need refinement for complex TABLE definitions
230+
# inner_content = returns_lower[len("table(") : -1].strip()
231+
# cols = [normalize_arg(c) for c in inner_content.split(",")]
232+
# normalized_returns = f"table({', '.join(cols)})"
233+
# return cls()
234+
235+
# Normalize base return type (including array types)
236+
norm_type = type_map.get(returns_lower, returns_lower)
237+
if norm_type.endswith("[]"):
238+
base = norm_type[:-2]
239+
norm_base = type_map.get(base, base)
240+
normalized_returns = f"{norm_base}[]"
241+
else:
242+
normalized_returns = norm_type
243+
244+
return cls(value=normalized_returns)
245+
246+
def normalize(self) -> FunctionReturn:
247+
value = self.value
248+
249+
table = self.table
250+
if self.table:
251+
table = []
252+
for arg in self.table:
253+
if isinstance(arg, Column):
254+
arg_name = arg.name
255+
arg_type = arg.type.compile()
256+
elif isinstance(arg, tuple):
257+
arg_name, arg_type = arg
258+
else:
259+
arg_name, arg_type = arg.strip().split(maxsplit=1)
260+
261+
arg_type = arg_type.lower()
262+
arg_type = type_map.get(arg_type, arg_type)
263+
table.append((arg_name.lower(), arg_type))
264+
265+
return replace(self, value=value, table=table)
266+
267+
def to_sql_create(self) -> str:
268+
if self.value:
269+
return self.value or "void"
270+
271+
if self.table:
272+
table = cast(list[tuple[str, str]], self.table)
273+
table_args = ", ".join(f"{name} {type}" for name, type in table)
274+
return f"TABLE({table_args})"
275+
276+
raise NotImplementedError()
277+
157278

158279
type_map = {
159280
"bigint": "int8",

src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from collections.abc import Sequence
4+
from itertools import zip_longest
45
from typing import Container, List, cast
56

67
from sqlalchemy import Index, UniqueConstraint
@@ -13,6 +14,7 @@
1314
)
1415
from sqlalchemy_declarative_extensions.dialects.postgresql.function import (
1516
Function,
17+
FunctionParam,
1618
FunctionSecurity,
1719
FunctionVolatility,
1820
)
@@ -225,9 +227,15 @@ def get_functions_postgresql(connection: Connection) -> Sequence[BaseFunction]:
225227
schema = f.schema if f.schema != "public" else None
226228

227229
function = Function(
228-
parameters=(
229-
[p.strip() for p in f.parameters.split(",")] if f.parameters else None
230-
),
230+
parameters=[
231+
FunctionParam(name, type, default, mode)
232+
for name, type, default, mode in zip_longest(
233+
f.arg_names or [],
234+
f.arg_types or [],
235+
f.arg_defaults or [],
236+
f.arg_modes or [],
237+
)
238+
],
231239
volatility=FunctionVolatility.from_provolatile(f.volatility),
232240
name=name,
233241
definition=definition,

0 commit comments

Comments
 (0)