265 lines
6.9 KiB
Go
265 lines
6.9 KiB
Go
package stardb
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// StarTx represents a database transaction
|
|
type StarTx struct {
|
|
tx *sql.Tx
|
|
db *StarDB
|
|
}
|
|
|
|
func (t *StarTx) ensureTx() error {
|
|
if t == nil || t.tx == nil || t.db == nil {
|
|
return ErrTxNotInitialized
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Query executes a query within the transaction
|
|
func (t *StarTx) Query(query string, args ...interface{}) (*StarRows, error) {
|
|
return t.query(nil, query, args...)
|
|
}
|
|
|
|
// QueryContext executes a query with context within the transaction
|
|
func (t *StarTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
|
return t.query(ctx, query, args...)
|
|
}
|
|
|
|
// QueryRaw executes a query in transaction and returns *sql.Rows without automatic parsing.
|
|
func (t *StarTx) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) {
|
|
return t.queryRaw(nil, query, args...)
|
|
}
|
|
|
|
// QueryRawContext executes a query with context in transaction and returns *sql.Rows without automatic parsing.
|
|
func (t *StarTx) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
return t.queryRaw(ctx, query, args...)
|
|
}
|
|
|
|
func (t *StarTx) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
if err := t.ensureTx(); err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.TrimSpace(query) == "" {
|
|
return nil, ErrQueryEmpty
|
|
}
|
|
|
|
query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
|
|
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
|
|
if beforeHook != nil {
|
|
beforeHook(hookCtx, query, hookArgs)
|
|
}
|
|
start := time.Now()
|
|
|
|
var (
|
|
rows *sql.Rows
|
|
err error
|
|
)
|
|
if ctx == nil {
|
|
rows, err = t.tx.Query(query, args...)
|
|
} else {
|
|
rows, err = t.tx.QueryContext(ctx, query, args...)
|
|
}
|
|
duration := time.Since(start)
|
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
afterHook(hookCtx, query, hookArgs, duration, err)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return rows, nil
|
|
}
|
|
|
|
// query is the internal query implementation
|
|
func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
|
rows, err := t.queryRaw(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
starRows := &StarRows{
|
|
rows: rows,
|
|
db: t.db,
|
|
}
|
|
|
|
if !t.db.ManualScan {
|
|
if err := starRows.parse(); err != nil {
|
|
_ = rows.Close()
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return starRows, nil
|
|
}
|
|
|
|
// Exec executes a query within the transaction
|
|
func (t *StarTx) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
return t.exec(nil, query, args...)
|
|
}
|
|
|
|
// ExecContext executes a query with context within the transaction
|
|
func (t *StarTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
return t.exec(ctx, query, args...)
|
|
}
|
|
|
|
// exec is the internal exec implementation
|
|
func (t *StarTx) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
if err := t.ensureTx(); err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.TrimSpace(query) == "" {
|
|
return nil, ErrQueryEmpty
|
|
}
|
|
|
|
query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
|
|
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
|
|
if beforeHook != nil {
|
|
beforeHook(hookCtx, query, hookArgs)
|
|
}
|
|
start := time.Now()
|
|
|
|
var (
|
|
result sql.Result
|
|
err error
|
|
)
|
|
if ctx == nil {
|
|
result, err = t.tx.Exec(query, args...)
|
|
} else {
|
|
result, err = t.tx.ExecContext(ctx, query, args...)
|
|
}
|
|
duration := time.Since(start)
|
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
afterHook(hookCtx, query, hookArgs, duration, err)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// Prepare creates a prepared statement within the transaction
|
|
func (t *StarTx) Prepare(query string) (*StarStmt, error) {
|
|
if err := t.ensureTx(); err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.TrimSpace(query) == "" {
|
|
return nil, ErrQueryEmpty
|
|
}
|
|
|
|
query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
|
|
hookCtx := t.db.hookContext(nil, query, beforeHook, afterHook)
|
|
if beforeHook != nil {
|
|
beforeHook(hookCtx, query, nil)
|
|
}
|
|
start := time.Now()
|
|
|
|
stmt, err := t.tx.Prepare(query)
|
|
duration := time.Since(start)
|
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
afterHook(hookCtx, query, nil, duration, err)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
|
|
}
|
|
|
|
// PrepareContext creates a prepared statement with context
|
|
func (t *StarTx) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
|
|
if err := t.ensureTx(); err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.TrimSpace(query) == "" {
|
|
return nil, ErrQueryEmpty
|
|
}
|
|
|
|
query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
|
|
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
|
|
if beforeHook != nil {
|
|
beforeHook(hookCtx, query, nil)
|
|
}
|
|
start := time.Now()
|
|
|
|
stmt, err := t.tx.PrepareContext(ctx, query)
|
|
duration := time.Since(start)
|
|
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
afterHook(hookCtx, query, nil, duration, err)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
|
|
}
|
|
|
|
// QueryStmt executes a prepared statement query within the transaction
|
|
func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
|
|
if strings.TrimSpace(query) == "" {
|
|
return nil, ErrQueryEmpty
|
|
}
|
|
stmt, err := t.Prepare(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
return stmt.Query(args...)
|
|
}
|
|
|
|
// QueryStmtContext executes a prepared statement query with context
|
|
func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
|
if strings.TrimSpace(query) == "" {
|
|
return nil, ErrQueryEmpty
|
|
}
|
|
stmt, err := t.PrepareContext(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
return stmt.QueryContext(ctx, args...)
|
|
}
|
|
|
|
// ExecStmt executes a prepared statement within the transaction
|
|
func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
|
|
if strings.TrimSpace(query) == "" {
|
|
return nil, ErrQueryEmpty
|
|
}
|
|
stmt, err := t.Prepare(query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
return stmt.Exec(args...)
|
|
}
|
|
|
|
// ExecStmtContext executes a prepared statement with context
|
|
func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
if strings.TrimSpace(query) == "" {
|
|
return nil, ErrQueryEmpty
|
|
}
|
|
stmt, err := t.PrepareContext(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer stmt.Close()
|
|
return stmt.ExecContext(ctx, args...)
|
|
}
|
|
|
|
// Commit commits the transaction
|
|
func (t *StarTx) Commit() error {
|
|
if err := t.ensureTx(); err != nil {
|
|
return err
|
|
}
|
|
return t.tx.Commit()
|
|
}
|
|
|
|
// Rollback rolls back the transaction
|
|
func (t *StarTx) Rollback() error {
|
|
if err := t.ensureTx(); err != nil {
|
|
return err
|
|
}
|
|
return t.tx.Rollback()
|
|
}
|