120 lines
2.1 KiB
Go

package sqlplaceholder
import (
"strconv"
"strings"
)
// ConvertQuestionToDollarPlaceholders rewrites each positional '?' marker in
// query as a numbered placeholder ($1, $2, ...), numbering from 1 in order of
// appearance, and returns the resulting SQL text.
//
// A '?' is left alone when it appears inside:
//   - a single-quoted string ('...'), where '' is an escaped quote;
//   - a double-quoted section ("..."), where "" is an escaped quote;
//   - a backtick-quoted section (`...`);
//   - a line comment (-- up to and including the next newline);
//   - a block comment (/* ... */).
//
// A quote or comment that is never closed extends to the end of query. Every
// byte other than a converted '?' is copied through unchanged. A query without
// any '?' is returned as is.
func ConvertQuestionToDollarPlaceholders(query string) string {
	if !strings.Contains(query, "?") {
		// Nothing to rewrite; this also covers the empty query.
		return query
	}

	n := len(query)

	// pastNext reports the offset just beyond the first occurrence of sep at
	// or after from. When sep is absent it reports n, so an unterminated
	// region swallows the rest of the query.
	pastNext := func(from int, sep string) int {
		rel := strings.Index(query[from:], sep)
		if rel < 0 {
			return n
		}
		return from + rel + len(sep)
	}

	var out strings.Builder
	out.Grow(n + 8)

	ordinal := 1 // number given to the next placeholder
	copied := 0  // query[:copied] has already been written to out

	for pos := 0; pos < n; {
		switch ch := query[pos]; ch {
		case '?':
			// Flush the untouched text before the marker, then emit $N.
			out.WriteString(query[copied:pos])
			out.WriteByte('$')
			out.WriteString(strconv.Itoa(ordinal))
			ordinal++
			pos++
			copied = pos

		case '\'', '"':
			// Skip to the closing delimiter. A delimiter immediately followed
			// by the same character is an escaped quote, so keep searching.
			delim := query[pos : pos+1]
			pos = pastNext(pos+1, delim)
			for pos < n && query[pos] == ch {
				pos = pastNext(pos+1, delim)
			}

		case '`':
			pos = pastNext(pos+1, "`")

		case '-':
			if strings.HasPrefix(query[pos:], "--") {
				// Line comment: runs through the terminating newline.
				pos = pastNext(pos+2, "\n")
			} else {
				pos++
			}

		case '/':
			if strings.HasPrefix(query[pos:], "/*") {
				// Block comment: the search starts after the opener, so the
				// '*' of "/*" cannot double as the start of "*/".
				pos = pastNext(pos+2, "*/")
			} else {
				pos++
			}

		default:
			pos++
		}
	}

	// Whatever follows the last placeholder is copied verbatim.
	out.WriteString(query[copied:])
	return out.String()
}