package stardb

import (
	"context"
	"database/sql"
	"strings"
	"time"
)

// StarTx represents a database transaction. It pairs the underlying *sql.Tx
// with its owning *StarDB, whose prepareSQLCall/hookContext pipeline and
// ManualScan setting are applied to every statement run through the
// transaction.
type StarTx struct {
	tx *sql.Tx
	db *StarDB
}

// ensureTx returns ErrTxNotInitialized when the receiver, the wrapped *sql.Tx
// or the owning *StarDB is nil, so callers can fail cleanly instead of
// dereferencing a nil pointer.
func (t *StarTx) ensureTx() error {
	if t == nil || t.tx == nil || t.db == nil {
		return ErrTxNotInitialized
	}
	return nil
}

// Query executes query within the transaction without a context and returns
// the result as *StarRows. Unless the owning StarDB has ManualScan set, the
// rows are parsed before Query returns.
func (t *StarTx) Query(query string, args ...interface{}) (*StarRows, error) {
	return t.query(nil, query, args...)
}

// QueryContext is like Query but executes the statement under ctx. A nil ctx
// falls back to the non-context database/sql call, exactly like Query.
func (t *StarTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
	return t.query(ctx, query, args...)
}

// QueryRaw executes query in the transaction and returns *sql.Rows without
// automatic parsing. The caller owns the returned rows and must close them.
func (t *StarTx) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) {
	return t.queryRaw(nil, query, args...)
}

// QueryRawContext is like QueryRaw but executes the statement under ctx. A nil
// ctx falls back to the non-context database/sql call.
func (t *StarTx) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	return t.queryRaw(ctx, query, args...)
}

// queryRaw validates the transaction and query text, lets the owning StarDB
// rewrite the SQL and supply hooks via prepareSQLCall, and runs the query
// between the before/after hooks. A nil ctx selects tx.Query instead of
// tx.QueryContext. The after hook runs only when shouldRunAfterHook allows it
// for the measured duration and error (it is given the slow-query threshold).
func (t *StarTx) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	if err := t.ensureTx(); err != nil {
		return nil, err
	}
	if strings.TrimSpace(query) == "" {
		return nil, ErrQueryEmpty
	}

	// prepareSQLCall may return SQL text that differs from the input; the
	// returned text is what is executed and what the hooks observe.
	query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
	hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
	if beforeHook != nil {
		beforeHook(hookCtx, query, hookArgs)
	}

	start := time.Now()
	var (
		rows *sql.Rows
		err  error
	)
	if ctx == nil {
		rows, err = t.tx.Query(query, args...)
	} else {
		rows, err = t.tx.QueryContext(ctx, query, args...)
	}
	duration := time.Since(start)

	if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
		afterHook(hookCtx, query, hookArgs, duration, err)
	}
	if err != nil {
		return nil, err
	}
	return rows, nil
}

// query is the internal query implementation: it runs queryRaw and wraps the
// rows in *StarRows. When ManualScan is off the rows are parsed immediately;
// if parsing fails the rows are closed and the parse error is returned.
func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
	rows, err := t.queryRaw(ctx, query, args...)
	if err != nil {
		return nil, err
	}

	starRows := &StarRows{
		rows: rows,
		db:   t.db,
	}
	if !t.db.ManualScan {
		if err := starRows.parse(); err != nil {
			// Best-effort close: the parse error is the one reported.
			_ = rows.Close()
			return nil, err
		}
	}
	return starRows, nil
}

// Exec executes a statement that returns no rows within the transaction,
// without a context.
func (t *StarTx) Exec(query string, args ...interface{}) (sql.Result, error) {
	return t.exec(nil, query, args...)
}

// ExecContext is like Exec but executes the statement under ctx. A nil ctx
// falls back to the non-context database/sql call, exactly like Exec.
func (t *StarTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	return t.exec(ctx, query, args...)
}

// exec is the internal exec implementation. It mirrors queryRaw: validate,
// run prepareSQLCall, call the before hook, execute (tx.Exec for a nil ctx,
// tx.ExecContext otherwise), then call the after hook if shouldRunAfterHook
// allows it.
func (t *StarTx) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	if err := t.ensureTx(); err != nil {
		return nil, err
	}
	if strings.TrimSpace(query) == "" {
		return nil, ErrQueryEmpty
	}

	query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
	hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
	if beforeHook != nil {
		beforeHook(hookCtx, query, hookArgs)
	}

	start := time.Now()
	var (
		result sql.Result
		err    error
	)
	if ctx == nil {
		result, err = t.tx.Exec(query, args...)
	} else {
		result, err = t.tx.ExecContext(ctx, query, args...)
	}
	duration := time.Since(start)

	if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
		afterHook(hookCtx, query, hookArgs, duration, err)
	}
	if err != nil {
		return nil, err
	}
	return result, nil
}

