stardb/scan_each.go

125 lines
3.3 KiB
Go
Raw Permalink Normal View History

package stardb
import (
"context"
"database/sql"
"errors"
"reflect"
)
// ScanEachFunc is the per-row callback used by the ScanEach family of methods.
//
// It is invoked once for each scanned row. Returning nil continues with the
// next row. Returning ErrScanStopped (or any error that matches it through
// errors.Is) ends the iteration early and the ScanEach call itself returns nil.
// Any other non-nil error aborts the iteration and is returned to the caller
// unchanged.
type ScanEachFunc func(row *StarResult) error
// ScanEach executes query with args in streaming mode and invokes fn once for
// each row of the result.
//
// It forwards to scanEach with a nil context; use ScanEachContext to supply
// one. The returned error is whatever scanEach reports: a query failure, a
// scan failure, ErrScanFuncNil when fn is nil, or the error returned by fn
// (except ErrScanStopped, which ends the iteration and yields nil).
func (s *StarDB) ScanEach(query string, fn ScanEachFunc, args ...interface{}) error {
// nil means "no context supplied"; it is passed on to queryRaw as is.
return s.scanEach(nil, query, fn, args...)
}
// ScanEachContext executes query with args in streaming mode and invokes fn
// once for each row of the result.
//
// ctx is passed through to queryRaw unchanged. The returned error is whatever
// scanEach reports: a query failure, a scan failure, ErrScanFuncNil when fn is
// nil, or the error returned by fn (except ErrScanStopped, which ends the
// iteration and yields nil).
func (s *StarDB) ScanEachContext(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
return s.scanEach(ctx, query, fn, args...)
}
// scanEach is the shared implementation behind StarDB.ScanEach and
// StarDB.ScanEachContext: it obtains the raw rows for query and streams them
// to fn.
func (s *StarDB) scanEach(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
	resultSet, queryErr := s.queryRaw(ctx, query, args...)
	if queryErr != nil {
		// queryRaw reported a failure; hand it back unchanged.
		return queryErr
	}

	// scanEachSQLRows closes resultSet on every return path.
	return scanEachSQLRows(resultSet, fn)
}
// ScanEach executes query with args on the transaction in streaming mode and
// invokes fn once for each row of the result.
//
// It forwards to scanEach with a nil context; use ScanEachContext to supply
// one. The returned error is whatever scanEach reports: a query failure, a
// scan failure, ErrScanFuncNil when fn is nil, or the error returned by fn
// (except ErrScanStopped, which ends the iteration and yields nil).
func (t *StarTx) ScanEach(query string, fn ScanEachFunc, args ...interface{}) error {
// nil means "no context supplied"; it is passed on to queryRaw as is.
return t.scanEach(nil, query, fn, args...)
}
// ScanEachContext executes query with args on the transaction in streaming
// mode and invokes fn once for each row of the result.
//
// ctx is passed through to queryRaw unchanged. The returned error is whatever
// scanEach reports: a query failure, a scan failure, ErrScanFuncNil when fn is
// nil, or the error returned by fn (except ErrScanStopped, which ends the
// iteration and yields nil).
func (t *StarTx) ScanEachContext(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
return t.scanEach(ctx, query, fn, args...)
}
// scanEach is the shared implementation behind StarTx.ScanEach and
// StarTx.ScanEachContext: it obtains the raw rows for query from the
// transaction and streams them to fn.
func (t *StarTx) scanEach(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
	resultSet, queryErr := t.queryRaw(ctx, query, args...)
	if queryErr != nil {
		// queryRaw reported a failure; hand it back unchanged.
		return queryErr
	}

	// scanEachSQLRows closes resultSet on every return path.
	return scanEachSQLRows(resultSet, fn)
}
// ScanEach executes the prepared statement with args in streaming mode and
// invokes fn once for each row of the result.
//
// It forwards to scanEach with a nil context; use ScanEachContext to supply
// one. The returned error is whatever scanEach reports: a query failure, a
// scan failure, ErrScanFuncNil when fn is nil, or the error returned by fn
// (except ErrScanStopped, which ends the iteration and yields nil).
func (s *StarStmt) ScanEach(fn ScanEachFunc, args ...interface{}) error {
// nil means "no context supplied"; it is passed on to queryRaw as is.
return s.scanEach(nil, fn, args...)
}
// ScanEachContext executes the prepared statement with args in streaming mode
// and invokes fn once for each row of the result.
//
// ctx is passed through to queryRaw unchanged. The returned error is whatever
// scanEach reports: a query failure, a scan failure, ErrScanFuncNil when fn is
// nil, or the error returned by fn (except ErrScanStopped, which ends the
// iteration and yields nil).
func (s *StarStmt) ScanEachContext(ctx context.Context, fn ScanEachFunc, args ...interface{}) error {
return s.scanEach(ctx, fn, args...)
}
// scanEach is the shared implementation behind StarStmt.ScanEach and
// StarStmt.ScanEachContext: it runs the prepared statement with args and
// streams the resulting rows to fn.
func (s *StarStmt) scanEach(ctx context.Context, fn ScanEachFunc, args ...interface{}) error {
	resultSet, queryErr := s.queryRaw(ctx, args...)
	if queryErr != nil {
		// queryRaw reported a failure; hand it back unchanged.
		return queryErr
	}

	// scanEachSQLRows closes resultSet on every return path.
	return scanEachSQLRows(resultSet, fn)
}
// scanEachSQLRows walks rows and hands every row to fn as a *StarResult.
//
// rows is closed on every return path. A nil fn yields ErrScanFuncNil. When fn
// returns an error matching ErrScanStopped the walk ends early and nil is
// returned; any other error from fn, from rows.Scan, from the column metadata
// lookups, or from rows.Err is returned as is.
func scanEachSQLRows(rows *sql.Rows, fn ScanEachFunc) error {
	if fn == nil {
		// No consumer for the rows: release them right away and report it.
		_ = rows.Close()
		return ErrScanFuncNil
	}
	defer rows.Close()

	names, err := rows.Columns()
	if err != nil {
		return err
	}
	colTypes, err := rows.ColumnTypes()
	if err != nil {
		return err
	}

	// Column metadata is built once and shared by every StarResult handed out.
	scanTypes := make([]reflect.Type, len(colTypes))
	for idx := range colTypes {
		scanTypes[idx] = colTypes[idx].ScanType()
	}
	width := len(names)
	indexByName := make(map[string]int, width)
	for idx := range names {
		indexByName[names[idx]] = idx
	}

	// cells receives the scanned values; dest holds the pointers into cells
	// that rows.Scan writes through.
	cells := make([]interface{}, width)
	dest := make([]interface{}, width)
	for idx := range cells {
		dest[idx] = &cells[idx]
	}

	for rows.Next() {
		if scanErr := rows.Scan(dest...); scanErr != nil {
			return scanErr
		}

		// cells is overwritten by the next Scan, so each row gets its own
		// slice of cloned values.
		snapshot := make([]interface{}, 0, width)
		for _, cell := range cells {
			snapshot = append(snapshot, cloneScannedValue(cell))
		}

		current := &StarResult{
			result:      snapshot,
			columns:     names,
			columnIndex: indexByName,
			columnsType: scanTypes,
		}

		switch cbErr := fn(current); {
		case cbErr == nil:
			// Keep streaming.
		case errors.Is(cbErr, ErrScanStopped):
			// The callback asked to stop; that is not a failure.
			return nil
		default:
			return cbErr
		}
	}
	return rows.Err()
}