bug fix:修复可能的panic状态;增加更多功能
This commit is contained in:
parent
8f1a9893bc
commit
e0af498fa4
7
.gitignore
vendored
7
.gitignore
vendored
@ -13,7 +13,12 @@
|
|||||||
vendor/
|
vendor/
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.idea/
|
.idea
|
||||||
|
|
||||||
# OS
|
# OS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
|
# Agent local governance files
|
||||||
|
.sentrux/
|
||||||
|
agent_readme.md
|
||||||
|
target.md
|
||||||
|
|||||||
59
CHANGELOG.MD
Normal file
59
CHANGELOG.MD
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
# Changelog
|
||||||
|
|
||||||
|
本文档记录 StarDB 的主要变更。
|
||||||
|
|
||||||
|
## [Unreleased] - 2026-03-20
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- 新增可判定错误类型(`errors.Is` 友好):
|
||||||
|
- 生命周期:`ErrDBNotInitialized` `ErrTxNotInitialized` `ErrStmtNotInitialized`
|
||||||
|
- 参数/目标校验:`ErrQueryEmpty` `ErrTargetNil` `ErrTargetNotPointer` 等
|
||||||
|
- 映射与批量写入:`ErrColumnNotFound` `ErrNoInsertValues` `ErrBatchRowValueCountMismatch` 等
|
||||||
|
- 新增流式查询能力(DB / Tx / Stmt):
|
||||||
|
- `QueryRaw` / `QueryRawContext`
|
||||||
|
- `ScanEach` / `ScanEachContext`
|
||||||
|
- `ScanEachORM` / `ScanEachORMContext`
|
||||||
|
- 新增 NULL 安全取值:
|
||||||
|
- `GetNullString` `GetNullInt64` `GetNullFloat64` `GetNullBool` `GetNullTime`
|
||||||
|
- 新增 ORM 行为开关:
|
||||||
|
- `SetStrictORM(true)` 启用严格列检查
|
||||||
|
- `ClearReflectCache()` 清理反射缓存
|
||||||
|
- 新增 SQL 运行时可观测能力:
|
||||||
|
- Hook:`SetSQLHooks` `SetSQLBeforeHook` `SetSQLAfterHook`
|
||||||
|
- 慢 SQL 阈值:`SetSQLSlowThreshold`
|
||||||
|
- 指纹:`SetSQLFingerprintEnabled` `SetSQLFingerprintMode` `SetSQLFingerprintKeepComments`
|
||||||
|
- 指纹计数:`SetSQLFingerprintCounterEnabled` `SQLFingerprintCounters` `ResetSQLFingerprintCounters`
|
||||||
|
- Context 元信息:`SQLHookMetaFromContext` `BatchExecMetaFromContext`
|
||||||
|
- 新增占位符方言适配:
|
||||||
|
- `SetPlaceholderStyle(PlaceholderQuestion|PlaceholderDollar)`(`?` / `$1,$2...`)
|
||||||
|
- 新增批量插入分片控制:
|
||||||
|
- `SetBatchInsertMaxRows`
|
||||||
|
- `SetBatchInsertMaxParams`
|
||||||
|
- 常见驱动参数上限自动识别(SQLite / PostgreSQL / MySQL / SQL Server)
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
- 批量写入在开启分片或触发参数阈值时,改为事务内多分片执行,降低单条 SQL 过大风险。
|
||||||
|
- 分片批量写入结果语义明确:
|
||||||
|
- `RowsAffected()` 返回分片累计值
|
||||||
|
- `LastInsertId()` 返回最后一个分片的 insert id
|
||||||
|
- 内部结构按模块归档到 `internal/`,保持外部 API 稳定:
|
||||||
|
- `internal/convert`
|
||||||
|
- `internal/scanutil`
|
||||||
|
- `internal/sqlplaceholder`
|
||||||
|
- `internal/sqlruntime`
|
||||||
|
- README 重写为面向使用场景的说明,补齐能力边界、接入顺序和 API 细节。
|
||||||
|
|
||||||
|
### Behavior Notes
|
||||||
|
- 默认查询 `Query` 仍为内存模式(解析到 `StarRows`)。
|
||||||
|
- 关闭内存预读时,使用 `QueryRaw` / `ScanEach` / `ScanEachORM`。
|
||||||
|
- SQL Hook、指纹与指纹计数默认关闭,需显式开启。
|
||||||
|
- 批量分片关闭条件:`maxRows <= 0` 且 `maxParams <= 0` 且未命中驱动自动阈值。
|
||||||
|
|
||||||
|
### Tests
|
||||||
|
- 新增/补强测试覆盖:
|
||||||
|
- 流式查询与流式 ORM
|
||||||
|
- NULL 安全取值
|
||||||
|
- 严格 ORM 行为
|
||||||
|
- 占位符转换
|
||||||
|
- SQL Hook、慢 SQL 阈值、指纹模式、注释保留开关、指纹计数
|
||||||
|
- BatchInsert 分片(按行数/参数)、失败回滚与结果语义
|
||||||
260
batch.go
260
batch.go
@ -6,8 +6,220 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type multiSQLResult struct {
|
||||||
|
results []sql.Result
|
||||||
|
}
|
||||||
|
|
||||||
|
type batchExecMetaKey struct{}
|
||||||
|
|
||||||
|
// BatchExecMeta contains chunk execution metadata for batch insert operations.
|
||||||
|
// It is attached to context for chunked execution and can be read in SQL hooks.
|
||||||
|
type BatchExecMeta struct {
|
||||||
|
ChunkIndex int
|
||||||
|
ChunkCount int
|
||||||
|
ChunkRows int
|
||||||
|
TotalRows int
|
||||||
|
ColumnCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchExecMetaFromContext extracts batch chunk metadata from context.
|
||||||
|
func BatchExecMetaFromContext(ctx context.Context) (BatchExecMeta, bool) {
|
||||||
|
if ctx == nil {
|
||||||
|
return BatchExecMeta{}, false
|
||||||
|
}
|
||||||
|
meta, ok := ctx.Value(batchExecMetaKey{}).(BatchExecMeta)
|
||||||
|
return meta, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func withBatchExecMeta(ctx context.Context, meta BatchExecMeta) context.Context {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, batchExecMetaKey{}, meta)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m multiSQLResult) LastInsertId() (int64, error) {
|
||||||
|
if len(m.results) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return m.results[len(m.results)-1].LastInsertId()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m multiSQLResult) RowsAffected() (int64, error) {
|
||||||
|
var total int64
|
||||||
|
for _, result := range m.results {
|
||||||
|
affected, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
total += affected
|
||||||
|
}
|
||||||
|
return total, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBatchInsertMaxRows configures max row count per INSERT statement for batch APIs.
|
||||||
|
// <= 0 disables splitting and keeps single-statement behavior.
|
||||||
|
func (s *StarDB) SetBatchInsertMaxRows(maxRows int) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if maxRows < 0 {
|
||||||
|
maxRows = 0
|
||||||
|
}
|
||||||
|
atomic.StoreInt64(&s.batchInsertMaxRows, int64(maxRows))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchInsertMaxRows returns current batch split threshold.
|
||||||
|
// 0 means disabled.
|
||||||
|
func (s *StarDB) BatchInsertMaxRows() int {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
value := atomic.LoadInt64(&s.batchInsertMaxRows)
|
||||||
|
if value <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return int(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBatchInsertMaxParams configures max bind parameter count per INSERT statement for batch APIs.
|
||||||
|
// <= 0 means auto mode (use built-in defaults for known drivers) or no limit for unknown drivers.
|
||||||
|
func (s *StarDB) SetBatchInsertMaxParams(maxParams int) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if maxParams < 0 {
|
||||||
|
maxParams = 0
|
||||||
|
}
|
||||||
|
atomic.StoreInt64(&s.batchInsertMaxParams, int64(maxParams))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BatchInsertMaxParams returns configured max bind parameter threshold.
|
||||||
|
// 0 means auto mode.
|
||||||
|
func (s *StarDB) BatchInsertMaxParams() int {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
value := atomic.LoadInt64(&s.batchInsertMaxParams)
|
||||||
|
if value <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return int(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func detectBatchInsertMaxParams(db *sql.DB) int {
|
||||||
|
if db == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
driverType := strings.ToLower(fmt.Sprintf("%T", db.Driver()))
|
||||||
|
switch {
|
||||||
|
case strings.Contains(driverType, "sqlite"):
|
||||||
|
// Keep conservative default for wide compatibility.
|
||||||
|
return 999
|
||||||
|
case strings.Contains(driverType, "postgres"), strings.Contains(driverType, "pgx"), strings.Contains(driverType, "pq"):
|
||||||
|
return 65535
|
||||||
|
case strings.Contains(driverType, "mysql"):
|
||||||
|
return 65535
|
||||||
|
case strings.Contains(driverType, "sqlserver"), strings.Contains(driverType, "mssql"):
|
||||||
|
return 2100
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func minPositive(a, b int) int {
|
||||||
|
if a <= 0 {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
if b <= 0 {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
if a < b {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarDB) batchInsertChunkSize(columnCount int) (int, error) {
|
||||||
|
maxRows := s.BatchInsertMaxRows()
|
||||||
|
maxParams := s.BatchInsertMaxParams()
|
||||||
|
if maxParams <= 0 {
|
||||||
|
maxParams = detectBatchInsertMaxParams(s.db)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxRowsByParams := 0
|
||||||
|
if maxParams > 0 {
|
||||||
|
maxRowsByParams = maxParams / columnCount
|
||||||
|
if maxRowsByParams <= 0 {
|
||||||
|
return 0, ErrBatchInsertMaxParamsTooLow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return minPositive(maxRows, maxRowsByParams), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildBatchInsertQuery(tableName string, columns []string, values [][]interface{}) (string, []interface{}) {
|
||||||
|
placeholderGroup := "(" + strings.Repeat("?, ", len(columns)-1) + "?)"
|
||||||
|
placeholders := strings.Repeat(placeholderGroup+", ", len(values)-1) + placeholderGroup
|
||||||
|
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
|
||||||
|
tableName,
|
||||||
|
strings.Join(columns, ", "),
|
||||||
|
placeholders)
|
||||||
|
|
||||||
|
args := make([]interface{}, 0, len(values)*len(columns))
|
||||||
|
for _, row := range values {
|
||||||
|
args = append(args, row...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return query, args
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarDB) batchInsertChunked(ctx context.Context, tableName string, columns []string, values [][]interface{}, chunkSize int) (sql.Result, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := s.BeginTx(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkCount := (len(values) + chunkSize - 1) / chunkSize
|
||||||
|
results := make([]sql.Result, 0, chunkCount)
|
||||||
|
for start := 0; start < len(values); start += chunkSize {
|
||||||
|
end := start + chunkSize
|
||||||
|
if end > len(values) {
|
||||||
|
end = len(values)
|
||||||
|
}
|
||||||
|
chunkIndex := start/chunkSize + 1
|
||||||
|
query, args := buildBatchInsertQuery(tableName, columns, values[start:end])
|
||||||
|
chunkCtx := withBatchExecMeta(ctx, BatchExecMeta{
|
||||||
|
ChunkIndex: chunkIndex,
|
||||||
|
ChunkCount: chunkCount,
|
||||||
|
ChunkRows: end - start,
|
||||||
|
TotalRows: len(values),
|
||||||
|
ColumnCount: len(columns),
|
||||||
|
})
|
||||||
|
result, execErr := tx.exec(chunkCtx, query, args...)
|
||||||
|
if execErr != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return nil, execErr
|
||||||
|
}
|
||||||
|
results = append(results, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return multiSQLResult{results: results}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// BatchInsert performs batch insert operation
|
// BatchInsert performs batch insert operation
|
||||||
// Usage: BatchInsert("users", []string{"name", "age"}, [][]interface{}{{"Alice", 25}, {"Bob", 30}})
|
// Usage: BatchInsert("users", []string{"name", "age"}, [][]interface{}{{"Alice", 25}, {"Bob", 30}})
|
||||||
func (s *StarDB) BatchInsert(tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
|
func (s *StarDB) BatchInsert(tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
|
||||||
@ -21,26 +233,33 @@ func (s *StarDB) BatchInsertContext(ctx context.Context, tableName string, colum
|
|||||||
|
|
||||||
// batchInsert is the internal implementation
|
// batchInsert is the internal implementation
|
||||||
func (s *StarDB) batchInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
|
func (s *StarDB) batchInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
|
||||||
|
if strings.TrimSpace(tableName) == "" {
|
||||||
|
return nil, ErrTableNameEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(columns) == 0 {
|
||||||
|
return nil, ErrNoInsertColumns
|
||||||
|
}
|
||||||
|
|
||||||
if len(values) == 0 {
|
if len(values) == 0 {
|
||||||
return nil, fmt.Errorf("no values to insert")
|
return nil, ErrNoInsertValues
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build placeholders: (?, ?), (?, ?), ...
|
for i, row := range values {
|
||||||
placeholderGroup := "(" + strings.Repeat("?, ", len(columns)-1) + "?)"
|
if len(row) != len(columns) {
|
||||||
placeholders := strings.Repeat(placeholderGroup+", ", len(values)-1) + placeholderGroup
|
return nil, wrapBatchRowValueCountMismatch(i, len(row), len(columns))
|
||||||
|
}
|
||||||
// Build SQL
|
|
||||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
|
|
||||||
tableName,
|
|
||||||
strings.Join(columns, ", "),
|
|
||||||
placeholders)
|
|
||||||
|
|
||||||
// Flatten values
|
|
||||||
var args []interface{}
|
|
||||||
for _, row := range values {
|
|
||||||
args = append(args, row...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
chunkSize, err := s.batchInsertChunkSize(len(columns))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if chunkSize > 0 && len(values) > chunkSize {
|
||||||
|
return s.batchInsertChunked(ctx, tableName, columns, values, chunkSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
query, args := buildBatchInsertQuery(tableName, columns, values)
|
||||||
return s.exec(ctx, query, args...)
|
return s.exec(ctx, query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -56,18 +275,25 @@ func (s *StarDB) BatchInsertStructsContext(ctx context.Context, tableName string
|
|||||||
|
|
||||||
// batchInsertStructs is the internal implementation
|
// batchInsertStructs is the internal implementation
|
||||||
func (s *StarDB) batchInsertStructs(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
|
func (s *StarDB) batchInsertStructs(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
|
||||||
|
if structs == nil {
|
||||||
|
return nil, ErrStructsNil
|
||||||
|
}
|
||||||
|
|
||||||
// Get slice of structs
|
// Get slice of structs
|
||||||
targetValue := reflect.ValueOf(structs)
|
targetValue := reflect.ValueOf(structs)
|
||||||
if targetValue.Kind() == reflect.Ptr {
|
if targetValue.Kind() == reflect.Ptr {
|
||||||
|
if targetValue.IsNil() {
|
||||||
|
return nil, ErrStructsPointerNil
|
||||||
|
}
|
||||||
targetValue = targetValue.Elem()
|
targetValue = targetValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetValue.Kind() != reflect.Slice && targetValue.Kind() != reflect.Array {
|
if targetValue.Kind() != reflect.Slice && targetValue.Kind() != reflect.Array {
|
||||||
return nil, fmt.Errorf("structs must be a slice or array")
|
return nil, ErrStructsNotSlice
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetValue.Len() == 0 {
|
if targetValue.Len() == 0 {
|
||||||
return nil, fmt.Errorf("no structs to insert")
|
return nil, ErrNoStructsToInsert
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get field names from first struct
|
// Get field names from first struct
|
||||||
|
|||||||
57
builder.go
57
builder.go
@ -7,13 +7,17 @@ import (
|
|||||||
|
|
||||||
// QueryBuilder helps build SQL queries
|
// QueryBuilder helps build SQL queries
|
||||||
type QueryBuilder struct {
|
type QueryBuilder struct {
|
||||||
table string
|
table string
|
||||||
columns []string
|
columns []string
|
||||||
where []string
|
joins []string
|
||||||
whereArgs []interface{}
|
where []string
|
||||||
orderBy string
|
whereArgs []interface{}
|
||||||
limit int
|
groupBy []string
|
||||||
offset int
|
having []string
|
||||||
|
havingArgs []interface{}
|
||||||
|
orderBy string
|
||||||
|
limit int
|
||||||
|
offset int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewQueryBuilder creates a new query builder
|
// NewQueryBuilder creates a new query builder
|
||||||
@ -37,6 +41,25 @@ func (qb *QueryBuilder) Where(condition string, args ...interface{}) *QueryBuild
|
|||||||
return qb
|
return qb
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Join adds a JOIN clause.
|
||||||
|
func (qb *QueryBuilder) Join(clause string) *QueryBuilder {
|
||||||
|
qb.joins = append(qb.joins, clause)
|
||||||
|
return qb
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupBy sets GROUP BY columns.
|
||||||
|
func (qb *QueryBuilder) GroupBy(columns ...string) *QueryBuilder {
|
||||||
|
qb.groupBy = append(qb.groupBy, columns...)
|
||||||
|
return qb
|
||||||
|
}
|
||||||
|
|
||||||
|
// Having adds a HAVING condition.
|
||||||
|
func (qb *QueryBuilder) Having(condition string, args ...interface{}) *QueryBuilder {
|
||||||
|
qb.having = append(qb.having, condition)
|
||||||
|
qb.havingArgs = append(qb.havingArgs, args...)
|
||||||
|
return qb
|
||||||
|
}
|
||||||
|
|
||||||
// OrderBy sets the ORDER BY clause
|
// OrderBy sets the ORDER BY clause
|
||||||
func (qb *QueryBuilder) OrderBy(orderBy string) *QueryBuilder {
|
func (qb *QueryBuilder) OrderBy(orderBy string) *QueryBuilder {
|
||||||
qb.orderBy = orderBy
|
qb.orderBy = orderBy
|
||||||
@ -63,11 +86,26 @@ func (qb *QueryBuilder) Build() (string, []interface{}) {
|
|||||||
parts = append(parts, fmt.Sprintf("SELECT %s FROM %s",
|
parts = append(parts, fmt.Sprintf("SELECT %s FROM %s",
|
||||||
strings.Join(qb.columns, ", "), qb.table))
|
strings.Join(qb.columns, ", "), qb.table))
|
||||||
|
|
||||||
|
// JOIN
|
||||||
|
if len(qb.joins) > 0 {
|
||||||
|
parts = append(parts, strings.Join(qb.joins, " "))
|
||||||
|
}
|
||||||
|
|
||||||
// WHERE
|
// WHERE
|
||||||
if len(qb.where) > 0 {
|
if len(qb.where) > 0 {
|
||||||
parts = append(parts, "WHERE "+strings.Join(qb.where, " AND "))
|
parts = append(parts, "WHERE "+strings.Join(qb.where, " AND "))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GROUP BY
|
||||||
|
if len(qb.groupBy) > 0 {
|
||||||
|
parts = append(parts, "GROUP BY "+strings.Join(qb.groupBy, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HAVING
|
||||||
|
if len(qb.having) > 0 {
|
||||||
|
parts = append(parts, "HAVING "+strings.Join(qb.having, " AND "))
|
||||||
|
}
|
||||||
|
|
||||||
// ORDER BY
|
// ORDER BY
|
||||||
if qb.orderBy != "" {
|
if qb.orderBy != "" {
|
||||||
parts = append(parts, "ORDER BY "+qb.orderBy)
|
parts = append(parts, "ORDER BY "+qb.orderBy)
|
||||||
@ -83,7 +121,10 @@ func (qb *QueryBuilder) Build() (string, []interface{}) {
|
|||||||
parts = append(parts, fmt.Sprintf("OFFSET %d", qb.offset))
|
parts = append(parts, fmt.Sprintf("OFFSET %d", qb.offset))
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Join(parts, " "), qb.whereArgs
|
args := make([]interface{}, 0, len(qb.whereArgs)+len(qb.havingArgs))
|
||||||
|
args = append(args, qb.whereArgs...)
|
||||||
|
args = append(args, qb.havingArgs...)
|
||||||
|
return strings.Join(parts, " "), args
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query executes the query
|
// Query executes the query
|
||||||
|
|||||||
@ -271,6 +271,21 @@ func TestQueryBuilder_Chaining(t *testing.T) {
|
|||||||
if qb != qb6 {
|
if qb != qb6 {
|
||||||
t.Error("Offset should return the same builder instance")
|
t.Error("Offset should return the same builder instance")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
qb7 := qb.Join("LEFT JOIN orders o ON o.user_id = users.id")
|
||||||
|
if qb != qb7 {
|
||||||
|
t.Error("Join should return the same builder instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
qb8 := qb.GroupBy("users.id")
|
||||||
|
if qb != qb8 {
|
||||||
|
t.Error("GroupBy should return the same builder instance")
|
||||||
|
}
|
||||||
|
|
||||||
|
qb9 := qb.Having("COUNT(o.id) > ?", 1)
|
||||||
|
if qb != qb9 {
|
||||||
|
t.Error("Having should return the same builder instance")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestQueryBuilder_EmptyWhere(t *testing.T) {
|
func TestQueryBuilder_EmptyWhere(t *testing.T) {
|
||||||
@ -439,6 +454,50 @@ func TestQueryBuilder_JoinLikeWhere(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryBuilder_Build_WithJoinGroupByHaving(t *testing.T) {
|
||||||
|
qb := NewQueryBuilder("users u").
|
||||||
|
Select("u.id", "u.name", "COUNT(o.id) AS order_count").
|
||||||
|
Join("LEFT JOIN orders o ON o.user_id = u.id").
|
||||||
|
Where("u.active = ?", true).
|
||||||
|
GroupBy("u.id", "u.name").
|
||||||
|
Having("COUNT(o.id) > ?", 2).
|
||||||
|
OrderBy("order_count DESC")
|
||||||
|
|
||||||
|
query, args := qb.Build()
|
||||||
|
|
||||||
|
expectedQuery := "SELECT u.id, u.name, COUNT(o.id) AS order_count FROM users u LEFT JOIN orders o ON o.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(o.id) > ? ORDER BY order_count DESC"
|
||||||
|
if query != expectedQuery {
|
||||||
|
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedArgs := []interface{}{true, 2}
|
||||||
|
if len(args) != len(expectedArgs) {
|
||||||
|
t.Fatalf("Expected %d args, got %d", len(expectedArgs), len(args))
|
||||||
|
}
|
||||||
|
for i, expected := range expectedArgs {
|
||||||
|
if args[i] != expected {
|
||||||
|
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryBuilder_Build_HavingWithoutWhere(t *testing.T) {
|
||||||
|
qb := NewQueryBuilder("orders").
|
||||||
|
Select("user_id", "COUNT(*) AS cnt").
|
||||||
|
GroupBy("user_id").
|
||||||
|
Having("COUNT(*) >= ?", 3)
|
||||||
|
|
||||||
|
query, args := qb.Build()
|
||||||
|
|
||||||
|
expectedQuery := "SELECT user_id, COUNT(*) AS cnt FROM orders GROUP BY user_id HAVING COUNT(*) >= ?"
|
||||||
|
if query != expectedQuery {
|
||||||
|
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
||||||
|
}
|
||||||
|
if len(args) != 1 || args[0] != 3 {
|
||||||
|
t.Errorf("Expected args [3], got %v", args)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Benchmark tests
|
// Benchmark tests
|
||||||
func BenchmarkQueryBuilder_Simple(b *testing.B) {
|
func BenchmarkQueryBuilder_Simple(b *testing.B) {
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
|
|||||||
156
converter.go
156
converter.go
@ -1,176 +1,32 @@
|
|||||||
package stardb
|
package stardb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"strconv"
|
internalconv "b612.me/stardb/internal/convert"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// convertToInt64 converts any value to int64
|
// convertToInt64 converts any value to int64
|
||||||
func convertToInt64(val interface{}) int64 {
|
func convertToInt64(val interface{}) int64 {
|
||||||
switch v := val.(type) {
|
return internalconv.ToInt64(val)
|
||||||
case nil:
|
|
||||||
return 0
|
|
||||||
case int:
|
|
||||||
return int64(v)
|
|
||||||
case int32:
|
|
||||||
return int64(v)
|
|
||||||
case int64:
|
|
||||||
return v
|
|
||||||
case uint64:
|
|
||||||
return int64(v)
|
|
||||||
case float32:
|
|
||||||
return int64(v)
|
|
||||||
case float64:
|
|
||||||
return int64(v)
|
|
||||||
case string:
|
|
||||||
result, _ := strconv.ParseInt(v, 10, 64)
|
|
||||||
return result
|
|
||||||
case bool:
|
|
||||||
if v {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
case time.Time:
|
|
||||||
return v.Unix()
|
|
||||||
case []byte:
|
|
||||||
result, _ := strconv.ParseInt(string(v), 10, 64)
|
|
||||||
return result
|
|
||||||
default:
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertToUint64 converts any value to uint64
|
// convertToUint64 converts any value to uint64
|
||||||
func convertToUint64(val interface{}) uint64 {
|
func convertToUint64(val interface{}) uint64 {
|
||||||
switch v := val.(type) {
|
return internalconv.ToUint64(val)
|
||||||
case nil:
|
|
||||||
return 0
|
|
||||||
case int:
|
|
||||||
return uint64(v)
|
|
||||||
case int32:
|
|
||||||
return uint64(v)
|
|
||||||
case int64:
|
|
||||||
return uint64(v)
|
|
||||||
case uint64:
|
|
||||||
return v
|
|
||||||
case float32:
|
|
||||||
return uint64(v)
|
|
||||||
case float64:
|
|
||||||
return uint64(v)
|
|
||||||
case string:
|
|
||||||
result, _ := strconv.ParseUint(v, 10, 64)
|
|
||||||
return result
|
|
||||||
case bool:
|
|
||||||
if v {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
case time.Time:
|
|
||||||
return uint64(v.Unix())
|
|
||||||
case []byte:
|
|
||||||
result, _ := strconv.ParseUint(string(v), 10, 64)
|
|
||||||
return result
|
|
||||||
default:
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertToFloat64 converts any value to float64
|
// convertToFloat64 converts any value to float64
|
||||||
func convertToFloat64(val interface{}) float64 {
|
func convertToFloat64(val interface{}) float64 {
|
||||||
switch v := val.(type) {
|
return internalconv.ToFloat64(val)
|
||||||
case nil:
|
|
||||||
return 0
|
|
||||||
case int:
|
|
||||||
return float64(v)
|
|
||||||
case int32:
|
|
||||||
return float64(v)
|
|
||||||
case int64:
|
|
||||||
return float64(v)
|
|
||||||
case uint64:
|
|
||||||
return float64(v)
|
|
||||||
case float32:
|
|
||||||
return float64(v)
|
|
||||||
case float64:
|
|
||||||
return v
|
|
||||||
case string:
|
|
||||||
result, _ := strconv.ParseFloat(v, 64)
|
|
||||||
return result
|
|
||||||
case bool:
|
|
||||||
if v {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
case time.Time:
|
|
||||||
return float64(v.Unix())
|
|
||||||
case []byte:
|
|
||||||
result, _ := strconv.ParseFloat(string(v), 64)
|
|
||||||
return result
|
|
||||||
default:
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertToBool converts any value to bool
|
// convertToBool converts any value to bool
|
||||||
// Non-zero numbers are considered true
|
// Non-zero numbers are considered true
|
||||||
func convertToBool(val interface{}) bool {
|
func convertToBool(val interface{}) bool {
|
||||||
switch v := val.(type) {
|
return internalconv.ToBool(val)
|
||||||
case nil:
|
|
||||||
return false
|
|
||||||
case bool:
|
|
||||||
return v
|
|
||||||
case int:
|
|
||||||
return v != 0
|
|
||||||
case int32:
|
|
||||||
return v != 0
|
|
||||||
case int64:
|
|
||||||
return v != 0
|
|
||||||
case uint64:
|
|
||||||
return v != 0
|
|
||||||
case float32:
|
|
||||||
return v != 0
|
|
||||||
case float64:
|
|
||||||
return v != 0
|
|
||||||
case string:
|
|
||||||
result, _ := strconv.ParseBool(v)
|
|
||||||
return result
|
|
||||||
case []byte:
|
|
||||||
result, _ := strconv.ParseBool(string(v))
|
|
||||||
return result
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// convertToTime converts any value to time.Time
|
// convertToTime converts any value to time.Time
|
||||||
func convertToTime(val interface{}, layout string) time.Time {
|
func convertToTime(val interface{}, layout string) time.Time {
|
||||||
switch v := val.(type) {
|
return internalconv.ToTime(val, layout)
|
||||||
case nil:
|
|
||||||
return time.Time{}
|
|
||||||
case time.Time:
|
|
||||||
return v
|
|
||||||
case int:
|
|
||||||
return time.Unix(int64(v), 0)
|
|
||||||
case int32:
|
|
||||||
return time.Unix(int64(v), 0)
|
|
||||||
case int64:
|
|
||||||
return time.Unix(v, 0)
|
|
||||||
case uint64:
|
|
||||||
return time.Unix(int64(v), 0)
|
|
||||||
case float32:
|
|
||||||
sec := int64(v)
|
|
||||||
nsec := int64((v - float32(sec)) * 1e9)
|
|
||||||
return time.Unix(sec, nsec)
|
|
||||||
case float64:
|
|
||||||
sec := int64(v)
|
|
||||||
nsec := int64((v - float64(sec)) * 1e9)
|
|
||||||
return time.Unix(sec, nsec)
|
|
||||||
case string:
|
|
||||||
result, _ := time.Parse(layout, v)
|
|
||||||
return result
|
|
||||||
case []byte:
|
|
||||||
result, _ := time.Parse(layout, string(v))
|
|
||||||
return result
|
|
||||||
default:
|
|
||||||
return time.Time{}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,68 +1,13 @@
|
|||||||
package stardb
|
package stardb
|
||||||
|
|
||||||
import (
|
import internalconv "b612.me/stardb/internal/convert"
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConvertToInt64Safe converts any value to int64 with error handling
|
// ConvertToInt64Safe converts any value to int64 with error handling
|
||||||
func ConvertToInt64Safe(val interface{}) (int64, error) {
|
func ConvertToInt64Safe(val interface{}) (int64, error) {
|
||||||
switch v := val.(type) {
|
return internalconv.ToInt64Safe(val)
|
||||||
case nil:
|
|
||||||
return 0, nil
|
|
||||||
case int:
|
|
||||||
return int64(v), nil
|
|
||||||
case int32:
|
|
||||||
return int64(v), nil
|
|
||||||
case int64:
|
|
||||||
return v, nil
|
|
||||||
case uint64:
|
|
||||||
return int64(v), nil
|
|
||||||
case float32:
|
|
||||||
return int64(v), nil
|
|
||||||
case float64:
|
|
||||||
return int64(v), nil
|
|
||||||
case string:
|
|
||||||
return strconv.ParseInt(v, 10, 64)
|
|
||||||
case bool:
|
|
||||||
if v {
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
return 0, nil
|
|
||||||
case time.Time:
|
|
||||||
return v.Unix(), nil
|
|
||||||
case []byte:
|
|
||||||
return strconv.ParseInt(string(v), 10, 64)
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("cannot convert %T to int64", val)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertToStringSafe converts any value to string with error handling
|
// ConvertToStringSafe converts any value to string with error handling
|
||||||
func ConvertToStringSafe(val interface{}) (string, error) {
|
func ConvertToStringSafe(val interface{}) (string, error) {
|
||||||
switch v := val.(type) {
|
return internalconv.ToStringSafe(val)
|
||||||
case nil:
|
|
||||||
return "", nil
|
|
||||||
case string:
|
|
||||||
return v, nil
|
|
||||||
case int:
|
|
||||||
return strconv.Itoa(v), nil
|
|
||||||
case int32:
|
|
||||||
return strconv.FormatInt(int64(v), 10), nil
|
|
||||||
case int64:
|
|
||||||
return strconv.FormatInt(v, 10), nil
|
|
||||||
case float32:
|
|
||||||
return strconv.FormatFloat(float64(v), 'f', -1, 32), nil
|
|
||||||
case float64:
|
|
||||||
return strconv.FormatFloat(v, 'f', -1, 64), nil
|
|
||||||
case bool:
|
|
||||||
return strconv.FormatBool(v), nil
|
|
||||||
case time.Time:
|
|
||||||
return v.String(), nil
|
|
||||||
case []byte:
|
|
||||||
return string(v), nil
|
|
||||||
default:
|
|
||||||
return "", fmt.Errorf("cannot convert %T to string", val)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
64
errors.go
Normal file
64
errors.go
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// Lifecycle errors.
|
||||||
|
ErrDBNotInitialized = errors.New("database is not initialized; call Open or SetDB first")
|
||||||
|
ErrTxNotInitialized = errors.New("transaction is not initialized")
|
||||||
|
ErrStmtNotInitialized = errors.New("statement is not initialized")
|
||||||
|
ErrStmtDBNotInitialized = errors.New("statement database context is not initialized")
|
||||||
|
|
||||||
|
// SQL input errors.
|
||||||
|
ErrQueryEmpty = errors.New("query string cannot be empty")
|
||||||
|
ErrScanStopped = errors.New("scan stopped by callback")
|
||||||
|
ErrScanFuncNil = errors.New("scan callback cannot be nil")
|
||||||
|
ErrScanORMFuncNil = errors.New("scan orm callback cannot be nil")
|
||||||
|
|
||||||
|
// Mapping and schema errors.
|
||||||
|
ErrColumnNotFound = errors.New("column not found")
|
||||||
|
ErrFieldNotFound = errors.New("field not found")
|
||||||
|
ErrRowIndexOutOfRange = errors.New("row index out of range")
|
||||||
|
|
||||||
|
// Target validation errors.
|
||||||
|
ErrTargetNil = errors.New("target cannot be nil")
|
||||||
|
ErrTargetsNil = errors.New("targets cannot be nil")
|
||||||
|
ErrTargetNotPointer = errors.New("target must be a pointer")
|
||||||
|
ErrTargetPointerNil = errors.New("target pointer cannot be nil")
|
||||||
|
ErrTargetsPointerNil = errors.New("targets pointer is nil")
|
||||||
|
ErrTargetNotStruct = errors.New("target is not a struct")
|
||||||
|
ErrTargetNotWritable = errors.New("target is not writable")
|
||||||
|
ErrPointerTargetNil = errors.New("pointer target is nil")
|
||||||
|
|
||||||
|
// SQL builder errors.
|
||||||
|
ErrTableNameEmpty = errors.New("table name cannot be empty")
|
||||||
|
ErrPrimaryKeyRequired = errors.New("at least one primary key is required")
|
||||||
|
ErrPrimaryKeyEmpty = errors.New("primary key cannot be empty")
|
||||||
|
ErrNoInsertColumns = errors.New("no columns to insert")
|
||||||
|
ErrNoInsertValues = errors.New("no values to insert")
|
||||||
|
ErrBatchInsertMaxParamsTooLow = errors.New("batch insert max params is lower than column count")
|
||||||
|
ErrNoUpdateFields = errors.New("no fields to update after excluding primary keys")
|
||||||
|
ErrBatchRowValueCountMismatch = errors.New("row values count does not match columns")
|
||||||
|
ErrStructsNil = errors.New("structs cannot be nil")
|
||||||
|
ErrStructsPointerNil = errors.New("structs pointer is nil")
|
||||||
|
ErrStructsNotSlice = errors.New("structs must be a slice or array")
|
||||||
|
ErrNoStructsToInsert = errors.New("no structs to insert")
|
||||||
|
|
||||||
|
// Transaction helper errors.
|
||||||
|
ErrTxFuncNil = errors.New("transaction callback cannot be nil")
|
||||||
|
)
|
||||||
|
|
||||||
|
func wrapColumnNotFound(column string) error {
|
||||||
|
return fmt.Errorf("%w: %s", ErrColumnNotFound, column)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapFieldNotFound(field string) error {
|
||||||
|
return fmt.Errorf("%w: %s", ErrFieldNotFound, field)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapBatchRowValueCountMismatch(rowIndex, got, expected int) error {
|
||||||
|
return fmt.Errorf("%w: row %d has %d values, expected %d", ErrBatchRowValueCountMismatch, rowIndex, got, expected)
|
||||||
|
}
|
||||||
359
internal/convert/basic.go
Normal file
359
internal/convert/basic.go
Normal file
@ -0,0 +1,359 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var defaultNullTimeLayouts = []string{
|
||||||
|
time.RFC3339Nano,
|
||||||
|
time.RFC3339,
|
||||||
|
"2006-01-02 15:04:05",
|
||||||
|
"2006-01-02",
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToInt64 converts any value to int64.
|
||||||
|
func ToInt64(val interface{}) int64 {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return 0
|
||||||
|
case int:
|
||||||
|
return int64(v)
|
||||||
|
case int32:
|
||||||
|
return int64(v)
|
||||||
|
case int64:
|
||||||
|
return v
|
||||||
|
case uint64:
|
||||||
|
return int64(v)
|
||||||
|
case float32:
|
||||||
|
return int64(v)
|
||||||
|
case float64:
|
||||||
|
return int64(v)
|
||||||
|
case string:
|
||||||
|
result, _ := strconv.ParseInt(v, 10, 64)
|
||||||
|
return result
|
||||||
|
case bool:
|
||||||
|
if v {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
case time.Time:
|
||||||
|
return v.Unix()
|
||||||
|
case []byte:
|
||||||
|
result, _ := strconv.ParseInt(string(v), 10, 64)
|
||||||
|
return result
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToUint64 converts any value to uint64.
|
||||||
|
func ToUint64(val interface{}) uint64 {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return 0
|
||||||
|
case int:
|
||||||
|
return uint64(v)
|
||||||
|
case int32:
|
||||||
|
return uint64(v)
|
||||||
|
case int64:
|
||||||
|
return uint64(v)
|
||||||
|
case uint64:
|
||||||
|
return v
|
||||||
|
case float32:
|
||||||
|
return uint64(v)
|
||||||
|
case float64:
|
||||||
|
return uint64(v)
|
||||||
|
case string:
|
||||||
|
result, _ := strconv.ParseUint(v, 10, 64)
|
||||||
|
return result
|
||||||
|
case bool:
|
||||||
|
if v {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
case time.Time:
|
||||||
|
return uint64(v.Unix())
|
||||||
|
case []byte:
|
||||||
|
result, _ := strconv.ParseUint(string(v), 10, 64)
|
||||||
|
return result
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToFloat64 converts any value to float64.
|
||||||
|
func ToFloat64(val interface{}) float64 {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return 0
|
||||||
|
case int:
|
||||||
|
return float64(v)
|
||||||
|
case int32:
|
||||||
|
return float64(v)
|
||||||
|
case int64:
|
||||||
|
return float64(v)
|
||||||
|
case uint64:
|
||||||
|
return float64(v)
|
||||||
|
case float32:
|
||||||
|
return float64(v)
|
||||||
|
case float64:
|
||||||
|
return v
|
||||||
|
case string:
|
||||||
|
result, _ := strconv.ParseFloat(v, 64)
|
||||||
|
return result
|
||||||
|
case bool:
|
||||||
|
if v {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
case time.Time:
|
||||||
|
return float64(v.Unix())
|
||||||
|
case []byte:
|
||||||
|
result, _ := strconv.ParseFloat(string(v), 64)
|
||||||
|
return result
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToBool converts any value to bool.
|
||||||
|
func ToBool(val interface{}) bool {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return false
|
||||||
|
case bool:
|
||||||
|
return v
|
||||||
|
case int:
|
||||||
|
return v != 0
|
||||||
|
case int32:
|
||||||
|
return v != 0
|
||||||
|
case int64:
|
||||||
|
return v != 0
|
||||||
|
case uint64:
|
||||||
|
return v != 0
|
||||||
|
case float32:
|
||||||
|
return v != 0
|
||||||
|
case float64:
|
||||||
|
return v != 0
|
||||||
|
case string:
|
||||||
|
result, _ := strconv.ParseBool(v)
|
||||||
|
return result
|
||||||
|
case []byte:
|
||||||
|
result, _ := strconv.ParseBool(string(v))
|
||||||
|
return result
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToTime converts any value to time.Time.
|
||||||
|
func ToTime(val interface{}, layout string) time.Time {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return time.Time{}
|
||||||
|
case time.Time:
|
||||||
|
return v
|
||||||
|
case int:
|
||||||
|
return time.Unix(int64(v), 0)
|
||||||
|
case int32:
|
||||||
|
return time.Unix(int64(v), 0)
|
||||||
|
case int64:
|
||||||
|
return time.Unix(v, 0)
|
||||||
|
case uint64:
|
||||||
|
return time.Unix(int64(v), 0)
|
||||||
|
case float32:
|
||||||
|
sec := int64(v)
|
||||||
|
nsec := int64((v - float32(sec)) * 1e9)
|
||||||
|
return time.Unix(sec, nsec)
|
||||||
|
case float64:
|
||||||
|
sec := int64(v)
|
||||||
|
nsec := int64((v - float64(sec)) * 1e9)
|
||||||
|
return time.Unix(sec, nsec)
|
||||||
|
case string:
|
||||||
|
result, _ := time.Parse(layout, v)
|
||||||
|
return result
|
||||||
|
case []byte:
|
||||||
|
result, _ := time.Parse(layout, string(v))
|
||||||
|
return result
|
||||||
|
default:
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToInt64Safe converts any value to int64 with error handling.
|
||||||
|
func ToInt64Safe(val interface{}) (int64, error) {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return 0, nil
|
||||||
|
case int:
|
||||||
|
return int64(v), nil
|
||||||
|
case int32:
|
||||||
|
return int64(v), nil
|
||||||
|
case int64:
|
||||||
|
return v, nil
|
||||||
|
case uint64:
|
||||||
|
return int64(v), nil
|
||||||
|
case float32:
|
||||||
|
return int64(v), nil
|
||||||
|
case float64:
|
||||||
|
return int64(v), nil
|
||||||
|
case string:
|
||||||
|
return strconv.ParseInt(v, 10, 64)
|
||||||
|
case bool:
|
||||||
|
if v {
|
||||||
|
return 1, nil
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
case time.Time:
|
||||||
|
return v.Unix(), nil
|
||||||
|
case []byte:
|
||||||
|
return strconv.ParseInt(string(v), 10, 64)
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("cannot convert %T to int64", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToStringSafe converts any value to string with error handling.
|
||||||
|
func ToStringSafe(val interface{}) (string, error) {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return "", nil
|
||||||
|
case string:
|
||||||
|
return v, nil
|
||||||
|
case int:
|
||||||
|
return strconv.Itoa(v), nil
|
||||||
|
case int32:
|
||||||
|
return strconv.FormatInt(int64(v), 10), nil
|
||||||
|
case int64:
|
||||||
|
return strconv.FormatInt(v, 10), nil
|
||||||
|
case float32:
|
||||||
|
return strconv.FormatFloat(float64(v), 'f', -1, 32), nil
|
||||||
|
case float64:
|
||||||
|
return strconv.FormatFloat(v, 'f', -1, 64), nil
|
||||||
|
case bool:
|
||||||
|
return strconv.FormatBool(v), nil
|
||||||
|
case time.Time:
|
||||||
|
return v.String(), nil
|
||||||
|
case []byte:
|
||||||
|
return string(v), nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("cannot convert %T to string", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToFloat64Safe converts any value to float64 with error handling.
|
||||||
|
func ToFloat64Safe(val interface{}) (float64, error) {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return 0, nil
|
||||||
|
case float64:
|
||||||
|
return v, nil
|
||||||
|
case float32:
|
||||||
|
return float64(v), nil
|
||||||
|
case int, int32, int64, uint64:
|
||||||
|
intVal, err := ToInt64Safe(v)
|
||||||
|
return float64(intVal), err
|
||||||
|
case string:
|
||||||
|
return strconv.ParseFloat(v, 64)
|
||||||
|
case []byte:
|
||||||
|
return strconv.ParseFloat(string(v), 64)
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("cannot convert %T to float64", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToBoolSafe converts any value to bool with error handling.
|
||||||
|
func ToBoolSafe(val interface{}) (bool, error) {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return false, nil
|
||||||
|
case bool:
|
||||||
|
return v, nil
|
||||||
|
case int:
|
||||||
|
return v != 0, nil
|
||||||
|
case int8:
|
||||||
|
return v != 0, nil
|
||||||
|
case int16:
|
||||||
|
return v != 0, nil
|
||||||
|
case int32:
|
||||||
|
return v != 0, nil
|
||||||
|
case int64:
|
||||||
|
return v != 0, nil
|
||||||
|
case uint:
|
||||||
|
return v != 0, nil
|
||||||
|
case uint8:
|
||||||
|
return v != 0, nil
|
||||||
|
case uint16:
|
||||||
|
return v != 0, nil
|
||||||
|
case uint32:
|
||||||
|
return v != 0, nil
|
||||||
|
case uint64:
|
||||||
|
return v != 0, nil
|
||||||
|
case float32:
|
||||||
|
return v != 0, nil
|
||||||
|
case float64:
|
||||||
|
return v != 0, nil
|
||||||
|
case string:
|
||||||
|
return ParseBoolString(v)
|
||||||
|
case []byte:
|
||||||
|
return ParseBoolString(string(v))
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("cannot convert %T to bool", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseBoolString parses string-like bool values.
|
||||||
|
func ParseBoolString(raw string) (bool, error) {
|
||||||
|
normalized := strings.TrimSpace(strings.ToLower(raw))
|
||||||
|
switch normalized {
|
||||||
|
case "", "0", "false", "f", "off", "no", "n":
|
||||||
|
return false, nil
|
||||||
|
case "1", "true", "t", "on", "yes", "y":
|
||||||
|
return true, nil
|
||||||
|
default:
|
||||||
|
return false, fmt.Errorf("cannot parse bool value: %q", raw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToTimeSafe converts any value to time.Time with error handling.
|
||||||
|
func ToTimeSafe(val interface{}) (time.Time, error) {
|
||||||
|
switch v := val.(type) {
|
||||||
|
case nil:
|
||||||
|
return time.Time{}, nil
|
||||||
|
case time.Time:
|
||||||
|
return v, nil
|
||||||
|
case int:
|
||||||
|
return time.Unix(int64(v), 0), nil
|
||||||
|
case int64:
|
||||||
|
return time.Unix(v, 0), nil
|
||||||
|
case string:
|
||||||
|
return ParseTimeValue(v)
|
||||||
|
case []byte:
|
||||||
|
return ParseTimeValue(string(v))
|
||||||
|
default:
|
||||||
|
return time.Time{}, fmt.Errorf("cannot convert %T to time.Time", val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseTimeValue parses common SQL date-time formats and unix timestamp.
|
||||||
|
func ParseTimeValue(raw string) (time.Time, error) {
|
||||||
|
trimmed := strings.TrimSpace(raw)
|
||||||
|
if trimmed == "" {
|
||||||
|
return time.Time{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, layout := range defaultNullTimeLayouts {
|
||||||
|
if t, err := time.Parse(layout, trimmed); err == nil {
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ts, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
|
||||||
|
return time.Unix(ts, 0), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return time.Time{}, fmt.Errorf("cannot parse time value: %q", raw)
|
||||||
|
}
|
||||||
12
internal/scanutil/value_clone.go
Normal file
12
internal/scanutil/value_clone.go
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
package scanutil
|
||||||
|
|
||||||
|
// CloneScannedValue copies driver-scanned values that may be reused by driver.
|
||||||
|
// []byte is deep-copied; other types are returned as-is.
|
||||||
|
func CloneScannedValue(val interface{}) interface{} {
|
||||||
|
if b, ok := val.([]byte); ok {
|
||||||
|
copied := make([]byte, len(b))
|
||||||
|
copy(copied, b)
|
||||||
|
return copied
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
119
internal/sqlplaceholder/convert.go
Normal file
119
internal/sqlplaceholder/convert.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package sqlplaceholder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertQuestionToDollarPlaceholders converts '?' to '$1,$2,...' in SQL text.
|
||||||
|
// It skips quoted strings, quoted identifiers and comments.
|
||||||
|
func ConvertQuestionToDollarPlaceholders(query string) string {
|
||||||
|
if query == "" || !strings.Contains(query, "?") {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
stateNormal = iota
|
||||||
|
stateSingleQuote
|
||||||
|
stateDoubleQuote
|
||||||
|
stateBacktick
|
||||||
|
stateLineComment
|
||||||
|
stateBlockComment
|
||||||
|
)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(query) + 8)
|
||||||
|
|
||||||
|
state := stateNormal
|
||||||
|
index := 1
|
||||||
|
|
||||||
|
for i := 0; i < len(query); i++ {
|
||||||
|
c := query[i]
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case stateNormal:
|
||||||
|
if c == '\'' {
|
||||||
|
state = stateSingleQuote
|
||||||
|
b.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '"' {
|
||||||
|
state = stateDoubleQuote
|
||||||
|
b.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '`' {
|
||||||
|
state = stateBacktick
|
||||||
|
b.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
|
||||||
|
state = stateLineComment
|
||||||
|
b.WriteByte(c)
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
|
||||||
|
state = stateBlockComment
|
||||||
|
b.WriteByte(c)
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '?' {
|
||||||
|
b.WriteByte('$')
|
||||||
|
b.WriteString(strconv.Itoa(index))
|
||||||
|
index++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.WriteByte(c)
|
||||||
|
|
||||||
|
case stateSingleQuote:
|
||||||
|
b.WriteByte(c)
|
||||||
|
if c == '\'' {
|
||||||
|
// SQL escaped single quote: ''
|
||||||
|
if i+1 < len(query) && query[i+1] == '\'' {
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateDoubleQuote:
|
||||||
|
b.WriteByte(c)
|
||||||
|
if c == '"' {
|
||||||
|
// escaped double quote: ""
|
||||||
|
if i+1 < len(query) && query[i+1] == '"' {
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateBacktick:
|
||||||
|
b.WriteByte(c)
|
||||||
|
if c == '`' {
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateLineComment:
|
||||||
|
b.WriteByte(c)
|
||||||
|
if c == '\n' {
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateBlockComment:
|
||||||
|
b.WriteByte(c)
|
||||||
|
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
323
internal/sqlruntime/fingerprint.go
Normal file
323
internal/sqlruntime/fingerprint.go
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
package sqlruntime
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
|
// FingerprintSQL creates a normalized SQL fingerprint.
|
||||||
|
// mode controls literal masking; keepComments controls whether comments are preserved.
|
||||||
|
func FingerprintSQL(query string, mode int, keepComments bool) string {
|
||||||
|
prepared := query
|
||||||
|
if !keepComments {
|
||||||
|
prepared = stripSQLComments(prepared)
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized := normalizeSQL(prepared)
|
||||||
|
if normalized == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
if NormalizeFingerprintMode(mode) == fingerprintModeMaskLiterals {
|
||||||
|
return maskSQLLiterals(normalized, keepComments)
|
||||||
|
}
|
||||||
|
return normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeSQL(query string) string {
|
||||||
|
normalized := strings.ToLower(strings.TrimSpace(query))
|
||||||
|
if normalized == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strings.Join(strings.Fields(normalized), " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func stripSQLComments(query string) string {
|
||||||
|
if query == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
stateNormal = iota
|
||||||
|
stateSingleQuote
|
||||||
|
stateDoubleQuote
|
||||||
|
stateBacktick
|
||||||
|
stateLineComment
|
||||||
|
stateBlockComment
|
||||||
|
)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(query))
|
||||||
|
state := stateNormal
|
||||||
|
|
||||||
|
for i := 0; i < len(query); i++ {
|
||||||
|
c := query[i]
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case stateNormal:
|
||||||
|
if c == '\'' {
|
||||||
|
state = stateSingleQuote
|
||||||
|
b.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '"' {
|
||||||
|
state = stateDoubleQuote
|
||||||
|
b.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '`' {
|
||||||
|
state = stateBacktick
|
||||||
|
b.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
i++
|
||||||
|
state = stateLineComment
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
i++
|
||||||
|
state = stateBlockComment
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
b.WriteByte(c)
|
||||||
|
|
||||||
|
case stateSingleQuote:
|
||||||
|
b.WriteByte(c)
|
||||||
|
if c == '\'' {
|
||||||
|
if i+1 < len(query) && query[i+1] == '\'' {
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateDoubleQuote:
|
||||||
|
b.WriteByte(c)
|
||||||
|
if c == '"' {
|
||||||
|
if i+1 < len(query) && query[i+1] == '"' {
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateBacktick:
|
||||||
|
b.WriteByte(c)
|
||||||
|
if c == '`' {
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateLineComment:
|
||||||
|
if c == '\n' {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateBlockComment:
|
||||||
|
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
i++
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func maskSQLLiterals(query string, keepComments bool) string {
|
||||||
|
if query == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
stateNormal = iota
|
||||||
|
stateSingleQuote
|
||||||
|
stateDoubleQuote
|
||||||
|
stateBacktick
|
||||||
|
stateLineComment
|
||||||
|
stateBlockComment
|
||||||
|
)
|
||||||
|
|
||||||
|
var b strings.Builder
|
||||||
|
b.Grow(len(query))
|
||||||
|
state := stateNormal
|
||||||
|
|
||||||
|
for i := 0; i < len(query); i++ {
|
||||||
|
c := query[i]
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case stateNormal:
|
||||||
|
if c == '\'' {
|
||||||
|
b.WriteByte('?')
|
||||||
|
state = stateSingleQuote
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '"' {
|
||||||
|
b.WriteByte(c)
|
||||||
|
state = stateDoubleQuote
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '`' {
|
||||||
|
b.WriteByte(c)
|
||||||
|
state = stateBacktick
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
|
||||||
|
if keepComments {
|
||||||
|
b.WriteByte(c)
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
} else {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
state = stateLineComment
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
|
||||||
|
if keepComments {
|
||||||
|
b.WriteByte(c)
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
} else {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
state = stateBlockComment
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c == '$' {
|
||||||
|
j := i + 1
|
||||||
|
for j < len(query) && isDigit(query[j]) {
|
||||||
|
j++
|
||||||
|
}
|
||||||
|
if j > i+1 {
|
||||||
|
b.WriteByte('?')
|
||||||
|
i = j - 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c == '-' && i+1 < len(query) && isDigit(query[i+1]) && isNumberBoundaryBefore(query, i) {
|
||||||
|
j := scanNumber(query, i+1)
|
||||||
|
if isNumberBoundaryAfter(query, j) {
|
||||||
|
b.WriteByte('?')
|
||||||
|
i = j - 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if isDigit(c) && isNumberBoundaryBefore(query, i) {
|
||||||
|
j := scanNumber(query, i)
|
||||||
|
if isNumberBoundaryAfter(query, j) {
|
||||||
|
b.WriteByte('?')
|
||||||
|
i = j - 1
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.WriteByte(c)
|
||||||
|
|
||||||
|
case stateSingleQuote:
|
||||||
|
if c == '\'' {
|
||||||
|
if i+1 < len(query) && query[i+1] == '\'' {
|
||||||
|
i++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateDoubleQuote:
|
||||||
|
b.WriteByte(c)
|
||||||
|
if c == '"' {
|
||||||
|
if i+1 < len(query) && query[i+1] == '"' {
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateBacktick:
|
||||||
|
b.WriteByte(c)
|
||||||
|
if c == '`' {
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateLineComment:
|
||||||
|
if keepComments {
|
||||||
|
b.WriteByte(c)
|
||||||
|
}
|
||||||
|
if c == '\n' {
|
||||||
|
if !keepComments {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
}
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
|
||||||
|
case stateBlockComment:
|
||||||
|
if keepComments {
|
||||||
|
b.WriteByte(c)
|
||||||
|
}
|
||||||
|
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
|
||||||
|
if keepComments {
|
||||||
|
i++
|
||||||
|
b.WriteByte(query[i])
|
||||||
|
} else {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
state = stateNormal
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(strings.Fields(b.String()), " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func isDigit(c byte) bool {
|
||||||
|
return c >= '0' && c <= '9'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isIdentifierChar(c byte) bool {
|
||||||
|
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNumberBoundaryBefore(query string, index int) bool {
|
||||||
|
if index <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
prev := query[index-1]
|
||||||
|
return !isIdentifierChar(prev) && prev != '$' && prev != '.'
|
||||||
|
}
|
||||||
|
|
||||||
|
func isNumberBoundaryAfter(query string, index int) bool {
|
||||||
|
if index >= len(query) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
next := query[index]
|
||||||
|
return !isIdentifierChar(next) && next != '.'
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanNumber(query string, start int) int {
|
||||||
|
i := start
|
||||||
|
for i < len(query) && isDigit(query[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
if i < len(query) && query[i] == '.' {
|
||||||
|
i++
|
||||||
|
for i < len(query) && isDigit(query[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i < len(query) && (query[i] == 'e' || query[i] == 'E') {
|
||||||
|
i++
|
||||||
|
if i < len(query) && (query[i] == '+' || query[i] == '-') {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
for i < len(query) && isDigit(query[i]) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
27
internal/sqlruntime/hooks.go
Normal file
27
internal/sqlruntime/hooks.go
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
package sqlruntime
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// CloneHookArgs creates a shallow copy for hook consumers to avoid mutation races.
|
||||||
|
func CloneHookArgs(args []interface{}) []interface{} {
|
||||||
|
if len(args) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
copied := make([]interface{}, len(args))
|
||||||
|
copy(copied, args)
|
||||||
|
return copied
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldRunAfterHook decides whether after-hook should run.
|
||||||
|
func ShouldRunAfterHook(hasAfterHook bool, slowThreshold, duration time.Duration, err error) bool {
|
||||||
|
if !hasAfterHook {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if slowThreshold <= 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return duration >= slowThreshold
|
||||||
|
}
|
||||||
269
internal/sqlruntime/state.go
Normal file
269
internal/sqlruntime/state.go
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
package sqlruntime
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
placeholderQuestion = 0
|
||||||
|
placeholderDollar = 1
|
||||||
|
|
||||||
|
fingerprintModeBasic = 0
|
||||||
|
fingerprintModeMaskLiterals = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
// NormalizePlaceholderStyle converts unknown style values to default question style.
|
||||||
|
func NormalizePlaceholderStyle(style int) int {
|
||||||
|
switch style {
|
||||||
|
case placeholderDollar:
|
||||||
|
return placeholderDollar
|
||||||
|
default:
|
||||||
|
return placeholderQuestion
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NormalizeFingerprintMode converts unknown mode values to default basic mode.
|
||||||
|
func NormalizeFingerprintMode(mode int) int {
|
||||||
|
switch mode {
|
||||||
|
case fingerprintModeMaskLiterals:
|
||||||
|
return fingerprintModeMaskLiterals
|
||||||
|
default:
|
||||||
|
return fingerprintModeBasic
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// State stores runtime SQL behavior toggles in a thread-safe manner.
|
||||||
|
type State struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
beforeHook interface{}
|
||||||
|
afterHook interface{}
|
||||||
|
placeholder int
|
||||||
|
slowThreshold time.Duration
|
||||||
|
fingerprintEnabled bool
|
||||||
|
fingerprintMode int
|
||||||
|
fingerprintKeepComments bool
|
||||||
|
fingerprintCounterEnabled bool
|
||||||
|
fingerprintCounts map[string]uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options returns snapshot of current runtime options.
|
||||||
|
func (s *State) Options() (before, after interface{}, placeholder int, slowThreshold time.Duration) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, nil, placeholderQuestion, 0
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
before = s.beforeHook
|
||||||
|
after = s.afterHook
|
||||||
|
placeholder = NormalizePlaceholderStyle(s.placeholder)
|
||||||
|
slowThreshold = s.slowThreshold
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return before, after, placeholder, slowThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hooks returns before/after hooks and slow threshold.
|
||||||
|
func (s *State) Hooks() (before, after interface{}, slowThreshold time.Duration) {
|
||||||
|
before, after, _, slowThreshold = s.Options()
|
||||||
|
return before, after, slowThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetHooks sets before/after hooks.
|
||||||
|
func (s *State) SetHooks(before, after interface{}) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.beforeHook = before
|
||||||
|
s.afterHook = after
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBeforeHook sets before hook.
|
||||||
|
func (s *State) SetBeforeHook(before interface{}) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.beforeHook = before
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetAfterHook sets after hook.
|
||||||
|
func (s *State) SetAfterHook(after interface{}) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.afterHook = after
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPlaceholderStyle sets placeholder style.
|
||||||
|
func (s *State) SetPlaceholderStyle(style int) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.placeholder = NormalizePlaceholderStyle(style)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderStyle returns placeholder style.
|
||||||
|
func (s *State) PlaceholderStyle() int {
|
||||||
|
if s == nil {
|
||||||
|
return placeholderQuestion
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
style := NormalizePlaceholderStyle(s.placeholder)
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return style
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSlowThreshold sets minimum duration for triggering after hook.
|
||||||
|
func (s *State) SetSlowThreshold(threshold time.Duration) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if threshold < 0 {
|
||||||
|
threshold = 0
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.slowThreshold = threshold
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlowThreshold returns current slow threshold.
|
||||||
|
func (s *State) SlowThreshold() time.Duration {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
threshold := s.slowThreshold
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return threshold
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFingerprintEnabled toggles SQL fingerprint metadata generation for hooks.
|
||||||
|
func (s *State) SetFingerprintEnabled(enabled bool) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.fingerprintEnabled = enabled
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FingerprintEnabled reports whether SQL fingerprint metadata generation is enabled.
|
||||||
|
func (s *State) FingerprintEnabled() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
enabled := s.fingerprintEnabled
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFingerprintMode sets SQL fingerprint mode.
|
||||||
|
func (s *State) SetFingerprintMode(mode int) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.fingerprintMode = NormalizeFingerprintMode(mode)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FingerprintMode returns SQL fingerprint mode.
|
||||||
|
func (s *State) FingerprintMode() int {
|
||||||
|
if s == nil {
|
||||||
|
return fingerprintModeBasic
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
mode := NormalizeFingerprintMode(s.fingerprintMode)
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return mode
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFingerprintKeepComments toggles comment preservation in generated SQL fingerprints.
|
||||||
|
func (s *State) SetFingerprintKeepComments(keep bool) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.fingerprintKeepComments = keep
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FingerprintKeepComments reports whether comments are kept in generated SQL fingerprints.
|
||||||
|
func (s *State) FingerprintKeepComments() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
keep := s.fingerprintKeepComments
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return keep
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetFingerprintCounterEnabled toggles in-memory SQL fingerprint hit counter.
|
||||||
|
func (s *State) SetFingerprintCounterEnabled(enabled bool) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.fingerprintCounterEnabled = enabled
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FingerprintCounterEnabled reports whether in-memory SQL fingerprint hit counter is enabled.
|
||||||
|
func (s *State) FingerprintCounterEnabled() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
enabled := s.fingerprintCounterEnabled
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// IncFingerprintCount increments hit count for a fingerprint.
|
||||||
|
func (s *State) IncFingerprintCount(fingerprint string) {
|
||||||
|
if s == nil || fingerprint == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.fingerprintCounts == nil {
|
||||||
|
s.fingerprintCounts = make(map[string]uint64)
|
||||||
|
}
|
||||||
|
s.fingerprintCounts[fingerprint]++
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FingerprintCountsSnapshot returns a snapshot copy of fingerprint counters.
|
||||||
|
func (s *State) FingerprintCountsSnapshot() map[string]uint64 {
|
||||||
|
if s == nil {
|
||||||
|
return map[string]uint64{}
|
||||||
|
}
|
||||||
|
s.mu.RLock()
|
||||||
|
if len(s.fingerprintCounts) == 0 {
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return map[string]uint64{}
|
||||||
|
}
|
||||||
|
out := make(map[string]uint64, len(s.fingerprintCounts))
|
||||||
|
for k, v := range s.fingerprintCounts {
|
||||||
|
out[k] = v
|
||||||
|
}
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetFingerprintCounts clears all fingerprint counters.
|
||||||
|
func (s *State) ResetFingerprintCounts() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.fingerprintCounts = nil
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
250
orm.go
250
orm.go
@ -3,7 +3,6 @@ package stardb
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
@ -23,20 +22,27 @@ func (r *StarRows) Orm(target interface{}) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if target == nil {
|
||||||
|
return ErrTargetNil
|
||||||
|
}
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
targetType := reflect.TypeOf(target)
|
||||||
targetValue := reflect.ValueOf(target)
|
targetValue := reflect.ValueOf(target)
|
||||||
|
|
||||||
if targetType.Kind() != reflect.Ptr {
|
if targetType.Kind() != reflect.Ptr {
|
||||||
return errors.New("target must be a pointer")
|
return ErrTargetNotPointer
|
||||||
|
}
|
||||||
|
if targetValue.IsNil() {
|
||||||
|
return ErrTargetPointerNil
|
||||||
}
|
}
|
||||||
|
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
targetValue = targetValue.Elem()
|
targetValue = targetValue.Elem()
|
||||||
|
|
||||||
// Handle slice/array
|
// Handle slice
|
||||||
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
if targetValue.Kind() == reflect.Slice {
|
||||||
elementType := targetType.Elem()
|
elementType := targetType.Elem()
|
||||||
result := reflect.New(targetType).Elem()
|
result := reflect.MakeSlice(targetType, 0, r.Length())
|
||||||
|
|
||||||
if r.Length() == 0 {
|
if r.Length() == 0 {
|
||||||
targetValue.Set(result)
|
targetValue.Set(result)
|
||||||
@ -55,6 +61,29 @@ func (r *StarRows) Orm(target interface{}) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Handle array
|
||||||
|
if targetValue.Kind() == reflect.Array {
|
||||||
|
elementType := targetType.Elem()
|
||||||
|
|
||||||
|
if r.Length() == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.Length() > targetValue.Len() {
|
||||||
|
return fmt.Errorf("target array length %d is smaller than rows %d", targetValue.Len(), r.Length())
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < r.Length(); i++ {
|
||||||
|
element := reflect.New(elementType)
|
||||||
|
if err := r.setStructFieldsFromRow(element.Interface(), "db", i); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
targetValue.Index(i).Set(element.Elem())
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Handle single struct
|
// Handle single struct
|
||||||
if r.Length() == 0 {
|
if r.Length() == 0 {
|
||||||
return nil
|
return nil
|
||||||
@ -63,6 +92,35 @@ func (r *StarRows) Orm(target interface{}) error {
|
|||||||
return r.setStructFieldsFromRow(target, "db", 0)
|
return r.setStructFieldsFromRow(target, "db", 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func bindNamedArgs(args []interface{}, fieldValues map[string]interface{}) ([]interface{}, error) {
|
||||||
|
processedArgs := make([]interface{}, len(args))
|
||||||
|
for i, arg := range args {
|
||||||
|
str, ok := arg.(string)
|
||||||
|
if !ok {
|
||||||
|
processedArgs[i] = arg
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(str, `\:`) {
|
||||||
|
processedArgs[i] = str[1:]
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(str, ":") {
|
||||||
|
fieldName := str[1:]
|
||||||
|
val, exists := fieldValues[fieldName]
|
||||||
|
if !exists {
|
||||||
|
return nil, wrapFieldNotFound(fieldName)
|
||||||
|
}
|
||||||
|
processedArgs[i] = val
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
processedArgs[i] = arg
|
||||||
|
}
|
||||||
|
return processedArgs, nil
|
||||||
|
}
|
||||||
|
|
||||||
// QueryX executes a query with named parameter binding
|
// QueryX executes a query with named parameter binding
|
||||||
// Usage: QueryX(&user, "SELECT * FROM users WHERE id = ?", ":id")
|
// Usage: QueryX(&user, "SELECT * FROM users WHERE id = ?", ":id")
|
||||||
func (s *StarDB) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) {
|
func (s *StarDB) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) {
|
||||||
@ -81,25 +139,9 @@ func (s *StarDB) queryX(ctx context.Context, target interface{}, query string, a
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replace named parameters with actual values
|
processedArgs, err := bindNamedArgs(args, fieldValues)
|
||||||
processedArgs := make([]interface{}, len(args))
|
if err != nil {
|
||||||
for i, arg := range args {
|
return nil, err
|
||||||
if str, ok := arg.(string); ok {
|
|
||||||
if strings.HasPrefix(str, ":") {
|
|
||||||
fieldName := str[1:]
|
|
||||||
if val, exists := fieldValues[fieldName]; exists {
|
|
||||||
processedArgs[i] = val
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = ""
|
|
||||||
}
|
|
||||||
} else if strings.HasPrefix(str, `\:`) {
|
|
||||||
processedArgs[i] = str[1:]
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = arg
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = arg
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.query(ctx, query, processedArgs...)
|
return s.query(ctx, query, processedArgs...)
|
||||||
@ -118,16 +160,23 @@ func (s *StarDB) QueryXSContext(ctx context.Context, targets interface{}, query
|
|||||||
// queryXS is the internal implementation
|
// queryXS is the internal implementation
|
||||||
func (s *StarDB) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
|
func (s *StarDB) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
|
||||||
var results []*StarRows
|
var results []*StarRows
|
||||||
|
if targets == nil {
|
||||||
|
return results, ErrTargetsNil
|
||||||
|
}
|
||||||
|
|
||||||
targetType := reflect.TypeOf(targets)
|
targetType := reflect.TypeOf(targets)
|
||||||
targetValue := reflect.ValueOf(targets)
|
targetValue := reflect.ValueOf(targets)
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
if targetType.Kind() == reflect.Ptr {
|
||||||
|
if targetValue.IsNil() {
|
||||||
|
return results, ErrTargetsPointerNil
|
||||||
|
}
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
targetValue = targetValue.Elem()
|
targetValue = targetValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
||||||
|
results = make([]*StarRows, 0, targetValue.Len())
|
||||||
for i := 0; i < targetValue.Len(); i++ {
|
for i := 0; i < targetValue.Len(); i++ {
|
||||||
result, err := s.queryX(ctx, targetValue.Index(i).Interface(), query, args...)
|
result, err := s.queryX(ctx, targetValue.Index(i).Interface(), query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -163,25 +212,9 @@ func (s *StarDB) execX(ctx context.Context, target interface{}, query string, ar
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Replace named parameters with actual values
|
processedArgs, err := bindNamedArgs(args, fieldValues)
|
||||||
processedArgs := make([]interface{}, len(args))
|
if err != nil {
|
||||||
for i, arg := range args {
|
return nil, err
|
||||||
if str, ok := arg.(string); ok {
|
|
||||||
if strings.HasPrefix(str, ":") {
|
|
||||||
fieldName := str[1:]
|
|
||||||
if val, exists := fieldValues[fieldName]; exists {
|
|
||||||
processedArgs[i] = val
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = ""
|
|
||||||
}
|
|
||||||
} else if strings.HasPrefix(str, `\:`) {
|
|
||||||
processedArgs[i] = str[1:]
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = arg
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = arg
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.exec(ctx, query, processedArgs...)
|
return s.exec(ctx, query, processedArgs...)
|
||||||
@ -200,16 +233,23 @@ func (s *StarDB) ExecXSContext(ctx context.Context, targets interface{}, query s
|
|||||||
// execXS is the internal implementation
|
// execXS is the internal implementation
|
||||||
func (s *StarDB) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
|
func (s *StarDB) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
|
||||||
var results []sql.Result
|
var results []sql.Result
|
||||||
|
if targets == nil {
|
||||||
|
return results, ErrTargetsNil
|
||||||
|
}
|
||||||
|
|
||||||
targetType := reflect.TypeOf(targets)
|
targetType := reflect.TypeOf(targets)
|
||||||
targetValue := reflect.ValueOf(targets)
|
targetValue := reflect.ValueOf(targets)
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
if targetType.Kind() == reflect.Ptr {
|
||||||
|
if targetValue.IsNil() {
|
||||||
|
return results, ErrTargetsPointerNil
|
||||||
|
}
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
targetValue = targetValue.Elem()
|
targetValue = targetValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
||||||
|
results = make([]sql.Result, 0, targetValue.Len())
|
||||||
for i := 0; i < targetValue.Len(); i++ {
|
for i := 0; i < targetValue.Len(); i++ {
|
||||||
result, err := s.execX(ctx, targetValue.Index(i).Interface(), query, args...)
|
result, err := s.execX(ctx, targetValue.Index(i).Interface(), query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -246,9 +286,9 @@ func (s *StarDB) insert(ctx context.Context, target interface{}, tableName strin
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []interface{}{}
|
args := make([]interface{}, len(params))
|
||||||
for _, param := range params {
|
for i, param := range params {
|
||||||
args = append(args, param)
|
args[i] = param
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.execX(ctx, target, query, args...)
|
return s.execX(ctx, target, query, args...)
|
||||||
@ -272,9 +312,9 @@ func (s *StarDB) update(ctx context.Context, target interface{}, tableName strin
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []interface{}{}
|
args := make([]interface{}, len(params))
|
||||||
for _, param := range params {
|
for i, param := range params {
|
||||||
args = append(args, param)
|
args[i] = param
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.execX(ctx, target, query, args...)
|
return s.execX(ctx, target, query, args...)
|
||||||
@ -282,6 +322,10 @@ func (s *StarDB) update(ctx context.Context, target interface{}, tableName strin
|
|||||||
|
|
||||||
// buildInsertSQL builds an INSERT SQL statement
|
// buildInsertSQL builds an INSERT SQL statement
|
||||||
func buildInsertSQL(target interface{}, tableName string, autoIncrementFields ...string) (string, []string, error) {
|
func buildInsertSQL(target interface{}, tableName string, autoIncrementFields ...string) (string, []string, error) {
|
||||||
|
if strings.TrimSpace(tableName) == "" {
|
||||||
|
return "", []string{}, ErrTableNameEmpty
|
||||||
|
}
|
||||||
|
|
||||||
fieldNames, err := getStructFieldNames(target, "db")
|
fieldNames, err := getStructFieldNames(target, "db")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", []string{}, err
|
return "", []string{}, err
|
||||||
@ -290,17 +334,14 @@ func buildInsertSQL(target interface{}, tableName string, autoIncrementFields ..
|
|||||||
var columns []string
|
var columns []string
|
||||||
var placeholders []string
|
var placeholders []string
|
||||||
var params []string
|
var params []string
|
||||||
|
autoIncrementSet := make(map[string]struct{}, len(autoIncrementFields))
|
||||||
|
for _, autoField := range autoIncrementFields {
|
||||||
|
autoIncrementSet[autoField] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
for _, fieldName := range fieldNames {
|
for _, fieldName := range fieldNames {
|
||||||
// Skip auto-increment fields
|
// Skip auto-increment fields
|
||||||
isAutoIncrement := false
|
if _, isAutoIncrement := autoIncrementSet[fieldName]; isAutoIncrement {
|
||||||
for _, autoField := range autoIncrementFields {
|
|
||||||
if fieldName == autoField {
|
|
||||||
isAutoIncrement = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if isAutoIncrement {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -309,6 +350,10 @@ func buildInsertSQL(target interface{}, tableName string, autoIncrementFields ..
|
|||||||
params = append(params, ":"+fieldName)
|
params = append(params, ":"+fieldName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(columns) == 0 {
|
||||||
|
return "", []string{}, ErrNoInsertColumns
|
||||||
|
}
|
||||||
|
|
||||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
||||||
tableName,
|
tableName,
|
||||||
strings.Join(columns, ", "),
|
strings.Join(columns, ", "),
|
||||||
@ -319,20 +364,43 @@ func buildInsertSQL(target interface{}, tableName string, autoIncrementFields ..
|
|||||||
|
|
||||||
// buildUpdateSQL builds an UPDATE SQL statement
|
// buildUpdateSQL builds an UPDATE SQL statement
|
||||||
func buildUpdateSQL(target interface{}, tableName string, primaryKeys ...string) (string, []string, error) {
|
func buildUpdateSQL(target interface{}, tableName string, primaryKeys ...string) (string, []string, error) {
|
||||||
|
if strings.TrimSpace(tableName) == "" {
|
||||||
|
return "", []string{}, ErrTableNameEmpty
|
||||||
|
}
|
||||||
|
|
||||||
fieldNames, err := getStructFieldNames(target, "db")
|
fieldNames, err := getStructFieldNames(target, "db")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", []string{}, err
|
return "", []string{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(primaryKeys) == 0 {
|
||||||
|
return "", []string{}, ErrPrimaryKeyRequired
|
||||||
|
}
|
||||||
|
|
||||||
|
primaryKeySet := make(map[string]struct{}, len(primaryKeys))
|
||||||
|
for _, pk := range primaryKeys {
|
||||||
|
if pk == "" {
|
||||||
|
return "", []string{}, ErrPrimaryKeyEmpty
|
||||||
|
}
|
||||||
|
primaryKeySet[pk] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
var setClauses []string
|
var setClauses []string
|
||||||
var params []string
|
var params []string
|
||||||
|
|
||||||
// Build SET clause
|
// Build SET clause
|
||||||
for _, fieldName := range fieldNames {
|
for _, fieldName := range fieldNames {
|
||||||
|
if _, isPrimaryKey := primaryKeySet[fieldName]; isPrimaryKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
setClauses = append(setClauses, fmt.Sprintf("%s = ?", fieldName))
|
setClauses = append(setClauses, fmt.Sprintf("%s = ?", fieldName))
|
||||||
params = append(params, ":"+fieldName)
|
params = append(params, ":"+fieldName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(setClauses) == 0 {
|
||||||
|
return "", []string{}, ErrNoUpdateFields
|
||||||
|
}
|
||||||
|
|
||||||
// Build WHERE clause
|
// Build WHERE clause
|
||||||
var whereClauses []string
|
var whereClauses []string
|
||||||
for _, pk := range primaryKeys {
|
for _, pk := range primaryKeys {
|
||||||
@ -367,24 +435,9 @@ func (t *StarTx) queryX(ctx context.Context, target interface{}, query string, a
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
processedArgs := make([]interface{}, len(args))
|
processedArgs, err := bindNamedArgs(args, fieldValues)
|
||||||
for i, arg := range args {
|
if err != nil {
|
||||||
if str, ok := arg.(string); ok {
|
return nil, err
|
||||||
if strings.HasPrefix(str, ":") {
|
|
||||||
fieldName := str[1:]
|
|
||||||
if val, exists := fieldValues[fieldName]; exists {
|
|
||||||
processedArgs[i] = val
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = ""
|
|
||||||
}
|
|
||||||
} else if strings.HasPrefix(str, `\:`) {
|
|
||||||
processedArgs[i] = str[1:]
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = arg
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = arg
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return t.query(ctx, query, processedArgs...)
|
return t.query(ctx, query, processedArgs...)
|
||||||
@ -403,16 +456,23 @@ func (t *StarTx) QueryXSContext(ctx context.Context, targets interface{}, query
|
|||||||
// queryXS is the internal implementation
|
// queryXS is the internal implementation
|
||||||
func (t *StarTx) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
|
func (t *StarTx) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
|
||||||
var results []*StarRows
|
var results []*StarRows
|
||||||
|
if targets == nil {
|
||||||
|
return results, ErrTargetsNil
|
||||||
|
}
|
||||||
|
|
||||||
targetType := reflect.TypeOf(targets)
|
targetType := reflect.TypeOf(targets)
|
||||||
targetValue := reflect.ValueOf(targets)
|
targetValue := reflect.ValueOf(targets)
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
if targetType.Kind() == reflect.Ptr {
|
||||||
|
if targetValue.IsNil() {
|
||||||
|
return results, ErrTargetsPointerNil
|
||||||
|
}
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
targetValue = targetValue.Elem()
|
targetValue = targetValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
||||||
|
results = make([]*StarRows, 0, targetValue.Len())
|
||||||
for i := 0; i < targetValue.Len(); i++ {
|
for i := 0; i < targetValue.Len(); i++ {
|
||||||
result, err := t.queryX(ctx, targetValue.Index(i).Interface(), query, args...)
|
result, err := t.queryX(ctx, targetValue.Index(i).Interface(), query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -448,24 +508,9 @@ func (t *StarTx) execX(ctx context.Context, target interface{}, query string, ar
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
processedArgs := make([]interface{}, len(args))
|
processedArgs, err := bindNamedArgs(args, fieldValues)
|
||||||
for i, arg := range args {
|
if err != nil {
|
||||||
if str, ok := arg.(string); ok {
|
return nil, err
|
||||||
if strings.HasPrefix(str, ":") {
|
|
||||||
fieldName := str[1:]
|
|
||||||
if val, exists := fieldValues[fieldName]; exists {
|
|
||||||
processedArgs[i] = val
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = ""
|
|
||||||
}
|
|
||||||
} else if strings.HasPrefix(str, `\:`) {
|
|
||||||
processedArgs[i] = str[1:]
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = arg
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
processedArgs[i] = arg
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return t.exec(ctx, query, processedArgs...)
|
return t.exec(ctx, query, processedArgs...)
|
||||||
@ -484,16 +529,23 @@ func (t *StarTx) ExecXSContext(ctx context.Context, targets interface{}, query s
|
|||||||
// execXS is the internal implementation
|
// execXS is the internal implementation
|
||||||
func (t *StarTx) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
|
func (t *StarTx) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
|
||||||
var results []sql.Result
|
var results []sql.Result
|
||||||
|
if targets == nil {
|
||||||
|
return results, ErrTargetsNil
|
||||||
|
}
|
||||||
|
|
||||||
targetType := reflect.TypeOf(targets)
|
targetType := reflect.TypeOf(targets)
|
||||||
targetValue := reflect.ValueOf(targets)
|
targetValue := reflect.ValueOf(targets)
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
if targetType.Kind() == reflect.Ptr {
|
||||||
|
if targetValue.IsNil() {
|
||||||
|
return results, ErrTargetsPointerNil
|
||||||
|
}
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
targetValue = targetValue.Elem()
|
targetValue = targetValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
||||||
|
results = make([]sql.Result, 0, targetValue.Len())
|
||||||
for i := 0; i < targetValue.Len(); i++ {
|
for i := 0; i < targetValue.Len(); i++ {
|
||||||
result, err := t.execX(ctx, targetValue.Index(i).Interface(), query, args...)
|
result, err := t.execX(ctx, targetValue.Index(i).Interface(), query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -529,9 +581,9 @@ func (t *StarTx) insert(ctx context.Context, target interface{}, tableName strin
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []interface{}{}
|
args := make([]interface{}, len(params))
|
||||||
for _, param := range params {
|
for i, param := range params {
|
||||||
args = append(args, param)
|
args[i] = param
|
||||||
}
|
}
|
||||||
|
|
||||||
return t.execX(ctx, target, query, args...)
|
return t.execX(ctx, target, query, args...)
|
||||||
@ -554,9 +606,9 @@ func (t *StarTx) update(ctx context.Context, target interface{}, tableName strin
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
args := []interface{}{}
|
args := make([]interface{}, len(params))
|
||||||
for _, param := range params {
|
for i, param := range params {
|
||||||
args = append(args, param)
|
args[i] = param
|
||||||
}
|
}
|
||||||
|
|
||||||
return t.execX(ctx, target, query, args...)
|
return t.execX(ctx, target, query, args...)
|
||||||
|
|||||||
191
orm_test.go
191
orm_test.go
@ -1,6 +1,8 @@
|
|||||||
package stardb
|
package stardb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -27,6 +29,21 @@ type NestedUser struct {
|
|||||||
Profile `db:"---"`
|
Profile `db:"---"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type NestedUserPtr struct {
|
||||||
|
ID int64 `db:"id"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
Profile *Profile `db:"---"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AutoIDOnly struct {
|
||||||
|
ID int64 `db:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HiddenTagged struct {
|
||||||
|
ID int64 `db:"id"`
|
||||||
|
hidden string `db:"hidden"`
|
||||||
|
}
|
||||||
|
|
||||||
func TestBuildInsertSQL(t *testing.T) {
|
func TestBuildInsertSQL(t *testing.T) {
|
||||||
user := User{
|
user := User{
|
||||||
ID: 1,
|
ID: 1,
|
||||||
@ -76,13 +93,183 @@ func TestBuildUpdateSQL(t *testing.T) {
|
|||||||
t.Fatalf("buildUpdateSQL failed: %v", err)
|
t.Fatalf("buildUpdateSQL failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedQuery := "UPDATE users SET id = ?, name = ?, email = ?, age = ?, balance = ?, active = ?, created_at = ? WHERE id = ?"
|
expectedQuery := "UPDATE users SET name = ?, email = ?, age = ?, balance = ?, active = ?, created_at = ? WHERE id = ?"
|
||||||
if query != expectedQuery {
|
if query != expectedQuery {
|
||||||
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
|
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedParamCount := 8 // 7 fields + 1 primary key
|
expectedParamCount := 7 // 6 fields + 1 primary key
|
||||||
if len(params) != expectedParamCount {
|
if len(params) != expectedParamCount {
|
||||||
t.Errorf("Expected %d params, got %d", expectedParamCount, len(params))
|
t.Errorf("Expected %d params, got %d", expectedParamCount, len(params))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestBuildInsertSQL_NoColumns(t *testing.T) {
|
||||||
|
model := AutoIDOnly{ID: 1}
|
||||||
|
|
||||||
|
_, _, err := buildInsertSQL(&model, "users", "id")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error when no columns remain to insert, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrNoInsertColumns) {
|
||||||
|
t.Fatalf("Expected ErrNoInsertColumns, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildInsertSQL_EmptyTableName(t *testing.T) {
|
||||||
|
user := User{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := buildInsertSQL(&user, "", "id")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error when table name is empty, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrTableNameEmpty) {
|
||||||
|
t.Fatalf("Expected ErrTableNameEmpty, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUpdateSQL_NoPrimaryKey(t *testing.T) {
|
||||||
|
user := User{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := buildUpdateSQL(&user, "users")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error when no primary key is provided, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrPrimaryKeyRequired) {
|
||||||
|
t.Fatalf("Expected ErrPrimaryKeyRequired, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUpdateSQL_EmptyTableName(t *testing.T) {
|
||||||
|
user := User{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _, err := buildUpdateSQL(&user, "", "id")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error when table name is empty, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrTableNameEmpty) {
|
||||||
|
t.Fatalf("Expected ErrTableNameEmpty, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildUpdateSQL_OnlyPrimaryKey(t *testing.T) {
|
||||||
|
model := AutoIDOnly{ID: 1}
|
||||||
|
|
||||||
|
_, _, err := buildUpdateSQL(&model, "users", "id")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error when no fields remain for SET clause, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrNoUpdateFields) {
|
||||||
|
t.Fatalf("Expected ErrNoUpdateFields, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStructFieldValues_NilNestedPointer(t *testing.T) {
|
||||||
|
user := NestedUserPtr{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test",
|
||||||
|
Profile: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
values, err := getStructFieldValues(user, "db")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getStructFieldValues failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if values["id"] != int64(1) {
|
||||||
|
t.Errorf("Expected id=1, got %v", values["id"])
|
||||||
|
}
|
||||||
|
if values["name"] != "Test" {
|
||||||
|
t.Errorf("Expected name=Test, got %v", values["name"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStructFieldNames_NilNestedPointer(t *testing.T) {
|
||||||
|
user := NestedUserPtr{
|
||||||
|
ID: 1,
|
||||||
|
Name: "Test",
|
||||||
|
Profile: nil,
|
||||||
|
}
|
||||||
|
|
||||||
|
names, err := getStructFieldNames(user, "db")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getStructFieldNames failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"id", "name"}
|
||||||
|
if !reflect.DeepEqual(names, expected) {
|
||||||
|
t.Errorf("Expected names %v, got %v", expected, names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStructFieldNames_SkipUnexportedField(t *testing.T) {
|
||||||
|
model := HiddenTagged{
|
||||||
|
ID: 1,
|
||||||
|
hidden: "secret",
|
||||||
|
}
|
||||||
|
|
||||||
|
names, err := getStructFieldNames(model, "db")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getStructFieldNames failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := []string{"id"}
|
||||||
|
if !reflect.DeepEqual(names, expected) {
|
||||||
|
t.Errorf("Expected names %v, got %v", expected, names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStructFieldValues_NilTarget(t *testing.T) {
|
||||||
|
_, err := getStructFieldValues(nil, "db")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for nil target, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrTargetNil) {
|
||||||
|
t.Fatalf("Expected ErrTargetNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetStructFieldNames_NilTarget(t *testing.T) {
|
||||||
|
_, err := getStructFieldNames(nil, "db")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error for nil target, got nil")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, ErrTargetNil) {
|
||||||
|
t.Fatalf("Expected ErrTargetNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClearReflectCache(t *testing.T) {
|
||||||
|
type cacheUser struct {
|
||||||
|
ID int64 `db:"id"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := reflect.TypeOf(cacheUser{})
|
||||||
|
plan1, err := getStructTagPlan(typ, "db")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getStructTagPlan failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(plan1) != 2 {
|
||||||
|
t.Fatalf("Expected 2 fields in plan, got %d", len(plan1))
|
||||||
|
}
|
||||||
|
|
||||||
|
ClearReflectCache()
|
||||||
|
|
||||||
|
plan2, err := getStructTagPlan(typ, "db")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("getStructTagPlan after clear failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(plan2) != 2 {
|
||||||
|
t.Fatalf("Expected 2 fields in plan after clear, got %d", len(plan2))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
4
pool.go
4
pool.go
@ -24,6 +24,10 @@ func DefaultPoolConfig() *PoolConfig {
|
|||||||
|
|
||||||
// SetPoolConfig applies pool configuration to the database
|
// SetPoolConfig applies pool configuration to the database
|
||||||
func (s *StarDB) SetPoolConfig(config *PoolConfig) {
|
func (s *StarDB) SetPoolConfig(config *PoolConfig) {
|
||||||
|
if s == nil || s.db == nil || config == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if config.MaxOpenConns > 0 {
|
if config.MaxOpenConns > 0 {
|
||||||
s.db.SetMaxOpenConns(config.MaxOpenConns)
|
s.db.SetMaxOpenConns(config.MaxOpenConns)
|
||||||
}
|
}
|
||||||
|
|||||||
302
reflect.go
302
reflect.go
@ -1,13 +1,126 @@
|
|||||||
package stardb
|
package stardb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type structTagField struct {
|
||||||
|
path []int
|
||||||
|
tag string
|
||||||
|
}
|
||||||
|
|
||||||
|
type structTagPlanKey struct {
|
||||||
|
typ reflect.Type
|
||||||
|
tagKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
var structTagPlanCache sync.Map
|
||||||
|
|
||||||
|
// ClearReflectCache clears internal reflection metadata cache.
|
||||||
|
// Useful after schema/tag refactors in long-running processes.
|
||||||
|
func ClearReflectCache() {
|
||||||
|
structTagPlanCache = sync.Map{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStructTagPlan(targetType reflect.Type, tagKey string) ([]structTagField, error) {
|
||||||
|
if targetType.Kind() == reflect.Ptr {
|
||||||
|
targetType = targetType.Elem()
|
||||||
|
}
|
||||||
|
if targetType.Kind() != reflect.Struct {
|
||||||
|
return nil, ErrTargetNotStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey := structTagPlanKey{
|
||||||
|
typ: targetType,
|
||||||
|
tagKey: tagKey,
|
||||||
|
}
|
||||||
|
if cached, ok := structTagPlanCache.Load(cacheKey); ok {
|
||||||
|
return cached.([]structTagField), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := make([]structTagField, 0, targetType.NumField())
|
||||||
|
if err := buildStructTagPlan(targetType, tagKey, nil, &fields); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
structTagPlanCache.Store(cacheKey, fields)
|
||||||
|
return fields, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildStructTagPlan(currentType reflect.Type, tagKey string, prefix []int, out *[]structTagField) error {
|
||||||
|
if currentType.Kind() == reflect.Ptr {
|
||||||
|
currentType = currentType.Elem()
|
||||||
|
}
|
||||||
|
if currentType.Kind() != reflect.Struct {
|
||||||
|
return ErrTargetNotStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < currentType.NumField(); i++ {
|
||||||
|
field := currentType.Field(i)
|
||||||
|
tagValue := field.Tag.Get(tagKey)
|
||||||
|
fieldType := field.Type
|
||||||
|
|
||||||
|
path := make([]int, len(prefix)+1)
|
||||||
|
copy(path, prefix)
|
||||||
|
path[len(prefix)] = i
|
||||||
|
|
||||||
|
if tagValue == "---" {
|
||||||
|
if fieldType.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct {
|
||||||
|
if err := buildStructTagPlan(fieldType.Elem(), tagKey, path, out); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fieldType.Kind() == reflect.Struct {
|
||||||
|
if err := buildStructTagPlan(fieldType, tagKey, path, out); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tagValue != "" {
|
||||||
|
*out = append(*out, structTagField{
|
||||||
|
path: path,
|
||||||
|
tag: tagValue,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func resolveFieldByPath(root reflect.Value, path []int) (reflect.Value, bool) {
|
||||||
|
current := root
|
||||||
|
for _, idx := range path {
|
||||||
|
if current.Kind() == reflect.Ptr {
|
||||||
|
if current.IsNil() {
|
||||||
|
return reflect.Value{}, false
|
||||||
|
}
|
||||||
|
current = current.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
if current.Kind() != reflect.Struct {
|
||||||
|
return reflect.Value{}, false
|
||||||
|
}
|
||||||
|
if idx < 0 || idx >= current.NumField() {
|
||||||
|
return reflect.Value{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
current = current.Field(idx)
|
||||||
|
}
|
||||||
|
|
||||||
|
return current, true
|
||||||
|
}
|
||||||
|
|
||||||
// setStructFieldsFromRow sets struct fields from a row result using reflection
|
// setStructFieldsFromRow sets struct fields from a row result using reflection
|
||||||
func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, rowIndex int) error {
|
func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, rowIndex int) error {
|
||||||
|
if target == nil {
|
||||||
|
return ErrTargetNil
|
||||||
|
}
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
targetType := reflect.TypeOf(target)
|
||||||
targetValue := reflect.ValueOf(target)
|
targetValue := reflect.ValueOf(target)
|
||||||
|
|
||||||
@ -16,7 +129,7 @@ func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, row
|
|||||||
}
|
}
|
||||||
|
|
||||||
if targetType.Kind() != reflect.Ptr && !targetValue.CanSet() {
|
if targetType.Kind() != reflect.Ptr && !targetValue.CanSet() {
|
||||||
return errors.New("target is not writable")
|
return ErrTargetNotWritable
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
if targetType.Kind() == reflect.Ptr {
|
||||||
@ -24,7 +137,12 @@ func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, row
|
|||||||
}
|
}
|
||||||
|
|
||||||
if targetValue.Kind() != reflect.Struct {
|
if targetValue.Kind() != reflect.Struct {
|
||||||
return errors.New("target is not a struct")
|
return ErrTargetNotStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
row := r.Row(rowIndex)
|
||||||
|
if row.columnIndex == nil {
|
||||||
|
return ErrRowIndexOutOfRange
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < targetType.NumField(); i++ {
|
for i := 0; i < targetType.NumField(); i++ {
|
||||||
@ -32,14 +150,25 @@ func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, row
|
|||||||
fieldValue := targetValue.Field(i)
|
fieldValue := targetValue.Field(i)
|
||||||
tagValue := field.Tag.Get(tagKey)
|
tagValue := field.Tag.Get(tagKey)
|
||||||
|
|
||||||
|
if !fieldValue.CanInterface() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip unexported or otherwise non-settable fields.
|
||||||
|
if !fieldValue.CanSet() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
// Handle nested structs
|
// Handle nested structs
|
||||||
if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct {
|
if fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct {
|
||||||
if tagValue == "" {
|
if tagValue == "" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if tagValue == "---" {
|
if tagValue == "---" {
|
||||||
nestedPtr := reflect.New(reflect.TypeOf(fieldValue.Interface()).Elem()).Interface()
|
nestedPtr := reflect.New(fieldValue.Type().Elem()).Interface()
|
||||||
r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex)
|
if err := r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
targetValue.Field(i).Set(reflect.ValueOf(nestedPtr))
|
targetValue.Field(i).Set(reflect.ValueOf(nestedPtr))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -51,7 +180,9 @@ func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, row
|
|||||||
}
|
}
|
||||||
if tagValue == "---" {
|
if tagValue == "---" {
|
||||||
nestedPtr := reflect.New(reflect.TypeOf(targetValue.Field(i).Interface())).Interface()
|
nestedPtr := reflect.New(reflect.TypeOf(targetValue.Field(i).Interface())).Interface()
|
||||||
r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex)
|
if err := r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
targetValue.Field(i).Set(reflect.ValueOf(nestedPtr).Elem())
|
targetValue.Field(i).Set(reflect.ValueOf(nestedPtr).Elem())
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@ -62,20 +193,21 @@ func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, row
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if column exists
|
// Check if column exists
|
||||||
if _, ok := r.Row(rowIndex).columnIndex[tagValue]; !ok {
|
if _, ok := row.columnIndex[tagValue]; !ok {
|
||||||
|
if r.db != nil && r.db.StrictORM {
|
||||||
|
return wrapColumnNotFound(tagValue)
|
||||||
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set field value based on type
|
// Set field value based on type
|
||||||
r.setFieldValue(fieldValue, tagValue, rowIndex)
|
r.setFieldValue(fieldValue, tagValue, row)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setFieldValue sets a single field value
|
// setFieldValue sets a single field value
|
||||||
func (r *StarRows) setFieldValue(fieldValue reflect.Value, columnName string, rowIndex int) {
|
func (r *StarRows) setFieldValue(fieldValue reflect.Value, columnName string, row *StarResult) {
|
||||||
row := r.Row(rowIndex)
|
|
||||||
|
|
||||||
switch fieldValue.Kind() {
|
switch fieldValue.Kind() {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
fieldValue.SetString(row.MustString(columnName))
|
fieldValue.SetString(row.MustString(columnName))
|
||||||
@ -105,79 +237,70 @@ func (r *StarRows) setFieldValue(fieldValue reflect.Value, columnName string, ro
|
|||||||
fieldValue.SetFloat(float64(row.MustFloat32(columnName)))
|
fieldValue.SetFloat(float64(row.MustFloat32(columnName)))
|
||||||
case reflect.Float64:
|
case reflect.Float64:
|
||||||
fieldValue.SetFloat(row.MustFloat64(columnName))
|
fieldValue.SetFloat(row.MustFloat64(columnName))
|
||||||
case reflect.Interface, reflect.Struct, reflect.Ptr:
|
case reflect.Struct:
|
||||||
// Handle special types like time.Time
|
// Handle special struct types like time.Time
|
||||||
colIndex := r.columnIndex[columnName]
|
colIndex := row.columnIndex[columnName]
|
||||||
val := row.Result()[colIndex]
|
val := row.Result()[colIndex]
|
||||||
if t, ok := val.(time.Time); ok {
|
if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
|
||||||
fieldValue.Set(reflect.ValueOf(t))
|
if t, ok := val.(time.Time); ok {
|
||||||
|
fieldValue.Set(reflect.ValueOf(t))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Ptr:
|
||||||
|
// Handle pointer to special types like *time.Time
|
||||||
|
colIndex := row.columnIndex[columnName]
|
||||||
|
val := row.Result()[colIndex]
|
||||||
|
if fieldValue.Type().Elem() == reflect.TypeOf(time.Time{}) {
|
||||||
|
if t, ok := val.(time.Time); ok {
|
||||||
|
tCopy := t
|
||||||
|
fieldValue.Set(reflect.ValueOf(&tCopy))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case reflect.Interface:
|
||||||
|
colIndex := row.columnIndex[columnName]
|
||||||
|
val := row.Result()[colIndex]
|
||||||
|
if val != nil {
|
||||||
|
fieldValue.Set(reflect.ValueOf(val))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getStructFieldValues extracts all field values from a struct
|
// getStructFieldValues extracts all field values from a struct
|
||||||
func getStructFieldValues(target interface{}, tagKey string) (map[string]interface{}, error) {
|
func getStructFieldValues(target interface{}, tagKey string) (map[string]interface{}, error) {
|
||||||
result := make(map[string]interface{})
|
if target == nil {
|
||||||
|
return nil, ErrTargetNil
|
||||||
|
}
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
targetType := reflect.TypeOf(target)
|
||||||
targetValue := reflect.ValueOf(target)
|
targetValue := reflect.ValueOf(target)
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
if targetType.Kind() == reflect.Ptr {
|
||||||
if targetValue.IsNil() {
|
if targetValue.IsNil() {
|
||||||
return nil, errors.New("pointer target is nil")
|
return nil, ErrPointerTargetNil
|
||||||
}
|
}
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
targetValue = targetValue.Elem()
|
targetValue = targetValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
if targetValue.Kind() != reflect.Struct {
|
if targetValue.Kind() != reflect.Struct {
|
||||||
return nil, errors.New("target is not a struct")
|
return nil, ErrTargetNotStruct
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < targetType.NumField(); i++ {
|
plan, err := getStructTagPlan(targetType, tagKey)
|
||||||
field := targetType.Field(i)
|
if err != nil {
|
||||||
fieldValue := targetValue.Field(i)
|
return nil, err
|
||||||
tagValue := field.Tag.Get(tagKey)
|
}
|
||||||
|
|
||||||
// Handle nested pointer structs
|
result := make(map[string]interface{}, len(plan))
|
||||||
if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct {
|
for _, field := range plan {
|
||||||
if fieldValue.IsNil() {
|
fieldValue, ok := resolveFieldByPath(targetValue, field.path)
|
||||||
continue
|
if !ok {
|
||||||
}
|
|
||||||
if tagValue == "---" {
|
|
||||||
nestedValues, err := getStructFieldValues(reflect.ValueOf(fieldValue.Elem().Interface()).Interface(), tagKey)
|
|
||||||
if err != nil {
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
for k, v := range nestedValues {
|
|
||||||
result[k] = v
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle nested structs
|
|
||||||
if targetValue.Field(i).Kind() == reflect.Struct {
|
|
||||||
if tagValue == "---" {
|
|
||||||
nestedValues, err := getStructFieldValues(targetValue.Field(i).Interface(), tagKey)
|
|
||||||
if err != nil {
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
for k, v := range nestedValues {
|
|
||||||
result[k] = v
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tagValue == "" {
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if !fieldValue.CanInterface() {
|
if !fieldValue.CanInterface() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
result[field.tag] = fieldValue.Interface()
|
||||||
result[tagValue] = fieldValue.Interface()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
@ -185,10 +308,8 @@ func getStructFieldValues(target interface{}, tagKey string) (map[string]interfa
|
|||||||
|
|
||||||
// getStructFieldNames extracts all field names (tag values) from a struct
|
// getStructFieldNames extracts all field names (tag values) from a struct
|
||||||
func getStructFieldNames(target interface{}, tagKey string) ([]string, error) {
|
func getStructFieldNames(target interface{}, tagKey string) ([]string, error) {
|
||||||
var result []string
|
if target == nil {
|
||||||
|
return []string{}, ErrTargetNil
|
||||||
if !isStruct(target) {
|
|
||||||
return []string{}, errors.New("target is not a struct")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
targetType := reflect.TypeOf(target)
|
||||||
@ -196,45 +317,31 @@ func getStructFieldNames(target interface{}, tagKey string) ([]string, error) {
|
|||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
if targetType.Kind() == reflect.Ptr {
|
||||||
if targetValue.IsNil() {
|
if targetValue.IsNil() {
|
||||||
return []string{}, errors.New("pointer target is nil")
|
return []string{}, ErrPointerTargetNil
|
||||||
}
|
}
|
||||||
targetType = targetType.Elem()
|
targetType = targetType.Elem()
|
||||||
targetValue = targetValue.Elem()
|
targetValue = targetValue.Elem()
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < targetType.NumField(); i++ {
|
if targetValue.Kind() != reflect.Struct {
|
||||||
fieldValue := targetValue.Field(i)
|
return []string{}, ErrTargetNotStruct
|
||||||
field := targetType.Field(i)
|
}
|
||||||
tagValue := field.Tag.Get(tagKey)
|
|
||||||
|
|
||||||
// Handle nested pointer structs
|
plan, err := getStructTagPlan(targetType, tagKey)
|
||||||
if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct {
|
if err != nil {
|
||||||
if fieldValue.IsNil() {
|
return []string{}, err
|
||||||
continue
|
}
|
||||||
}
|
|
||||||
if tagValue == "---" {
|
|
||||||
nestedNames, err := getStructFieldNames(reflect.ValueOf(fieldValue.Elem().Interface()).Interface(), tagKey)
|
|
||||||
if err != nil {
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
result = append(result, nestedNames...)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle nested structs
|
result := make([]string, 0, len(plan))
|
||||||
if targetValue.Field(i).Kind() == reflect.Struct && tagValue == "---" {
|
for _, field := range plan {
|
||||||
nestedNames, err := getStructFieldNames(targetValue.Field(i).Interface(), tagKey)
|
fieldValue, ok := resolveFieldByPath(targetValue, field.path)
|
||||||
if err != nil {
|
if !ok {
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
result = append(result, nestedNames...)
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if !fieldValue.CanInterface() {
|
||||||
if tagValue != "" {
|
continue
|
||||||
result = append(result, tagValue)
|
|
||||||
}
|
}
|
||||||
|
result = append(result, field.tag)
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
@ -242,6 +349,10 @@ func getStructFieldNames(target interface{}, tagKey string) ([]string, error) {
|
|||||||
|
|
||||||
// isWritable checks if a value is writable
|
// isWritable checks if a value is writable
|
||||||
func isWritable(target interface{}) bool {
|
func isWritable(target interface{}) bool {
|
||||||
|
if target == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
targetType := reflect.TypeOf(target)
|
||||||
targetValue := reflect.ValueOf(target)
|
targetValue := reflect.ValueOf(target)
|
||||||
return targetType.Kind() == reflect.Ptr || targetValue.CanSet()
|
return targetType.Kind() == reflect.Ptr || targetValue.CanSet()
|
||||||
@ -249,8 +360,15 @@ func isWritable(target interface{}) bool {
|
|||||||
|
|
||||||
// isStruct checks if a value is a struct
|
// isStruct checks if a value is a struct
|
||||||
func isStruct(target interface{}) bool {
|
func isStruct(target interface{}) bool {
|
||||||
|
if target == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
targetValue := reflect.ValueOf(target)
|
targetValue := reflect.ValueOf(target)
|
||||||
if targetValue.Kind() == reflect.Ptr {
|
if targetValue.Kind() == reflect.Ptr {
|
||||||
|
if targetValue.IsNil() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
targetValue = targetValue.Elem()
|
targetValue = targetValue.Elem()
|
||||||
}
|
}
|
||||||
return targetValue.Kind() == reflect.Struct
|
return targetValue.Kind() == reflect.Struct
|
||||||
|
|||||||
137
result_safe.go
137
result_safe.go
@ -1,51 +1,124 @@
|
|||||||
package stardb
|
package stardb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
internalconv "b612.me/stardb/internal/convert"
|
||||||
"fmt"
|
"database/sql"
|
||||||
"strconv"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func (r *StarResult) getColumnValue(name string) (interface{}, error) {
|
||||||
|
if r == nil || r.columnIndex == nil {
|
||||||
|
return nil, wrapColumnNotFound(name)
|
||||||
|
}
|
||||||
|
index, ok := r.columnIndex[name]
|
||||||
|
if !ok {
|
||||||
|
return nil, wrapColumnNotFound(name)
|
||||||
|
}
|
||||||
|
return r.Result()[index], nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetString returns column value as string with error
|
// GetString returns column value as string with error
|
||||||
func (r *StarResult) GetString(name string) (string, error) {
|
func (r *StarResult) GetString(name string) (string, error) {
|
||||||
index, ok := r.columnIndex[name]
|
val, err := r.getColumnValue(name)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return "", errors.New("column not found: " + name)
|
return "", err
|
||||||
}
|
}
|
||||||
return ConvertToStringSafe(r.Result()[index])
|
return ConvertToStringSafe(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetInt64 returns column value as int64 with error
|
// GetInt64 returns column value as int64 with error
|
||||||
func (r *StarResult) GetInt64(name string) (int64, error) {
|
func (r *StarResult) GetInt64(name string) (int64, error) {
|
||||||
index, ok := r.columnIndex[name]
|
val, err := r.getColumnValue(name)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return 0, errors.New("column not found: " + name)
|
return 0, err
|
||||||
}
|
}
|
||||||
return ConvertToInt64Safe(r.Result()[index])
|
return ConvertToInt64Safe(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFloat64 returns column value as float64 with error
|
// GetFloat64 returns column value as float64 with error
|
||||||
func (r *StarResult) GetFloat64(name string) (float64, error) {
|
func (r *StarResult) GetFloat64(name string) (float64, error) {
|
||||||
index, ok := r.columnIndex[name]
|
val, err := r.getColumnValue(name)
|
||||||
if !ok {
|
if err != nil {
|
||||||
return 0, errors.New("column not found: " + name)
|
return 0, err
|
||||||
}
|
|
||||||
|
|
||||||
switch v := r.Result()[index].(type) {
|
|
||||||
case nil:
|
|
||||||
return 0, nil
|
|
||||||
case float64:
|
|
||||||
return v, nil
|
|
||||||
case float32:
|
|
||||||
return float64(v), nil
|
|
||||||
case int, int32, int64, uint64:
|
|
||||||
val, err := ConvertToInt64Safe(v)
|
|
||||||
return float64(val), err
|
|
||||||
case string:
|
|
||||||
return strconv.ParseFloat(v, 64)
|
|
||||||
case []byte:
|
|
||||||
return strconv.ParseFloat(string(v), 64)
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("cannot convert %T to float64", v)
|
|
||||||
}
|
}
|
||||||
|
return internalconv.ToFloat64Safe(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNullString returns a nullable string value.
|
||||||
|
func (r *StarResult) GetNullString(name string) (sql.NullString, error) {
|
||||||
|
val, err := r.getColumnValue(name)
|
||||||
|
if err != nil {
|
||||||
|
return sql.NullString{}, err
|
||||||
|
}
|
||||||
|
if val == nil {
|
||||||
|
return sql.NullString{}, nil
|
||||||
|
}
|
||||||
|
str, err := ConvertToStringSafe(val)
|
||||||
|
if err != nil {
|
||||||
|
return sql.NullString{}, err
|
||||||
|
}
|
||||||
|
return sql.NullString{String: str, Valid: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNullInt64 returns a nullable int64 value.
|
||||||
|
func (r *StarResult) GetNullInt64(name string) (sql.NullInt64, error) {
|
||||||
|
val, err := r.getColumnValue(name)
|
||||||
|
if err != nil {
|
||||||
|
return sql.NullInt64{}, err
|
||||||
|
}
|
||||||
|
if val == nil {
|
||||||
|
return sql.NullInt64{}, nil
|
||||||
|
}
|
||||||
|
i, err := ConvertToInt64Safe(val)
|
||||||
|
if err != nil {
|
||||||
|
return sql.NullInt64{}, err
|
||||||
|
}
|
||||||
|
return sql.NullInt64{Int64: i, Valid: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNullFloat64 returns a nullable float64 value.
|
||||||
|
func (r *StarResult) GetNullFloat64(name string) (sql.NullFloat64, error) {
|
||||||
|
val, err := r.getColumnValue(name)
|
||||||
|
if err != nil {
|
||||||
|
return sql.NullFloat64{}, err
|
||||||
|
}
|
||||||
|
if val == nil {
|
||||||
|
return sql.NullFloat64{}, nil
|
||||||
|
}
|
||||||
|
f, err := internalconv.ToFloat64Safe(val)
|
||||||
|
if err != nil {
|
||||||
|
return sql.NullFloat64{}, err
|
||||||
|
}
|
||||||
|
return sql.NullFloat64{Float64: f, Valid: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNullBool returns a nullable bool value.
|
||||||
|
func (r *StarResult) GetNullBool(name string) (sql.NullBool, error) {
|
||||||
|
val, err := r.getColumnValue(name)
|
||||||
|
if err != nil {
|
||||||
|
return sql.NullBool{}, err
|
||||||
|
}
|
||||||
|
if val == nil {
|
||||||
|
return sql.NullBool{}, nil
|
||||||
|
}
|
||||||
|
b, err := internalconv.ToBoolSafe(val)
|
||||||
|
if err != nil {
|
||||||
|
return sql.NullBool{}, err
|
||||||
|
}
|
||||||
|
return sql.NullBool{Bool: b, Valid: true}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetNullTime returns a nullable time value.
|
||||||
|
func (r *StarResult) GetNullTime(name string) (sql.NullTime, error) {
|
||||||
|
val, err := r.getColumnValue(name)
|
||||||
|
if err != nil {
|
||||||
|
return sql.NullTime{}, err
|
||||||
|
}
|
||||||
|
if val == nil {
|
||||||
|
return sql.NullTime{}, nil
|
||||||
|
}
|
||||||
|
t, err := internalconv.ToTimeSafe(val)
|
||||||
|
if err != nil {
|
||||||
|
return sql.NullTime{}, err
|
||||||
|
}
|
||||||
|
return sql.NullTime{Time: t, Valid: true}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
21
rows.go
21
rows.go
@ -1,6 +1,7 @@
|
|||||||
package stardb
|
package stardb
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"b612.me/stardb/internal/scanutil"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -53,7 +54,7 @@ func (r *StarRows) Rescan() error {
|
|||||||
// Row returns a specific row by index
|
// Row returns a specific row by index
|
||||||
func (r *StarRows) Row(index int) *StarResult {
|
func (r *StarRows) Row(index int) *StarResult {
|
||||||
result := &StarResult{}
|
result := &StarResult{}
|
||||||
if index >= len(r.data) {
|
if index < 0 || index >= len(r.data) {
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
result.result = r.data[index]
|
result.result = r.data[index]
|
||||||
@ -82,13 +83,11 @@ func (r *StarRows) parse() error {
|
|||||||
if r.parsed {
|
if r.parsed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
r.parsed = true
|
|
||||||
}()
|
|
||||||
|
|
||||||
r.data = [][]interface{}{}
|
r.data = [][]interface{}{}
|
||||||
r.columnIndex = make(map[string]int)
|
r.columnIndex = make(map[string]int)
|
||||||
r.stringResult = []map[string]string{}
|
r.stringResult = []map[string]string{}
|
||||||
|
r.columnsType = []reflect.Type{}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
r.columns, err = r.rows.Columns()
|
r.columns, err = r.rows.Columns()
|
||||||
@ -127,18 +126,28 @@ func (r *StarRows) parse() error {
|
|||||||
rowCopy := make([]interface{}, len(values))
|
rowCopy := make([]interface{}, len(values))
|
||||||
|
|
||||||
for i, val := range values {
|
for i, val := range values {
|
||||||
rowCopy[i] = val
|
copiedVal := cloneScannedValue(val)
|
||||||
record[r.columns[i]] = convertToString(val)
|
rowCopy[i] = copiedVal
|
||||||
|
record[r.columns[i]] = convertToString(copiedVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.data = append(r.data, rowCopy)
|
r.data = append(r.data, rowCopy)
|
||||||
r.stringResult = append(r.stringResult, record)
|
r.stringResult = append(r.stringResult, record)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.rows.Err(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
r.length = len(r.stringResult)
|
r.length = len(r.stringResult)
|
||||||
|
r.parsed = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cloneScannedValue(val interface{}) interface{} {
|
||||||
|
return scanutil.CloneScannedValue(val)
|
||||||
|
}
|
||||||
|
|
||||||
// convertToString converts any value to string
|
// convertToString converts any value to string
|
||||||
func convertToString(val interface{}) string {
|
func convertToString(val interface{}) string {
|
||||||
switch v := val.(type) {
|
switch v := val.(type) {
|
||||||
|
|||||||
29
rows_internal_test.go
Normal file
29
rows_internal_test.go
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestCloneScannedValue_BytesAreCopied(t *testing.T) {
|
||||||
|
original := []byte("hello")
|
||||||
|
clonedAny := cloneScannedValue(original)
|
||||||
|
cloned, ok := clonedAny.([]byte)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected []byte, got %T", clonedAny)
|
||||||
|
}
|
||||||
|
|
||||||
|
original[0] = 'H'
|
||||||
|
if string(cloned) != "hello" {
|
||||||
|
t.Fatalf("expected cloned value to remain 'hello', got '%s'", string(cloned))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(cloned) > 0 && &cloned[0] == &original[0] {
|
||||||
|
t.Fatal("expected cloned bytes to have a different backing array")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCloneScannedValue_NonBytesKeepReference(t *testing.T) {
|
||||||
|
in := int64(42)
|
||||||
|
out := cloneScannedValue(in)
|
||||||
|
if out != in {
|
||||||
|
t.Fatalf("expected %v, got %v", in, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
124
scan_each.go
Normal file
124
scan_each.go
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ScanEachFunc is called for each scanned row in streaming mode.
|
||||||
|
type ScanEachFunc func(row *StarResult) error
|
||||||
|
|
||||||
|
// ScanEach executes query in streaming mode and invokes fn for each row.
|
||||||
|
func (s *StarDB) ScanEach(query string, fn ScanEachFunc, args ...interface{}) error {
|
||||||
|
return s.scanEach(nil, query, fn, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanEachContext executes query with context in streaming mode and invokes fn for each row.
|
||||||
|
func (s *StarDB) ScanEachContext(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
|
||||||
|
return s.scanEach(ctx, query, fn, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarDB) scanEach(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
|
||||||
|
rows, err := s.queryRaw(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return scanEachSQLRows(rows, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanEach executes query in transaction streaming mode and invokes fn for each row.
|
||||||
|
func (t *StarTx) ScanEach(query string, fn ScanEachFunc, args ...interface{}) error {
|
||||||
|
return t.scanEach(nil, query, fn, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanEachContext executes query with context in transaction streaming mode and invokes fn for each row.
|
||||||
|
func (t *StarTx) ScanEachContext(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
|
||||||
|
return t.scanEach(ctx, query, fn, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *StarTx) scanEach(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
|
||||||
|
rows, err := t.queryRaw(ctx, query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return scanEachSQLRows(rows, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanEach executes prepared statement query in streaming mode and invokes fn for each row.
|
||||||
|
func (s *StarStmt) ScanEach(fn ScanEachFunc, args ...interface{}) error {
|
||||||
|
return s.scanEach(nil, fn, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanEachContext executes prepared statement query with context in streaming mode and invokes fn for each row.
|
||||||
|
func (s *StarStmt) ScanEachContext(ctx context.Context, fn ScanEachFunc, args ...interface{}) error {
|
||||||
|
return s.scanEach(ctx, fn, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarStmt) scanEach(ctx context.Context, fn ScanEachFunc, args ...interface{}) error {
|
||||||
|
rows, err := s.queryRaw(ctx, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return scanEachSQLRows(rows, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func scanEachSQLRows(rows *sql.Rows, fn ScanEachFunc) error {
|
||||||
|
if fn == nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
return ErrScanFuncNil
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
columns, err := rows.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
columnTypes, err := rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
types := make([]reflect.Type, len(columnTypes))
|
||||||
|
for i, colType := range columnTypes {
|
||||||
|
types[i] = colType.ScanType()
|
||||||
|
}
|
||||||
|
|
||||||
|
columnIndex := make(map[string]int, len(columns))
|
||||||
|
for i, colName := range columns {
|
||||||
|
columnIndex[colName] = i
|
||||||
|
}
|
||||||
|
|
||||||
|
scanArgs := make([]interface{}, len(columns))
|
||||||
|
values := make([]interface{}, len(columns))
|
||||||
|
for i := range values {
|
||||||
|
scanArgs[i] = &values[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(scanArgs...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
rowCopy := make([]interface{}, len(values))
|
||||||
|
for i, val := range values {
|
||||||
|
rowCopy[i] = cloneScannedValue(val)
|
||||||
|
}
|
||||||
|
|
||||||
|
row := &StarResult{
|
||||||
|
result: rowCopy,
|
||||||
|
columns: columns,
|
||||||
|
columnIndex: columnIndex,
|
||||||
|
columnsType: types,
|
||||||
|
}
|
||||||
|
if err := fn(row); err != nil {
|
||||||
|
if errors.Is(err, ErrScanStopped) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rows.Err()
|
||||||
|
}
|
||||||
119
scan_each_orm.go
Normal file
119
scan_each_orm.go
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ScanEachORMFunc is called for each mapped struct in streaming ORM mode.
|
||||||
|
type ScanEachORMFunc func(target interface{}) error
|
||||||
|
|
||||||
|
// ScanEachORM streams query rows and maps each row to target before invoking fn.
|
||||||
|
// target must be a pointer to struct; it is reused for each row.
|
||||||
|
func (s *StarDB) ScanEachORM(query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
||||||
|
return s.ScanEachORMContext(nil, query, target, fn, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanEachORMContext streams query rows with context and maps each row to target before invoking fn.
|
||||||
|
// target must be a pointer to struct; it is reused for each row.
|
||||||
|
func (s *StarDB) ScanEachORMContext(ctx context.Context, query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
||||||
|
if fn == nil {
|
||||||
|
return ErrScanORMFuncNil
|
||||||
|
}
|
||||||
|
if err := validateScanORMTarget(target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.ScanEachContext(ctx, query, func(row *StarResult) error {
|
||||||
|
if err := mapResultToStructTarget(row, target, s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fn(target)
|
||||||
|
}, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanEachORM streams transaction rows and maps each row to target before invoking fn.
|
||||||
|
// target must be a pointer to struct; it is reused for each row.
|
||||||
|
func (t *StarTx) ScanEachORM(query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
||||||
|
return t.ScanEachORMContext(nil, query, target, fn, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanEachORMContext streams transaction rows with context and maps each row to target before invoking fn.
|
||||||
|
// target must be a pointer to struct; it is reused for each row.
|
||||||
|
func (t *StarTx) ScanEachORMContext(ctx context.Context, query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
||||||
|
if fn == nil {
|
||||||
|
return ErrScanORMFuncNil
|
||||||
|
}
|
||||||
|
if err := validateScanORMTarget(target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return t.ScanEachContext(ctx, query, func(row *StarResult) error {
|
||||||
|
if err := mapResultToStructTarget(row, target, t.db); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fn(target)
|
||||||
|
}, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanEachORM streams prepared statement rows and maps each row to target before invoking fn.
|
||||||
|
// target must be a pointer to struct; it is reused for each row.
|
||||||
|
func (s *StarStmt) ScanEachORM(target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
||||||
|
return s.ScanEachORMContext(nil, target, fn, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanEachORMContext streams prepared statement rows with context and maps each row to target before invoking fn.
|
||||||
|
// target must be a pointer to struct; it is reused for each row.
|
||||||
|
func (s *StarStmt) ScanEachORMContext(ctx context.Context, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
||||||
|
if fn == nil {
|
||||||
|
return ErrScanORMFuncNil
|
||||||
|
}
|
||||||
|
if err := validateScanORMTarget(target); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.ScanEachContext(ctx, func(row *StarResult) error {
|
||||||
|
if err := mapResultToStructTarget(row, target, s.db); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fn(target)
|
||||||
|
}, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateScanORMTarget(target interface{}) error {
|
||||||
|
if target == nil {
|
||||||
|
return ErrTargetNil
|
||||||
|
}
|
||||||
|
|
||||||
|
targetType := reflect.TypeOf(target)
|
||||||
|
targetValue := reflect.ValueOf(target)
|
||||||
|
|
||||||
|
if targetType.Kind() != reflect.Ptr {
|
||||||
|
return ErrTargetNotPointer
|
||||||
|
}
|
||||||
|
if targetValue.IsNil() {
|
||||||
|
return ErrTargetPointerNil
|
||||||
|
}
|
||||||
|
if targetValue.Elem().Kind() != reflect.Struct {
|
||||||
|
return ErrTargetNotStruct
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapResultToStructTarget(row *StarResult, target interface{}, db *StarDB) error {
|
||||||
|
targetValue := reflect.ValueOf(target)
|
||||||
|
targetValue.Elem().Set(reflect.Zero(targetValue.Elem().Type()))
|
||||||
|
|
||||||
|
rowWrapper := &StarRows{
|
||||||
|
db: db,
|
||||||
|
length: 1,
|
||||||
|
columns: row.columns,
|
||||||
|
columnsType: row.columnsType,
|
||||||
|
columnIndex: row.columnIndex,
|
||||||
|
data: [][]interface{}{row.result},
|
||||||
|
parsed: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
return rowWrapper.setStructFieldsFromRow(target, "db", 0)
|
||||||
|
}
|
||||||
19
sql_placeholder.go
Normal file
19
sql_placeholder.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import internalsqlplaceholder "b612.me/stardb/internal/sqlplaceholder"
|
||||||
|
|
||||||
|
// ConvertPlaceholders converts placeholders according to style.
|
||||||
|
func ConvertPlaceholders(query string, style PlaceholderStyle) string {
|
||||||
|
switch normalizePlaceholderStyle(style) {
|
||||||
|
case PlaceholderDollar:
|
||||||
|
return ConvertQuestionToDollarPlaceholders(query)
|
||||||
|
default:
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertQuestionToDollarPlaceholders converts '?' to '$1,$2,...' in SQL text.
|
||||||
|
// It skips quoted strings, quoted identifiers and comments.
|
||||||
|
func ConvertQuestionToDollarPlaceholders(query string) string {
|
||||||
|
return internalsqlplaceholder.ConvertQuestionToDollarPlaceholders(query)
|
||||||
|
}
|
||||||
34
sql_placeholder_test.go
Normal file
34
sql_placeholder_test.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestConvertQuestionToDollarPlaceholders(t *testing.T) {
|
||||||
|
query := "SELECT * FROM users WHERE id = ? AND name = ?"
|
||||||
|
got := ConvertQuestionToDollarPlaceholders(query)
|
||||||
|
want := "SELECT * FROM users WHERE id = $1 AND name = $2"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("expected %q, got %q", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertQuestionToDollarPlaceholders_SkipQuotedAndComments(t *testing.T) {
|
||||||
|
query := "SELECT '?', \"?\", `?`, col FROM t WHERE id = ? -- ?\nAND note = '??' /* ? */ AND x = ?"
|
||||||
|
got := ConvertQuestionToDollarPlaceholders(query)
|
||||||
|
want := "SELECT '?', \"?\", `?`, col FROM t WHERE id = $1 -- ?\nAND note = '??' /* ? */ AND x = $2"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("expected %q, got %q", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConvertPlaceholders(t *testing.T) {
|
||||||
|
query := "SELECT * FROM t WHERE a = ? AND b = ?"
|
||||||
|
if got := ConvertPlaceholders(query, PlaceholderQuestion); got != query {
|
||||||
|
t.Fatalf("question style should keep query unchanged, got %q", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := ConvertPlaceholders(query, PlaceholderDollar)
|
||||||
|
want := "SELECT * FROM t WHERE a = $1 AND b = $2"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("expected %q, got %q", want, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
276
sql_runtime.go
Normal file
276
sql_runtime.go
Normal file
@ -0,0 +1,276 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import (
|
||||||
|
internalsqlruntime "b612.me/stardb/internal/sqlruntime"
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SQLBeforeHook runs before a SQL statement is executed.
|
||||||
|
type SQLBeforeHook func(ctx context.Context, query string, args []interface{})
|
||||||
|
|
||||||
|
// SQLAfterHook runs after a SQL statement is executed.
|
||||||
|
type SQLAfterHook func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error)
|
||||||
|
|
||||||
|
// PlaceholderStyle controls SQL placeholder format conversion.
|
||||||
|
type PlaceholderStyle int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PlaceholderQuestion keeps '?' placeholders unchanged.
|
||||||
|
PlaceholderQuestion PlaceholderStyle = iota
|
||||||
|
// PlaceholderDollar converts '?' placeholders to '$1,$2,...'.
|
||||||
|
PlaceholderDollar
|
||||||
|
)
|
||||||
|
|
||||||
|
// SQLFingerprintMode controls SQL fingerprint generation strategy.
|
||||||
|
type SQLFingerprintMode int
|
||||||
|
|
||||||
|
const (
|
||||||
|
// SQLFingerprintBasic lowercases SQL and collapses whitespace.
|
||||||
|
SQLFingerprintBasic SQLFingerprintMode = iota
|
||||||
|
// SQLFingerprintMaskLiterals also masks numeric/string literals and $n placeholders.
|
||||||
|
SQLFingerprintMaskLiterals
|
||||||
|
)
|
||||||
|
|
||||||
|
type sqlHookMetaKey struct{}
|
||||||
|
|
||||||
|
// SQLHookMeta contains extra hook metadata attached to context.
|
||||||
|
type SQLHookMeta struct {
|
||||||
|
Fingerprint string
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLHookMetaFromContext extracts SQL hook metadata from context.
|
||||||
|
func SQLHookMetaFromContext(ctx context.Context) (SQLHookMeta, bool) {
|
||||||
|
if ctx == nil {
|
||||||
|
return SQLHookMeta{}, false
|
||||||
|
}
|
||||||
|
meta, ok := ctx.Value(sqlHookMetaKey{}).(SQLHookMeta)
|
||||||
|
return meta, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func withSQLHookMeta(ctx context.Context, meta SQLHookMeta) context.Context {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, sqlHookMetaKey{}, meta)
|
||||||
|
}
|
||||||
|
|
||||||
|
type sqlRuntime struct {
|
||||||
|
state internalsqlruntime.State
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizePlaceholderStyle(style PlaceholderStyle) PlaceholderStyle {
|
||||||
|
return PlaceholderStyle(internalsqlruntime.NormalizePlaceholderStyle(int(style)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeSQLFingerprintMode(mode SQLFingerprintMode) SQLFingerprintMode {
|
||||||
|
return SQLFingerprintMode(internalsqlruntime.NormalizeFingerprintMode(int(mode)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneHookArgs(args []interface{}) []interface{} {
|
||||||
|
return internalsqlruntime.CloneHookArgs(args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarDB) runtimeOptions() (SQLBeforeHook, SQLAfterHook, PlaceholderStyle, time.Duration) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, nil, PlaceholderQuestion, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
beforeAny, afterAny, rawStyle, slowThreshold := s.runtime.state.Options()
|
||||||
|
var before SQLBeforeHook
|
||||||
|
if b, ok := beforeAny.(SQLBeforeHook); ok {
|
||||||
|
before = b
|
||||||
|
}
|
||||||
|
var after SQLAfterHook
|
||||||
|
if a, ok := afterAny.(SQLAfterHook); ok {
|
||||||
|
after = a
|
||||||
|
}
|
||||||
|
|
||||||
|
return before, after, normalizePlaceholderStyle(PlaceholderStyle(rawStyle)), slowThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarDB) sqlHooks() (SQLBeforeHook, SQLAfterHook, time.Duration) {
|
||||||
|
before, after, _, slowThreshold := s.runtimeOptions()
|
||||||
|
return before, after, slowThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarDB) prepareSQLCall(query string, args []interface{}) (string, SQLBeforeHook, SQLAfterHook, []interface{}, time.Duration) {
|
||||||
|
before, after, style, slowThreshold := s.runtimeOptions()
|
||||||
|
query = ConvertPlaceholders(query, style)
|
||||||
|
if before == nil && after == nil {
|
||||||
|
return query, nil, nil, nil, slowThreshold
|
||||||
|
}
|
||||||
|
return query, before, after, cloneHookArgs(args), slowThreshold
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarDB) hookContext(ctx context.Context, query string, before SQLBeforeHook, after SQLAfterHook) context.Context {
|
||||||
|
if s == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
needCounter := s.SQLFingerprintCounterEnabled()
|
||||||
|
needMeta := (before != nil || after != nil) && s.SQLFingerprintEnabled()
|
||||||
|
if !needCounter && !needMeta {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
mode := s.SQLFingerprintMode()
|
||||||
|
keepComments := s.SQLFingerprintKeepComments()
|
||||||
|
fingerprint := internalsqlruntime.FingerprintSQL(query, int(mode), keepComments)
|
||||||
|
if fingerprint == "" {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
|
||||||
|
if needCounter {
|
||||||
|
s.runtime.state.IncFingerprintCount(fingerprint)
|
||||||
|
}
|
||||||
|
if !needMeta {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
return withSQLHookMeta(ctx, SQLHookMeta{Fingerprint: fingerprint})
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldRunAfterHook(after SQLAfterHook, slowThreshold, duration time.Duration, err error) bool {
|
||||||
|
return internalsqlruntime.ShouldRunAfterHook(after != nil, slowThreshold, duration, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSQLHooks sets SQL before/after hooks.
|
||||||
|
func (s *StarDB) SetSQLHooks(before SQLBeforeHook, after SQLAfterHook) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.state.SetHooks(before, after)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSQLBeforeHook sets SQL before hook.
|
||||||
|
func (s *StarDB) SetSQLBeforeHook(before SQLBeforeHook) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.state.SetBeforeHook(before)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSQLAfterHook sets SQL after hook.
|
||||||
|
func (s *StarDB) SetSQLAfterHook(after SQLAfterHook) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.state.SetAfterHook(after)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPlaceholderStyle sets placeholder conversion style.
|
||||||
|
func (s *StarDB) SetPlaceholderStyle(style PlaceholderStyle) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.state.SetPlaceholderStyle(int(style))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSQLSlowThreshold sets minimum duration for triggering after hook.
|
||||||
|
// When threshold > 0, after hook runs only for statements slower than threshold or those with error.
|
||||||
|
func (s *StarDB) SetSQLSlowThreshold(threshold time.Duration) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.state.SetSlowThreshold(threshold)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLSlowThreshold returns current slow SQL threshold.
|
||||||
|
func (s *StarDB) SQLSlowThreshold() time.Duration {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return s.runtime.state.SlowThreshold()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PlaceholderStyle returns current placeholder style.
|
||||||
|
func (s *StarDB) PlaceholderStyle() PlaceholderStyle {
|
||||||
|
if s == nil {
|
||||||
|
return PlaceholderQuestion
|
||||||
|
}
|
||||||
|
style := PlaceholderStyle(s.runtime.state.PlaceholderStyle())
|
||||||
|
return normalizePlaceholderStyle(style)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSQLFingerprintEnabled toggles SQL fingerprint metadata generation for hooks.
|
||||||
|
func (s *StarDB) SetSQLFingerprintEnabled(enabled bool) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.state.SetFingerprintEnabled(enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLFingerprintEnabled reports whether SQL fingerprint metadata generation is enabled.
|
||||||
|
func (s *StarDB) SQLFingerprintEnabled() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.runtime.state.FingerprintEnabled()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSQLFingerprintMode sets SQL fingerprint generation mode.
|
||||||
|
func (s *StarDB) SetSQLFingerprintMode(mode SQLFingerprintMode) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.state.SetFingerprintMode(int(mode))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLFingerprintMode returns SQL fingerprint generation mode.
|
||||||
|
func (s *StarDB) SQLFingerprintMode() SQLFingerprintMode {
|
||||||
|
if s == nil {
|
||||||
|
return SQLFingerprintBasic
|
||||||
|
}
|
||||||
|
mode := SQLFingerprintMode(s.runtime.state.FingerprintMode())
|
||||||
|
return normalizeSQLFingerprintMode(mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSQLFingerprintKeepComments controls whether comments are preserved in SQL fingerprints.
|
||||||
|
// Default is false.
|
||||||
|
func (s *StarDB) SetSQLFingerprintKeepComments(keep bool) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.state.SetFingerprintKeepComments(keep)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLFingerprintKeepComments reports whether SQL fingerprints preserve comments.
|
||||||
|
func (s *StarDB) SQLFingerprintKeepComments() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.runtime.state.FingerprintKeepComments()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSQLFingerprintCounterEnabled toggles in-memory SQL fingerprint hit counter.
|
||||||
|
// Default is false.
|
||||||
|
func (s *StarDB) SetSQLFingerprintCounterEnabled(enabled bool) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.state.SetFingerprintCounterEnabled(enabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLFingerprintCounterEnabled reports whether in-memory SQL fingerprint hit counter is enabled.
|
||||||
|
func (s *StarDB) SQLFingerprintCounterEnabled() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return s.runtime.state.FingerprintCounterEnabled()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SQLFingerprintCounters returns a snapshot of fingerprint hit counters.
|
||||||
|
func (s *StarDB) SQLFingerprintCounters() map[string]uint64 {
|
||||||
|
if s == nil {
|
||||||
|
return map[string]uint64{}
|
||||||
|
}
|
||||||
|
return s.runtime.state.FingerprintCountsSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetSQLFingerprintCounters clears all in-memory fingerprint hit counters.
|
||||||
|
func (s *StarDB) ResetSQLFingerprintCounters() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.state.ResetFingerprintCounts()
|
||||||
|
}
|
||||||
102
sql_runtime_test.go
Normal file
102
sql_runtime_test.go
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStarDB_RuntimeConfigConcurrent(t *testing.T) {
|
||||||
|
db := NewStarDB()
|
||||||
|
|
||||||
|
before := func(ctx context.Context, query string, args []interface{}) {}
|
||||||
|
after := func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for i := 0; i < 16; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 1000; j++ {
|
||||||
|
if (i+j)%2 == 0 {
|
||||||
|
db.SetPlaceholderStyle(PlaceholderDollar)
|
||||||
|
} else {
|
||||||
|
db.SetPlaceholderStyle(PlaceholderQuestion)
|
||||||
|
}
|
||||||
|
db.SetSQLSlowThreshold(time.Duration((i+j)%5) * time.Millisecond)
|
||||||
|
db.SetSQLFingerprintEnabled((i+j)%3 == 0)
|
||||||
|
db.SetSQLFingerprintMode(SQLFingerprintMode((i + j) % 3))
|
||||||
|
db.SetSQLFingerprintKeepComments((i+j)%4 == 0)
|
||||||
|
db.SetSQLFingerprintCounterEnabled((i+j)%5 == 0)
|
||||||
|
if (i+j)%7 == 0 {
|
||||||
|
db.ResetSQLFingerprintCounters()
|
||||||
|
}
|
||||||
|
db.SetSQLHooks(before, after)
|
||||||
|
_ = db.PlaceholderStyle()
|
||||||
|
_ = db.SQLSlowThreshold()
|
||||||
|
_ = db.SQLFingerprintEnabled()
|
||||||
|
_ = db.SQLFingerprintMode()
|
||||||
|
_ = db.SQLFingerprintKeepComments()
|
||||||
|
_ = db.SQLFingerprintCounterEnabled()
|
||||||
|
_ = db.SQLFingerprintCounters()
|
||||||
|
_, _, _, _ = db.runtimeOptions()
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLFingerprintMode(t *testing.T) {
|
||||||
|
db := NewStarDB()
|
||||||
|
|
||||||
|
if got := db.SQLFingerprintMode(); got != SQLFingerprintBasic {
|
||||||
|
t.Fatalf("expected default mode SQLFingerprintBasic, got %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetSQLFingerprintMode(SQLFingerprintMaskLiterals)
|
||||||
|
if got := db.SQLFingerprintMode(); got != SQLFingerprintMaskLiterals {
|
||||||
|
t.Fatalf("expected SQLFingerprintMaskLiterals, got %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetSQLFingerprintMode(SQLFingerprintMode(99))
|
||||||
|
if got := db.SQLFingerprintMode(); got != SQLFingerprintBasic {
|
||||||
|
t.Fatalf("expected invalid mode fallback to SQLFingerprintBasic, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLFingerprintKeepComments(t *testing.T) {
|
||||||
|
db := NewStarDB()
|
||||||
|
|
||||||
|
if db.SQLFingerprintKeepComments() {
|
||||||
|
t.Fatal("expected default keep comments to be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetSQLFingerprintKeepComments(true)
|
||||||
|
if !db.SQLFingerprintKeepComments() {
|
||||||
|
t.Fatal("expected keep comments to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetSQLFingerprintKeepComments(false)
|
||||||
|
if db.SQLFingerprintKeepComments() {
|
||||||
|
t.Fatal("expected keep comments to be false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLFingerprintCounterSwitch(t *testing.T) {
|
||||||
|
db := NewStarDB()
|
||||||
|
|
||||||
|
if db.SQLFingerprintCounterEnabled() {
|
||||||
|
t.Fatal("expected default counter switch to be false")
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetSQLFingerprintCounterEnabled(true)
|
||||||
|
if !db.SQLFingerprintCounterEnabled() {
|
||||||
|
t.Fatal("expected counter switch to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
db.ResetSQLFingerprintCounters()
|
||||||
|
if got := len(db.SQLFingerprintCounters()); got != 0 {
|
||||||
|
t.Fatalf("expected empty counters after reset, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
324
stardb.go
324
stardb.go
@ -3,13 +3,20 @@ package stardb
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StarDB is a simple wrapper around sql.DB providing enhanced functionality
|
// StarDB is a simple wrapper around sql.DB providing enhanced functionality
|
||||||
type StarDB struct {
|
type StarDB struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
ManualScan bool // If true, rows won't be automatically parsed
|
ManualScan bool // If true, rows won't be automatically parsed
|
||||||
|
StrictORM bool // If true, Orm requires all tagged columns to exist in query results
|
||||||
|
// batchInsertMaxRows controls batch split size for BatchInsert/BatchInsertStructs.
|
||||||
|
// <= 0 means no split (single SQL statement).
|
||||||
|
batchInsertMaxRows int64
|
||||||
|
batchInsertMaxParams int64
|
||||||
|
runtime sqlRuntime
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStarDB creates a new StarDB instance
|
// NewStarDB creates a new StarDB instance
|
||||||
@ -32,6 +39,21 @@ func (s *StarDB) SetDB(db *sql.DB) {
|
|||||||
s.db = db
|
s.db = db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetStrictORM enables or disables strict column validation for Orm mapping.
|
||||||
|
func (s *StarDB) SetStrictORM(strict bool) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.StrictORM = strict
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarDB) ensureDB() error {
|
||||||
|
if s == nil || s.db == nil {
|
||||||
|
return ErrDBNotInitialized
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Open opens a new database connection
|
// Open opens a new database connection
|
||||||
func (s *StarDB) Open(driver, connStr string) error {
|
func (s *StarDB) Open(driver, connStr string) error {
|
||||||
var err error
|
var err error
|
||||||
@ -41,36 +63,57 @@ func (s *StarDB) Open(driver, connStr string) error {
|
|||||||
|
|
||||||
// Close closes the database connection
|
// Close closes the database connection
|
||||||
func (s *StarDB) Close() error {
|
func (s *StarDB) Close() error {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return s.db.Close()
|
return s.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ping verifies the database connection is alive
|
// Ping verifies the database connection is alive
|
||||||
func (s *StarDB) Ping() error {
|
func (s *StarDB) Ping() error {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return s.db.Ping()
|
return s.db.Ping()
|
||||||
}
|
}
|
||||||
|
|
||||||
// PingContext verifies the database connection with context
|
// PingContext verifies the database connection with context
|
||||||
func (s *StarDB) PingContext(ctx context.Context) error {
|
func (s *StarDB) PingContext(ctx context.Context) error {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return s.db.PingContext(ctx)
|
return s.db.PingContext(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stats returns database statistics
|
// Stats returns database statistics
|
||||||
func (s *StarDB) Stats() sql.DBStats {
|
func (s *StarDB) Stats() sql.DBStats {
|
||||||
|
if s == nil || s.db == nil {
|
||||||
|
return sql.DBStats{}
|
||||||
|
}
|
||||||
return s.db.Stats()
|
return s.db.Stats()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMaxOpenConns sets the maximum number of open connections
|
// SetMaxOpenConns sets the maximum number of open connections
|
||||||
func (s *StarDB) SetMaxOpenConns(n int) {
|
func (s *StarDB) SetMaxOpenConns(n int) {
|
||||||
|
if s == nil || s.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
s.db.SetMaxOpenConns(n)
|
s.db.SetMaxOpenConns(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetMaxIdleConns sets the maximum number of idle connections
|
// SetMaxIdleConns sets the maximum number of idle connections
|
||||||
func (s *StarDB) SetMaxIdleConns(n int) {
|
func (s *StarDB) SetMaxIdleConns(n int) {
|
||||||
|
if s == nil || s.db == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
s.db.SetMaxIdleConns(n)
|
s.db.SetMaxIdleConns(n)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Conn returns a single connection from the pool
|
// Conn returns a single connection from the pool
|
||||||
func (s *StarDB) Conn(ctx context.Context) (*sql.Conn, error) {
|
func (s *StarDB) Conn(ctx context.Context) (*sql.Conn, error) {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
return s.db.Conn(ctx)
|
return s.db.Conn(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -85,21 +128,55 @@ func (s *StarDB) QueryContext(ctx context.Context, query string, args ...interfa
|
|||||||
return s.query(ctx, query, args...)
|
return s.query(ctx, query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query is the internal query implementation
|
// QueryRaw executes a query and returns *sql.Rows without automatic parsing.
|
||||||
func (s *StarDB) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
func (s *StarDB) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
if err := s.db.Ping(); err != nil {
|
return s.queryRaw(nil, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRawContext executes a query with context and returns *sql.Rows without automatic parsing.
|
||||||
|
func (s *StarDB) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
return s.queryRaw(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarDB) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(query) == "" {
|
||||||
|
return nil, ErrQueryEmpty
|
||||||
|
}
|
||||||
|
|
||||||
var rows *sql.Rows
|
query, beforeHook, afterHook, hookArgs, slowThreshold := s.prepareSQLCall(query, args)
|
||||||
var err error
|
hookCtx := s.hookContext(ctx, query, beforeHook, afterHook)
|
||||||
|
if beforeHook != nil {
|
||||||
|
beforeHook(hookCtx, query, hookArgs)
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
var (
|
||||||
|
rows *sql.Rows
|
||||||
|
err error
|
||||||
|
)
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
rows, err = s.db.Query(query, args...)
|
rows, err = s.db.Query(query, args...)
|
||||||
} else {
|
} else {
|
||||||
rows, err = s.db.QueryContext(ctx, query, args...)
|
rows, err = s.db.QueryContext(ctx, query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
||||||
|
afterHook(hookCtx, query, hookArgs, duration, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// query is the internal query implementation
|
||||||
|
func (s *StarDB) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
||||||
|
rows, err := s.queryRaw(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -110,10 +187,13 @@ func (s *StarDB) query(ctx context.Context, query string, args ...interface{}) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !s.ManualScan {
|
if !s.ManualScan {
|
||||||
err = starRows.parse()
|
if err := starRows.parse(); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return starRows, err
|
return starRows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec executes a query that doesn't return rows
|
// Exec executes a query that doesn't return rows
|
||||||
@ -129,39 +209,100 @@ func (s *StarDB) ExecContext(ctx context.Context, query string, args ...interfac
|
|||||||
|
|
||||||
// exec is the internal exec implementation
|
// exec is the internal exec implementation
|
||||||
func (s *StarDB) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
func (s *StarDB) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
if err := s.db.Ping(); err != nil {
|
if err := s.ensureDB(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(query) == "" {
|
||||||
if ctx == nil {
|
return nil, ErrQueryEmpty
|
||||||
return s.db.Exec(query, args...)
|
|
||||||
}
|
}
|
||||||
return s.db.ExecContext(ctx, query, args...)
|
|
||||||
|
query, beforeHook, afterHook, hookArgs, slowThreshold := s.prepareSQLCall(query, args)
|
||||||
|
hookCtx := s.hookContext(ctx, query, beforeHook, afterHook)
|
||||||
|
if beforeHook != nil {
|
||||||
|
beforeHook(hookCtx, query, hookArgs)
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
var (
|
||||||
|
result sql.Result
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if ctx == nil {
|
||||||
|
result, err = s.db.Exec(query, args...)
|
||||||
|
} else {
|
||||||
|
result, err = s.db.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
duration := time.Since(start)
|
||||||
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
||||||
|
afterHook(hookCtx, query, hookArgs, duration, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare creates a prepared statement
|
// Prepare creates a prepared statement
|
||||||
func (s *StarDB) Prepare(query string) (*StarStmt, error) {
|
func (s *StarDB) Prepare(query string) (*StarStmt, error) {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(query) == "" {
|
||||||
|
return nil, ErrQueryEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
query, beforeHook, afterHook, _, slowThreshold := s.prepareSQLCall(query, nil)
|
||||||
|
hookCtx := s.hookContext(nil, query, beforeHook, afterHook)
|
||||||
|
if beforeHook != nil {
|
||||||
|
beforeHook(hookCtx, query, nil)
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
stmt, err := s.db.Prepare(query)
|
stmt, err := s.db.Prepare(query)
|
||||||
|
duration := time.Since(start)
|
||||||
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
||||||
|
afterHook(hookCtx, query, nil, duration, err)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &StarStmt{stmt: stmt, db: s}, nil
|
return &StarStmt{stmt: stmt, db: s, sqlText: query}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrepareContext creates a prepared statement with context
|
// PrepareContext creates a prepared statement with context
|
||||||
func (s *StarDB) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
|
func (s *StarDB) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(query) == "" {
|
||||||
|
return nil, ErrQueryEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
query, beforeHook, afterHook, _, slowThreshold := s.prepareSQLCall(query, nil)
|
||||||
|
hookCtx := s.hookContext(ctx, query, beforeHook, afterHook)
|
||||||
|
if beforeHook != nil {
|
||||||
|
beforeHook(hookCtx, query, nil)
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
stmt, err := s.db.PrepareContext(ctx, query)
|
stmt, err := s.db.PrepareContext(ctx, query)
|
||||||
|
duration := time.Since(start)
|
||||||
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
||||||
|
afterHook(hookCtx, query, nil, duration, err)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &StarStmt{stmt: stmt, db: s}, nil
|
return &StarStmt{stmt: stmt, db: s, sqlText: query}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryStmt executes a prepared statement query
|
// QueryStmt executes a prepared statement query
|
||||||
// Usage: QueryStmt("SELECT * FROM users WHERE id = ?", 1)
|
// Usage: QueryStmt("SELECT * FROM users WHERE id = ?", 1)
|
||||||
func (s *StarDB) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
|
func (s *StarDB) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
|
||||||
if query == "" {
|
if strings.TrimSpace(query) == "" {
|
||||||
return nil, errors.New("query string cannot be empty")
|
return nil, ErrQueryEmpty
|
||||||
}
|
}
|
||||||
stmt, err := s.Prepare(query)
|
stmt, err := s.Prepare(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -173,8 +314,8 @@ func (s *StarDB) QueryStmt(query string, args ...interface{}) (*StarRows, error)
|
|||||||
|
|
||||||
// QueryStmtContext executes a prepared statement query with context
|
// QueryStmtContext executes a prepared statement query with context
|
||||||
func (s *StarDB) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
func (s *StarDB) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
||||||
if query == "" {
|
if strings.TrimSpace(query) == "" {
|
||||||
return nil, errors.New("query string cannot be empty")
|
return nil, ErrQueryEmpty
|
||||||
}
|
}
|
||||||
stmt, err := s.PrepareContext(ctx, query)
|
stmt, err := s.PrepareContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -187,8 +328,8 @@ func (s *StarDB) QueryStmtContext(ctx context.Context, query string, args ...int
|
|||||||
// ExecStmt executes a prepared statement
|
// ExecStmt executes a prepared statement
|
||||||
// Usage: ExecStmt("INSERT INTO users (name) VALUES (?)", "John")
|
// Usage: ExecStmt("INSERT INTO users (name) VALUES (?)", "John")
|
||||||
func (s *StarDB) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
|
func (s *StarDB) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
|
||||||
if query == "" {
|
if strings.TrimSpace(query) == "" {
|
||||||
return nil, errors.New("query string cannot be empty")
|
return nil, ErrQueryEmpty
|
||||||
}
|
}
|
||||||
stmt, err := s.Prepare(query)
|
stmt, err := s.Prepare(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -200,8 +341,8 @@ func (s *StarDB) ExecStmt(query string, args ...interface{}) (sql.Result, error)
|
|||||||
|
|
||||||
// ExecStmtContext executes a prepared statement with context
|
// ExecStmtContext executes a prepared statement with context
|
||||||
func (s *StarDB) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
func (s *StarDB) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
if query == "" {
|
if strings.TrimSpace(query) == "" {
|
||||||
return nil, errors.New("query string cannot be empty")
|
return nil, ErrQueryEmpty
|
||||||
}
|
}
|
||||||
stmt, err := s.PrepareContext(ctx, query)
|
stmt, err := s.PrepareContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -213,6 +354,9 @@ func (s *StarDB) ExecStmtContext(ctx context.Context, query string, args ...inte
|
|||||||
|
|
||||||
// Begin starts a transaction
|
// Begin starts a transaction
|
||||||
func (s *StarDB) Begin() (*StarTx, error) {
|
func (s *StarDB) Begin() (*StarTx, error) {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
tx, err := s.db.Begin()
|
tx, err := s.db.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -222,6 +366,9 @@ func (s *StarDB) Begin() (*StarTx, error) {
|
|||||||
|
|
||||||
// BeginTx starts a transaction with options
|
// BeginTx starts a transaction with options
|
||||||
func (s *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, error) {
|
func (s *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, error) {
|
||||||
|
if err := s.ensureDB(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
tx, err := s.db.BeginTx(ctx, opts)
|
tx, err := s.db.BeginTx(ctx, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -231,8 +378,26 @@ func (s *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, err
|
|||||||
|
|
||||||
// StarStmt represents a prepared statement
|
// StarStmt represents a prepared statement
|
||||||
type StarStmt struct {
|
type StarStmt struct {
|
||||||
stmt *sql.Stmt
|
stmt *sql.Stmt
|
||||||
db *StarDB
|
db *StarDB
|
||||||
|
sqlText string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarStmt) ensureStmt() error {
|
||||||
|
if s == nil || s.stmt == nil {
|
||||||
|
return ErrStmtNotInitialized
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarStmt) ensureStmtWithDB() error {
|
||||||
|
if err := s.ensureStmt(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if s.db == nil {
|
||||||
|
return ErrStmtDBNotInitialized
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query executes a prepared statement query
|
// Query executes a prepared statement query
|
||||||
@ -245,17 +410,66 @@ func (s *StarStmt) QueryContext(ctx context.Context, args ...interface{}) (*Star
|
|||||||
return s.query(ctx, args...)
|
return s.query(ctx, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query is the internal query implementation
|
// QueryRaw executes a prepared statement query and returns *sql.Rows without automatic parsing.
|
||||||
func (s *StarStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) {
|
func (s *StarStmt) QueryRaw(args ...interface{}) (*sql.Rows, error) {
|
||||||
var rows *sql.Rows
|
return s.queryRaw(nil, args...)
|
||||||
var err error
|
}
|
||||||
|
|
||||||
|
// QueryRawContext executes a prepared statement query with context and returns *sql.Rows without automatic parsing.
|
||||||
|
func (s *StarStmt) QueryRawContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
return s.queryRaw(ctx, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *StarStmt) queryRaw(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
if err := s.ensureStmt(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var beforeHook SQLBeforeHook
|
||||||
|
var afterHook SQLAfterHook
|
||||||
|
var slowThreshold time.Duration
|
||||||
|
if s.db != nil {
|
||||||
|
beforeHook, afterHook, slowThreshold = s.db.sqlHooks()
|
||||||
|
}
|
||||||
|
var hookArgs []interface{}
|
||||||
|
if beforeHook != nil || afterHook != nil {
|
||||||
|
hookArgs = cloneHookArgs(args)
|
||||||
|
}
|
||||||
|
hookCtx := ctx
|
||||||
|
if s.db != nil {
|
||||||
|
hookCtx = s.db.hookContext(ctx, s.sqlText, beforeHook, afterHook)
|
||||||
|
}
|
||||||
|
if beforeHook != nil {
|
||||||
|
beforeHook(hookCtx, s.sqlText, hookArgs)
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
var (
|
||||||
|
rows *sql.Rows
|
||||||
|
err error
|
||||||
|
)
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
rows, err = s.stmt.Query(args...)
|
rows, err = s.stmt.Query(args...)
|
||||||
} else {
|
} else {
|
||||||
rows, err = s.stmt.QueryContext(ctx, args...)
|
rows, err = s.stmt.QueryContext(ctx, args...)
|
||||||
}
|
}
|
||||||
|
duration := time.Since(start)
|
||||||
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
||||||
|
afterHook(hookCtx, s.sqlText, hookArgs, duration, err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// query is the internal query implementation
|
||||||
|
func (s *StarStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) {
|
||||||
|
if err := s.ensureStmtWithDB(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := s.queryRaw(ctx, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -266,10 +480,13 @@ func (s *StarStmt) query(ctx context.Context, args ...interface{}) (*StarRows, e
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !s.db.ManualScan {
|
if !s.db.ManualScan {
|
||||||
err = starRows.parse()
|
if err := starRows.parse(); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return starRows, err
|
return starRows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec executes a prepared statement
|
// Exec executes a prepared statement
|
||||||
@ -284,13 +501,52 @@ func (s *StarStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Re
|
|||||||
|
|
||||||
// exec is the internal exec implementation
|
// exec is the internal exec implementation
|
||||||
func (s *StarStmt) exec(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
func (s *StarStmt) exec(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
||||||
if ctx == nil {
|
if err := s.ensureStmt(); err != nil {
|
||||||
return s.stmt.Exec(args...)
|
return nil, err
|
||||||
}
|
}
|
||||||
return s.stmt.ExecContext(ctx, args...)
|
|
||||||
|
var beforeHook SQLBeforeHook
|
||||||
|
var afterHook SQLAfterHook
|
||||||
|
var slowThreshold time.Duration
|
||||||
|
if s.db != nil {
|
||||||
|
beforeHook, afterHook, slowThreshold = s.db.sqlHooks()
|
||||||
|
}
|
||||||
|
var hookArgs []interface{}
|
||||||
|
if beforeHook != nil || afterHook != nil {
|
||||||
|
hookArgs = cloneHookArgs(args)
|
||||||
|
}
|
||||||
|
hookCtx := ctx
|
||||||
|
if s.db != nil {
|
||||||
|
hookCtx = s.db.hookContext(ctx, s.sqlText, beforeHook, afterHook)
|
||||||
|
}
|
||||||
|
if beforeHook != nil {
|
||||||
|
beforeHook(hookCtx, s.sqlText, hookArgs)
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
var (
|
||||||
|
result sql.Result
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if ctx == nil {
|
||||||
|
result, err = s.stmt.Exec(args...)
|
||||||
|
} else {
|
||||||
|
result, err = s.stmt.ExecContext(ctx, args...)
|
||||||
|
}
|
||||||
|
duration := time.Since(start)
|
||||||
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
||||||
|
afterHook(hookCtx, s.sqlText, hookArgs, duration, err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the prepared statement
|
// Close closes the prepared statement
|
||||||
func (s *StarStmt) Close() error {
|
func (s *StarStmt) Close() error {
|
||||||
|
if err := s.ensureStmt(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return s.stmt.Close()
|
return s.stmt.Close()
|
||||||
}
|
}
|
||||||
|
|||||||
127
stardb_safe_test.go
Normal file
127
stardb_safe_test.go
Normal file
@ -0,0 +1,127 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStarDB_NotInitialized(t *testing.T) {
|
||||||
|
db := NewStarDB()
|
||||||
|
|
||||||
|
if err := db.Close(); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from Close, got %v", err)
|
||||||
|
}
|
||||||
|
if err := db.Ping(); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from Ping, got %v", err)
|
||||||
|
}
|
||||||
|
if err := db.PingContext(context.Background()); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from PingContext, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Conn(context.Background()); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from Conn, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Query("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from Query, got %v", err)
|
||||||
|
}
|
||||||
|
if err := db.ScanEach("SELECT 1", func(row *StarResult) error { return nil }); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from ScanEach, got %v", err)
|
||||||
|
}
|
||||||
|
var model struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
}
|
||||||
|
if err := db.ScanEachORM("SELECT 1", &model, func(target interface{}) error { return nil }); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from ScanEachORM, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.QueryRaw("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from QueryRaw, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Exec("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from Exec, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Prepare("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from Prepare, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Begin(); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from Begin, got %v", err)
|
||||||
|
}
|
||||||
|
if err := db.WithTx(nil); !errors.Is(err, ErrTxFuncNil) {
|
||||||
|
t.Fatalf("expected ErrTxFuncNil from WithTx, got %v", err)
|
||||||
|
}
|
||||||
|
if err := db.WithTx(func(tx *StarTx) error { return nil }); !errors.Is(err, ErrDBNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrDBNotInitialized from WithTx, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.QueryStmt(""); !errors.Is(err, ErrQueryEmpty) {
|
||||||
|
t.Fatalf("expected ErrQueryEmpty from QueryStmt, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.ExecStmt(""); !errors.Is(err, ErrQueryEmpty) {
|
||||||
|
t.Fatalf("expected ErrQueryEmpty from ExecStmt, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = db.Stats()
|
||||||
|
db.SetMaxOpenConns(5)
|
||||||
|
db.SetMaxIdleConns(2)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarTx_NotInitialized(t *testing.T) {
|
||||||
|
tx := &StarTx{}
|
||||||
|
var model struct {
|
||||||
|
ID int `db:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Query("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrTxNotInitialized from Query, got %v", err)
|
||||||
|
}
|
||||||
|
if err := tx.ScanEach("SELECT 1", func(row *StarResult) error { return nil }); !errors.Is(err, ErrTxNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrTxNotInitialized from ScanEach, got %v", err)
|
||||||
|
}
|
||||||
|
if err := tx.ScanEachORM("SELECT 1", &model, func(target interface{}) error { return nil }); !errors.Is(err, ErrTxNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrTxNotInitialized from ScanEachORM, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := tx.QueryRaw("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrTxNotInitialized from QueryRaw, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := tx.Exec("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrTxNotInitialized from Exec, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := tx.Prepare("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrTxNotInitialized from Prepare, got %v", err)
|
||||||
|
}
|
||||||
|
if err := tx.Commit(); !errors.Is(err, ErrTxNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrTxNotInitialized from Commit, got %v", err)
|
||||||
|
}
|
||||||
|
if err := tx.Rollback(); !errors.Is(err, ErrTxNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrTxNotInitialized from Rollback, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.QueryStmt(""); !errors.Is(err, ErrQueryEmpty) {
|
||||||
|
t.Fatalf("expected ErrQueryEmpty from QueryStmt, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := tx.ExecStmt(""); !errors.Is(err, ErrQueryEmpty) {
|
||||||
|
t.Fatalf("expected ErrQueryEmpty from ExecStmt, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarStmt_NotInitialized(t *testing.T) {
|
||||||
|
stmt := &StarStmt{}
|
||||||
|
|
||||||
|
if _, err := stmt.Query(); !errors.Is(err, ErrStmtNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrStmtNotInitialized from Query, got %v", err)
|
||||||
|
}
|
||||||
|
if err := stmt.ScanEach(func(row *StarResult) error { return nil }); !errors.Is(err, ErrStmtNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrStmtNotInitialized from ScanEach, got %v", err)
|
||||||
|
}
|
||||||
|
if err := stmt.ScanEachORM(nil, func(target interface{}) error { return nil }); !errors.Is(err, ErrTargetNil) {
|
||||||
|
t.Fatalf("expected ErrTargetNil from ScanEachORM, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := stmt.QueryRaw(); !errors.Is(err, ErrStmtNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrStmtNotInitialized from QueryRaw, got %v", err)
|
||||||
|
}
|
||||||
|
if _, err := stmt.Exec(); !errors.Is(err, ErrStmtNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrStmtNotInitialized from Exec, got %v", err)
|
||||||
|
}
|
||||||
|
if err := stmt.Close(); !errors.Is(err, ErrStmtNotInitialized) {
|
||||||
|
t.Fatalf("expected ErrStmtNotInitialized from Close, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -2,6 +2,7 @@ package testing
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -110,8 +111,53 @@ func TestStarDB_BatchInsert_Empty(t *testing.T) {
|
|||||||
values := [][]interface{}{}
|
values := [][]interface{}{}
|
||||||
|
|
||||||
_, err := db.BatchInsert("users", columns, values)
|
_, err := db.BatchInsert("users", columns, values)
|
||||||
if err == nil {
|
if !errors.Is(err, stardb.ErrNoInsertValues) {
|
||||||
t.Error("Expected error with empty values, got nil")
|
t.Errorf("Expected ErrNoInsertValues, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsert_EmptyColumns(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
values := [][]interface{}{
|
||||||
|
{"Alice"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.BatchInsert("users", nil, values)
|
||||||
|
if !errors.Is(err, stardb.ErrNoInsertColumns) {
|
||||||
|
t.Errorf("Expected ErrNoInsertColumns, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsert_EmptyTableName(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
columns := []string{"name", "email", "age"}
|
||||||
|
values := [][]interface{}{
|
||||||
|
{"Alice", "alice@example.com", 25},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.BatchInsert("", columns, values)
|
||||||
|
if !errors.Is(err, stardb.ErrTableNameEmpty) {
|
||||||
|
t.Errorf("Expected ErrTableNameEmpty, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsert_RowLengthMismatch(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
columns := []string{"name", "email", "age"}
|
||||||
|
values := [][]interface{}{
|
||||||
|
{"Alice", "alice@example.com", 25},
|
||||||
|
{"Bob", "bob@example.com"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.BatchInsert("users", columns, values)
|
||||||
|
if !errors.Is(err, stardb.ErrBatchRowValueCountMismatch) {
|
||||||
|
t.Errorf("Expected ErrBatchRowValueCountMismatch, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -204,6 +250,236 @@ func TestStarDB_BatchInsertContext_Timeout(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsertMaxRows_Config(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if got := db.BatchInsertMaxRows(); got != 0 {
|
||||||
|
t.Fatalf("Expected default chunk size 0, got %d", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxRows(3)
|
||||||
|
if got := db.BatchInsertMaxRows(); got != 3 {
|
||||||
|
t.Fatalf("Expected chunk size 3, got %d", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxRows(-10)
|
||||||
|
if got := db.BatchInsertMaxRows(); got != 0 {
|
||||||
|
t.Fatalf("Expected chunk size reset to 0, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsertMaxParams_Config(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
if got := db.BatchInsertMaxParams(); got != 0 {
|
||||||
|
t.Fatalf("Expected default max params 0, got %d", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxParams(100)
|
||||||
|
if got := db.BatchInsertMaxParams(); got != 100 {
|
||||||
|
t.Fatalf("Expected max params 100, got %d", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxParams(-1)
|
||||||
|
if got := db.BatchInsertMaxParams(); got != 0 {
|
||||||
|
t.Fatalf("Expected max params reset to 0, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsert_Chunked(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxRows(2)
|
||||||
|
|
||||||
|
columns := []string{"name", "email", "age"}
|
||||||
|
values := [][]interface{}{
|
||||||
|
{"Alice", "alice@example.com", 25},
|
||||||
|
{"Bob", "bob@example.com", 30},
|
||||||
|
{"Charlie", "charlie@example.com", 35},
|
||||||
|
{"David", "david@example.com", 40},
|
||||||
|
{"Eva", "eva@example.com", 28},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := db.BatchInsert("users", columns, values)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chunked BatchInsert failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RowsAffected failed: %v", err)
|
||||||
|
}
|
||||||
|
if affected != int64(len(values)) {
|
||||||
|
t.Fatalf("Expected %d affected rows, got %d", len(values), affected)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if count := rows.Row(0).MustInt("count"); count != len(values) {
|
||||||
|
t.Fatalf("Expected %d rows in db, got %d", len(values), count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsert_ChunkedRollbackOnError(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxRows(2)
|
||||||
|
|
||||||
|
columns := []string{"name", "email", "age"}
|
||||||
|
values := [][]interface{}{
|
||||||
|
{"Alice", "alice@example.com", 25},
|
||||||
|
{"Bob", "bob@example.com", 30},
|
||||||
|
{"Charlie", nil, 35}, // email NOT NULL, forces second chunk failure
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.BatchInsert("users", columns, values); err == nil {
|
||||||
|
t.Fatal("Expected chunked BatchInsert to fail, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
if count := rows.Row(0).MustInt("count"); count != 0 {
|
||||||
|
t.Fatalf("Expected rollback to keep table empty, got %d rows", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsert_ChunkedByMaxParams(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxRows(0) // disabled
|
||||||
|
db.SetBatchInsertMaxParams(4) // 3 columns -> 1 row per chunk
|
||||||
|
|
||||||
|
columns := []string{"name", "email", "age"}
|
||||||
|
values := [][]interface{}{
|
||||||
|
{"Alice", "alice@example.com", 25},
|
||||||
|
{"Bob", "bob@example.com", 30},
|
||||||
|
{"Charlie", "charlie@example.com", 35},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := db.BatchInsert("users", columns, values)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BatchInsert by max params failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RowsAffected failed: %v", err)
|
||||||
|
}
|
||||||
|
if affected != int64(len(values)) {
|
||||||
|
t.Fatalf("Expected %d affected rows, got %d", len(values), affected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsert_MaxParamsTooLow(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxRows(0)
|
||||||
|
db.SetBatchInsertMaxParams(2) // columns=3 -> invalid
|
||||||
|
|
||||||
|
columns := []string{"name", "email", "age"}
|
||||||
|
values := [][]interface{}{
|
||||||
|
{"Alice", "alice@example.com", 25},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.BatchInsert("users", columns, values)
|
||||||
|
if !errors.Is(err, stardb.ErrBatchInsertMaxParamsTooLow) {
|
||||||
|
t.Fatalf("Expected ErrBatchInsertMaxParamsTooLow, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsert_ChunkedHookMeta(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxRows(2)
|
||||||
|
db.SetBatchInsertMaxParams(0)
|
||||||
|
|
||||||
|
var metas []stardb.BatchExecMeta
|
||||||
|
db.SetSQLHooks(nil, func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) {
|
||||||
|
if meta, ok := stardb.BatchExecMetaFromContext(ctx); ok {
|
||||||
|
metas = append(metas, meta)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
columns := []string{"name", "email", "age"}
|
||||||
|
values := [][]interface{}{
|
||||||
|
{"Alice", "alice@example.com", 25},
|
||||||
|
{"Bob", "bob@example.com", 30},
|
||||||
|
{"Charlie", "charlie@example.com", 35},
|
||||||
|
{"David", "david@example.com", 40},
|
||||||
|
{"Eva", "eva@example.com", 28},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.BatchInsertContext(context.Background(), "users", columns, values)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BatchInsertContext failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(metas) != 3 {
|
||||||
|
t.Fatalf("Expected 3 chunk metas, got %d", len(metas))
|
||||||
|
}
|
||||||
|
|
||||||
|
wantRows := []int{2, 2, 1}
|
||||||
|
for i, meta := range metas {
|
||||||
|
if meta.ChunkIndex != i+1 {
|
||||||
|
t.Fatalf("Chunk %d: expected index %d, got %d", i, i+1, meta.ChunkIndex)
|
||||||
|
}
|
||||||
|
if meta.ChunkCount != 3 {
|
||||||
|
t.Fatalf("Chunk %d: expected count 3, got %d", i, meta.ChunkCount)
|
||||||
|
}
|
||||||
|
if meta.ChunkRows != wantRows[i] {
|
||||||
|
t.Fatalf("Chunk %d: expected rows %d, got %d", i, wantRows[i], meta.ChunkRows)
|
||||||
|
}
|
||||||
|
if meta.TotalRows != len(values) {
|
||||||
|
t.Fatalf("Chunk %d: expected total rows %d, got %d", i, len(values), meta.TotalRows)
|
||||||
|
}
|
||||||
|
if meta.ColumnCount != len(columns) {
|
||||||
|
t.Fatalf("Chunk %d: expected column count %d, got %d", i, len(columns), meta.ColumnCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsert_HookMetaAbsentWithoutChunking(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxRows(0)
|
||||||
|
db.SetBatchInsertMaxParams(0)
|
||||||
|
|
||||||
|
metaFound := false
|
||||||
|
db.SetSQLHooks(nil, func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) {
|
||||||
|
if _, ok := stardb.BatchExecMetaFromContext(ctx); ok {
|
||||||
|
metaFound = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
columns := []string{"name", "email", "age"}
|
||||||
|
values := [][]interface{}{
|
||||||
|
{"Alice", "alice@example.com", 25},
|
||||||
|
{"Bob", "bob@example.com", 30},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.BatchInsertContext(context.Background(), "users", columns, values)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BatchInsertContext failed: %v", err)
|
||||||
|
}
|
||||||
|
if metaFound {
|
||||||
|
t.Fatal("Expected no batch meta for non-chunked execution")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_Basic(t *testing.T) {
|
func TestStarDB_BatchInsertStructs_Basic(t *testing.T) {
|
||||||
db := setupBatchTestDB(t)
|
db := setupBatchTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -246,6 +522,34 @@ func TestStarDB_BatchInsertStructs_Basic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsertStructs_Chunked(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetBatchInsertMaxRows(2)
|
||||||
|
|
||||||
|
users := []TestUser{
|
||||||
|
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
|
||||||
|
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
|
||||||
|
{Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()},
|
||||||
|
{Name: "David", Email: "david@example.com", Age: 40, CreatedAt: time.Now()},
|
||||||
|
{Name: "Eva", Email: "eva@example.com", Age: 28, CreatedAt: time.Now()},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := db.BatchInsertStructs("users", users, "id")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Chunked BatchInsertStructs failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RowsAffected failed: %v", err)
|
||||||
|
}
|
||||||
|
if affected != int64(len(users)) {
|
||||||
|
t.Fatalf("Expected %d affected rows, got %d", len(users), affected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_Single(t *testing.T) {
|
func TestStarDB_BatchInsertStructs_Single(t *testing.T) {
|
||||||
db := setupBatchTestDB(t)
|
db := setupBatchTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -276,8 +580,8 @@ func TestStarDB_BatchInsertStructs_Empty(t *testing.T) {
|
|||||||
users := []TestUser{}
|
users := []TestUser{}
|
||||||
|
|
||||||
_, err := db.BatchInsertStructs("users", users, "id")
|
_, err := db.BatchInsertStructs("users", users, "id")
|
||||||
if err == nil {
|
if !errors.Is(err, stardb.ErrNoStructsToInsert) {
|
||||||
t.Error("Expected error with empty slice, got nil")
|
t.Errorf("Expected ErrNoStructsToInsert, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -288,8 +592,29 @@ func TestStarDB_BatchInsertStructs_NotSlice(t *testing.T) {
|
|||||||
user := TestUser{Name: "Alice", Email: "alice@example.com", Age: 25}
|
user := TestUser{Name: "Alice", Email: "alice@example.com", Age: 25}
|
||||||
|
|
||||||
_, err := db.BatchInsertStructs("users", user, "id")
|
_, err := db.BatchInsertStructs("users", user, "id")
|
||||||
if err == nil {
|
if !errors.Is(err, stardb.ErrStructsNotSlice) {
|
||||||
t.Error("Expected error with non-slice, got nil")
|
t.Errorf("Expected ErrStructsNotSlice, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsertStructs_Nil(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err := db.BatchInsertStructs("users", nil, "id")
|
||||||
|
if !errors.Is(err, stardb.ErrStructsNil) {
|
||||||
|
t.Errorf("Expected ErrStructsNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_BatchInsertStructs_NilPointer(t *testing.T) {
|
||||||
|
db := setupBatchTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var users *[]TestUser
|
||||||
|
_, err := db.BatchInsertStructs("users", users, "id")
|
||||||
|
if !errors.Is(err, stardb.ErrStructsPointerNil) {
|
||||||
|
t.Errorf("Expected ErrStructsPointerNil, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -2,8 +2,11 @@ package testing
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"b612.me/stardb"
|
||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
@ -28,6 +31,12 @@ type NestedUser struct {
|
|||||||
Profile `db:"---"`
|
Profile `db:"---"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UserWithPrivateField struct {
|
||||||
|
ID int64 `db:"id"`
|
||||||
|
Name string `db:"name"`
|
||||||
|
age int `db:"age"`
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarRows_Orm_Single(t *testing.T) {
|
func TestStarRows_Orm_Single(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -93,6 +102,47 @@ func TestStarRows_Orm_Multiple(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarRows_Orm_Array(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var users [3]User
|
||||||
|
err = rows.Orm(&users)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Orm failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedNames := []string{"Alice", "Bob", "Charlie"}
|
||||||
|
for i, user := range users {
|
||||||
|
if user.Name != expectedNames[i] {
|
||||||
|
t.Errorf("Expected name '%s', got '%s'", expectedNames[i], user.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRows_Orm_ArrayTooSmall(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var users [2]User
|
||||||
|
err = rows.Orm(&users)
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Expected error when target array is smaller than row count, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarRows_Orm_Empty(t *testing.T) {
|
func TestStarRows_Orm_Empty(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -126,8 +176,105 @@ func TestStarRows_Orm_NotPointer(t *testing.T) {
|
|||||||
|
|
||||||
var user User
|
var user User
|
||||||
err = rows.Orm(user) // Not a pointer
|
err = rows.Orm(user) // Not a pointer
|
||||||
|
if !errors.Is(err, stardb.ErrTargetNotPointer) {
|
||||||
|
t.Errorf("Expected ErrTargetNotPointer, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRows_Orm_NilTarget(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
err = rows.Orm(nil)
|
||||||
|
if !errors.Is(err, stardb.ErrTargetNil) {
|
||||||
|
t.Errorf("Expected ErrTargetNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRows_Orm_NilPointerTarget(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var user *User
|
||||||
|
err = rows.Orm(user)
|
||||||
|
if !errors.Is(err, stardb.ErrTargetPointerNil) {
|
||||||
|
t.Errorf("Expected ErrTargetPointerNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRows_Orm_MissingColumns_NonStrict(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT id, name FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var user User
|
||||||
|
err = rows.Orm(&user)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected non-strict ORM to ignore missing columns, got error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.Name != "Alice" {
|
||||||
|
t.Errorf("Expected name 'Alice', got '%s'", user.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRows_Orm_MissingColumns_Strict(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetStrictORM(true)
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT id, name FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var user User
|
||||||
|
err = rows.Orm(&user)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Errorf("Expected error when passing non-pointer, got nil")
|
t.Fatalf("Expected strict ORM to fail on missing columns, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarRows_Orm_UnexportedTaggedField(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var user UserWithPrivateField
|
||||||
|
err = rows.Orm(&user)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Orm failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if user.ID <= 0 {
|
||||||
|
t.Errorf("Expected positive ID, got %d", user.ID)
|
||||||
|
}
|
||||||
|
if user.Name != "Alice" {
|
||||||
|
t.Errorf("Expected name 'Alice', got '%s'", user.Name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -355,6 +502,28 @@ func TestStarDB_QueryXContext(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarDB_QueryX_MissingField(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
user := User{Name: "Alice"}
|
||||||
|
|
||||||
|
_, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":unknown")
|
||||||
|
if !errors.Is(err, stardb.ErrFieldNotFound) {
|
||||||
|
t.Errorf("Expected ErrFieldNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_QueryX_NilTarget(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err := db.QueryX(nil, "SELECT * FROM users WHERE id = ?", ":id")
|
||||||
|
if !errors.Is(err, stardb.ErrTargetNil) {
|
||||||
|
t.Errorf("Expected ErrTargetNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarDB_QueryXS(t *testing.T) {
|
func TestStarDB_QueryXS(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -380,6 +549,16 @@ func TestStarDB_QueryXS(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarDB_QueryXS_NilTargets(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err := db.QueryXS(nil, "SELECT * FROM users")
|
||||||
|
if !errors.Is(err, stardb.ErrTargetsNil) {
|
||||||
|
t.Errorf("Expected ErrTargetsNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarDB_ExecX(t *testing.T) {
|
func TestStarDB_ExecX(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -441,6 +620,28 @@ func TestStarDB_ExecXContext(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarDB_ExecX_MissingField(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
user := User{Name: "Alice", Age: 99}
|
||||||
|
|
||||||
|
_, err := db.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":unknown")
|
||||||
|
if !errors.Is(err, stardb.ErrFieldNotFound) {
|
||||||
|
t.Errorf("Expected ErrFieldNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_ExecX_NilTarget(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err := db.ExecX(nil, "UPDATE users SET age = ? WHERE id = ?", ":age", ":id")
|
||||||
|
if !errors.Is(err, stardb.ErrTargetNil) {
|
||||||
|
t.Errorf("Expected ErrTargetNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarDB_ExecXS(t *testing.T) {
|
func TestStarDB_ExecXS(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -470,6 +671,16 @@ func TestStarDB_ExecXS(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarDB_ExecXS_NilTargets(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err := db.ExecXS(nil, "UPDATE users SET age = age")
|
||||||
|
if !errors.Is(err, stardb.ErrTargetsNil) {
|
||||||
|
t.Errorf("Expected ErrTargetsNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarTx_Insert(t *testing.T) {
|
func TestStarTx_Insert(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -592,6 +803,22 @@ func TestStarTx_QueryX(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarTx_QueryX_NilTarget(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Begin failed: %v", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
_, err = tx.QueryX(nil, "SELECT * FROM users WHERE id = ?", ":id")
|
||||||
|
if !errors.Is(err, stardb.ErrTargetNil) {
|
||||||
|
t.Errorf("Expected ErrTargetNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarTx_ExecX(t *testing.T) {
|
func TestStarTx_ExecX(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -627,6 +854,22 @@ func TestStarTx_ExecX(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarTx_ExecX_NilTarget(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Begin failed: %v", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
_, err = tx.ExecX(nil, "UPDATE users SET age = ? WHERE id = ?", ":age", ":id")
|
||||||
|
if !errors.Is(err, stardb.ErrTargetNil) {
|
||||||
|
t.Errorf("Expected ErrTargetNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarTx_Rollback(t *testing.T) {
|
func TestStarTx_Rollback(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
|
|||||||
128
testing/perf_test.go
Normal file
128
testing/perf_test.go
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
package testing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"b612.me/stardb"
|
||||||
|
_ "modernc.org/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setupBenchmarkDB(b *testing.B) *stardb.StarDB {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
db := &stardb.StarDB{}
|
||||||
|
if err := db.Open("sqlite", ":memory:"); err != nil {
|
||||||
|
b.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.Exec(`
|
||||||
|
CREATE TABLE users (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
email TEXT NOT NULL,
|
||||||
|
age INTEGER,
|
||||||
|
balance REAL,
|
||||||
|
active BOOLEAN,
|
||||||
|
created_at DATETIME
|
||||||
|
)
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to create table: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec(`
|
||||||
|
INSERT INTO users (name, email, age, balance, active, created_at) VALUES
|
||||||
|
('Alice', 'alice@example.com', 25, 100.50, 1, '2024-01-01 10:00:00'),
|
||||||
|
('Bob', 'bob@example.com', 30, 200.75, 1, '2024-01-02 11:00:00'),
|
||||||
|
('Charlie', 'charlie@example.com', 35, 300.25, 0, '2024-01-03 12:00:00')
|
||||||
|
`)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Failed to insert seed data: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkQueryX(b *testing.B) {
|
||||||
|
db := setupBenchmarkDB(b)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
target := User{Name: "Alice"}
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rows, err := db.QueryX(&target, "SELECT * FROM users WHERE name = ?", ":name")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("QueryX failed: %v", err)
|
||||||
|
}
|
||||||
|
_ = rows.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkOrm(b *testing.B) {
|
||||||
|
db := setupBenchmarkDB(b)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var users []User
|
||||||
|
if err := rows.Orm(&users); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
b.Fatalf("Orm failed: %v", err)
|
||||||
|
}
|
||||||
|
_ = rows.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkScanEach(b *testing.B) {
|
||||||
|
db := setupBenchmarkDB(b)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
count := 0
|
||||||
|
err := db.ScanEach("SELECT * FROM users ORDER BY name", func(row *stardb.StarResult) error {
|
||||||
|
_ = row.MustString("name")
|
||||||
|
count++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("ScanEach failed: %v", err)
|
||||||
|
}
|
||||||
|
if count != 3 {
|
||||||
|
b.Fatalf("Unexpected row count: %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBatchInsert(b *testing.B) {
|
||||||
|
db := setupBenchmarkDB(b)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
columns := []string{"name", "email", "age", "balance", "active", "created_at"}
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
base := i * 2
|
||||||
|
values := [][]interface{}{
|
||||||
|
{fmt.Sprintf("bench_user_%d", base), fmt.Sprintf("bench_%d@example.com", base), 20 + (base % 20), 99.5, true, time.Now()},
|
||||||
|
{fmt.Sprintf("bench_user_%d", base+1), fmt.Sprintf("bench_%d@example.com", base+1), 20 + ((base + 1) % 20), 199.5, false, time.Now()},
|
||||||
|
}
|
||||||
|
if _, err := db.BatchInsert("users", columns, values); err != nil {
|
||||||
|
b.Fatalf("BatchInsert failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -95,6 +95,32 @@ func TestStarDB_SetPoolConfig_Zero(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SetPoolConfig_NilConfig(t *testing.T) {
|
||||||
|
db := stardb.NewStarDB()
|
||||||
|
err := db.Open("sqlite", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to open database: %v", err)
|
||||||
|
}
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetPoolConfig(nil)
|
||||||
|
|
||||||
|
err = db.Ping()
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Ping failed after SetPoolConfig(nil): %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SetPoolConfig_BeforeOpen(t *testing.T) {
|
||||||
|
db := stardb.NewStarDB()
|
||||||
|
|
||||||
|
db.SetPoolConfig(&stardb.PoolConfig{
|
||||||
|
MaxOpenConns: 10,
|
||||||
|
})
|
||||||
|
|
||||||
|
// should not panic when called before Open
|
||||||
|
}
|
||||||
|
|
||||||
func TestOpenWithPool_Default(t *testing.T) {
|
func TestOpenWithPool_Default(t *testing.T) {
|
||||||
db, err := stardb.OpenWithPool("sqlite", ":memory:", nil)
|
db, err := stardb.OpenWithPool("sqlite", ":memory:", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
package testing
|
package testing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"b612.me/stardb"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestStarResult_MustString(t *testing.T) {
|
func TestStarResult_MustString(t *testing.T) {
|
||||||
@ -229,3 +232,100 @@ func TestStarResultCol_MustBool(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarResult_GetColumnNotFoundError(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
row := rows.Row(0)
|
||||||
|
_, err = row.GetString("does_not_exist")
|
||||||
|
if !errors.Is(err, stardb.ErrColumnNotFound) {
|
||||||
|
t.Fatalf("Expected ErrColumnNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarResult_GetNullValues(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err := db.Exec(
|
||||||
|
"INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
||||||
|
"NullUser", "null@example.com", nil, nil, nil, nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Insert failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "NullUser")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
row := rows.Row(0)
|
||||||
|
|
||||||
|
name, err := row.GetNullString("name")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetNullString failed: %v", err)
|
||||||
|
}
|
||||||
|
if !name.Valid || name.String != "NullUser" {
|
||||||
|
t.Fatalf("Expected valid name NullUser, got %+v", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
age, err := row.GetNullInt64("age")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetNullInt64 failed: %v", err)
|
||||||
|
}
|
||||||
|
if age.Valid {
|
||||||
|
t.Fatalf("Expected NULL age, got %+v", age)
|
||||||
|
}
|
||||||
|
|
||||||
|
balance, err := row.GetNullFloat64("balance")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetNullFloat64 failed: %v", err)
|
||||||
|
}
|
||||||
|
if balance.Valid {
|
||||||
|
t.Fatalf("Expected NULL balance, got %+v", balance)
|
||||||
|
}
|
||||||
|
|
||||||
|
active, err := row.GetNullBool("active")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetNullBool failed: %v", err)
|
||||||
|
}
|
||||||
|
if active.Valid {
|
||||||
|
t.Fatalf("Expected NULL active, got %+v", active)
|
||||||
|
}
|
||||||
|
|
||||||
|
createdAt, err := row.GetNullTime("created_at")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetNullTime failed: %v", err)
|
||||||
|
}
|
||||||
|
if createdAt.Valid {
|
||||||
|
t.Fatalf("Expected NULL created_at, got %+v", createdAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarResult_GetNullTime_Valid(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT created_at FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
value, err := rows.Row(0).GetNullTime("created_at")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetNullTime failed: %v", err)
|
||||||
|
}
|
||||||
|
if !value.Valid {
|
||||||
|
t.Fatal("Expected valid created_at")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -26,6 +26,12 @@ func TestStarRows_Row(t *testing.T) {
|
|||||||
if len(row.Result()) != 0 {
|
if len(row.Result()) != 0 {
|
||||||
t.Errorf("Expected empty result for out of bounds index")
|
t.Errorf("Expected empty result for out of bounds index")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test negative index
|
||||||
|
row = rows.Row(-1)
|
||||||
|
if len(row.Result()) != 0 {
|
||||||
|
t.Errorf("Expected empty result for negative index")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestStarRows_Col(t *testing.T) {
|
func TestStarRows_Col(t *testing.T) {
|
||||||
|
|||||||
@ -2,6 +2,10 @@ package testing
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@ -60,6 +64,457 @@ func TestStarDB_QueryContext(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarDB_QueryRaw(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
rows, err := db.QueryRaw("SELECT name FROM users WHERE age > ? ORDER BY name", 25)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("QueryRaw failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var (
|
||||||
|
name string
|
||||||
|
count int
|
||||||
|
)
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&name); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
t.Fatalf("Rows.Err failed: %v", err)
|
||||||
|
}
|
||||||
|
if count != 2 {
|
||||||
|
t.Errorf("Expected 2 rows, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_QueryRawContext(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
rows, err := db.QueryRawContext(ctx, "SELECT name FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("QueryRawContext failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Fatal("Expected at least one row")
|
||||||
|
}
|
||||||
|
var name string
|
||||||
|
if err := rows.Scan(&name); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if name != "Alice" {
|
||||||
|
t.Errorf("Expected name Alice, got %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_QueryRaw_EmptyQuery(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
_, err := db.QueryRaw(" ")
|
||||||
|
if !errors.Is(err, stardb.ErrQueryEmpty) {
|
||||||
|
t.Fatalf("Expected ErrQueryEmpty, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_ScanEach(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
err := db.ScanEach("SELECT name FROM users ORDER BY name", func(row *stardb.StarResult) error {
|
||||||
|
names = append(names, row.MustString("name"))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ScanEach failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(names) != 3 {
|
||||||
|
t.Fatalf("Expected 3 rows, got %d", len(names))
|
||||||
|
}
|
||||||
|
if names[0] != "Alice" || names[1] != "Bob" || names[2] != "Charlie" {
|
||||||
|
t.Fatalf("Unexpected names: %v", names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_ScanEach_Stop(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
count := 0
|
||||||
|
err := db.ScanEach("SELECT name FROM users ORDER BY name", func(row *stardb.StarResult) error {
|
||||||
|
count++
|
||||||
|
if row.MustString("name") == "Bob" {
|
||||||
|
return stardb.ErrScanStopped
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ScanEach stop failed: %v", err)
|
||||||
|
}
|
||||||
|
if count != 2 {
|
||||||
|
t.Fatalf("Expected callback count 2, got %d", count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_ScanEach_NilCallback(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
err := db.ScanEach("SELECT name FROM users", nil)
|
||||||
|
if !errors.Is(err, stardb.ErrScanFuncNil) {
|
||||||
|
t.Fatalf("Expected ErrScanFuncNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_ScanEachORM(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var user User
|
||||||
|
var names []string
|
||||||
|
err := db.ScanEachORM("SELECT * FROM users ORDER BY name", &user, func(target interface{}) error {
|
||||||
|
u := target.(*User)
|
||||||
|
names = append(names, u.Name)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ScanEachORM failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(names) != 3 {
|
||||||
|
t.Fatalf("Expected 3 names, got %d", len(names))
|
||||||
|
}
|
||||||
|
if names[0] != "Alice" || names[1] != "Bob" || names[2] != "Charlie" {
|
||||||
|
t.Fatalf("Unexpected names: %v", names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_ScanEachORM_NilCallback(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var user User
|
||||||
|
err := db.ScanEachORM("SELECT * FROM users", &user, nil)
|
||||||
|
if !errors.Is(err, stardb.ErrScanORMFuncNil) {
|
||||||
|
t.Fatalf("Expected ErrScanORMFuncNil, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_PlaceholderDollar(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetPlaceholderStyle(stardb.PlaceholderDollar)
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT name FROM users WHERE name = ? AND age = ?", "Alice", 25)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query with dollar placeholders failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if rows.Length() != 1 {
|
||||||
|
t.Fatalf("Expected 1 row, got %d", rows.Length())
|
||||||
|
}
|
||||||
|
if got := rows.Row(0).MustString("name"); got != "Alice" {
|
||||||
|
t.Fatalf("Expected Alice, got %s", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec("UPDATE users SET age = ? WHERE name = ?", 26, "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Exec with dollar placeholders failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLHooks(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetPlaceholderStyle(stardb.PlaceholderDollar)
|
||||||
|
|
||||||
|
var beforeCount int64
|
||||||
|
var afterCount int64
|
||||||
|
var mu sync.Mutex
|
||||||
|
var beforeQuery string
|
||||||
|
var afterQuery string
|
||||||
|
var afterErr error
|
||||||
|
|
||||||
|
db.SetSQLHooks(
|
||||||
|
func(ctx context.Context, query string, args []interface{}) {
|
||||||
|
atomic.AddInt64(&beforeCount, 1)
|
||||||
|
mu.Lock()
|
||||||
|
beforeQuery = query
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {
|
||||||
|
atomic.AddInt64(&afterCount, 1)
|
||||||
|
mu.Lock()
|
||||||
|
afterQuery = query
|
||||||
|
afterErr = err
|
||||||
|
mu.Unlock()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 31, "Bob"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if atomic.LoadInt64(&beforeCount) == 0 {
|
||||||
|
t.Fatal("Expected before hook to be called")
|
||||||
|
}
|
||||||
|
if atomic.LoadInt64(&afterCount) == 0 {
|
||||||
|
t.Fatal("Expected after hook to be called")
|
||||||
|
}
|
||||||
|
if !strings.Contains(beforeQuery, "$1") {
|
||||||
|
t.Fatalf("Expected converted query in before hook, got %s", beforeQuery)
|
||||||
|
}
|
||||||
|
if !strings.Contains(afterQuery, "$1") {
|
||||||
|
t.Fatalf("Expected converted query in after hook, got %s", afterQuery)
|
||||||
|
}
|
||||||
|
if afterErr != nil {
|
||||||
|
t.Fatalf("Expected nil error in after hook, got %v", afterErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, execErr := db.Exec("UPDATE table_does_not_exist SET age = ? WHERE name = ?", 31, "Bob")
|
||||||
|
if execErr == nil {
|
||||||
|
t.Fatal("Expected execution error for invalid table")
|
||||||
|
}
|
||||||
|
if atomic.LoadInt64(&afterCount) < 2 {
|
||||||
|
t.Fatalf("Expected after hook call count >= 2, got %d", atomic.LoadInt64(&afterCount))
|
||||||
|
}
|
||||||
|
if afterErr == nil {
|
||||||
|
t.Fatal("Expected after hook to capture non-nil error for failed SQL")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLHooks_SlowThreshold(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var afterCount int64
|
||||||
|
db.SetSQLHooks(
|
||||||
|
nil,
|
||||||
|
func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {
|
||||||
|
atomic.AddInt64(&afterCount, 1)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
db.SetSQLSlowThreshold(time.Hour)
|
||||||
|
if got := db.SQLSlowThreshold(); got != time.Hour {
|
||||||
|
t.Fatalf("Expected threshold 1h, got %v", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 41, "Alice"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt64(&afterCount) != 0 {
|
||||||
|
t.Fatalf("Expected after hook to be skipped under threshold, got %d", afterCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := db.Exec("UPDATE table_does_not_exist SET age = ? WHERE name = ?", 31, "Bob")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected execution error for invalid table")
|
||||||
|
}
|
||||||
|
if atomic.LoadInt64(&afterCount) != 1 {
|
||||||
|
t.Fatalf("Expected error path to trigger after hook, got %d", afterCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
db.SetSQLSlowThreshold(0)
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 42, "Alice"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
if atomic.LoadInt64(&afterCount) != 2 {
|
||||||
|
t.Fatalf("Expected after hook to run after disabling threshold, got %d", afterCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLHooks_FingerprintMeta(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetSQLFingerprintEnabled(true)
|
||||||
|
|
||||||
|
var gotMeta stardb.SQLHookMeta
|
||||||
|
var metaFound bool
|
||||||
|
db.SetSQLHooks(
|
||||||
|
nil,
|
||||||
|
func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {
|
||||||
|
gotMeta, metaFound = stardb.SQLHookMetaFromContext(ctx)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 31, "Bob"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !metaFound {
|
||||||
|
t.Fatal("Expected SQL fingerprint metadata in hook context")
|
||||||
|
}
|
||||||
|
if gotMeta.Fingerprint != "update users set age = ? where name = ?" {
|
||||||
|
t.Fatalf("Unexpected fingerprint: %q", gotMeta.Fingerprint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLHooks_FingerprintMetaMaskLiterals(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetSQLFingerprintEnabled(true)
|
||||||
|
db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals)
|
||||||
|
|
||||||
|
var gotMeta stardb.SQLHookMeta
|
||||||
|
var metaFound bool
|
||||||
|
db.SetSQLHooks(
|
||||||
|
nil,
|
||||||
|
func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {
|
||||||
|
gotMeta, metaFound = stardb.SQLHookMetaFromContext(ctx)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = 42 WHERE name = 'Bob' AND age < 100"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !metaFound {
|
||||||
|
t.Fatal("Expected SQL fingerprint metadata in hook context")
|
||||||
|
}
|
||||||
|
if gotMeta.Fingerprint != "update users set age = ? where name = ? and age < ?" {
|
||||||
|
t.Fatalf("Unexpected fingerprint for mask mode: %q", gotMeta.Fingerprint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLHooks_FingerprintMetaMaskLiterals_StripComments(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetSQLFingerprintEnabled(true)
|
||||||
|
db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals)
|
||||||
|
db.SetSQLFingerprintKeepComments(false) // default, explicit for clarity
|
||||||
|
|
||||||
|
var gotMeta stardb.SQLHookMeta
|
||||||
|
var metaFound bool
|
||||||
|
db.SetSQLHooks(
|
||||||
|
nil,
|
||||||
|
func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {
|
||||||
|
gotMeta, metaFound = stardb.SQLHookMetaFromContext(ctx)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = 42 /* trace:abc */ WHERE name = 'Bob'"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !metaFound {
|
||||||
|
t.Fatal("Expected SQL fingerprint metadata in hook context")
|
||||||
|
}
|
||||||
|
if gotMeta.Fingerprint != "update users set age = ? where name = ?" {
|
||||||
|
t.Fatalf("Unexpected fingerprint with stripped comments: %q", gotMeta.Fingerprint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLHooks_FingerprintMetaMaskLiterals_KeepComments(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetSQLFingerprintEnabled(true)
|
||||||
|
db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals)
|
||||||
|
db.SetSQLFingerprintKeepComments(true)
|
||||||
|
|
||||||
|
var gotMeta stardb.SQLHookMeta
|
||||||
|
var metaFound bool
|
||||||
|
db.SetSQLHooks(
|
||||||
|
nil,
|
||||||
|
func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {
|
||||||
|
gotMeta, metaFound = stardb.SQLHookMetaFromContext(ctx)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = 42 /* trace:abc */ WHERE name = 'Bob'"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !metaFound {
|
||||||
|
t.Fatal("Expected SQL fingerprint metadata in hook context")
|
||||||
|
}
|
||||||
|
if gotMeta.Fingerprint != "update users set age = ? /* trace:abc */ where name = ?" {
|
||||||
|
t.Fatalf("Unexpected fingerprint with kept comments: %q", gotMeta.Fingerprint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLFingerprintCounter(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetSQLFingerprintCounterEnabled(true)
|
||||||
|
db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals)
|
||||||
|
db.SetSQLFingerprintKeepComments(false)
|
||||||
|
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = 41 WHERE name = 'Alice'"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = 42 WHERE name = 'Bob'"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
counters := db.SQLFingerprintCounters()
|
||||||
|
key := "update users set age = ? where name = ?"
|
||||||
|
if counters[key] != 2 {
|
||||||
|
t.Fatalf("Expected fingerprint %q count=2, got %d", key, counters[key])
|
||||||
|
}
|
||||||
|
|
||||||
|
db.ResetSQLFingerprintCounters()
|
||||||
|
if got := len(db.SQLFingerprintCounters()); got != 0 {
|
||||||
|
t.Fatalf("Expected counters to be reset, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLFingerprintCounterDisabled(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetSQLFingerprintCounterEnabled(false)
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = 31 WHERE name = ?", "Bob"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := len(db.SQLFingerprintCounters()); got != 0 {
|
||||||
|
t.Fatalf("Expected no counters when disabled, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_SQLHooks_FingerprintMetaDisabled(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
db.SetSQLFingerprintEnabled(false)
|
||||||
|
|
||||||
|
metaFound := false
|
||||||
|
db.SetSQLHooks(
|
||||||
|
nil,
|
||||||
|
func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {
|
||||||
|
_, metaFound = stardb.SQLHookMetaFromContext(ctx)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if _, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 31, "Bob"); err != nil {
|
||||||
|
t.Fatalf("Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
if metaFound {
|
||||||
|
t.Fatal("Expected no SQL fingerprint metadata when disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarDB_Exec(t *testing.T) {
|
func TestStarDB_Exec(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -154,6 +609,119 @@ func TestStarDB_Prepare(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarStmt_QueryRaw(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
stmt, err := db.Prepare("SELECT name FROM users WHERE name = ?")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Prepare failed: %v", err)
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
rows, err := stmt.QueryRaw("Bob")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stmt.QueryRaw failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Fatal("Expected one row for Bob")
|
||||||
|
}
|
||||||
|
var name string
|
||||||
|
if err := rows.Scan(&name); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if name != "Bob" {
|
||||||
|
t.Errorf("Expected Bob, got %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarStmt_ScanEach(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
stmt, err := db.Prepare("SELECT name FROM users WHERE age >= ? ORDER BY name")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Prepare failed: %v", err)
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
err = stmt.ScanEach(func(row *stardb.StarResult) error {
|
||||||
|
names = append(names, row.MustString("name"))
|
||||||
|
return nil
|
||||||
|
}, 30)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stmt.ScanEach failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(names) != 2 || names[0] != "Bob" || names[1] != "Charlie" {
|
||||||
|
t.Fatalf("Unexpected names: %v", names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarStmt_ScanEachORM(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
stmt, err := db.Prepare("SELECT * FROM users WHERE age >= ? ORDER BY name")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Prepare failed: %v", err)
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
var user User
|
||||||
|
var names []string
|
||||||
|
err = stmt.ScanEachORM(&user, func(target interface{}) error {
|
||||||
|
u := target.(*User)
|
||||||
|
names = append(names, u.Name)
|
||||||
|
return nil
|
||||||
|
}, 30)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stmt.ScanEachORM failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(names) != 2 || names[0] != "Bob" || names[1] != "Charlie" {
|
||||||
|
t.Fatalf("Unexpected names: %v", names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarStmt_SQLHooks(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var (
|
||||||
|
beforeCount int64
|
||||||
|
afterCount int64
|
||||||
|
)
|
||||||
|
db.SetSQLHooks(
|
||||||
|
func(ctx context.Context, query string, args []interface{}) {
|
||||||
|
atomic.AddInt64(&beforeCount, 1)
|
||||||
|
},
|
||||||
|
func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {
|
||||||
|
atomic.AddInt64(&afterCount, 1)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
stmt, err := db.Prepare("SELECT name FROM users WHERE name = ?")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Prepare failed: %v", err)
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
|
||||||
|
rows, err := stmt.Query("Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stmt.Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if rows.Length() != 1 {
|
||||||
|
t.Fatalf("Expected 1 row, got %d", rows.Length())
|
||||||
|
}
|
||||||
|
if atomic.LoadInt64(&beforeCount) == 0 || atomic.LoadInt64(&afterCount) == 0 {
|
||||||
|
t.Fatalf("Expected stmt execution to trigger hooks, before=%d after=%d", beforeCount, afterCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarDB_Transaction(t *testing.T) {
|
func TestStarDB_Transaction(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -219,6 +787,131 @@ func TestStarDB_TransactionRollback(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarTx_QueryRaw(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Begin failed: %v", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
rows, err := tx.QueryRaw("SELECT name FROM users WHERE name = ?", "Charlie")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tx.QueryRaw failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if !rows.Next() {
|
||||||
|
t.Fatal("Expected one row for Charlie")
|
||||||
|
}
|
||||||
|
var name string
|
||||||
|
if err := rows.Scan(&name); err != nil {
|
||||||
|
t.Fatalf("Scan failed: %v", err)
|
||||||
|
}
|
||||||
|
if name != "Charlie" {
|
||||||
|
t.Errorf("Expected Charlie, got %s", name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarTx_ScanEach(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Begin failed: %v", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
err = tx.ScanEach("SELECT name FROM users WHERE age >= ? ORDER BY name", func(row *stardb.StarResult) error {
|
||||||
|
names = append(names, row.MustString("name"))
|
||||||
|
return nil
|
||||||
|
}, 30)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tx.ScanEach failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(names) != 2 || names[0] != "Bob" || names[1] != "Charlie" {
|
||||||
|
t.Fatalf("Unexpected names: %v", names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarTx_ScanEachORM(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Begin failed: %v", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
var user User
|
||||||
|
var names []string
|
||||||
|
err = tx.ScanEachORM("SELECT * FROM users WHERE age >= ? ORDER BY name", &user, func(target interface{}) error {
|
||||||
|
u := target.(*User)
|
||||||
|
names = append(names, u.Name)
|
||||||
|
return nil
|
||||||
|
}, 30)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Tx.ScanEachORM failed: %v", err)
|
||||||
|
}
|
||||||
|
if len(names) != 2 || names[0] != "Bob" || names[1] != "Charlie" {
|
||||||
|
t.Fatalf("Unexpected names: %v", names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_WithTx(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
err := db.WithTx(func(tx *stardb.StarTx) error {
|
||||||
|
_, err := tx.Exec("UPDATE users SET age = ? WHERE name = ?", 41, "Alice")
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("WithTx failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if got := rows.Row(0).MustInt("age"); got != 41 {
|
||||||
|
t.Fatalf("Expected age 41, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarDB_WithTx_RollbackOnError(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
expectedErr := errors.New("business error")
|
||||||
|
err := db.WithTx(func(tx *stardb.StarTx) error {
|
||||||
|
if _, err := tx.Exec("UPDATE users SET age = ? WHERE name = ?", 55, "Alice"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return expectedErr
|
||||||
|
})
|
||||||
|
if !errors.Is(err, expectedErr) {
|
||||||
|
t.Fatalf("Expected %v, got %v", expectedErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Query failed: %v", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
if got := rows.Row(0).MustInt("age"); got == 55 {
|
||||||
|
t.Fatalf("Expected rollback to keep original age, got %d", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestStarDB_SetMaxConnections(t *testing.T) {
|
func TestStarDB_SetMaxConnections(t *testing.T) {
|
||||||
db := setupTestDB(t)
|
db := setupTestDB(t)
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
@ -231,3 +924,80 @@ func TestStarDB_SetMaxConnections(t *testing.T) {
|
|||||||
t.Errorf("Expected MaxOpenConnections 10, got %d", stats.MaxOpenConnections)
|
t.Errorf("Expected MaxOpenConnections 10, got %d", stats.MaxOpenConnections)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStarDB_ConcurrentRuntimeAndQuery(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
db.SetMaxOpenConns(1)
|
||||||
|
db.SetMaxIdleConns(1)
|
||||||
|
|
||||||
|
db.SetSQLHooks(
|
||||||
|
func(ctx context.Context, query string, args []interface{}) {},
|
||||||
|
func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {},
|
||||||
|
)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
for i := 0; i < 4; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
rows, err := db.Query("SELECT id FROM users WHERE name = ?", "Alice")
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Concurrent query failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = rows.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
if j%2 == 0 {
|
||||||
|
db.SetPlaceholderStyle(stardb.PlaceholderDollar)
|
||||||
|
db.SetSQLSlowThreshold(time.Millisecond)
|
||||||
|
} else {
|
||||||
|
db.SetPlaceholderStyle(stardb.PlaceholderQuestion)
|
||||||
|
db.SetSQLSlowThreshold(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStarTx_SQLHooks(t *testing.T) {
|
||||||
|
db := setupTestDB(t)
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
var beforeCount int64
|
||||||
|
var afterCount int64
|
||||||
|
db.SetSQLHooks(
|
||||||
|
func(ctx context.Context, query string, args []interface{}) {
|
||||||
|
atomic.AddInt64(&beforeCount, 1)
|
||||||
|
},
|
||||||
|
func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {
|
||||||
|
atomic.AddInt64(&afterCount, 1)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Begin failed: %v", err)
|
||||||
|
}
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
if _, err := tx.Exec("UPDATE users SET age = ? WHERE name = ?", 27, "Alice"); err != nil {
|
||||||
|
t.Fatalf("Tx.Exec failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if atomic.LoadInt64(&beforeCount) == 0 || atomic.LoadInt64(&afterCount) == 0 {
|
||||||
|
t.Fatalf("Expected tx execution to trigger hooks, before=%d after=%d", beforeCount, afterCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
154
tx.go
154
tx.go
@ -3,7 +3,8 @@ package stardb
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StarTx represents a database transaction
|
// StarTx represents a database transaction
|
||||||
@ -12,6 +13,13 @@ type StarTx struct {
|
|||||||
db *StarDB
|
db *StarDB
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *StarTx) ensureTx() error {
|
||||||
|
if t == nil || t.tx == nil || t.db == nil {
|
||||||
|
return ErrTxNotInitialized
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Query executes a query within the transaction
|
// Query executes a query within the transaction
|
||||||
func (t *StarTx) Query(query string, args ...interface{}) (*StarRows, error) {
|
func (t *StarTx) Query(query string, args ...interface{}) (*StarRows, error) {
|
||||||
return t.query(nil, query, args...)
|
return t.query(nil, query, args...)
|
||||||
@ -22,21 +30,53 @@ func (t *StarTx) QueryContext(ctx context.Context, query string, args ...interfa
|
|||||||
return t.query(ctx, query, args...)
|
return t.query(ctx, query, args...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// query is the internal query implementation
|
// QueryRaw executes a query in transaction and returns *sql.Rows without automatic parsing.
|
||||||
func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
func (t *StarTx) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
if err := t.db.Ping(); err != nil {
|
return t.queryRaw(nil, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRawContext executes a query with context in transaction and returns *sql.Rows without automatic parsing.
|
||||||
|
func (t *StarTx) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
return t.queryRaw(ctx, query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *StarTx) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||||
|
if err := t.ensureTx(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(query) == "" {
|
||||||
|
return nil, ErrQueryEmpty
|
||||||
|
}
|
||||||
|
|
||||||
var rows *sql.Rows
|
query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
|
||||||
var err error
|
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
|
||||||
|
if beforeHook != nil {
|
||||||
|
beforeHook(hookCtx, query, hookArgs)
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
var (
|
||||||
|
rows *sql.Rows
|
||||||
|
err error
|
||||||
|
)
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
rows, err = t.tx.Query(query, args...)
|
rows, err = t.tx.Query(query, args...)
|
||||||
} else {
|
} else {
|
||||||
rows, err = t.tx.QueryContext(ctx, query, args...)
|
rows, err = t.tx.QueryContext(ctx, query, args...)
|
||||||
}
|
}
|
||||||
|
duration := time.Since(start)
|
||||||
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
||||||
|
afterHook(hookCtx, query, hookArgs, duration, err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return rows, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// query is the internal query implementation
|
||||||
|
func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
||||||
|
rows, err := t.queryRaw(ctx, query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -47,10 +87,13 @@ func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !t.db.ManualScan {
|
if !t.db.ManualScan {
|
||||||
err = starRows.parse()
|
if err := starRows.parse(); err != nil {
|
||||||
|
_ = rows.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return starRows, err
|
return starRows, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec executes a query within the transaction
|
// Exec executes a query within the transaction
|
||||||
@ -65,38 +108,97 @@ func (t *StarTx) ExecContext(ctx context.Context, query string, args ...interfac
|
|||||||
|
|
||||||
// exec is the internal exec implementation
|
// exec is the internal exec implementation
|
||||||
func (t *StarTx) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
func (t *StarTx) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
if err := t.db.Ping(); err != nil {
|
if err := t.ensureTx(); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if strings.TrimSpace(query) == "" {
|
||||||
if ctx == nil {
|
return nil, ErrQueryEmpty
|
||||||
return t.tx.Exec(query, args...)
|
|
||||||
}
|
}
|
||||||
return t.tx.ExecContext(ctx, query, args...)
|
|
||||||
|
query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
|
||||||
|
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
|
||||||
|
if beforeHook != nil {
|
||||||
|
beforeHook(hookCtx, query, hookArgs)
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
var (
|
||||||
|
result sql.Result
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if ctx == nil {
|
||||||
|
result, err = t.tx.Exec(query, args...)
|
||||||
|
} else {
|
||||||
|
result, err = t.tx.ExecContext(ctx, query, args...)
|
||||||
|
}
|
||||||
|
duration := time.Since(start)
|
||||||
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
||||||
|
afterHook(hookCtx, query, hookArgs, duration, err)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare creates a prepared statement within the transaction
|
// Prepare creates a prepared statement within the transaction
|
||||||
func (t *StarTx) Prepare(query string) (*StarStmt, error) {
|
func (t *StarTx) Prepare(query string) (*StarStmt, error) {
|
||||||
|
if err := t.ensureTx(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(query) == "" {
|
||||||
|
return nil, ErrQueryEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
|
||||||
|
hookCtx := t.db.hookContext(nil, query, beforeHook, afterHook)
|
||||||
|
if beforeHook != nil {
|
||||||
|
beforeHook(hookCtx, query, nil)
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
stmt, err := t.tx.Prepare(query)
|
stmt, err := t.tx.Prepare(query)
|
||||||
|
duration := time.Since(start)
|
||||||
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
||||||
|
afterHook(hookCtx, query, nil, duration, err)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &StarStmt{stmt: stmt, db: t.db}, nil
|
return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrepareContext creates a prepared statement with context
|
// PrepareContext creates a prepared statement with context
|
||||||
func (t *StarTx) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
|
func (t *StarTx) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
|
||||||
|
if err := t.ensureTx(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(query) == "" {
|
||||||
|
return nil, ErrQueryEmpty
|
||||||
|
}
|
||||||
|
|
||||||
|
query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
|
||||||
|
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
|
||||||
|
if beforeHook != nil {
|
||||||
|
beforeHook(hookCtx, query, nil)
|
||||||
|
}
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
stmt, err := t.tx.PrepareContext(ctx, query)
|
stmt, err := t.tx.PrepareContext(ctx, query)
|
||||||
|
duration := time.Since(start)
|
||||||
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
||||||
|
afterHook(hookCtx, query, nil, duration, err)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &StarStmt{stmt: stmt, db: t.db}, nil
|
return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryStmt executes a prepared statement query within the transaction
|
// QueryStmt executes a prepared statement query within the transaction
|
||||||
func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
|
func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
|
||||||
if query == "" {
|
if strings.TrimSpace(query) == "" {
|
||||||
return nil, errors.New("query string cannot be empty")
|
return nil, ErrQueryEmpty
|
||||||
}
|
}
|
||||||
stmt, err := t.Prepare(query)
|
stmt, err := t.Prepare(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -108,8 +210,8 @@ func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error)
|
|||||||
|
|
||||||
// QueryStmtContext executes a prepared statement query with context
|
// QueryStmtContext executes a prepared statement query with context
|
||||||
func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
||||||
if query == "" {
|
if strings.TrimSpace(query) == "" {
|
||||||
return nil, errors.New("query string cannot be empty")
|
return nil, ErrQueryEmpty
|
||||||
}
|
}
|
||||||
stmt, err := t.PrepareContext(ctx, query)
|
stmt, err := t.PrepareContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -121,8 +223,8 @@ func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...int
|
|||||||
|
|
||||||
// ExecStmt executes a prepared statement within the transaction
|
// ExecStmt executes a prepared statement within the transaction
|
||||||
func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
|
func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
|
||||||
if query == "" {
|
if strings.TrimSpace(query) == "" {
|
||||||
return nil, errors.New("query string cannot be empty")
|
return nil, ErrQueryEmpty
|
||||||
}
|
}
|
||||||
stmt, err := t.Prepare(query)
|
stmt, err := t.Prepare(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -134,8 +236,8 @@ func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error)
|
|||||||
|
|
||||||
// ExecStmtContext executes a prepared statement with context
|
// ExecStmtContext executes a prepared statement with context
|
||||||
func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||||
if query == "" {
|
if strings.TrimSpace(query) == "" {
|
||||||
return nil, errors.New("query string cannot be empty")
|
return nil, ErrQueryEmpty
|
||||||
}
|
}
|
||||||
stmt, err := t.PrepareContext(ctx, query)
|
stmt, err := t.PrepareContext(ctx, query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -147,10 +249,16 @@ func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...inte
|
|||||||
|
|
||||||
// Commit commits the transaction
|
// Commit commits the transaction
|
||||||
func (t *StarTx) Commit() error {
|
func (t *StarTx) Commit() error {
|
||||||
|
if err := t.ensureTx(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return t.tx.Commit()
|
return t.tx.Commit()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rollback rolls back the transaction
|
// Rollback rolls back the transaction
|
||||||
func (t *StarTx) Rollback() error {
|
func (t *StarTx) Rollback() error {
|
||||||
|
if err := t.ensureTx(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return t.tx.Rollback()
|
return t.tx.Rollback()
|
||||||
}
|
}
|
||||||
|
|||||||
45
tx_helper.go
Normal file
45
tx_helper.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
// WithTx runs fn in a transaction and handles commit/rollback automatically.
|
||||||
|
func (s *StarDB) WithTx(fn func(tx *StarTx) error) error {
|
||||||
|
return s.WithTxContext(context.Background(), nil, fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithTxContext runs fn in a transaction with context/options and handles commit/rollback automatically.
|
||||||
|
func (s *StarDB) WithTxContext(ctx context.Context, opts *sql.TxOptions, fn func(tx *StarTx) error) (err error) {
|
||||||
|
if fn == nil {
|
||||||
|
return ErrTxFuncNil
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
|
||||||
|
tx, err := s.BeginTx(ctx, opts)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if p := recover(); p != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
panic(p)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := fn(tx); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
Loading…
x
Reference in New Issue
Block a user