// stardb/batch.go — batch INSERT support for StarDB (339 lines, 8.8 KiB, Go).

package stardb
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"sync/atomic"
)
// multiSQLResult aggregates the sql.Result values produced by the individual
// INSERT statements of a chunked batch insert, so callers still receive a
// single sql.Result for the whole batch.
type multiSQLResult struct {
	results []sql.Result // one entry per executed chunk, in execution order
}
// batchExecMetaKey is the unexported context key under which chunk metadata
// is stored; using a private struct type prevents collisions with other keys.
type batchExecMetaKey struct{}

// BatchExecMeta describes one chunk of a chunked batch insert operation.
// It is attached to the context passed to each chunk's exec call so SQL
// hooks can observe per-chunk progress.
type BatchExecMeta struct {
	ChunkIndex  int // 1-based index of the current chunk
	ChunkCount  int // total number of chunks in this batch
	ChunkRows   int // number of rows carried by the current chunk
	TotalRows   int // total row count across the whole batch
	ColumnCount int // number of insert columns
}

// BatchExecMetaFromContext extracts batch chunk metadata from ctx.
// The boolean result is false when ctx is nil or carries no metadata.
func BatchExecMetaFromContext(ctx context.Context) (BatchExecMeta, bool) {
	if ctx == nil {
		return BatchExecMeta{}, false
	}
	if meta, ok := ctx.Value(batchExecMetaKey{}).(BatchExecMeta); ok {
		return meta, true
	}
	return BatchExecMeta{}, false
}

