3
3
import enum
4
4
import textwrap
5
5
from dataclasses import dataclass , replace
6
- from typing import List , Optional
6
+ from typing import Any , Literal , Sequence , cast
7
+
8
+ from sqlalchemy import Column
7
9
8
10
from sqlalchemy_declarative_extensions .function import base
9
11
from sqlalchemy_declarative_extensions .sql import quote_name
@@ -33,44 +35,38 @@ def from_provolatile(cls, provolatile: str) -> FunctionVolatility:
33
35
raise ValueError (f"Invalid volatility: { provolatile } " )
34
36
35
37
36
- def normalize_arg (arg : str ) -> str :
37
- parts = arg .strip ().split (maxsplit = 1 )
38
- if len (parts ) == 2 :
39
- name , type_str = parts
40
- norm_type = type_map .get (type_str .lower (), type_str .lower ())
41
- # Handle array types
42
- if norm_type .endswith ("[]" ):
43
- base_type = norm_type [:- 2 ]
44
- norm_base_type = type_map .get (base_type , base_type )
45
- norm_type = f"{ norm_base_type } []"
46
-
47
- return f"{ name } { norm_type } "
48
- else :
49
- # Handle case where it might just be the type (e.g., from DROP FUNCTION)
50
- type_str = arg .strip ()
51
- norm_type = type_map .get (type_str .lower (), type_str .lower ())
52
- if norm_type .endswith ("[]" ):
53
- base_type = norm_type [:- 2 ]
54
- norm_base_type = type_map .get (base_type , base_type )
55
- norm_type = f"{ norm_base_type } []"
56
- return norm_type
38
+ # def normalize_arg(arg: str) -> str:
39
+ # parts = arg.strip().split(maxsplit=1)
40
+ # if len(parts) == 2:
41
+ # name, type_str = parts
42
+ # norm_type = type_map.get(type_str.lower(), type_str.lower())
43
+ # # Handle array types
44
+ # if norm_type.endswith("[]"):
45
+ # base_type = norm_type[:-2]
46
+ # norm_base_type = type_map.get(base_type, base_type)
47
+ # norm_type = f"{norm_base_type}[]"
48
+ #
49
+ # return f"{name} {norm_type}"
50
+ # # Handle case where it might just be the type (e.g., from DROP FUNCTION)
51
+ # type_str = arg.strip()
52
+ # norm_type = type_map.get(type_str.lower(), type_str.lower())
53
+ # if norm_type.endswith("[]"):
54
+ # base_type = norm_type[:-2]
55
+ # norm_base_type = type_map.get(base_type, base_type)
56
+ # norm_type = f"{norm_base_type}[]"
57
+ # return norm_type
57
58
58
59
59
60
@dataclass
60
61
class Function (base .Function ):
61
62
"""Describes a PostgreSQL function.
62
63
63
- Many attributes are not currently supported. Support is **currently**
64
- minimal due to being a means to an end for defining triggers, but can certainly
65
- be evaluated/added on request.
64
+ Not all functionality is currently implemented, but can be evaluated/added on request.
66
65
"""
67
66
68
67
security : FunctionSecurity = FunctionSecurity .invoker
69
-
70
- #: Defines the parameters for the function, e.g. ["param1 int", "param2 varchar"]
71
- parameters : Optional [List [str ]] = None
72
-
73
- #: Defines the volatility of the function.
68
+ returns : FunctionReturn | str | None = None # type: ignore
69
+ parameters : Sequence [FunctionParam | str ] | None = None # type: ignore
74
70
volatility : FunctionVolatility = FunctionVolatility .VOLATILE
75
71
76
72
def to_sql_create (self , replace = False ) -> list [str ]:
@@ -81,11 +77,15 @@ def to_sql_create(self, replace=False) -> list[str]:
81
77
82
78
parameter_str = ""
83
79
if self .parameters :
84
- parameter_str = ", " .join (self .parameters )
80
+ parameter_str = ", " .join (
81
+ cast (FunctionParam , p ).to_sql_create () for p in self .parameters
82
+ )
85
83
86
84
components .append ("FUNCTION" )
87
85
components .append (quote_name (self .qualified_name ) + f"({ parameter_str } )" )
88
- components .append (f"RETURNS { self .returns } " )
86
+
87
+ returns = cast (FunctionReturn , self .returns )
88
+ components .append (f"RETURNS { returns .to_sql_create ()} " )
89
89
90
90
if self .security == FunctionSecurity .definer :
91
91
components .append ("SECURITY DEFINER" )
@@ -102,17 +102,12 @@ def to_sql_update(self) -> list[str]:
102
102
return self .to_sql_create (replace = True )
103
103
104
104
def to_sql_drop (self ) -> list [str ]:
105
- param_types = []
105
+ param_str = ""
106
106
if self .parameters :
107
- for param in self .parameters :
108
- # Naive split, assumes 'name type' or just 'type' format
109
- parts = param .split (maxsplit = 1 )
110
- if len (parts ) == 2 :
111
- param_types .append (parts [1 ])
112
- else :
113
- param_types .append (param ) # Assume it's just the type if no space
107
+ param_str = ", " .join (
108
+ cast (FunctionParam , p ).to_sql_drop () for p in self .parameters
109
+ )
114
110
115
- param_str = ", " .join (param_types )
116
111
return [f"DROP FUNCTION { self .qualified_name } ({ param_str } );" ]
117
112
118
113
def with_security (self , security : FunctionSecurity ):
@@ -124,36 +119,162 @@ def with_security_definer(self):
124
119
def normalize (self ) -> Function :
125
120
definition = textwrap .dedent (self .definition )
126
121
127
- # Handle RETURNS TABLE(...) normalization
128
- returns_lower = self .returns .lower ().strip ()
129
- if returns_lower .startswith ("table(" ):
130
- # Basic normalization: lowercase and remove extra spaces
131
- # This might need refinement for complex TABLE definitions
132
- inner_content = returns_lower [len ("table(" ):- 1 ].strip ()
133
- cols = [normalize_arg (c ) for c in inner_content .split (',' )]
134
- normalized_returns = f"table({ ', ' .join (cols )} )"
135
- else :
136
- # Normalize base return type (including array types)
137
- norm_type = type_map .get (returns_lower , returns_lower )
138
- if norm_type .endswith ("[]" ):
139
- base = norm_type [:- 2 ]
140
- norm_base = type_map .get (base , base )
141
- normalized_returns = f"{ norm_base } []"
142
- else :
143
- normalized_returns = norm_type
144
-
145
122
# Normalize parameter types
146
- normalized_parameters = None
123
+ parameters = []
147
124
if self .parameters :
148
- normalized_parameters = [normalize_arg (p ) for p in self .parameters ]
125
+ parameters = [
126
+ FunctionParam .from_unknown (p ).normalize () for p in self .parameters
127
+ ]
128
+
129
+ input_parameters = [p for p in parameters if p .is_input ]
130
+ table_parameters = [p for p in parameters if p .is_table ]
131
+
132
+ returns = FunctionReturn .from_unknown (self .returns , parameters = table_parameters )
133
+ if returns :
134
+ returns = returns .normalize ()
149
135
150
136
return replace (
151
137
self ,
152
138
definition = definition ,
153
- returns = normalized_returns ,
154
- parameters = normalized_parameters , # Use normalized parameters
139
+ returns = returns ,
140
+ parameters = input_parameters ,
141
+ )
142
+
143
+
144
+ @dataclass
145
+ class FunctionParam :
146
+ name : str
147
+ type : str
148
+ default : Any | None = None
149
+ mode : Literal ["i" , "o" , "b" , "v" , "t" ] | None = None
150
+
151
+ @classmethod
152
+ def from_unknown (
153
+ cls , source_param : str | tuple [str , str ] | FunctionParam
154
+ ) -> FunctionParam :
155
+ if isinstance (source_param , FunctionParam ):
156
+ return source_param
157
+
158
+ if isinstance (source_param , tuple ):
159
+ return cls (* source_param )
160
+
161
+ name , type = source_param .strip ().split (maxsplit = 1 )
162
+ return cls (name , type )
163
+
164
+ def normalize (self ) -> FunctionParam :
165
+ type = self .type .lower ()
166
+ return replace (
167
+ self ,
168
+ name = self .name .lower (),
169
+ mode = self .mode or "i" ,
170
+ type = type_map .get (type , type ),
171
+ default = str (self .default ) if self .default is not None else None ,
155
172
)
156
173
174
+ def to_sql_create (self ) -> str :
175
+ result = ""
176
+ if self .mode :
177
+ result += {"o" : "OUT " , "b" : "INOUT " , "v" : "VARIADIC " , "t" : "TABLE " }.get (
178
+ self .mode , ""
179
+ )
180
+
181
+ result += f"{ self .name } { self .type } "
182
+
183
+ if self .default is not None :
184
+ result += f" DEFAULT { self .default } "
185
+ return result
186
+
187
+ def to_sql_drop (self ) -> str :
188
+ return self .type
189
+
190
+ @property
191
+ def is_input (self ) -> bool :
192
+ """Check if the parameter is an input parameter."""
193
+ return self .mode not in {"o" , "t" }
194
+
195
+ @property
196
+ def is_table (self ) -> bool :
197
+ return self .mode == "t"
198
+
199
+
200
+ @dataclass
201
+ class FunctionReturn :
202
+ value : str | None = None
203
+ table : Sequence [Column | tuple [str , str ] | str ] | None = None
204
+
205
+ @classmethod
206
+ def from_unknown (
207
+ cls ,
208
+ source : str | FunctionReturn | None ,
209
+ parameters : list [FunctionParam ] | None = None ,
210
+ ) -> FunctionReturn | None :
211
+ if source is None :
212
+ return None
213
+
214
+ if isinstance (source , FunctionReturn ):
215
+ return source
216
+
217
+ # Handle RETURNS TABLE(...) normalization
218
+ returns_lower = source .lower ().strip ()
219
+ if returns_lower .startswith ("table(" ):
220
+ assert parameters is not None , (
221
+ "Parameters must be provided for TABLE return type"
222
+ )
223
+
224
+ table_return_params = [
225
+ (p .name , p .type ) for p in parameters if p .mode == "t"
226
+ ]
227
+ return cls (table = table_return_params )
228
+ # # Basic normalization: lowercase and remove extra spaces
229
+ # # This might need refinement for complex TABLE definitions
230
+ # inner_content = returns_lower[len("table(") : -1].strip()
231
+ # cols = [normalize_arg(c) for c in inner_content.split(",")]
232
+ # normalized_returns = f"table({', '.join(cols)})"
233
+ # return cls()
234
+
235
+ # Normalize base return type (including array types)
236
+ norm_type = type_map .get (returns_lower , returns_lower )
237
+ if norm_type .endswith ("[]" ):
238
+ base = norm_type [:- 2 ]
239
+ norm_base = type_map .get (base , base )
240
+ normalized_returns = f"{ norm_base } []"
241
+ else :
242
+ normalized_returns = norm_type
243
+
244
+ return cls (value = normalized_returns )
245
+
246
+ def normalize (self ) -> FunctionReturn :
247
+ value = self .value
248
+
249
+ table = self .table
250
+ if self .table :
251
+ table = []
252
+ for arg in self .table :
253
+ if isinstance (arg , Column ):
254
+ arg_name = arg .name
255
+ arg_type = arg .type .compile ()
256
+ elif isinstance (arg , tuple ):
257
+ arg_name , arg_type = arg
258
+ else :
259
+ arg_name , arg_type = arg .strip ().split (maxsplit = 1 )
260
+
261
+ arg_type = arg_type .lower ()
262
+ arg_type = type_map .get (arg_type , arg_type )
263
+ table .append ((arg_name .lower (), arg_type ))
264
+
265
+ return replace (self , value = value , table = table )
266
+
267
+ def to_sql_create (self ) -> str :
268
+ if self .value :
269
+ return self .value or "void"
270
+
271
+ if self .table :
272
+ table = cast (list [tuple [str , str ]], self .table )
273
+ table_args = ", " .join (f"{ name } { type } " for name , type in table )
274
+ return f"TABLE({ table_args } )"
275
+
276
+ raise NotImplementedError ()
277
+
157
278
158
279
type_map = {
159
280
"bigint" : "int8" ,
0 commit comments