@@ -16,7 +16,8 @@ package plpgsql
1616
1717import (
1818 "fmt"
19- "strings"
19+
20+ pg_query "github.com/pganalyze/pg_query_go/v6"
2021)
2122
2223// Statement represents a PL/pgSQL statement.
@@ -43,15 +44,12 @@ func (Assignment) OperationSize() int32 {
4344
4445// AppendOperations implements the interface Statement.
4546func (stmt Assignment ) AppendOperations (ops * []InterpreterOperation , stack * InterpreterStack ) {
46- // TODO: figure out how I'm supposed to actually do this, rather than a search and replace
47- expression := stmt .Expression
48- var referencedVariables []string
49- for varName := range stack .ListVariables () {
50- if strings .Contains (expression , varName ) {
51- referencedVariables = append (referencedVariables , varName )
52- }
53- expression = strings .Replace (expression , varName , fmt .Sprintf ("$%d" , len (referencedVariables )), 1 )
47+ expression , referencedVariables , err := substituteVariableReferences (stmt .Expression , stack )
48+ if err != nil {
49+ // TODO: add an error return param instead of panicing
50+ panic (err )
5451 }
52+
5553 * ops = append (* ops , InterpreterOperation {
5654 OpCode : OpCode_Assign ,
5755 PrimaryData : "SELECT " + expression + ";" ,
@@ -143,14 +141,12 @@ func (If) OperationSize() int32 {
143141
144142// AppendOperations implements the interface Statement.
145143func (stmt If ) AppendOperations (ops * []InterpreterOperation , stack * InterpreterStack ) {
146- condition := stmt .Condition
147- var referencedVariables []string
148- for varName := range stack .ListVariables () {
149- if strings .Contains (condition , varName ) {
150- referencedVariables = append (referencedVariables , varName )
151- }
152- condition = strings .Replace (condition , varName , fmt .Sprintf ("$%d" , len (referencedVariables )), 1 )
144+ condition , referencedVariables , err := substituteVariableReferences (stmt .Condition , stack )
145+ if err != nil {
146+ // TODO: add an error return param instead of panicing
147+ panic (err )
153148 }
149+
154150 * ops = append (* ops , InterpreterOperation {
155151 OpCode : OpCode_If ,
156152 PrimaryData : "SELECT " + condition + ";" ,
@@ -173,15 +169,12 @@ func (Return) OperationSize() int32 {
173169
174170// AppendOperations implements the interface Statement.
175171func (stmt Return ) AppendOperations (ops * []InterpreterOperation , stack * InterpreterStack ) {
176- // TODO: figure out how I'm supposed to actually do this, rather than a search and replace
177- expression := stmt .Expression
178- var referencedVariables []string
179- for varName := range stack .ListVariables () {
180- if strings .Contains (expression , varName ) {
181- referencedVariables = append (referencedVariables , varName )
182- }
183- expression = strings .Replace (expression , varName , fmt .Sprintf ("$%d" , len (referencedVariables )), 1 )
172+ expression , referencedVariables , err := substituteVariableReferences (stmt .Expression , stack )
173+ if err != nil {
174+ // TODO: add an error return param instead of panicing
175+ panic (err )
184176 }
177+
185178 * ops = append (* ops , InterpreterOperation {
186179 OpCode : OpCode_Return ,
187180 PrimaryData : "SELECT " + expression + ";" ,
@@ -204,3 +197,29 @@ func OperationSizeForStatements(stmts []Statement) int32 {
204197 }
205198 return total
206199}
200+
201+ // substituteVariableReferences parses the specified |expression| and replaces
202+ // any token that matches a variable name in the |stack| with "$N", where N
203+ // indicates which variable in the returned |referenceVars| slice is used.
204+ func substituteVariableReferences (expression string , stack * InterpreterStack ) (newExpression string , referencedVars []string , err error ) {
205+ scanResult , err := pg_query .Scan (expression )
206+ if err != nil {
207+ return "" , nil , err
208+ }
209+
210+ varMap := stack .ListVariables ()
211+ for _ , token := range scanResult .Tokens {
212+ substring := expression [token .Start :token .End ]
213+ if _ , ok := varMap [substring ]; ok {
214+ // TODO: There's another bug here where the same variable could
215+ // be referenced multiple times in the expression and the
216+ // length of the slice won't be the correct index.
217+ referencedVars = append (referencedVars , substring )
218+ newExpression += fmt .Sprintf ("$%d " , len (referencedVars ))
219+ } else {
220+ newExpression += substring + " "
221+ }
222+ }
223+
224+ return newExpression , referencedVars , nil
225+ }
0 commit comments