package stardb import ( "context" "database/sql" "fmt" "reflect" "strings" "sync/atomic" ) type multiSQLResult struct { results []sql.Result } type batchExecMetaKey struct{} // BatchExecMeta contains chunk execution metadata for batch insert operations. // It is attached to context for chunked execution and can be read in SQL hooks. type BatchExecMeta struct { ChunkIndex int ChunkCount int ChunkRows int TotalRows int ColumnCount int } // BatchExecMetaFromContext extracts batch chunk metadata from context. func BatchExecMetaFromContext(ctx context.Context) (BatchExecMeta, bool) { if ctx == nil { return BatchExecMeta{}, false } meta, ok := ctx.Value(batchExecMetaKey{}).(BatchExecMeta) return meta, ok } func withBatchExecMeta(ctx context.Context, meta BatchExecMeta) context.Context { if ctx == nil { ctx = context.Background() } return context.WithValue(ctx, batchExecMetaKey{}, meta) } func (m multiSQLResult) LastInsertId() (int64, error) { if len(m.results) == 0 { return 0, nil } return m.results[len(m.results)-1].LastInsertId() } func (m multiSQLResult) RowsAffected() (int64, error) { var total int64 for _, result := range m.results { affected, err := result.RowsAffected() if err != nil { return 0, err } total += affected } return total, nil } // SetBatchInsertMaxRows configures max row count per INSERT statement for batch APIs. // <= 0 disables splitting and keeps single-statement behavior. func (s *StarDB) SetBatchInsertMaxRows(maxRows int) { if s == nil { return } if maxRows < 0 { maxRows = 0 } atomic.StoreInt64(&s.batchInsertMaxRows, int64(maxRows)) } // BatchInsertMaxRows returns current batch split threshold. // 0 means disabled. func (s *StarDB) BatchInsertMaxRows() int { if s == nil { return 0 } value := atomic.LoadInt64(&s.batchInsertMaxRows) if value <= 0 { return 0 } return int(value) } // SetBatchInsertMaxParams configures max bind parameter count per INSERT statement for batch APIs. // <= 0 means auto mode (use built-in defaults for known drivers) or no limit for unknown drivers. func (s *StarDB) SetBatchInsertMaxParams(maxParams int) { if s == nil { return } if maxParams < 0 { maxParams = 0 } atomic.StoreInt64(&s.batchInsertMaxParams, int64(maxParams)) } // BatchInsertMaxParams returns configured max bind parameter threshold. // 0 means auto mode. func (s *StarDB) BatchInsertMaxParams() int { if s == nil { return 0 } value := atomic.LoadInt64(&s.batchInsertMaxParams) if value <= 0 { return 0 } return int(value) } func detectBatchInsertMaxParams(db *sql.DB) int { if db == nil { return 0 } driverType := strings.ToLower(fmt.Sprintf("%T", db.Driver())) switch { case strings.Contains(driverType, "sqlite"): // Keep conservative default for wide compatibility. return 999 case strings.Contains(driverType, "postgres"), strings.Contains(driverType, "pgx"), strings.Contains(driverType, "pq"): return 65535 case strings.Contains(driverType, "mysql"): return 65535 case strings.Contains(driverType, "sqlserver"), strings.Contains(driverType, "mssql"): return 2100 default: return 0 } } func minPositive(a, b int) int { if a <= 0 { return b } if b <= 0 { return a } if a < b { return a } return b } func (s *StarDB) batchInsertChunkSize(columnCount int) (int, error) { maxRows := s.BatchInsertMaxRows() maxParams := s.BatchInsertMaxParams() if maxParams <= 0 { maxParams = detectBatchInsertMaxParams(s.db) } maxRowsByParams := 0 if maxParams > 0 { maxRowsByParams = maxParams / columnCount if maxRowsByParams <= 0 { return 0, ErrBatchInsertMaxParamsTooLow } } return minPositive(maxRows, maxRowsByParams), nil } func buildBatchInsertQuery(tableName string, columns []string, values [][]interface{}) (string, []interface{}) { placeholderGroup := "(" + strings.Repeat("?, ", len(columns)-1) + "?)" placeholders := strings.Repeat(placeholderGroup+", ", len(values)-1) + placeholderGroup query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", tableName, strings.Join(columns, ", "), placeholders) args := make([]interface{}, 0, len(values)*len(columns)) for _, row := range values { args = append(args, row...) } return query, args } func (s *StarDB) batchInsertChunked(ctx context.Context, tableName string, columns []string, values [][]interface{}, chunkSize int) (sql.Result, error) { if ctx == nil { ctx = context.Background() } tx, err := s.BeginTx(ctx, nil) if err != nil { return nil, err } chunkCount := (len(values) + chunkSize - 1) / chunkSize results := make([]sql.Result, 0, chunkCount) for start := 0; start < len(values); start += chunkSize { end := start + chunkSize if end > len(values) { end = len(values) } chunkIndex := start/chunkSize + 1 query, args := buildBatchInsertQuery(tableName, columns, values[start:end]) chunkCtx := withBatchExecMeta(ctx, BatchExecMeta{ ChunkIndex: chunkIndex, ChunkCount: chunkCount, ChunkRows: end - start, TotalRows: len(values), ColumnCount: len(columns), }) result, execErr := tx.exec(chunkCtx, query, args...) if execErr != nil { _ = tx.Rollback() return nil, execErr } results = append(results, result) } if err := tx.Commit(); err != nil { _ = tx.Rollback() return nil, err } return multiSQLResult{results: results}, nil } // BatchInsert performs batch insert operation // Usage: BatchInsert("users", []string{"name", "age"}, [][]interface{}{{"Alice", 25}, {"Bob", 30}}) func (s *StarDB) BatchInsert(tableName string, columns []string, values [][]interface{}) (sql.Result, error) { return s.batchInsert(nil, tableName, columns, values) } // BatchInsertContext performs batch insert with context func (s *StarDB) BatchInsertContext(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) { return s.batchInsert(ctx, tableName, columns, values) } // batchInsert is the internal implementation func (s *StarDB) batchInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) { if strings.TrimSpace(tableName) == "" { return nil, ErrTableNameEmpty } if len(columns) == 0 { return nil, ErrNoInsertColumns } if len(values) == 0 { return nil, ErrNoInsertValues } for i, row := range values { if len(row) != len(columns) { return nil, wrapBatchRowValueCountMismatch(i, len(row), len(columns)) } } chunkSize, err := s.batchInsertChunkSize(len(columns)) if err != nil { return nil, err } if chunkSize > 0 && len(values) > chunkSize { return s.batchInsertChunked(ctx, tableName, columns, values, chunkSize) } query, args := buildBatchInsertQuery(tableName, columns, values) return s.exec(ctx, query, args...) } // BatchInsertStructs performs batch insert using structs func (s *StarDB) BatchInsertStructs(tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) { return s.batchInsertStructs(nil, tableName, structs, autoIncrementFields...) } // BatchInsertStructsContext performs batch insert using structs with context func (s *StarDB) BatchInsertStructsContext(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) { return s.batchInsertStructs(ctx, tableName, structs, autoIncrementFields...) } // batchInsertStructs is the internal implementation func (s *StarDB) batchInsertStructs(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) { if structs == nil { return nil, ErrStructsNil } // Get slice of structs targetValue := reflect.ValueOf(structs) if targetValue.Kind() == reflect.Ptr { if targetValue.IsNil() { return nil, ErrStructsPointerNil } targetValue = targetValue.Elem() } if targetValue.Kind() != reflect.Slice && targetValue.Kind() != reflect.Array { return nil, ErrStructsNotSlice } if targetValue.Len() == 0 { return nil, ErrNoStructsToInsert } // Get field names from first struct firstStruct := targetValue.Index(0).Interface() fieldNames, err := getStructFieldNames(firstStruct, "db") if err != nil { return nil, err } // Filter out auto-increment fields var columns []string for _, fieldName := range fieldNames { isAutoIncrement := false for _, autoField := range autoIncrementFields { if fieldName == autoField { isAutoIncrement = true break } } if !isAutoIncrement { columns = append(columns, fieldName) } } // Extract values from all structs var values [][]interface{} for i := 0; i < targetValue.Len(); i++ { structVal := targetValue.Index(i).Interface() fieldValues, err := getStructFieldValues(structVal, "db") if err != nil { return nil, err } var row []interface{} for _, col := range columns { row = append(row, fieldValues[col]) } values = append(values, row) } return s.batchInsert(ctx, tableName, columns, values) }