// withBatchExecMeta returns a child of ctx carrying meta.
// A nil ctx is treated as context.Background().
func withBatchExecMeta(ctx context.Context, meta BatchExecMeta) context.Context {
	if ctx == nil {
		ctx = context.Background()
	}
	return context.WithValue(ctx, batchExecMetaKey{}, meta)
}
// LastInsertId reports the last insert id of the final statement in the
// batch, or 0 with no error when no statements were executed.
func (m multiSQLResult) LastInsertId() (int64, error) {
	n := len(m.results)
	if n == 0 {
		return 0, nil
	}
	return m.results[n-1].LastInsertId()
}
// RowsAffected sums the affected-row counts of every statement in the batch,
// returning the first error encountered from an underlying result.
func (m multiSQLResult) RowsAffected() (int64, error) {
	var sum int64
	for i := range m.results {
		n, err := m.results[i].RowsAffected()
		if err != nil {
			return 0, err
		}
		sum += n
	}
	return sum, nil
}
// SetBatchInsertMaxRows configures the maximum row count per INSERT statement
// for the batch APIs. A value <= 0 disables splitting and keeps the
// single-statement behavior. The value is stored atomically, so it may be
// updated while batch inserts are running.
func (s *StarDB) SetBatchInsertMaxRows(maxRows int) {
	if s == nil {
		return
	}
	limit := int64(maxRows)
	if limit < 0 {
		limit = 0
	}
	atomic.StoreInt64(&s.batchInsertMaxRows, limit)
}
// BatchInsertMaxRows returns the current batch split threshold in rows.
// 0 means splitting by row count is disabled. Safe to call on a nil receiver.
func (s *StarDB) BatchInsertMaxRows() int {
	if s == nil {
		return 0
	}
	if v := atomic.LoadInt64(&s.batchInsertMaxRows); v > 0 {
		return int(v)
	}
	return 0
}
// SetBatchInsertMaxParams configures the maximum bind parameter count per
// INSERT statement for the batch APIs. A value <= 0 selects auto mode:
// built-in defaults are used for known drivers, and unknown drivers get no
// limit. The value is stored atomically for concurrent use.
func (s *StarDB) SetBatchInsertMaxParams(maxParams int) {
	if s == nil {
		return
	}
	limit := int64(maxParams)
	if limit < 0 {
		limit = 0
	}
	atomic.StoreInt64(&s.batchInsertMaxParams, limit)
}
// BatchInsertMaxParams returns the configured maximum bind parameter
// threshold. 0 means auto mode. Safe to call on a nil receiver.
func (s *StarDB) BatchInsertMaxParams() int {
	if s == nil {
		return 0
	}
	if v := atomic.LoadInt64(&s.batchInsertMaxParams); v > 0 {
		return int(v)
	}
	return 0
}
// detectBatchInsertMaxParams guesses the driver's bind parameter limit from
// the driver's Go type name. It returns 0 (no limit) for a nil handle or an
// unrecognized driver. Match order matters: "sqlite" is tested first so a
// name containing several keywords resolves the same way as before.
func detectBatchInsertMaxParams(db *sql.DB) int {
	if db == nil {
		return 0
	}
	name := strings.ToLower(fmt.Sprintf("%T", db.Driver()))
	has := func(sub string) bool { return strings.Contains(name, sub) }
	switch {
	case has("sqlite"):
		// Conservative historical SQLite default for wide compatibility.
		return 999
	case has("postgres"), has("pgx"), has("pq"):
		return 65535
	case has("mysql"):
		return 65535
	case has("sqlserver"), has("mssql"):
		return 2100
	default:
		return 0
	}
}
// minPositive returns the smaller of a and b, treating non-positive values
// as "unset": if one operand is <= 0 the other is returned, and if both are
// <= 0 the result is b (itself non-positive, meaning "no limit").
func minPositive(a, b int) int {
	switch {
	case a <= 0:
		return b
	case b <= 0:
		return a
	case b < a:
		return b
	default:
		return a
	}
}
// batchInsertChunkSize computes the effective chunk size (rows per INSERT)
// from the configured row limit and parameter limit. In parameter auto mode
// the driver default is detected. It returns ErrBatchInsertMaxParamsTooLow
// when the parameter limit cannot fit even a single row; a result of 0 means
// no splitting is required.
func (s *StarDB) batchInsertChunkSize(columnCount int) (int, error) {
	rowLimit := s.BatchInsertMaxRows()
	paramLimit := s.BatchInsertMaxParams()
	if paramLimit <= 0 {
		// Auto mode: fall back to the built-in per-driver default.
		paramLimit = detectBatchInsertMaxParams(s.db)
	}
	rowsByParams := 0
	if paramLimit > 0 {
		rowsByParams = paramLimit / columnCount
		if rowsByParams <= 0 {
			return 0, ErrBatchInsertMaxParamsTooLow
		}
	}
	return minPositive(rowLimit, rowsByParams), nil
}
// buildBatchInsertQuery renders a multi-row INSERT statement with `?`
// placeholders and flattens the row values into a single argument slice.
// Callers guarantee non-empty columns and values.
//
// NOTE(review): tableName and column names are interpolated verbatim (not
// quoted/escaped); callers must supply trusted identifiers.
func buildBatchInsertQuery(tableName string, columns []string, values [][]interface{}) (string, []interface{}) {
	// One "(?, ?, ..., ?)" group sized to the column count.
	group := "(" + strings.Repeat("?, ", len(columns)-1) + "?)"
	groups := make([]string, len(values))
	for i := range groups {
		groups[i] = group
	}
	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
		tableName,
		strings.Join(columns, ", "),
		strings.Join(groups, ", "))
	args := make([]interface{}, 0, len(values)*len(columns))
	for _, row := range values {
		for _, v := range row {
			args = append(args, v)
		}
	}
	return query, args
}
// batchInsertChunked executes the batch insert as chunkCount INSERT
// statements inside a single transaction, so the whole batch either commits
// or rolls back together. Each chunk's exec call receives a context carrying
// BatchExecMeta for SQL hooks. Any exec or commit error rolls the
// transaction back and is returned unchanged.
func (s *StarDB) batchInsertChunked(ctx context.Context, tableName string, columns []string, values [][]interface{}, chunkSize int) (sql.Result, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	tx, err := s.BeginTx(ctx, nil)
	if err != nil {
		return nil, err
	}
	total := len(values)
	chunkCount := (total + chunkSize - 1) / chunkSize // ceiling division
	results := make([]sql.Result, 0, chunkCount)
	for idx := 0; idx < chunkCount; idx++ {
		lo := idx * chunkSize
		hi := lo + chunkSize
		if hi > total {
			hi = total
		}
		chunk := values[lo:hi]
		query, args := buildBatchInsertQuery(tableName, columns, chunk)
		// Expose per-chunk progress to SQL hooks via the context.
		chunkCtx := withBatchExecMeta(ctx, BatchExecMeta{
			ChunkIndex:  idx + 1,
			ChunkCount:  chunkCount,
			ChunkRows:   len(chunk),
			TotalRows:   total,
			ColumnCount: len(columns),
		})
		res, execErr := tx.exec(chunkCtx, query, args...)
		if execErr != nil {
			_ = tx.Rollback() // best effort; the exec error is what matters
			return nil, execErr
		}
		results = append(results, res)
	}
	if err := tx.Commit(); err != nil {
		_ = tx.Rollback()
		return nil, err
	}
	return multiSQLResult{results: results}, nil
}
// BatchInsert performs a batch insert operation.
// Usage: BatchInsert("users", []string{"name", "age"}, [][]interface{}{{"Alice", 25}, {"Bob", 30}})
//
// It runs with context.Background(); use BatchInsertContext to control
// cancellation and deadlines.
func (s *StarDB) BatchInsert(tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
	// Pass an explicit background context instead of nil so downstream code
	// (exec, hooks) never has to guard against a nil context.
	return s.batchInsert(context.Background(), tableName, columns, values)
}
// BatchInsertContext performs a batch insert like BatchInsert, honoring the
// supplied context for cancellation and deadlines in every statement.
func (s *StarDB) BatchInsertContext(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
	return s.batchInsert(ctx, tableName, columns, values)
}
// batchInsert is the shared implementation behind the BatchInsert APIs.
// It validates inputs, determines whether the statement must be split to
// respect row/parameter limits, and then executes either a single INSERT
// or a chunked transactional insert.
func (s *StarDB) batchInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
	switch {
	case strings.TrimSpace(tableName) == "":
		return nil, ErrTableNameEmpty
	case len(columns) == 0:
		return nil, ErrNoInsertColumns
	case len(values) == 0:
		return nil, ErrNoInsertValues
	}
	// Every row must supply exactly one value per column.
	for i, row := range values {
		if len(row) != len(columns) {
			return nil, wrapBatchRowValueCountMismatch(i, len(row), len(columns))
		}
	}
	chunkSize, err := s.batchInsertChunkSize(len(columns))
	if err != nil {
		return nil, err
	}
	if chunkSize > 0 && len(values) > chunkSize {
		return s.batchInsertChunked(ctx, tableName, columns, values, chunkSize)
	}
	// Small enough (or splitting disabled): execute as one statement.
	query, args := buildBatchInsertQuery(tableName, columns, values)
	return s.exec(ctx, query, args...)
}
// BatchInsertStructs performs a batch insert using a slice (or array) of
// structs; columns are derived from the first element's `db` struct tags.
// Field names listed in autoIncrementFields are excluded from the INSERT.
//
// It runs with context.Background(); use BatchInsertStructsContext to
// control cancellation and deadlines.
func (s *StarDB) BatchInsertStructs(tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
	// Pass an explicit background context instead of nil so downstream code
	// never has to guard against a nil context.
	return s.batchInsertStructs(context.Background(), tableName, structs, autoIncrementFields...)
}
// BatchInsertStructsContext performs a batch insert using structs like
// BatchInsertStructs, honoring the supplied context for cancellation and
// deadlines.
func (s *StarDB) BatchInsertStructsContext(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
	return s.batchInsertStructs(ctx, tableName, structs, autoIncrementFields...)
}
// batchInsertStructs reflects over a slice (or array) of structs, derives
// the column list from the first element's `db` tags, drops any
// auto-increment columns, extracts each element's values, and delegates to
// batchInsert. A pointer to a slice/array is dereferenced first.
func (s *StarDB) batchInsertStructs(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
	if structs == nil {
		return nil, ErrStructsNil
	}
	sliceVal := reflect.ValueOf(structs)
	if sliceVal.Kind() == reflect.Ptr {
		if sliceVal.IsNil() {
			return nil, ErrStructsPointerNil
		}
		sliceVal = sliceVal.Elem()
	}
	if kind := sliceVal.Kind(); kind != reflect.Slice && kind != reflect.Array {
		return nil, ErrStructsNotSlice
	}
	count := sliceVal.Len()
	if count == 0 {
		return nil, ErrNoStructsToInsert
	}
	// Column order comes from the first element's `db`-tagged fields.
	fieldNames, err := getStructFieldNames(sliceVal.Index(0).Interface(), "db")
	if err != nil {
		return nil, err
	}
	// Build a skip set so auto-increment columns are excluded from the INSERT.
	skip := make(map[string]struct{}, len(autoIncrementFields))
	for _, name := range autoIncrementFields {
		skip[name] = struct{}{}
	}
	var columns []string
	for _, name := range fieldNames {
		if _, excluded := skip[name]; !excluded {
			columns = append(columns, name)
		}
	}
	// Extract each element's values in column order.
	values := make([][]interface{}, 0, count)
	for i := 0; i < count; i++ {
		fieldValues, err := getStructFieldValues(sliceVal.Index(i).Interface(), "db")
		if err != nil {
			return nil, err
		}
		row := make([]interface{}, 0, len(columns))
		for _, col := range columns {
			row = append(row, fieldValues[col])
		}
		values = append(values, row)
	}
	return s.batchInsert(ctx, tableName, columns, values)
}