@@ -27,6 +27,18 @@ class Base(_Base): # type: ignore
27
27
returns = "INTEGER" ,
28
28
volatility = FunctionVolatility .STABLE ,
29
29
),
30
+ # Function returning TABLE
31
+ Function (
32
+ "generate_series_squared" ,
33
+ '''
34
+ SELECT i, i*i
35
+ FROM generate_series(1, _limit) as i;
36
+ ''' ,
37
+ language = "sql" ,
38
+ parameters = ["_limit integer" ],
39
+ returns = "TABLE(num integer, num_squared integer)" ,
40
+ volatility = FunctionVolatility .IMMUTABLE ,
41
+ ),
30
42
)
31
43
32
44
@@ -51,22 +63,21 @@ def test_create(pg):
51
63
result = pg .execute (text ("SELECT gimme()" )).scalar ()
52
64
assert result == 2
53
65
54
- connection = pg .connection ()
55
- diff = compare_functions (connection , Base .metadata .info ["functions" ])
56
- assert diff == []
57
-
66
+ # Test function with parameters
67
+ result_params = pg .execute (text ("SELECT add_stable(10)" )).scalar ()
68
+ assert result_params == 11
58
69
59
- def test_create_with_params (pg ):
60
- Base .metadata .create_all (bind = pg .connection ())
61
- pg .commit ()
70
+ result_params_2 = pg .execute (text ("SELECT add_stable(1)" )).scalar ()
71
+ assert result_params_2 == 2
62
72
63
- result = pg .execute (text ("SELECT add_stable(10)" )).scalar ()
64
- assert result == 11
65
-
66
- result = pg .execute (text ("SELECT add_stable(1)" )).scalar ()
67
- assert result == 2
73
+ # Test function returning table
74
+ result_table = pg .execute (text ("SELECT * FROM generate_series_squared(3)" )).fetchall ()
75
+ assert result_table == [
76
+ (1 , 1 ),
77
+ (2 , 4 ),
78
+ (3 , 9 ),
79
+ ]
68
80
69
- # Verify the function is there via comparison
70
81
connection = pg .connection ()
71
82
diff = compare_functions (connection , Base .metadata .info ["functions" ])
72
83
assert diff == []
0 commit comments