bug fix:修复可能的panic状态;增加更多功能
This commit is contained in:
@@ -0,0 +1,359 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var defaultNullTimeLayouts = []string{
|
||||
time.RFC3339Nano,
|
||||
time.RFC3339,
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02",
|
||||
}
|
||||
|
||||
// ToInt64 converts any value to int64.
|
||||
func ToInt64(val interface{}) int64 {
|
||||
switch v := val.(type) {
|
||||
case nil:
|
||||
return 0
|
||||
case int:
|
||||
return int64(v)
|
||||
case int32:
|
||||
return int64(v)
|
||||
case int64:
|
||||
return v
|
||||
case uint64:
|
||||
return int64(v)
|
||||
case float32:
|
||||
return int64(v)
|
||||
case float64:
|
||||
return int64(v)
|
||||
case string:
|
||||
result, _ := strconv.ParseInt(v, 10, 64)
|
||||
return result
|
||||
case bool:
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
case time.Time:
|
||||
return v.Unix()
|
||||
case []byte:
|
||||
result, _ := strconv.ParseInt(string(v), 10, 64)
|
||||
return result
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// ToUint64 converts any value to uint64.
|
||||
func ToUint64(val interface{}) uint64 {
|
||||
switch v := val.(type) {
|
||||
case nil:
|
||||
return 0
|
||||
case int:
|
||||
return uint64(v)
|
||||
case int32:
|
||||
return uint64(v)
|
||||
case int64:
|
||||
return uint64(v)
|
||||
case uint64:
|
||||
return v
|
||||
case float32:
|
||||
return uint64(v)
|
||||
case float64:
|
||||
return uint64(v)
|
||||
case string:
|
||||
result, _ := strconv.ParseUint(v, 10, 64)
|
||||
return result
|
||||
case bool:
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
case time.Time:
|
||||
return uint64(v.Unix())
|
||||
case []byte:
|
||||
result, _ := strconv.ParseUint(string(v), 10, 64)
|
||||
return result
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// ToFloat64 converts any value to float64.
|
||||
func ToFloat64(val interface{}) float64 {
|
||||
switch v := val.(type) {
|
||||
case nil:
|
||||
return 0
|
||||
case int:
|
||||
return float64(v)
|
||||
case int32:
|
||||
return float64(v)
|
||||
case int64:
|
||||
return float64(v)
|
||||
case uint64:
|
||||
return float64(v)
|
||||
case float32:
|
||||
return float64(v)
|
||||
case float64:
|
||||
return v
|
||||
case string:
|
||||
result, _ := strconv.ParseFloat(v, 64)
|
||||
return result
|
||||
case bool:
|
||||
if v {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
case time.Time:
|
||||
return float64(v.Unix())
|
||||
case []byte:
|
||||
result, _ := strconv.ParseFloat(string(v), 64)
|
||||
return result
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
// ToBool converts any value to bool.
|
||||
func ToBool(val interface{}) bool {
|
||||
switch v := val.(type) {
|
||||
case nil:
|
||||
return false
|
||||
case bool:
|
||||
return v
|
||||
case int:
|
||||
return v != 0
|
||||
case int32:
|
||||
return v != 0
|
||||
case int64:
|
||||
return v != 0
|
||||
case uint64:
|
||||
return v != 0
|
||||
case float32:
|
||||
return v != 0
|
||||
case float64:
|
||||
return v != 0
|
||||
case string:
|
||||
result, _ := strconv.ParseBool(v)
|
||||
return result
|
||||
case []byte:
|
||||
result, _ := strconv.ParseBool(string(v))
|
||||
return result
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// ToTime converts any value to time.Time.
|
||||
func ToTime(val interface{}, layout string) time.Time {
|
||||
switch v := val.(type) {
|
||||
case nil:
|
||||
return time.Time{}
|
||||
case time.Time:
|
||||
return v
|
||||
case int:
|
||||
return time.Unix(int64(v), 0)
|
||||
case int32:
|
||||
return time.Unix(int64(v), 0)
|
||||
case int64:
|
||||
return time.Unix(v, 0)
|
||||
case uint64:
|
||||
return time.Unix(int64(v), 0)
|
||||
case float32:
|
||||
sec := int64(v)
|
||||
nsec := int64((v - float32(sec)) * 1e9)
|
||||
return time.Unix(sec, nsec)
|
||||
case float64:
|
||||
sec := int64(v)
|
||||
nsec := int64((v - float64(sec)) * 1e9)
|
||||
return time.Unix(sec, nsec)
|
||||
case string:
|
||||
result, _ := time.Parse(layout, v)
|
||||
return result
|
||||
case []byte:
|
||||
result, _ := time.Parse(layout, string(v))
|
||||
return result
|
||||
default:
|
||||
return time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
// ToInt64Safe converts any value to int64 with error handling.
|
||||
func ToInt64Safe(val interface{}) (int64, error) {
|
||||
switch v := val.(type) {
|
||||
case nil:
|
||||
return 0, nil
|
||||
case int:
|
||||
return int64(v), nil
|
||||
case int32:
|
||||
return int64(v), nil
|
||||
case int64:
|
||||
return v, nil
|
||||
case uint64:
|
||||
return int64(v), nil
|
||||
case float32:
|
||||
return int64(v), nil
|
||||
case float64:
|
||||
return int64(v), nil
|
||||
case string:
|
||||
return strconv.ParseInt(v, 10, 64)
|
||||
case bool:
|
||||
if v {
|
||||
return 1, nil
|
||||
}
|
||||
return 0, nil
|
||||
case time.Time:
|
||||
return v.Unix(), nil
|
||||
case []byte:
|
||||
return strconv.ParseInt(string(v), 10, 64)
|
||||
default:
|
||||
return 0, fmt.Errorf("cannot convert %T to int64", val)
|
||||
}
|
||||
}
|
||||
|
||||
// ToStringSafe converts any value to string with error handling.
|
||||
func ToStringSafe(val interface{}) (string, error) {
|
||||
switch v := val.(type) {
|
||||
case nil:
|
||||
return "", nil
|
||||
case string:
|
||||
return v, nil
|
||||
case int:
|
||||
return strconv.Itoa(v), nil
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(v), 10), nil
|
||||
case int64:
|
||||
return strconv.FormatInt(v, 10), nil
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(v), 'f', -1, 32), nil
|
||||
case float64:
|
||||
return strconv.FormatFloat(v, 'f', -1, 64), nil
|
||||
case bool:
|
||||
return strconv.FormatBool(v), nil
|
||||
case time.Time:
|
||||
return v.String(), nil
|
||||
case []byte:
|
||||
return string(v), nil
|
||||
default:
|
||||
return "", fmt.Errorf("cannot convert %T to string", val)
|
||||
}
|
||||
}
|
||||
|
||||
// ToFloat64Safe converts any value to float64 with error handling.
|
||||
func ToFloat64Safe(val interface{}) (float64, error) {
|
||||
switch v := val.(type) {
|
||||
case nil:
|
||||
return 0, nil
|
||||
case float64:
|
||||
return v, nil
|
||||
case float32:
|
||||
return float64(v), nil
|
||||
case int, int32, int64, uint64:
|
||||
intVal, err := ToInt64Safe(v)
|
||||
return float64(intVal), err
|
||||
case string:
|
||||
return strconv.ParseFloat(v, 64)
|
||||
case []byte:
|
||||
return strconv.ParseFloat(string(v), 64)
|
||||
default:
|
||||
return 0, fmt.Errorf("cannot convert %T to float64", val)
|
||||
}
|
||||
}
|
||||
|
||||
// ToBoolSafe converts any value to bool with error handling.
|
||||
func ToBoolSafe(val interface{}) (bool, error) {
|
||||
switch v := val.(type) {
|
||||
case nil:
|
||||
return false, nil
|
||||
case bool:
|
||||
return v, nil
|
||||
case int:
|
||||
return v != 0, nil
|
||||
case int8:
|
||||
return v != 0, nil
|
||||
case int16:
|
||||
return v != 0, nil
|
||||
case int32:
|
||||
return v != 0, nil
|
||||
case int64:
|
||||
return v != 0, nil
|
||||
case uint:
|
||||
return v != 0, nil
|
||||
case uint8:
|
||||
return v != 0, nil
|
||||
case uint16:
|
||||
return v != 0, nil
|
||||
case uint32:
|
||||
return v != 0, nil
|
||||
case uint64:
|
||||
return v != 0, nil
|
||||
case float32:
|
||||
return v != 0, nil
|
||||
case float64:
|
||||
return v != 0, nil
|
||||
case string:
|
||||
return ParseBoolString(v)
|
||||
case []byte:
|
||||
return ParseBoolString(string(v))
|
||||
default:
|
||||
return false, fmt.Errorf("cannot convert %T to bool", val)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseBoolString parses string-like bool values.
|
||||
func ParseBoolString(raw string) (bool, error) {
|
||||
normalized := strings.TrimSpace(strings.ToLower(raw))
|
||||
switch normalized {
|
||||
case "", "0", "false", "f", "off", "no", "n":
|
||||
return false, nil
|
||||
case "1", "true", "t", "on", "yes", "y":
|
||||
return true, nil
|
||||
default:
|
||||
return false, fmt.Errorf("cannot parse bool value: %q", raw)
|
||||
}
|
||||
}
|
||||
|
||||
// ToTimeSafe converts any value to time.Time with error handling.
|
||||
func ToTimeSafe(val interface{}) (time.Time, error) {
|
||||
switch v := val.(type) {
|
||||
case nil:
|
||||
return time.Time{}, nil
|
||||
case time.Time:
|
||||
return v, nil
|
||||
case int:
|
||||
return time.Unix(int64(v), 0), nil
|
||||
case int64:
|
||||
return time.Unix(v, 0), nil
|
||||
case string:
|
||||
return ParseTimeValue(v)
|
||||
case []byte:
|
||||
return ParseTimeValue(string(v))
|
||||
default:
|
||||
return time.Time{}, fmt.Errorf("cannot convert %T to time.Time", val)
|
||||
}
|
||||
}
|
||||
|
||||
// ParseTimeValue parses common SQL date-time formats and unix timestamp.
|
||||
func ParseTimeValue(raw string) (time.Time, error) {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return time.Time{}, nil
|
||||
}
|
||||
|
||||
for _, layout := range defaultNullTimeLayouts {
|
||||
if t, err := time.Parse(layout, trimmed); err == nil {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
|
||||
if ts, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
|
||||
return time.Unix(ts, 0), nil
|
||||
}
|
||||
|
||||
return time.Time{}, fmt.Errorf("cannot parse time value: %q", raw)
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
package scanutil
|
||||
|
||||
// CloneScannedValue copies driver-scanned values that may be reused by driver.
|
||||
// []byte is deep-copied; other types are returned as-is.
|
||||
func CloneScannedValue(val interface{}) interface{} {
|
||||
if b, ok := val.([]byte); ok {
|
||||
copied := make([]byte, len(b))
|
||||
copy(copied, b)
|
||||
return copied
|
||||
}
|
||||
return val
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
package sqlplaceholder
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ConvertQuestionToDollarPlaceholders converts '?' to '$1,$2,...' in SQL text.
|
||||
// It skips quoted strings, quoted identifiers and comments.
|
||||
func ConvertQuestionToDollarPlaceholders(query string) string {
|
||||
if query == "" || !strings.Contains(query, "?") {
|
||||
return query
|
||||
}
|
||||
|
||||
const (
|
||||
stateNormal = iota
|
||||
stateSingleQuote
|
||||
stateDoubleQuote
|
||||
stateBacktick
|
||||
stateLineComment
|
||||
stateBlockComment
|
||||
)
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(query) + 8)
|
||||
|
||||
state := stateNormal
|
||||
index := 1
|
||||
|
||||
for i := 0; i < len(query); i++ {
|
||||
c := query[i]
|
||||
|
||||
switch state {
|
||||
case stateNormal:
|
||||
if c == '\'' {
|
||||
state = stateSingleQuote
|
||||
b.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
if c == '"' {
|
||||
state = stateDoubleQuote
|
||||
b.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
if c == '`' {
|
||||
state = stateBacktick
|
||||
b.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
|
||||
state = stateLineComment
|
||||
b.WriteByte(c)
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
continue
|
||||
}
|
||||
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
|
||||
state = stateBlockComment
|
||||
b.WriteByte(c)
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
continue
|
||||
}
|
||||
if c == '?' {
|
||||
b.WriteByte('$')
|
||||
b.WriteString(strconv.Itoa(index))
|
||||
index++
|
||||
continue
|
||||
}
|
||||
b.WriteByte(c)
|
||||
|
||||
case stateSingleQuote:
|
||||
b.WriteByte(c)
|
||||
if c == '\'' {
|
||||
// SQL escaped single quote: ''
|
||||
if i+1 < len(query) && query[i+1] == '\'' {
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
continue
|
||||
}
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateDoubleQuote:
|
||||
b.WriteByte(c)
|
||||
if c == '"' {
|
||||
// escaped double quote: ""
|
||||
if i+1 < len(query) && query[i+1] == '"' {
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
continue
|
||||
}
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateBacktick:
|
||||
b.WriteByte(c)
|
||||
if c == '`' {
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateLineComment:
|
||||
b.WriteByte(c)
|
||||
if c == '\n' {
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateBlockComment:
|
||||
b.WriteByte(c)
|
||||
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
state = stateNormal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
package sqlruntime
|
||||
|
||||
import "strings"
|
||||
|
||||
// FingerprintSQL creates a normalized SQL fingerprint.
|
||||
// mode controls literal masking; keepComments controls whether comments are preserved.
|
||||
func FingerprintSQL(query string, mode int, keepComments bool) string {
|
||||
prepared := query
|
||||
if !keepComments {
|
||||
prepared = stripSQLComments(prepared)
|
||||
}
|
||||
|
||||
normalized := normalizeSQL(prepared)
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if NormalizeFingerprintMode(mode) == fingerprintModeMaskLiterals {
|
||||
return maskSQLLiterals(normalized, keepComments)
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
func normalizeSQL(query string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(query))
|
||||
if normalized == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.Join(strings.Fields(normalized), " ")
|
||||
}
|
||||
|
||||
func stripSQLComments(query string) string {
|
||||
if query == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
const (
|
||||
stateNormal = iota
|
||||
stateSingleQuote
|
||||
stateDoubleQuote
|
||||
stateBacktick
|
||||
stateLineComment
|
||||
stateBlockComment
|
||||
)
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(query))
|
||||
state := stateNormal
|
||||
|
||||
for i := 0; i < len(query); i++ {
|
||||
c := query[i]
|
||||
|
||||
switch state {
|
||||
case stateNormal:
|
||||
if c == '\'' {
|
||||
state = stateSingleQuote
|
||||
b.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
if c == '"' {
|
||||
state = stateDoubleQuote
|
||||
b.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
if c == '`' {
|
||||
state = stateBacktick
|
||||
b.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
|
||||
b.WriteByte(' ')
|
||||
i++
|
||||
state = stateLineComment
|
||||
continue
|
||||
}
|
||||
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
|
||||
b.WriteByte(' ')
|
||||
i++
|
||||
state = stateBlockComment
|
||||
continue
|
||||
}
|
||||
b.WriteByte(c)
|
||||
|
||||
case stateSingleQuote:
|
||||
b.WriteByte(c)
|
||||
if c == '\'' {
|
||||
if i+1 < len(query) && query[i+1] == '\'' {
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
continue
|
||||
}
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateDoubleQuote:
|
||||
b.WriteByte(c)
|
||||
if c == '"' {
|
||||
if i+1 < len(query) && query[i+1] == '"' {
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
continue
|
||||
}
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateBacktick:
|
||||
b.WriteByte(c)
|
||||
if c == '`' {
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateLineComment:
|
||||
if c == '\n' {
|
||||
b.WriteByte(' ')
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateBlockComment:
|
||||
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
|
||||
b.WriteByte(' ')
|
||||
i++
|
||||
state = stateNormal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func maskSQLLiterals(query string, keepComments bool) string {
|
||||
if query == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
const (
|
||||
stateNormal = iota
|
||||
stateSingleQuote
|
||||
stateDoubleQuote
|
||||
stateBacktick
|
||||
stateLineComment
|
||||
stateBlockComment
|
||||
)
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(len(query))
|
||||
state := stateNormal
|
||||
|
||||
for i := 0; i < len(query); i++ {
|
||||
c := query[i]
|
||||
|
||||
switch state {
|
||||
case stateNormal:
|
||||
if c == '\'' {
|
||||
b.WriteByte('?')
|
||||
state = stateSingleQuote
|
||||
continue
|
||||
}
|
||||
if c == '"' {
|
||||
b.WriteByte(c)
|
||||
state = stateDoubleQuote
|
||||
continue
|
||||
}
|
||||
if c == '`' {
|
||||
b.WriteByte(c)
|
||||
state = stateBacktick
|
||||
continue
|
||||
}
|
||||
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
|
||||
if keepComments {
|
||||
b.WriteByte(c)
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
} else {
|
||||
b.WriteByte(' ')
|
||||
i++
|
||||
}
|
||||
state = stateLineComment
|
||||
continue
|
||||
}
|
||||
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
|
||||
if keepComments {
|
||||
b.WriteByte(c)
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
} else {
|
||||
b.WriteByte(' ')
|
||||
i++
|
||||
}
|
||||
state = stateBlockComment
|
||||
continue
|
||||
}
|
||||
if c == '$' {
|
||||
j := i + 1
|
||||
for j < len(query) && isDigit(query[j]) {
|
||||
j++
|
||||
}
|
||||
if j > i+1 {
|
||||
b.WriteByte('?')
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
if c == '-' && i+1 < len(query) && isDigit(query[i+1]) && isNumberBoundaryBefore(query, i) {
|
||||
j := scanNumber(query, i+1)
|
||||
if isNumberBoundaryAfter(query, j) {
|
||||
b.WriteByte('?')
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
if isDigit(c) && isNumberBoundaryBefore(query, i) {
|
||||
j := scanNumber(query, i)
|
||||
if isNumberBoundaryAfter(query, j) {
|
||||
b.WriteByte('?')
|
||||
i = j - 1
|
||||
continue
|
||||
}
|
||||
}
|
||||
b.WriteByte(c)
|
||||
|
||||
case stateSingleQuote:
|
||||
if c == '\'' {
|
||||
if i+1 < len(query) && query[i+1] == '\'' {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateDoubleQuote:
|
||||
b.WriteByte(c)
|
||||
if c == '"' {
|
||||
if i+1 < len(query) && query[i+1] == '"' {
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
continue
|
||||
}
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateBacktick:
|
||||
b.WriteByte(c)
|
||||
if c == '`' {
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateLineComment:
|
||||
if keepComments {
|
||||
b.WriteByte(c)
|
||||
}
|
||||
if c == '\n' {
|
||||
if !keepComments {
|
||||
b.WriteByte(' ')
|
||||
}
|
||||
state = stateNormal
|
||||
}
|
||||
|
||||
case stateBlockComment:
|
||||
if keepComments {
|
||||
b.WriteByte(c)
|
||||
}
|
||||
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
|
||||
if keepComments {
|
||||
i++
|
||||
b.WriteByte(query[i])
|
||||
} else {
|
||||
b.WriteByte(' ')
|
||||
i++
|
||||
}
|
||||
state = stateNormal
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(strings.Fields(b.String()), " ")
|
||||
}
|
||||
|
||||
func isDigit(c byte) bool {
|
||||
return c >= '0' && c <= '9'
|
||||
}
|
||||
|
||||
func isIdentifierChar(c byte) bool {
|
||||
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_'
|
||||
}
|
||||
|
||||
func isNumberBoundaryBefore(query string, index int) bool {
|
||||
if index <= 0 {
|
||||
return true
|
||||
}
|
||||
prev := query[index-1]
|
||||
return !isIdentifierChar(prev) && prev != '$' && prev != '.'
|
||||
}
|
||||
|
||||
func isNumberBoundaryAfter(query string, index int) bool {
|
||||
if index >= len(query) {
|
||||
return true
|
||||
}
|
||||
next := query[index]
|
||||
return !isIdentifierChar(next) && next != '.'
|
||||
}
|
||||
|
||||
func scanNumber(query string, start int) int {
|
||||
i := start
|
||||
for i < len(query) && isDigit(query[i]) {
|
||||
i++
|
||||
}
|
||||
if i < len(query) && query[i] == '.' {
|
||||
i++
|
||||
for i < len(query) && isDigit(query[i]) {
|
||||
i++
|
||||
}
|
||||
}
|
||||
if i < len(query) && (query[i] == 'e' || query[i] == 'E') {
|
||||
i++
|
||||
if i < len(query) && (query[i] == '+' || query[i] == '-') {
|
||||
i++
|
||||
}
|
||||
for i < len(query) && isDigit(query[i]) {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return i
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package sqlruntime
|
||||
|
||||
import "time"
|
||||
|
||||
// CloneHookArgs creates a shallow copy for hook consumers to avoid mutation races.
|
||||
func CloneHookArgs(args []interface{}) []interface{} {
|
||||
if len(args) == 0 {
|
||||
return nil
|
||||
}
|
||||
copied := make([]interface{}, len(args))
|
||||
copy(copied, args)
|
||||
return copied
|
||||
}
|
||||
|
||||
// ShouldRunAfterHook decides whether after-hook should run.
|
||||
func ShouldRunAfterHook(hasAfterHook bool, slowThreshold, duration time.Duration, err error) bool {
|
||||
if !hasAfterHook {
|
||||
return false
|
||||
}
|
||||
if err != nil {
|
||||
return true
|
||||
}
|
||||
if slowThreshold <= 0 {
|
||||
return true
|
||||
}
|
||||
return duration >= slowThreshold
|
||||
}
|
||||
@@ -0,0 +1,269 @@
|
||||
package sqlruntime
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
placeholderQuestion = 0
|
||||
placeholderDollar = 1
|
||||
|
||||
fingerprintModeBasic = 0
|
||||
fingerprintModeMaskLiterals = 1
|
||||
)
|
||||
|
||||
// NormalizePlaceholderStyle converts unknown style values to default question style.
|
||||
func NormalizePlaceholderStyle(style int) int {
|
||||
switch style {
|
||||
case placeholderDollar:
|
||||
return placeholderDollar
|
||||
default:
|
||||
return placeholderQuestion
|
||||
}
|
||||
}
|
||||
|
||||
// NormalizeFingerprintMode converts unknown mode values to default basic mode.
|
||||
func NormalizeFingerprintMode(mode int) int {
|
||||
switch mode {
|
||||
case fingerprintModeMaskLiterals:
|
||||
return fingerprintModeMaskLiterals
|
||||
default:
|
||||
return fingerprintModeBasic
|
||||
}
|
||||
}
|
||||
|
||||
// State stores runtime SQL behavior toggles in a thread-safe manner.
|
||||
type State struct {
|
||||
mu sync.RWMutex
|
||||
beforeHook interface{}
|
||||
afterHook interface{}
|
||||
placeholder int
|
||||
slowThreshold time.Duration
|
||||
fingerprintEnabled bool
|
||||
fingerprintMode int
|
||||
fingerprintKeepComments bool
|
||||
fingerprintCounterEnabled bool
|
||||
fingerprintCounts map[string]uint64
|
||||
}
|
||||
|
||||
// Options returns snapshot of current runtime options.
|
||||
func (s *State) Options() (before, after interface{}, placeholder int, slowThreshold time.Duration) {
|
||||
if s == nil {
|
||||
return nil, nil, placeholderQuestion, 0
|
||||
}
|
||||
s.mu.RLock()
|
||||
before = s.beforeHook
|
||||
after = s.afterHook
|
||||
placeholder = NormalizePlaceholderStyle(s.placeholder)
|
||||
slowThreshold = s.slowThreshold
|
||||
s.mu.RUnlock()
|
||||
return before, after, placeholder, slowThreshold
|
||||
}
|
||||
|
||||
// Hooks returns before/after hooks and slow threshold.
|
||||
func (s *State) Hooks() (before, after interface{}, slowThreshold time.Duration) {
|
||||
before, after, _, slowThreshold = s.Options()
|
||||
return before, after, slowThreshold
|
||||
}
|
||||
|
||||
// SetHooks sets before/after hooks.
|
||||
func (s *State) SetHooks(before, after interface{}) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.beforeHook = before
|
||||
s.afterHook = after
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetBeforeHook sets before hook.
|
||||
func (s *State) SetBeforeHook(before interface{}) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.beforeHook = before
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetAfterHook sets after hook.
|
||||
func (s *State) SetAfterHook(after interface{}) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.afterHook = after
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetPlaceholderStyle sets placeholder style.
|
||||
func (s *State) SetPlaceholderStyle(style int) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.placeholder = NormalizePlaceholderStyle(style)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// PlaceholderStyle returns placeholder style.
|
||||
func (s *State) PlaceholderStyle() int {
|
||||
if s == nil {
|
||||
return placeholderQuestion
|
||||
}
|
||||
s.mu.RLock()
|
||||
style := NormalizePlaceholderStyle(s.placeholder)
|
||||
s.mu.RUnlock()
|
||||
return style
|
||||
}
|
||||
|
||||
// SetSlowThreshold sets minimum duration for triggering after hook.
|
||||
func (s *State) SetSlowThreshold(threshold time.Duration) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if threshold < 0 {
|
||||
threshold = 0
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.slowThreshold = threshold
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// SlowThreshold returns current slow threshold.
|
||||
func (s *State) SlowThreshold() time.Duration {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
s.mu.RLock()
|
||||
threshold := s.slowThreshold
|
||||
s.mu.RUnlock()
|
||||
return threshold
|
||||
}
|
||||
|
||||
// SetFingerprintEnabled toggles SQL fingerprint metadata generation for hooks.
|
||||
func (s *State) SetFingerprintEnabled(enabled bool) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.fingerprintEnabled = enabled
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// FingerprintEnabled reports whether SQL fingerprint metadata generation is enabled.
|
||||
func (s *State) FingerprintEnabled() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.RLock()
|
||||
enabled := s.fingerprintEnabled
|
||||
s.mu.RUnlock()
|
||||
return enabled
|
||||
}
|
||||
|
||||
// SetFingerprintMode sets SQL fingerprint mode.
|
||||
func (s *State) SetFingerprintMode(mode int) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.fingerprintMode = NormalizeFingerprintMode(mode)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// FingerprintMode returns SQL fingerprint mode.
|
||||
func (s *State) FingerprintMode() int {
|
||||
if s == nil {
|
||||
return fingerprintModeBasic
|
||||
}
|
||||
s.mu.RLock()
|
||||
mode := NormalizeFingerprintMode(s.fingerprintMode)
|
||||
s.mu.RUnlock()
|
||||
return mode
|
||||
}
|
||||
|
||||
// SetFingerprintKeepComments toggles comment preservation in generated SQL fingerprints.
|
||||
func (s *State) SetFingerprintKeepComments(keep bool) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.fingerprintKeepComments = keep
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// FingerprintKeepComments reports whether comments are kept in generated SQL fingerprints.
|
||||
func (s *State) FingerprintKeepComments() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.RLock()
|
||||
keep := s.fingerprintKeepComments
|
||||
s.mu.RUnlock()
|
||||
return keep
|
||||
}
|
||||
|
||||
// SetFingerprintCounterEnabled toggles in-memory SQL fingerprint hit counter.
|
||||
func (s *State) SetFingerprintCounterEnabled(enabled bool) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.fingerprintCounterEnabled = enabled
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// FingerprintCounterEnabled reports whether in-memory SQL fingerprint hit counter is enabled.
|
||||
func (s *State) FingerprintCounterEnabled() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
s.mu.RLock()
|
||||
enabled := s.fingerprintCounterEnabled
|
||||
s.mu.RUnlock()
|
||||
return enabled
|
||||
}
|
||||
|
||||
// IncFingerprintCount increments hit count for a fingerprint.
|
||||
func (s *State) IncFingerprintCount(fingerprint string) {
|
||||
if s == nil || fingerprint == "" {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.fingerprintCounts == nil {
|
||||
s.fingerprintCounts = make(map[string]uint64)
|
||||
}
|
||||
s.fingerprintCounts[fingerprint]++
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// FingerprintCountsSnapshot returns a snapshot copy of fingerprint counters.
|
||||
func (s *State) FingerprintCountsSnapshot() map[string]uint64 {
|
||||
if s == nil {
|
||||
return map[string]uint64{}
|
||||
}
|
||||
s.mu.RLock()
|
||||
if len(s.fingerprintCounts) == 0 {
|
||||
s.mu.RUnlock()
|
||||
return map[string]uint64{}
|
||||
}
|
||||
out := make(map[string]uint64, len(s.fingerprintCounts))
|
||||
for k, v := range s.fingerprintCounts {
|
||||
out[k] = v
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
return out
|
||||
}
|
||||
|
||||
// ResetFingerprintCounts clears all fingerprint counters.
|
||||
func (s *State) ResetFingerprintCounts() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.fingerprintCounts = nil
|
||||
s.mu.Unlock()
|
||||
}
|
||||
Reference in New Issue
Block a user