// stardb/tx.go - 265 lines, 6.9 KiB (Go source).
// Last revision: 2026-03-07 19:27:44 +08:00.

package stardb
import (
"context"
"database/sql"
"strings"
"time"
2026-03-07 19:27:44 +08:00
)
// StarTx represents a database transaction.
//
// It pairs a *sql.Tx with the owning *StarDB. Statements run through a StarTx
// are passed through the StarDB's prepareSQLCall / hookContext machinery and
// honour its ManualScan setting. A StarTx whose pointer, tx or db is nil is
// rejected by every method with ErrTxNotInitialized.
type StarTx struct {
	// tx is the underlying database/sql transaction all statements run on.
	tx *sql.Tx
	// db is the owning handle that supplies SQL rewriting, hooks and ManualScan.
	db *StarDB
}
func (t *StarTx) ensureTx() error {
if t == nil || t.tx == nil || t.db == nil {
return ErrTxNotInitialized
}
return nil
}
2026-03-07 19:27:44 +08:00
// Query runs query inside the transaction without a context. Unless the owning
// StarDB has ManualScan enabled, the returned rows have already been parsed.
func (t *StarTx) Query(query string, args ...interface{}) (*StarRows, error) {
	// A nil context makes the helper use the context-free *sql.Tx call.
	var noCtx context.Context
	return t.query(noCtx, query, args...)
}
// QueryContext is the context-aware form of Query: ctx is handed to the
// underlying *sql.Tx and used to build the context seen by the hooks.
func (t *StarTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
	rows, err := t.query(ctx, query, args...)
	return rows, err
}
// QueryRaw runs query inside the transaction without a context and returns the
// raw *sql.Rows. No automatic parsing is performed; the caller owns the rows
// and must Close them.
func (t *StarTx) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) {
	var noCtx context.Context
	return t.queryRaw(noCtx, query, args...)
}
// QueryRawContext is the context-aware form of QueryRaw. The returned
// *sql.Rows are not parsed; the caller owns them and must Close them.
func (t *StarTx) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	rows, err := t.queryRaw(ctx, query, args...)
	return rows, err
}
func (t *StarTx) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
if err := t.ensureTx(); err != nil {
2026-03-07 19:27:44 +08:00
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
2026-03-07 19:27:44 +08:00
query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, hookArgs)
}
start := time.Now()
2026-03-07 19:27:44 +08:00
var (
rows *sql.Rows
err error
)
2026-03-07 19:27:44 +08:00
if ctx == nil {
rows, err = t.tx.Query(query, args...)
} else {
rows, err = t.tx.QueryContext(ctx, query, args...)
}
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, hookArgs, duration, err)
}
if err != nil {
return nil, err
}
return rows, nil
}
2026-03-07 19:27:44 +08:00
// query is the internal query implementation
func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
rows, err := t.queryRaw(ctx, query, args...)
2026-03-07 19:27:44 +08:00
if err != nil {
return nil, err
}
starRows := &StarRows{
rows: rows,
db: t.db,
}
if !t.db.ManualScan {
if err := starRows.parse(); err != nil {
_ = rows.Close()
return nil, err
}
2026-03-07 19:27:44 +08:00
}
return starRows, nil
2026-03-07 19:27:44 +08:00
}
// Exec runs a statement inside the transaction without a context and returns
// the driver's sql.Result.
func (t *StarTx) Exec(query string, args ...interface{}) (sql.Result, error) {
	// A nil context makes the helper use the context-free *sql.Tx call.
	var noCtx context.Context
	return t.exec(noCtx, query, args...)
}
// ExecContext is the context-aware form of Exec: ctx is handed to the
// underlying *sql.Tx and used to build the context seen by the hooks.
func (t *StarTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	res, err := t.exec(ctx, query, args...)
	return res, err
}
// exec is the internal exec implementation
func (t *StarTx) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if err := t.ensureTx(); err != nil {
2026-03-07 19:27:44 +08:00
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, hookArgs)
}
start := time.Now()
2026-03-07 19:27:44 +08:00
var (
result sql.Result
err error
)
2026-03-07 19:27:44 +08:00
if ctx == nil {
result, err = t.tx.Exec(query, args...)
} else {
result, err = t.tx.ExecContext(ctx, query, args...)
}
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, hookArgs, duration, err)
}
if err != nil {
return nil, err
2026-03-07 19:27:44 +08:00
}
return result, nil
2026-03-07 19:27:44 +08:00
}
// Prepare creates a prepared statement within the transaction
func (t *StarTx) Prepare(query string) (*StarStmt, error) {
if err := t.ensureTx(); err != nil {
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
hookCtx := t.db.hookContext(nil, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, nil)
}
start := time.Now()
2026-03-07 19:27:44 +08:00
stmt, err := t.tx.Prepare(query)
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, nil, duration, err)
}
2026-03-07 19:27:44 +08:00
if err != nil {
return nil, err
}
return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
2026-03-07 19:27:44 +08:00
}
// PrepareContext creates a prepared statement with context
func (t *StarTx) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
if err := t.ensureTx(); err != nil {
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, nil)
}
start := time.Now()
2026-03-07 19:27:44 +08:00
stmt, err := t.tx.PrepareContext(ctx, query)
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, nil, duration, err)
}
2026-03-07 19:27:44 +08:00
if err != nil {
return nil, err
}
return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
2026-03-07 19:27:44 +08:00
}
// QueryStmt executes a prepared statement query within the transaction
func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
2026-03-07 19:27:44 +08:00
}
stmt, err := t.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Query(args...)
}
// QueryStmtContext executes a prepared statement query with context
func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
2026-03-07 19:27:44 +08:00
}
stmt, err := t.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.QueryContext(ctx, args...)
}
// ExecStmt executes a prepared statement within the transaction
func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
2026-03-07 19:27:44 +08:00
}
stmt, err := t.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Exec(args...)
}
// ExecStmtContext executes a prepared statement with context
func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
2026-03-07 19:27:44 +08:00
}
stmt, err := t.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.ExecContext(ctx, args...)
}
// Commit commits the transaction
func (t *StarTx) Commit() error {
if err := t.ensureTx(); err != nil {
return err
}
2026-03-07 19:27:44 +08:00
return t.tx.Commit()
}
// Rollback rolls back the transaction
func (t *StarTx) Rollback() error {
if err := t.ensureTx(); err != nil {
return err
}
2026-03-07 19:27:44 +08:00
return t.tx.Rollback()
}