// Prepare creates a prepared statement within the transaction. It goes
// through the same SQL preparation and hooks as Query/Exec; the hooks receive
// nil arguments because no values are bound at prepare time.
func (t *StarTx) Prepare(query string) (*StarStmt, error) {
	return t.prepareStmt(nil, query)
}

// PrepareContext is like Prepare but prepares the statement under ctx. A nil
// ctx falls back to the non-context tx.Prepare (consistent with QueryContext
// and ExecContext) instead of being passed to database/sql, which would panic.
func (t *StarTx) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
	return t.prepareStmt(ctx, query)
}

// prepareStmt is the shared implementation of Prepare and PrepareContext.
// The returned StarStmt records the SQL text produced by prepareSQLCall.
func (t *StarTx) prepareStmt(ctx context.Context, query string) (*StarStmt, error) {
	if err := t.ensureTx(); err != nil {
		return nil, err
	}
	if strings.TrimSpace(query) == "" {
		return nil, ErrQueryEmpty
	}

	query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
	hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
	if beforeHook != nil {
		beforeHook(hookCtx, query, nil)
	}

	start := time.Now()
	var (
		stmt *sql.Stmt
		err  error
	)
	if ctx == nil {
		stmt, err = t.tx.Prepare(query)
	} else {
		stmt, err = t.tx.PrepareContext(ctx, query)
	}
	duration := time.Since(start)

	if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
		afterHook(hookCtx, query, nil, duration, err)
	}
	if err != nil {
		return nil, err
	}
	return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
}

// QueryStmt prepares query within the transaction, runs it once with args and
// closes the prepared statement before returning.
func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
	return t.queryStmt(nil, query, args...)
}

// QueryStmtContext is like QueryStmt but prepares and runs the statement under
// ctx. A nil ctx behaves like QueryStmt.
func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
	return t.queryStmt(ctx, query, args...)
}

// queryStmt is the shared implementation of QueryStmt and QueryStmtContext.
func (t *StarTx) queryStmt(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
	// Checked before prepareStmt so an empty query reports ErrQueryEmpty even
	// when the transaction itself is not initialized.
	if strings.TrimSpace(query) == "" {
		return nil, ErrQueryEmpty
	}
	stmt, err := t.prepareStmt(ctx, query)
	if err != nil {
		return nil, err
	}
	// NOTE(review): with ManualScan enabled the returned rows are still open
	// when the statement is closed here. database/sql appears to keep the
	// driver statement alive until dependent rows close — confirm that
	// StarStmt.Close does nothing that invalidates them.
	defer stmt.Close()

	if ctx == nil {
		return stmt.Query(args...)
	}
	return stmt.QueryContext(ctx, args...)
}

// ExecStmt prepares query within the transaction, executes it once with args
// and closes the prepared statement before returning.
func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
	return t.execStmt(nil, query, args...)
}

// ExecStmtContext is like ExecStmt but prepares and executes the statement
// under ctx. A nil ctx behaves like ExecStmt.
func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	return t.execStmt(ctx, query, args...)
}

// execStmt is the shared implementation of ExecStmt and ExecStmtContext.
func (t *StarTx) execStmt(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	// Checked before prepareStmt so an empty query reports ErrQueryEmpty even
	// when the transaction itself is not initialized.
	if strings.TrimSpace(query) == "" {
		return nil, ErrQueryEmpty
	}
	stmt, err := t.prepareStmt(ctx, query)
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	if ctx == nil {
		return stmt.Exec(args...)
	}
	return stmt.ExecContext(ctx, args...)
}

// Commit commits the transaction. It returns ErrTxNotInitialized for an
// uninitialized StarTx; otherwise it returns the result of (*sql.Tx).Commit.
func (t *StarTx) Commit() error {
	if err := t.ensureTx(); err != nil {
		return err
	}
	return t.tx.Commit()
}

// Rollback rolls back the transaction. It returns ErrTxNotInitialized for an
// uninitialized StarTx; otherwise it returns the result of (*sql.Tx).Rollback.
func (t *StarTx) Rollback() error {
	if err := t.ensureTx(); err != nil {
		return err
	}
	return t.tx.Rollback()
}