2026-03-07 19:27:44 +08:00
|
|
|
package stardb
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"context"
|
|
|
|
|
"database/sql"
|
|
|
|
|
"fmt"
|
|
|
|
|
"reflect"
|
|
|
|
|
"strings"
|
2026-03-20 13:36:59 +08:00
|
|
|
"sync/atomic"
|
2026-03-07 19:27:44 +08:00
|
|
|
)
|
|
|
|
|
|
2026-03-20 13:36:59 +08:00
|
|
|
// multiSQLResult aggregates the sql.Result values produced by the individual
// statements of one chunked batch insert, in execution order.
type multiSQLResult struct {
	// results holds one entry per executed chunk; it may be empty.
	results []sql.Result
}
|
|
|
|
|
|
|
|
|
|
// batchExecMetaKey is the unexported context key type under which
// BatchExecMeta is stored; being package-private, it cannot collide with keys
// defined by other packages.
type batchExecMetaKey struct{}
|
|
|
|
|
|
|
|
|
|
// BatchExecMeta describes one chunk of a chunked batch insert.
// The chunked insert path attaches it to the context passed to each chunk's
// exec call, so it can be read in SQL hooks via BatchExecMetaFromContext.
type BatchExecMeta struct {
	// ChunkIndex is the 1-based position of this chunk within the batch.
	ChunkIndex int
	// ChunkCount is the total number of chunks the batch was split into.
	ChunkCount int
	// ChunkRows is the number of rows carried by this chunk.
	ChunkRows int
	// TotalRows is the number of rows in the whole batch.
	TotalRows int
	// ColumnCount is the number of columns in each inserted row.
	ColumnCount int
}
|
|
|
|
|
|
|
|
|
|
// BatchExecMetaFromContext extracts batch chunk metadata from context.
|
|
|
|
|
func BatchExecMetaFromContext(ctx context.Context) (BatchExecMeta, bool) {
|
|
|
|
|
if ctx == nil {
|
|
|
|
|
return BatchExecMeta{}, false
|
|
|
|
|
}
|
|
|
|
|
meta, ok := ctx.Value(batchExecMetaKey{}).(BatchExecMeta)
|
|
|
|
|
return meta, ok
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func withBatchExecMeta(ctx context.Context, meta BatchExecMeta) context.Context {
|
|
|
|
|
if ctx == nil {
|
|
|
|
|
ctx = context.Background()
|
|
|
|
|
}
|
|
|
|
|
return context.WithValue(ctx, batchExecMetaKey{}, meta)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m multiSQLResult) LastInsertId() (int64, error) {
|
|
|
|
|
if len(m.results) == 0 {
|
|
|
|
|
return 0, nil
|
|
|
|
|
}
|
|
|
|
|
return m.results[len(m.results)-1].LastInsertId()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (m multiSQLResult) RowsAffected() (int64, error) {
|
|
|
|
|
var total int64
|
|
|
|
|
for _, result := range m.results {
|
|
|
|
|
affected, err := result.RowsAffected()
|
|
|
|
|
if err != nil {
|
|
|
|
|
return 0, err
|
|
|
|
|
}
|
|
|
|
|
total += affected
|
|
|
|
|
}
|
|
|
|
|
return total, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// SetBatchInsertMaxRows configures max row count per INSERT statement for batch APIs.
|
|
|
|
|
// <= 0 disables splitting and keeps single-statement behavior.
|
|
|
|
|
func (s *StarDB) SetBatchInsertMaxRows(maxRows int) {
|
|
|
|
|
if s == nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if maxRows < 0 {
|
|
|
|
|
maxRows = 0
|
|
|
|
|
}
|
|
|
|
|
atomic.StoreInt64(&s.batchInsertMaxRows, int64(maxRows))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// BatchInsertMaxRows returns current batch split threshold.
|
|
|
|
|
// 0 means disabled.
|
|
|
|
|
func (s *StarDB) BatchInsertMaxRows() int {
|
|
|
|
|
if s == nil {
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
value := atomic.LoadInt64(&s.batchInsertMaxRows)
|
|
|
|
|
if value <= 0 {
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
return int(value)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// SetBatchInsertMaxParams configures max bind parameter count per INSERT statement for batch APIs.
|
|
|
|
|
// <= 0 means auto mode (use built-in defaults for known drivers) or no limit for unknown drivers.
|
|
|
|
|
func (s *StarDB) SetBatchInsertMaxParams(maxParams int) {
|
|
|
|
|
if s == nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
if maxParams < 0 {
|
|
|
|
|
maxParams = 0
|
|
|
|
|
}
|
|
|
|
|
atomic.StoreInt64(&s.batchInsertMaxParams, int64(maxParams))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// BatchInsertMaxParams returns configured max bind parameter threshold.
|
|
|
|
|
// 0 means auto mode.
|
|
|
|
|
func (s *StarDB) BatchInsertMaxParams() int {
|
|
|
|
|
if s == nil {
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
value := atomic.LoadInt64(&s.batchInsertMaxParams)
|
|
|
|
|
if value <= 0 {
|
|
|
|
|
return 0
|
|
|
|
|
}
|
|
|
|
|
return int(value)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// detectBatchInsertMaxParams picks a default bind-parameter ceiling by looking
// for well-known substrings in the lower-cased Go type name of db's driver.
// It returns 0 (no limit known) when db is nil or the driver is not recognised.
func detectBatchInsertMaxParams(db *sql.DB) int {
	if db == nil {
		return 0
	}

	typeName := strings.ToLower(fmt.Sprintf("%T", db.Driver()))
	// matches reports whether the driver type name contains any marker.
	matches := func(markers ...string) bool {
		for _, marker := range markers {
			if strings.Contains(typeName, marker) {
				return true
			}
		}
		return false
	}

	// The checks run in a fixed order; the first family that matches wins.
	if matches("sqlite") {
		// Keep conservative default for wide compatibility.
		return 999
	}
	if matches("postgres", "pgx", "pq") {
		return 65535
	}
	if matches("mysql") {
		return 65535
	}
	if matches("sqlserver", "mssql") {
		return 2100
	}
	return 0
}
|
|
|
|
|
|
|
|
|
|
// minPositive returns the smaller of a and b, treating a non-positive value
// as "unset": if a is non-positive the result is b, otherwise if b is
// non-positive the result is a. When both are non-positive, b is returned
// unchanged.
func minPositive(a, b int) int {
	switch {
	case a <= 0:
		return b
	case b <= 0:
		return a
	case a < b:
		return a
	default:
		return b
	}
}
|
|
|
|
|
|
|
|
|
|
func (s *StarDB) batchInsertChunkSize(columnCount int) (int, error) {
|
|
|
|
|
maxRows := s.BatchInsertMaxRows()
|
|
|
|
|
maxParams := s.BatchInsertMaxParams()
|
|
|
|
|
if maxParams <= 0 {
|
|
|
|
|
maxParams = detectBatchInsertMaxParams(s.db)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
maxRowsByParams := 0
|
|
|
|
|
if maxParams > 0 {
|
|
|
|
|
maxRowsByParams = maxParams / columnCount
|
|
|
|
|
if maxRowsByParams <= 0 {
|
|
|
|
|
return 0, ErrBatchInsertMaxParamsTooLow
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return minPositive(maxRows, maxRowsByParams), nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// buildBatchInsertQuery renders a multi-row INSERT statement using "?"
// placeholders and returns it together with the flattened argument list
// (row by row, in column order).
//
// tableName and columns are interpolated into the SQL text verbatim, so they
// must come from trusted input. Each row is expected to have len(columns)
// values; this function does not verify that.
//
// When columns or values is empty there is nothing to insert and the function
// returns an empty query and nil args instead of panicking.
func buildBatchInsertQuery(tableName string, columns []string, values [][]interface{}) (string, []interface{}) {
	if len(columns) == 0 || len(values) == 0 {
		// strings.Repeat panics on a negative count, which len(...)-1 would be.
		return "", nil
	}

	// One "(?, ?, ...)" group per row, comma-separated.
	rowGroup := "(" + strings.Repeat("?, ", len(columns)-1) + "?)"
	allGroups := strings.Repeat(rowGroup+", ", len(values)-1) + rowGroup
	query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
		tableName,
		strings.Join(columns, ", "),
		allGroups)

	args := make([]interface{}, 0, len(values)*len(columns))
	for _, row := range values {
		args = append(args, row...)
	}

	return query, args
}
|
|
|
|
|
|
|
|
|
|
func (s *StarDB) batchInsertChunked(ctx context.Context, tableName string, columns []string, values [][]interface{}, chunkSize int) (sql.Result, error) {
|
|
|
|
|
if ctx == nil {
|
|
|
|
|
ctx = context.Background()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
tx, err := s.BeginTx(ctx, nil)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
chunkCount := (len(values) + chunkSize - 1) / chunkSize
|
|
|
|
|
results := make([]sql.Result, 0, chunkCount)
|
|
|
|
|
for start := 0; start < len(values); start += chunkSize {
|
|
|
|
|
end := start + chunkSize
|
|
|
|
|
if end > len(values) {
|
|
|
|
|
end = len(values)
|
|
|
|
|
}
|
|
|
|
|
chunkIndex := start/chunkSize + 1
|
|
|
|
|
query, args := buildBatchInsertQuery(tableName, columns, values[start:end])
|
|
|
|
|
chunkCtx := withBatchExecMeta(ctx, BatchExecMeta{
|
|
|
|
|
ChunkIndex: chunkIndex,
|
|
|
|
|
ChunkCount: chunkCount,
|
|
|
|
|
ChunkRows: end - start,
|
|
|
|
|
TotalRows: len(values),
|
|
|
|
|
ColumnCount: len(columns),
|
|
|
|
|
})
|
|
|
|
|
result, execErr := tx.exec(chunkCtx, query, args...)
|
|
|
|
|
if execErr != nil {
|
|
|
|
|
_ = tx.Rollback()
|
|
|
|
|
return nil, execErr
|
|
|
|
|
}
|
|
|
|
|
results = append(results, result)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
|
|
|
_ = tx.Rollback()
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return multiSQLResult{results: results}, nil
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-07 19:27:44 +08:00
|
|
|
// BatchInsert performs batch insert operation
|
|
|
|
|
// Usage: BatchInsert("users", []string{"name", "age"}, [][]interface{}{{"Alice", 25}, {"Bob", 30}})
|
|
|
|
|
func (s *StarDB) BatchInsert(tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
|
|
|
|
|
return s.batchInsert(nil, tableName, columns, values)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// BatchInsertContext performs batch insert with context
|
|
|
|
|
func (s *StarDB) BatchInsertContext(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
|
|
|
|
|
return s.batchInsert(ctx, tableName, columns, values)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// batchInsert is the internal implementation
|
|
|
|
|
func (s *StarDB) batchInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
|
2026-03-20 13:36:59 +08:00
|
|
|
if strings.TrimSpace(tableName) == "" {
|
|
|
|
|
return nil, ErrTableNameEmpty
|
2026-03-07 19:27:44 +08:00
|
|
|
}
|
|
|
|
|
|
2026-03-20 13:36:59 +08:00
|
|
|
if len(columns) == 0 {
|
|
|
|
|
return nil, ErrNoInsertColumns
|
|
|
|
|
}
|
2026-03-07 19:27:44 +08:00
|
|
|
|
2026-03-20 13:36:59 +08:00
|
|
|
if len(values) == 0 {
|
|
|
|
|
return nil, ErrNoInsertValues
|
|
|
|
|
}
|
2026-03-07 19:27:44 +08:00
|
|
|
|
2026-03-20 13:36:59 +08:00
|
|
|
for i, row := range values {
|
|
|
|
|
if len(row) != len(columns) {
|
|
|
|
|
return nil, wrapBatchRowValueCountMismatch(i, len(row), len(columns))
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
chunkSize, err := s.batchInsertChunkSize(len(columns))
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
if chunkSize > 0 && len(values) > chunkSize {
|
|
|
|
|
return s.batchInsertChunked(ctx, tableName, columns, values, chunkSize)
|
2026-03-07 19:27:44 +08:00
|
|
|
}
|
|
|
|
|
|
2026-03-20 13:36:59 +08:00
|
|
|
query, args := buildBatchInsertQuery(tableName, columns, values)
|
2026-03-07 19:27:44 +08:00
|
|
|
return s.exec(ctx, query, args...)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// BatchInsertStructs performs batch insert using structs
|
|
|
|
|
func (s *StarDB) BatchInsertStructs(tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
|
|
|
|
|
return s.batchInsertStructs(nil, tableName, structs, autoIncrementFields...)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// BatchInsertStructsContext performs batch insert using structs with context
|
|
|
|
|
func (s *StarDB) BatchInsertStructsContext(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
|
|
|
|
|
return s.batchInsertStructs(ctx, tableName, structs, autoIncrementFields...)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// batchInsertStructs is the internal implementation
|
|
|
|
|
func (s *StarDB) batchInsertStructs(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
|
2026-03-20 13:36:59 +08:00
|
|
|
if structs == nil {
|
|
|
|
|
return nil, ErrStructsNil
|
|
|
|
|
}
|
|
|
|
|
|
2026-03-07 19:27:44 +08:00
|
|
|
// Get slice of structs
|
|
|
|
|
targetValue := reflect.ValueOf(structs)
|
|
|
|
|
if targetValue.Kind() == reflect.Ptr {
|
2026-03-20 13:36:59 +08:00
|
|
|
if targetValue.IsNil() {
|
|
|
|
|
return nil, ErrStructsPointerNil
|
|
|
|
|
}
|
2026-03-07 19:27:44 +08:00
|
|
|
targetValue = targetValue.Elem()
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if targetValue.Kind() != reflect.Slice && targetValue.Kind() != reflect.Array {
|
2026-03-20 13:36:59 +08:00
|
|
|
return nil, ErrStructsNotSlice
|
2026-03-07 19:27:44 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if targetValue.Len() == 0 {
|
2026-03-20 13:36:59 +08:00
|
|
|
return nil, ErrNoStructsToInsert
|
2026-03-07 19:27:44 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Get field names from first struct
|
|
|
|
|
firstStruct := targetValue.Index(0).Interface()
|
|
|
|
|
fieldNames, err := getStructFieldNames(firstStruct, "db")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Filter out auto-increment fields
|
|
|
|
|
var columns []string
|
|
|
|
|
for _, fieldName := range fieldNames {
|
|
|
|
|
isAutoIncrement := false
|
|
|
|
|
for _, autoField := range autoIncrementFields {
|
|
|
|
|
if fieldName == autoField {
|
|
|
|
|
isAutoIncrement = true
|
|
|
|
|
break
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if !isAutoIncrement {
|
|
|
|
|
columns = append(columns, fieldName)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Extract values from all structs
|
|
|
|
|
var values [][]interface{}
|
|
|
|
|
for i := 0; i < targetValue.Len(); i++ {
|
|
|
|
|
structVal := targetValue.Index(i).Interface()
|
|
|
|
|
fieldValues, err := getStructFieldValues(structVal, "db")
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
var row []interface{}
|
|
|
|
|
for _, col := range columns {
|
|
|
|
|
row = append(row, fieldValues[col])
|
|
|
|
|
}
|
|
|
|
|
values = append(values, row)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return s.batchInsert(ctx, tableName, columns, values)
|
|
|
|
|
}
|