From e0af498fa49f41ac619e9753488c3c7b281da06b Mon Sep 17 00:00:00 2001 From: starainrt Date: Fri, 20 Mar 2026 13:36:59 +0800 Subject: [PATCH] =?UTF-8?q?bug=20fix:=E4=BF=AE=E5=A4=8D=E5=8F=AF=E8=83=BD?= =?UTF-8?q?=E7=9A=84panic=E7=8A=B6=E6=80=81=EF=BC=9B=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=9B=B4=E5=A4=9A=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 9 +- CHANGELOG.MD | 59 ++ README.MD | 913 +++++++++++++++-------------- batch.go | 260 +++++++- builder.go | 57 +- builder_test.go | 59 ++ converter.go | 156 +---- converter_safe.go | 61 +- errors.go | 64 ++ internal/convert/basic.go | 359 ++++++++++++ internal/scanutil/value_clone.go | 12 + internal/sqlplaceholder/convert.go | 119 ++++ internal/sqlruntime/fingerprint.go | 323 ++++++++++ internal/sqlruntime/hooks.go | 27 + internal/sqlruntime/state.go | 269 +++++++++ orm.go | 250 ++++---- orm_test.go | 191 +++++- pool.go | 4 + reflect.go | 302 +++++++--- result_safe.go | 137 ++++- rows.go | 21 +- rows_internal_test.go | 29 + scan_each.go | 124 ++++ scan_each_orm.go | 119 ++++ sql_placeholder.go | 19 + sql_placeholder_test.go | 34 ++ sql_runtime.go | 276 +++++++++ sql_runtime_test.go | 102 ++++ stardb.go | 324 ++++++++-- stardb_safe_test.go | 127 ++++ testing/batch_test.go | 337 ++++++++++- testing/orm_test.go | 245 +++++++- testing/perf_test.go | 128 ++++ testing/pool_test.go | 26 + testing/result_test.go | 100 ++++ testing/rows_test.go | 6 + testing/stardb_test.go | 770 ++++++++++++++++++++++++ tx.go | 154 ++++- tx_helper.go | 45 ++ 39 files changed, 5643 insertions(+), 974 deletions(-) create mode 100644 CHANGELOG.MD create mode 100644 errors.go create mode 100644 internal/convert/basic.go create mode 100644 internal/scanutil/value_clone.go create mode 100644 internal/sqlplaceholder/convert.go create mode 100644 internal/sqlruntime/fingerprint.go create mode 100644 internal/sqlruntime/hooks.go create mode 100644 internal/sqlruntime/state.go create mode 100644 rows_internal_test.go create mode 100644 scan_each.go create mode 100644 scan_each_orm.go create mode 100644 sql_placeholder.go create mode 100644 sql_placeholder_test.go create mode 100644 sql_runtime.go create mode 100644 sql_runtime_test.go create mode 100644 stardb_safe_test.go create mode 100644 testing/perf_test.go create mode 100644 tx_helper.go diff --git a/.gitignore b/.gitignore index cab69e8..221848e 100644 --- a/.gitignore +++ b/.gitignore @@ -13,7 +13,12 @@ vendor/ # IDE -.idea/ +.idea # OS -.DS_Store \ No newline at end of file +.DS_Store + +# Agent local governance files +.sentrux/ +agent_readme.md +target.md diff --git a/CHANGELOG.MD b/CHANGELOG.MD new file mode 100644 index 0000000..a9fd15f --- /dev/null +++ b/CHANGELOG.MD @@ -0,0 +1,59 @@ +# Changelog + +本文档记录 StarDB 的主要变更。 + +## [Unreleased] - 2026-03-20 + +### Added +- 新增可判定错误类型(`errors.Is` 友好): + - 生命周期:`ErrDBNotInitialized` `ErrTxNotInitialized` `ErrStmtNotInitialized` + - 参数/目标校验:`ErrQueryEmpty` `ErrTargetNil` `ErrTargetNotPointer` 等 + - 映射与批量写入:`ErrColumnNotFound` `ErrNoInsertValues` `ErrBatchRowValueCountMismatch` 等 +- 新增流式查询能力(DB / Tx / Stmt): + - `QueryRaw` / `QueryRawContext` + - `ScanEach` / `ScanEachContext` + - `ScanEachORM` / `ScanEachORMContext` +- 新增 NULL 安全取值: + - `GetNullString` `GetNullInt64` `GetNullFloat64` `GetNullBool` `GetNullTime` +- 新增 ORM 行为开关: + - `SetStrictORM(true)` 启用严格列检查 + - `ClearReflectCache()` 清理反射缓存 +- 新增 SQL 运行时可观测能力: + - Hook:`SetSQLHooks` `SetSQLBeforeHook` `SetSQLAfterHook` + - 慢 SQL 阈值:`SetSQLSlowThreshold` + - 指纹:`SetSQLFingerprintEnabled` `SetSQLFingerprintMode` `SetSQLFingerprintKeepComments` + - 指纹计数:`SetSQLFingerprintCounterEnabled` `SQLFingerprintCounters` `ResetSQLFingerprintCounters` + - Context 元信息:`SQLHookMetaFromContext` `BatchExecMetaFromContext` +- 新增占位符方言适配: + - `SetPlaceholderStyle(PlaceholderQuestion|PlaceholderDollar)`(`?` / `$1,$2...`) +- 新增批量插入分片控制: + - `SetBatchInsertMaxRows` + - `SetBatchInsertMaxParams` + - 常见驱动参数上限自动识别(SQLite / PostgreSQL / MySQL / SQL Server) + +### Changed +- 批量写入在开启分片或触发参数阈值时,改为事务内多分片执行,降低单条 SQL 过大风险。 +- 分片批量写入结果语义明确: + - `RowsAffected()` 返回分片累计值 + - `LastInsertId()` 返回最后一个分片的 insert id +- 内部结构按模块归档到 `internal/`,保持外部 API 稳定: + - `internal/convert` + - `internal/scanutil` + - `internal/sqlplaceholder` + - `internal/sqlruntime` +- README 重写为面向使用场景的说明,补齐能力边界、接入顺序和 API 细节。 + +### Behavior Notes +- 默认查询 `Query` 仍为内存模式(解析到 `StarRows`)。 +- 关闭内存预读时,使用 `QueryRaw` / `ScanEach` / `ScanEachORM`。 +- SQL Hook、指纹与指纹计数默认关闭,需显式开启。 +- 批量分片关闭条件:`maxRows <= 0` 且 `maxParams <= 0` 且未命中驱动自动阈值。 + +### Tests +- 新增/补强测试覆盖: + - 流式查询与流式 ORM + - NULL 安全取值 + - 严格 ORM 行为 + - 占位符转换 + - SQL Hook、慢 SQL 阈值、指纹模式、注释保留开关、指纹计数 + - BatchInsert 分片(按行数/参数)、失败回滚与结果语义 diff --git a/README.MD b/README.MD index abd5046..4682f6e 100644 --- a/README.MD +++ b/README.MD @@ -1,535 +1,560 @@ # StarDB -一个轻量级的 Go 数据库封装库,个人学习用,提供简洁的 API 和 ORM 功能。 +StarDB 是对 Go `database/sql` 的轻量封装,目标是把常见的数据库操作做得更直白: +- 少量 API 覆盖日常 CRUD、事务、批量写入、结构体映射。 +- 兼容原生 `database/sql` 心智,不引入重量级依赖。 +- 在可读性、可调试性和性能之间做实用平衡。 -[![Go Version](https://img.shields.io/badge/Go-%3E%3D%201.16-blue)](https://golang.org/) -[![License](https://img.shields.io/badge/license-Apache%20License%202.0-blue)](LICENSE) +适合: +- 想保留 SQL 控制权,但不想反复写样板代码。 +- 需要轻量 ORM 映射(不是全功能 ORM)。 +- 需要在生产里追踪 SQL(可选 Hook + 慢 SQL 阈值)。 -## ✨ 特性 +不适合: +- 需要完整领域模型关系管理、自动迁移、复杂查询 DSL 的项目。 -- ✅ **零第三方依赖** - 仅使用 Go 标准库 -- ✅ **类型安全** - 提供类型安全的结果转换方法 -- ✅ **简单 ORM** - 支持结构体与数据库的自动映射 -- ✅ **Context 支持** - 所有操作都支持 context -- ✅ **事务支持** - 完整的事务操作支持 -- ✅ **预编译语句** - 支持预编译语句以提升性能 -- ✅ **批量操作** - 高效的批量插入功能 -- ✅ **连接池管理** - 便捷的连接池配置 -- ✅ **查询构建器** - 链式调用构建 SQL 查询 - -## 📦 安装 +## 安装 ```bash go get b612.me/stardb ``` -## 🚀 快速开始 +要求: +- Go `>= 1.16` +- 自行选择并导入数据库驱动(本库只封装 `database/sql`) -### 基本使用 +## 常见 DSN 示例 + +下面示例都可以直接用于 `db.Open(driver, dsn)`,替换为实际账号、密码、库名即可。 + +### MySQL(`github.com/go-sql-driver/mysql`) ```go -package main +import _ "github.com/go-sql-driver/mysql" -import ( - "b612.me/stardb" - _ "github.com/mattn/go-sqlite3" -) - -func main() { - // 创建数据库实例 - db := stardb.NewStarDB() - err := db.Open("sqlite3", "test.db") - if err != nil { - panic(err) - } - defer db.Close() - - // 执行查询 - rows, err := db.Query("SELECT * FROM users WHERE age > ?", 18) - if err != nil { - panic(err) - } - defer rows.Close() - - // 遍历结果 - for i := 0; i < rows.Length(); i++ { - row := rows.Row(i) - name := row.MustString("name") - age := row.MustInt("age") - println(name, age) - } +dsn := "app:secret@tcp(127.0.0.1:3306)/demo?charset=utf8mb4&parseTime=true&loc=Local" +if err := db.Open("mysql", dsn); err != nil { + log.Fatal(err) } ``` -### ORM 使用 +常用参数说明: +- `charset=utf8mb4`:避免字符集问题。 +- `parseTime=true`:把 `DATETIME/TIMESTAMP` 解析为 `time.Time`。 +- `loc=Local`:指定时间解析时区(也可改成 `Asia/Shanghai`)。 + +### PostgreSQL(`github.com/lib/pq`) + +```go +import _ "github.com/lib/pq" + +dsn := "host=127.0.0.1 port=5432 user=postgres password=secret dbname=demo sslmode=disable" +if err := db.Open("postgres", dsn); err != nil { + log.Fatal(err) +} +``` + +也可以用 URL 形式: + +```go +urlDSN := "postgres://postgres:secret@127.0.0.1:5432/demo?sslmode=disable" +if err := db.Open("postgres", urlDSN); err != nil { + log.Fatal(err) +} +``` + +### SQLite(`modernc.org/sqlite`) + +```go +import _ "modernc.org/sqlite" + +// 文件数据库 +if err := db.Open("sqlite", "file:demo.db"); err != nil { + log.Fatal(err) +} + +// 内存数据库(适合测试) +if err := db.Open("sqlite", "file::memory:?cache=shared"); err != nil { + log.Fatal(err) +} +``` + +Windows 路径建议使用 `file:C:/data/demo.db` 这种写法,跨平台更稳。 + +## 能力概览 + +| 能力 | 主要 API | 说明 | +|---|---|---| +| 连接与连接池 | `Open` `Close` `Ping` `SetPoolConfig` | 保留原生 `sql.DB` 用法 | +| 常规查询 | `Query` `QueryContext` | 自动解析为 `StarRows` | +| 流式查询 | `QueryRaw` `ScanEach` | 大结果集不必全量进内存 | +| 流式 ORM | `ScanEachORM` | 逐行映射结构体 | +| 安全取值 | `Get*` `GetNull*` | 明确错误与 NULL 语义 | +| 结构体 ORM | `rows.Orm` | 支持单个、切片、数组映射 | +| 命名参数 | `QueryX` `ExecX` | `:field` 绑定结构体字段 | +| 结构体写入 | `Insert` `Update` | 通过 `db` tag 生成 SQL | +| 批量写入 | `BatchInsert` `BatchInsertStructs` `SetBatchInsertMaxRows` `SetBatchInsertMaxParams` | 多行插入,支持按行数/参数阈值分片 | +| 事务 | `Begin/Commit/Rollback` `WithTx` | 手动或托管事务 | +| 可观测性 | `SetSQLHooks` `SetSQLSlowThreshold` `SetSQLFingerprintEnabled` `SetSQLFingerprintMode` `SetSQLFingerprintKeepComments` `SetSQLFingerprintCounterEnabled` `SQLFingerprintCounters` `ResetSQLFingerprintCounters` `SQLHookMetaFromContext` `BatchExecMetaFromContext` | Before/After Hook,默认关闭,支持指纹策略、命中计数与批量分片元信息 | +| 方言占位符 | `SetPlaceholderStyle` | `?` / `$1,$2...` | +| 查询构建 | `QueryBuilder` | 支持 `JOIN/GROUP BY/HAVING` | + +## 场景选型 + +| 场景 | 首选 API | 说明 | +|---|---|---| +| 中小结果集查询 | `Query` + `rows.Orm` | 读取方便,开发效率高 | +| 大结果集查询 | `ScanEach` / `ScanEachORM` | 逐行处理,避免全量缓存 | +| 需要底层 `Scan` 控制 | `QueryRaw` | 直接返回 `*sql.Rows` | +| 批量写入 | `BatchInsert` + 分片阈值 | 控制单条 SQL 大小与参数数量 | +| SQL 可观测 | `SetSQLHooks` + `SetSQLSlowThreshold` + 指纹配置 | 支持慢 SQL、指纹、分片元信息 | + +## 快速开始 ```go package main import ( + "log" + "b612.me/stardb" - "time" - _ "github.com/mattn/go-sqlite3" + _ "modernc.org/sqlite" ) -// 定义结构体 type User struct { - ID int64 `db:"id"` - Name string `db:"name"` - Email string `db:"email"` - Age int `db:"age"` - Active bool `db:"active"` - CreatedAt time.Time `db:"created_at"` + ID int64 `db:"id"` + Name string `db:"name"` + Age int `db:"age"` } func main() { db := stardb.NewStarDB() - db.Open("sqlite3", "test.db") + if err := db.Open("sqlite", "test.db"); err != nil { + log.Fatal(err) + } defer db.Close() - // 查询单个对象 - rows, _ := db.Query("SELECT * FROM users WHERE id = ?", 1) - defer rows.Close() - - var user User - rows.Orm(&user) - println(user.Name) - - // 查询多个对象 - rows2, _ := db.Query("SELECT * FROM users WHERE age > ?", 18) - defer rows2.Close() - - var users []User - rows2.Orm(&users) - - for _, u := range users { - println(u.Name, u.Age) + rows, err := db.Query("SELECT id, name, age FROM users WHERE age >= ?", 18) + if err != nil { + log.Fatal(err) } + defer rows.Close() + + var users []User + if err := rows.Orm(&users); err != nil { + log.Fatal(err) + } + + log.Printf("users: %d", len(users)) } ``` -### 插入和更新 +## 接入流程 + +按下面顺序接入,可在开发阶段先固定运行边界: + +1. 建立连接并设置连接池。 +2. 在查询路径中区分内存模式与流式模式。 +3. 在批量写入路径设置分片阈值(按行数和参数数)。 +4. 启用 SQL Hook、慢 SQL 阈值、指纹策略。 +5. 在调用侧统一使用 `errors.Is` 判定错误类别。 + +一个常用初始化示例: ```go -// 插入数据 -user := User{ - Name: "Alice", - Email: "alice@example.com", - Age: 25, - Active: true, - CreatedAt: time.Now(), +db := stardb.NewStarDB() +if err := db.Open("mysql", dsn); err != nil { + return err } -result, err := db.Insert(&user, "users", "id") // "id" 是自增字段 -if err != nil { - panic(err) -} +db.SetPoolConfig(&stardb.PoolConfig{ + MaxOpenConns: 25, + MaxIdleConns: 5, + ConnMaxLifetime: time.Hour, + ConnMaxIdleTime: 10 * time.Minute, +}) -lastID, _ := result.LastInsertId() -println("插入的 ID:", lastID) +db.SetBatchInsertMaxRows(500) +db.SetBatchInsertMaxParams(60000) -// 更新数据 -user.Age = 26 -result, err = db.Update(&user, "users", "id") // "id" 是主键 -if err != nil { - panic(err) -} - -affected, _ := result.RowsAffected() -println("更新的行数:", affected) +db.SetSQLSlowThreshold(200 * time.Millisecond) +db.SetSQLFingerprintEnabled(true) +db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals) +db.SetSQLFingerprintKeepComments(false) ``` -### 批量插入 +## API 指南 + +### 1) 连接与连接池 + +```go +db := stardb.NewStarDB() +_ = db.Open("mysql", "user:pass@tcp(localhost:3306)/app") + +db.SetPoolConfig(&stardb.PoolConfig{ + MaxOpenConns: 25, + MaxIdleConns: 5, + ConnMaxLifetime: time.Hour, + ConnMaxIdleTime: 10 * time.Minute, +}) +``` + +也可以直接拿到底层连接: + +```go +raw := db.DB() +raw.SetMaxOpenConns(50) +``` + +### 2) 三种查询模式 + +#### A. 内存模式(默认) + +`Query` 会把结果解析为 `StarRows`,适合中小结果集: + +```go +rows, err := db.Query("SELECT * FROM users WHERE active = ?", true) +if err != nil { /* ... */ } +defer rows.Close() + +for i := 0; i < rows.Length(); i++ { + row := rows.Row(i) + _ = row.MustString("name") +} +``` + +#### B. 原生流式模式 + +`QueryRaw` 返回 `*sql.Rows`,完全按原生 `Scan` 处理: + +```go +rawRows, err := db.QueryRaw("SELECT id, name FROM users") +if err != nil { /* ... */ } +defer rawRows.Close() +``` + +#### C. 回调流式模式(常用) + +`ScanEach` 逐行回调,避免全量缓存: + +关闭内存预读时,使用 `QueryRaw` / `ScanEach` / `ScanEachORM`,不使用 `Query`。 + +```go +err := db.ScanEach("SELECT id, name FROM users", func(row *stardb.StarResult) error { + id := row.MustInt64("id") + name := row.MustString("name") + _ = id + _ = name + return nil +}) +``` + +可通过 `stardb.ErrScanStopped` 提前终止: + +```go +count := 0 +_ = db.ScanEach("SELECT * FROM users", func(row *stardb.StarResult) error { + count++ + if count >= 1000 { + return stardb.ErrScanStopped + } + return nil +}) +``` + +### 3) 流式 ORM(逐行映射) + +`ScanEachORM` 将每行映射到结构体,再回调。 + +```go +var model User +var users []User + +err := db.ScanEachORM("SELECT id, name, age FROM users", &model, func(target interface{}) error { + u := *(target.(*User)) // 注意拷贝一份,target 会被复用 + users = append(users, u) + return nil +}) +``` + +同样支持 `Tx` / `Stmt`: +- `tx.ScanEachORM(...)` +- `stmt.ScanEachORM(...)` + +### 4) 结果读取与 NULL 语义 + +#### Must 系列(无错误,失败给零值) +- `MustString` `MustInt64` `MustFloat64` `MustBool` ... + +#### 安全系列(带错误) +- `GetString` `GetInt64` `GetFloat64` +- `GetNullString` `GetNullInt64` `GetNullFloat64` `GetNullBool` `GetNullTime` + +```go +name, err := row.GetString("name") +age, err := row.GetNullInt64("age") +if age.Valid { + // use age.Int64 +} +``` + +### 5) ORM 映射 + +```go +type User struct { + ID int64 `db:"id"` + Name string `db:"name"` +} + +var u User +_ = rows.Orm(&u) + +var list []User +_ = rows.Orm(&list) +``` + +严格列检查(字段/SQL 变更敏感场景可开启): + +```go +db.SetStrictORM(true) +``` + +若结构体 tag 大范围调整,可清理反射缓存: + +```go +stardb.ClearReflectCache() +``` + +### 6) 命名参数绑定 + +```go +type Filter struct { + Name string `db:"name"` + MinAge int `db:"min_age"` +} + +f := Filter{Name: "Alice", MinAge: 18} +rows, err := db.QueryX(&f, + "SELECT * FROM users WHERE name = ? AND age >= ?", + ":name", ":min_age") +``` + +### 7) 写入能力 + +#### Insert / Update + +```go +_, _ = db.Insert(&user, "users", "id") // id 作为自增字段跳过 +_, _ = db.Update(&user, "users", "id") // id 作为主键 +``` + +#### BatchInsert ```go -// 方式 1:使用原始数据 columns := []string{"name", "email", "age"} values := [][]interface{}{ {"Alice", "alice@example.com", 25}, {"Bob", "bob@example.com", 30}, - {"Charlie", "charlie@example.com", 35}, } - -result, err := db.BatchInsert("users", columns, values) -if err != nil { - panic(err) -} - -// 方式 2:使用结构体 -users := []User{ - {Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()}, - {Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()}, - {Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()}, -} - -result, err = db.BatchInsertStructs("users", users, "id") -if err != nil { - panic(err) -} - -affected, _ := result.RowsAffected() -println("批量插入行数:", affected) +_, _ = db.BatchInsert("users", columns, values) ``` -### 事务操作 +如需避免单条 SQL 过大(参数过多),可打开分片: + +```go +db.SetBatchInsertMaxRows(500) // 0 或负数表示关闭分片(默认) +db.SetBatchInsertMaxParams(60000) // 0 表示自动识别常见驱动参数上限 +``` + +分片模式下会在一个事务里按块执行,避免部分写入成功。 +自动识别当前覆盖:SQLite `999`、PostgreSQL `65535`、MySQL `65535`、SQL Server `2100`。 + +分片行为细节: +- 分片阈值按更严格条件生效:`min(maxRows, maxParams/列数)`(忽略未设置项)。 +- 分片关闭条件:`maxRows <= 0` 且 `maxParams <= 0` 且未命中驱动自动阈值。 +- 分片执行失败会回滚整个批次。 +- 分片结果语义: + - `RowsAffected()` 返回所有分片累计值。 + - `LastInsertId()` 返回最后一个分片的 insert id。 + +#### BatchInsertStructs + +```go +users := []User{{Name: "Alice"}, {Name: "Bob"}} +_, _ = db.BatchInsertStructs("users", users, "id") +``` + +### 8) 事务 + +#### 手动事务 ```go -// 开始事务 tx, err := db.Begin() -if err != nil { - panic(err) -} +if err != nil { /* ... */ } +defer tx.Rollback() -// 执行操作 -_, err = tx.Exec("INSERT INTO users (name, email) VALUES (?, ?)", "Alice", "alice@example.com") -if err != nil { - tx.Rollback() - panic(err) -} - -_, err = tx.Exec("UPDATE accounts SET balance = balance - ? WHERE user_id = ?", 100, 1) -if err != nil { - tx.Rollback() - panic(err) -} - -// 提交事务 -err = tx.Commit() -if err != nil { - panic(err) +if _, err := tx.Exec("UPDATE users SET age = age + 1 WHERE id = ?", 1); err != nil { + return err } +return tx.Commit() ``` -### 预编译语句 +#### 托管事务(常用) ```go -// 创建预编译语句 -stmt, err := db.Prepare("SELECT * FROM users WHERE age > ?") -if err != nil { - panic(err) -} -defer stmt.Close() - -// 多次执行 -rows1, _ := stmt.Query(18) -defer rows1.Close() - -rows2, _ := stmt.Query(25) -defer rows2.Close() - -rows3, _ := stmt.Query(30) -defer rows3.Close() +err := db.WithTx(func(tx *stardb.StarTx) error { + if _, err := tx.Exec("UPDATE users SET age = ? WHERE id = ?", 26, 1); err != nil { + return err + } + if _, err := tx.Exec("INSERT INTO logs (msg) VALUES (?)", "age updated"); err != nil { + return err + } + return nil +}) ``` -### 查询构建器 +`WithTx` 规则: +- `fn` 返回 `nil` -> `Commit` +- `fn` 返回错误 -> `Rollback` +- `fn` panic -> `Rollback` 后继续抛出 panic + +### 9) SQL Hook 与慢 SQL 阈值 + +默认关闭;仅在显式设置时生效。 ```go -// 使用查询构建器 -rows, err := stardb.NewQueryBuilder("users"). - Select("id", "name", "email"). - Where("age > ?", 18). - Where("active = ?", true). - OrderBy("name ASC"). - Limit(10). +db.SetSQLSlowThreshold(200 * time.Millisecond) +db.SetSQLFingerprintEnabled(true) // 可选:在 Hook context 附带 SQL 指纹 +db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals) // 可选:指纹里脱敏数字/字符串字面量 +db.SetSQLFingerprintKeepComments(false) // 默认 false:指纹不保留 SQL 注释 +db.SetSQLFingerprintCounterEnabled(true) // 可选:记录指纹命中次数(内存级) + +db.SetSQLHooks( + func(ctx context.Context, query string, args []interface{}) { + // before + }, + func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) { + // after + if hookMeta, ok := stardb.SQLHookMetaFromContext(ctx); ok { + _ = hookMeta.Fingerprint + } + if meta, ok := stardb.BatchExecMetaFromContext(ctx); ok { + // chunked batch insert metadata + // meta.ChunkIndex / meta.ChunkCount / meta.ChunkRows ... + } + }, +) +``` + +阈值行为: +- `threshold <= 0`:`After` 每次都触发 +- `threshold > 0`:仅在“慢于阈值”或“执行出错”时触发 +- 打开 `SetSQLFingerprintEnabled(true)` 后,可从 `SQLHookMetaFromContext` 获取 SQL 指纹 +- 指纹模式:`SQLFingerprintBasic`(默认,仅归一化)/ `SQLFingerprintMaskLiterals`(归一化 + 字面量脱敏) +- `SetSQLFingerprintKeepComments(true)` 可保留注释文本(默认关闭,利于聚合) +- `SetSQLFingerprintCounterEnabled(true)` 后,可通过 `SQLFingerprintCounters()` 查看命中次数,`ResetSQLFingerprintCounters()` 清空 +- 若是分片批量写入,Hook 可通过 `BatchExecMetaFromContext` 读取分片元信息 + +Hook 上下文字段说明: +- `SQLHookMetaFromContext(ctx)`: + - `Fingerprint`:按配置生成的 SQL 指纹。 +- `BatchExecMetaFromContext(ctx)`(仅分片批量写入): + - `ChunkIndex`:当前分片序号(从 1 开始) + - `ChunkCount`:总分片数 + - `ChunkRows`:当前分片行数 + - `TotalRows`:本次批量总行数 + - `ColumnCount`:本次写入列数 + +### 10) 占位符方言 + +```go +db.SetPlaceholderStyle(stardb.PlaceholderQuestion) // 默认 +// 或 +db.SetPlaceholderStyle(stardb.PlaceholderDollar) // ? -> $1,$2... +``` + +### 11) QueryBuilder + +```go +query, args := stardb.NewQueryBuilder("users u"). + Select("u.id", "u.name", "COUNT(o.id) AS order_count"). + Join("LEFT JOIN orders o ON o.user_id = u.id"). + Where("u.active = ?", true). + GroupBy("u.id", "u.name"). + Having("COUNT(o.id) > ?", 2). + OrderBy("order_count DESC"). + Limit(20). Offset(0). - Query(db) + Build() -if err != nil { - panic(err) -} -defer rows.Close() - -var users []User -rows.Orm(&users) +_ = query +_ = args ``` -### 连接池配置 +## 错误处理 + +库内置可判定错误,调用侧使用 `errors.Is` 做分支处理: ```go -// 方式 1:使用默认配置 -db, err := stardb.OpenWithPool("sqlite3", "test.db", nil) -if err != nil { - panic(err) +if errors.Is(err, stardb.ErrDBNotInitialized) { + // 未初始化 } -defer db.Close() - -// 方式 2:自定义配置 -config := &stardb.PoolConfig{ - MaxOpenConns: 50, // 最大打开连接数 - MaxIdleConns: 10, // 最大空闲连接数 - ConnMaxLifetime: 1 * time.Hour, // 连接最大生命周期 - ConnMaxIdleTime: 10 * time.Minute, // 连接最大空闲时间 +if errors.Is(err, stardb.ErrColumnNotFound) { + // 字段/列不匹配 } - -db, err = stardb.OpenWithPool("mysql", "user:pass@tcp(localhost:3306)/dbname", config) -if err != nil { - panic(err) -} -defer db.Close() - -// 方式 3:手动设置 -db := stardb.NewStarDB() -db.Open("postgres", "postgres://user:pass@localhost/dbname") -db.SetPoolConfig(config) -``` - -### 命名参数绑定 - -```go -user := User{ - Name: "Alice", - Age: 25, -} - -// 使用 :fieldname 语法绑定结构体字段 -rows, err := db.QueryX(&user, - "SELECT * FROM users WHERE name = ? AND age > ?", - ":name", ":age") -if err != nil { - panic(err) -} -defer rows.Close() -``` - -## 📖 详细文档 - -### 结果转换方法 - -StarDB 提供了丰富的类型转换方法: - -```go -row := rows.Row(0) - -// 字符串 -name := row.MustString("name") - -// 整数 -age := row.MustInt("age") -id := row.MustInt64("id") -count := row.MustInt32("count") -uid := row.MustUint64("uid") - -// 浮点数 -price := row.MustFloat64("price") -rate := row.MustFloat32("rate") - -// 布尔值 -active := row.MustBool("active") - -// 字节数组 -data := row.MustBytes("data") - -// 时间 -createdAt := row.MustDate("created_at", "2006-01-02 15:04:05") - -// 检查 NULL -isNull := row.IsNil("optional_field") -``` - -### 列操作 - -```go -// 获取某一列的所有值 -col := rows.Col("name") - -names := col.MustString() // []string -ages := col.MustInt() // []int -prices := col.MustFloat64() // []float64 -actives := col.MustBool() // []bool -``` - -### Context 支持 - -所有操作都有对应的 Context 版本: - -```go -ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) -defer cancel() - -// 查询 -rows, err := db.QueryContext(ctx, "SELECT * FROM users") - -// 执行 -result, err := db.ExecContext(ctx, "UPDATE users SET age = ?", 26) - -// 事务 -tx, err := db.BeginTx(ctx, nil) - -// 批量插入 -result, err := db.BatchInsertContext(ctx, "users", columns, values) -``` - -### 错误处理 - -```go -// 使用 Must* 方法(忽略错误,返回零值) -name := row.MustString("name") - -// 使用 Get* 方法(返回错误) -name, err := row.GetString("name") -if err != nil { - // 处理错误 +if errors.Is(err, stardb.ErrNoInsertValues) { + // 批量插入空数据 } ``` -## 🔧 高级用法 +常见错误类别: +- 生命周期:`ErrDBNotInitialized` `ErrTxNotInitialized` `ErrStmtNotInitialized` +- 参数校验:`ErrQueryEmpty` `ErrTargetNil` `ErrTargetNotPointer` ... +- 映射问题:`ErrColumnNotFound` `ErrFieldNotFound` +- 批量写入:`ErrNoInsertColumns` `ErrNoInsertValues` `ErrBatchRowValueCountMismatch` +- 流式回调:`ErrScanFuncNil` `ErrScanORMFuncNil` -### 嵌套结构体 +## 使用边界 -```go -type Profile struct { - Bio string `db:"bio"` - Avatar string `db:"avatar"` -} +1. 这是轻量封装,不是全功能 ORM。 +- 不做模型关系管理(has-many/association) +- 不做自动迁移 +- 不做复杂查询 DSL -type User struct { - ID int64 `db:"id"` - Name string `db:"name"` - Profile Profile `db:"---"` // 使用 "---" 标记嵌套结构体 -} +2. 大结果集优先用流式 API。 +- `Query` 适合中小结果集 +- `ScanEach` / `ScanEachORM` 更稳 -rows, _ := db.Query("SELECT id, name, bio, avatar FROM users") -defer rows.Close() +3. 日志 Hook 按需打开。 +- 生产环境最好配合慢 SQL 阈值,减少噪音 -var user User -rows.Orm(&user) +4. `ScanEachORM` 回调里的 target 会复用。 +- 需要持久化时请拷贝结构体值 -println(user.Name) -println(user.Profile.Bio) -``` - -### 手动扫描模式 - -```go -db := stardb.NewStarDB() -db.ManualScan = true // 启用手动扫描模式 -db.Open("sqlite3", "test.db") - -rows, _ := db.Query("SELECT * FROM users") -defer rows.Close() - -// 手动触发解析 -rows.Rescan() - -// 现在可以访问数据 -println(rows.Length()) -``` - -### 直接访问底层 *sql.DB - -```go -db := stardb.NewStarDB() -db.Open("sqlite3", "test.db") - -// 获取底层 *sql.DB -rawDB := db.DB() -rawDB.SetMaxOpenConns(100) - -// 或者使用已有的 *sql.DB -sqlDB, _ := sql.Open("sqlite3", "test.db") -db := stardb.NewStarDBWithDB(sqlDB) -``` - -## 🎯 性能优化建议 - -### 1. 使用预编译语句 - -对于重复执行的查询,使用预编译语句可以提升 20-50% 的性能: - -```go -stmt, _ := db.Prepare("SELECT * FROM users WHERE id = ?") -defer stmt.Close() - -for _, id := range userIDs { - rows, _ := stmt.Query(id) - // 处理结果 - rows.Close() -} -``` - -### 2. 批量操作 - -批量插入比单条插入快 2-3 倍: - -```go -// ❌ 慢 -for _, user := range users { - db.Exec("INSERT INTO users (name) VALUES (?)", user.Name) -} - -// ✅ 快 -db.BatchInsertStructs("users", users, "id") -``` - -### 3. 合理配置连接池 - -```go -config := &stardb.PoolConfig{ - MaxOpenConns: 25, // 根据数据库服务器调整 - MaxIdleConns: 5, // 保持少量空闲连接 - ConnMaxLifetime: 1 * time.Hour, - ConnMaxIdleTime: 10 * time.Minute, -} -db.SetPoolConfig(config) -``` - -### 4. 使用事务 - -将多个操作放在一个事务中可以显著提升性能: - -```go -tx, _ := db.Begin() -for _, user := range users { - tx.Exec("INSERT INTO users (name) VALUES (?)", user.Name) -} -tx.Commit() -``` - -## 🧪 测试 - -项目包含完整的单元测试: +## 测试、竞态与基准 ```bash -# 运行所有测试 +# 根模块 +go test ./... + +go test -race ./... + +go test -run ^$ -bench BenchmarkQueryBuilder_ -benchmem ./... + +# testing 子模块(集成测试/基准) cd testing -go test -v - -# 运行特定测试 -go test -v -run TestStarDB_Query - -# 查看覆盖率 -go test -cover - -# 生成覆盖率报告 -go test -coverprofile=coverage.out -go tool cover -html=coverage.out - -# 运行基准测试 -go test -bench=. -benchmem +go test ./... +go test -race ./... +go test -run ^$ -bench "Benchmark(QueryX|Orm|ScanEach|BatchInsert)" -benchmem ``` -## 📊 支持的数据库 +## 支持数据库驱动 -StarDB 支持所有实现了 `database/sql` 接口的数据库驱动: +本库兼容所有实现 `database/sql` 的驱动。常见示例: +- SQLite: `_ "modernc.org/sqlite"` +- MySQL: `_ "github.com/go-sql-driver/mysql"` +- PostgreSQL: `_ "github.com/lib/pq"` -| 数据库 | 驱动 | 导入 | -|--------|------|------| -| SQLite | modernc.org/sqlite | `_ "modernc.org/sqlite"` | -| MySQL | github.com/go-sql-driver/mysql | `_ "github.com/go-sql-driver/mysql"` | -| PostgreSQL | github.com/lib/pq | `_ "github.com/lib/pq"` | -| SQL Server | github.com/denisenkom/go-mssqldb | `_ "github.com/denisenkom/go-mssqldb"` | -| Oracle | github.com/sijms/go-ora/v2 | `_ "github.com/sijms/go-ora/v2"` | +## License - -## 📄 许可证 - -本项目采用 Apache 2.0 许可证 - 详见 [LICENSE](LICENSE) 文件 - -## 🙏 致谢 - -- 感谢 Go 标准库提供的 `database/sql` 包 -- 灵感来源于 xorm、gorm 等优秀的 ORM 框架 - -## 📮 联系方式 - -- 项目主页: https://git.b612.me/b612/stardb.git \ No newline at end of file +Apache License 2.0 diff --git a/batch.go b/batch.go index 3607b03..3375b31 100644 --- a/batch.go +++ b/batch.go @@ -6,8 +6,220 @@ import ( "fmt" "reflect" "strings" + "sync/atomic" ) +type multiSQLResult struct { + results []sql.Result +} + +type batchExecMetaKey struct{} + +// BatchExecMeta contains chunk execution metadata for batch insert operations. +// It is attached to context for chunked execution and can be read in SQL hooks. +type BatchExecMeta struct { + ChunkIndex int + ChunkCount int + ChunkRows int + TotalRows int + ColumnCount int +} + +// BatchExecMetaFromContext extracts batch chunk metadata from context. +func BatchExecMetaFromContext(ctx context.Context) (BatchExecMeta, bool) { + if ctx == nil { + return BatchExecMeta{}, false + } + meta, ok := ctx.Value(batchExecMetaKey{}).(BatchExecMeta) + return meta, ok +} + +func withBatchExecMeta(ctx context.Context, meta BatchExecMeta) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, batchExecMetaKey{}, meta) +} + +func (m multiSQLResult) LastInsertId() (int64, error) { + if len(m.results) == 0 { + return 0, nil + } + return m.results[len(m.results)-1].LastInsertId() +} + +func (m multiSQLResult) RowsAffected() (int64, error) { + var total int64 + for _, result := range m.results { + affected, err := result.RowsAffected() + if err != nil { + return 0, err + } + total += affected + } + return total, nil +} + +// SetBatchInsertMaxRows configures max row count per INSERT statement for batch APIs. +// <= 0 disables splitting and keeps single-statement behavior. +func (s *StarDB) SetBatchInsertMaxRows(maxRows int) { + if s == nil { + return + } + if maxRows < 0 { + maxRows = 0 + } + atomic.StoreInt64(&s.batchInsertMaxRows, int64(maxRows)) +} + +// BatchInsertMaxRows returns current batch split threshold. +// 0 means disabled. +func (s *StarDB) BatchInsertMaxRows() int { + if s == nil { + return 0 + } + value := atomic.LoadInt64(&s.batchInsertMaxRows) + if value <= 0 { + return 0 + } + return int(value) +} + +// SetBatchInsertMaxParams configures max bind parameter count per INSERT statement for batch APIs. +// <= 0 means auto mode (use built-in defaults for known drivers) or no limit for unknown drivers. +func (s *StarDB) SetBatchInsertMaxParams(maxParams int) { + if s == nil { + return + } + if maxParams < 0 { + maxParams = 0 + } + atomic.StoreInt64(&s.batchInsertMaxParams, int64(maxParams)) +} + +// BatchInsertMaxParams returns configured max bind parameter threshold. +// 0 means auto mode. +func (s *StarDB) BatchInsertMaxParams() int { + if s == nil { + return 0 + } + value := atomic.LoadInt64(&s.batchInsertMaxParams) + if value <= 0 { + return 0 + } + return int(value) +} + +func detectBatchInsertMaxParams(db *sql.DB) int { + if db == nil { + return 0 + } + + driverType := strings.ToLower(fmt.Sprintf("%T", db.Driver())) + switch { + case strings.Contains(driverType, "sqlite"): + // Keep conservative default for wide compatibility. + return 999 + case strings.Contains(driverType, "postgres"), strings.Contains(driverType, "pgx"), strings.Contains(driverType, "pq"): + return 65535 + case strings.Contains(driverType, "mysql"): + return 65535 + case strings.Contains(driverType, "sqlserver"), strings.Contains(driverType, "mssql"): + return 2100 + default: + return 0 + } +} + +func minPositive(a, b int) int { + if a <= 0 { + return b + } + if b <= 0 { + return a + } + if a < b { + return a + } + return b +} + +func (s *StarDB) batchInsertChunkSize(columnCount int) (int, error) { + maxRows := s.BatchInsertMaxRows() + maxParams := s.BatchInsertMaxParams() + if maxParams <= 0 { + maxParams = detectBatchInsertMaxParams(s.db) + } + + maxRowsByParams := 0 + if maxParams > 0 { + maxRowsByParams = maxParams / columnCount + if maxRowsByParams <= 0 { + return 0, ErrBatchInsertMaxParamsTooLow + } + } + + return minPositive(maxRows, maxRowsByParams), nil +} + +func buildBatchInsertQuery(tableName string, columns []string, values [][]interface{}) (string, []interface{}) { + placeholderGroup := "(" + strings.Repeat("?, ", len(columns)-1) + "?)" + placeholders := strings.Repeat(placeholderGroup+", ", len(values)-1) + placeholderGroup + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", + tableName, + strings.Join(columns, ", "), + placeholders) + + args := make([]interface{}, 0, len(values)*len(columns)) + for _, row := range values { + args = append(args, row...) + } + + return query, args +} + +func (s *StarDB) batchInsertChunked(ctx context.Context, tableName string, columns []string, values [][]interface{}, chunkSize int) (sql.Result, error) { + if ctx == nil { + ctx = context.Background() + } + + tx, err := s.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + + chunkCount := (len(values) + chunkSize - 1) / chunkSize + results := make([]sql.Result, 0, chunkCount) + for start := 0; start < len(values); start += chunkSize { + end := start + chunkSize + if end > len(values) { + end = len(values) + } + chunkIndex := start/chunkSize + 1 + query, args := buildBatchInsertQuery(tableName, columns, values[start:end]) + chunkCtx := withBatchExecMeta(ctx, BatchExecMeta{ + ChunkIndex: chunkIndex, + ChunkCount: chunkCount, + ChunkRows: end - start, + TotalRows: len(values), + ColumnCount: len(columns), + }) + result, execErr := tx.exec(chunkCtx, query, args...) + if execErr != nil { + _ = tx.Rollback() + return nil, execErr + } + results = append(results, result) + } + + if err := tx.Commit(); err != nil { + _ = tx.Rollback() + return nil, err + } + + return multiSQLResult{results: results}, nil +} + // BatchInsert performs batch insert operation // Usage: BatchInsert("users", []string{"name", "age"}, [][]interface{}{{"Alice", 25}, {"Bob", 30}}) func (s *StarDB) BatchInsert(tableName string, columns []string, values [][]interface{}) (sql.Result, error) { @@ -21,26 +233,33 @@ func (s *StarDB) BatchInsertContext(ctx context.Context, tableName string, colum // batchInsert is the internal implementation func (s *StarDB) batchInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) { + if strings.TrimSpace(tableName) == "" { + return nil, ErrTableNameEmpty + } + + if len(columns) == 0 { + return nil, ErrNoInsertColumns + } + if len(values) == 0 { - return nil, fmt.Errorf("no values to insert") + return nil, ErrNoInsertValues } - // Build placeholders: (?, ?), (?, ?), ... - placeholderGroup := "(" + strings.Repeat("?, ", len(columns)-1) + "?)" - placeholders := strings.Repeat(placeholderGroup+", ", len(values)-1) + placeholderGroup - - // Build SQL - query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", - tableName, - strings.Join(columns, ", "), - placeholders) - - // Flatten values - var args []interface{} - for _, row := range values { - args = append(args, row...) + for i, row := range values { + if len(row) != len(columns) { + return nil, wrapBatchRowValueCountMismatch(i, len(row), len(columns)) + } } + chunkSize, err := s.batchInsertChunkSize(len(columns)) + if err != nil { + return nil, err + } + if chunkSize > 0 && len(values) > chunkSize { + return s.batchInsertChunked(ctx, tableName, columns, values, chunkSize) + } + + query, args := buildBatchInsertQuery(tableName, columns, values) return s.exec(ctx, query, args...) } @@ -56,18 +275,25 @@ func (s *StarDB) BatchInsertStructsContext(ctx context.Context, tableName string // batchInsertStructs is the internal implementation func (s *StarDB) batchInsertStructs(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) { + if structs == nil { + return nil, ErrStructsNil + } + // Get slice of structs targetValue := reflect.ValueOf(structs) if targetValue.Kind() == reflect.Ptr { + if targetValue.IsNil() { + return nil, ErrStructsPointerNil + } targetValue = targetValue.Elem() } if targetValue.Kind() != reflect.Slice && targetValue.Kind() != reflect.Array { - return nil, fmt.Errorf("structs must be a slice or array") + return nil, ErrStructsNotSlice } if targetValue.Len() == 0 { - return nil, fmt.Errorf("no structs to insert") + return nil, ErrNoStructsToInsert } // Get field names from first struct diff --git a/builder.go b/builder.go index 7cf5762..8e43b92 100644 --- a/builder.go +++ b/builder.go @@ -7,13 +7,17 @@ import ( // QueryBuilder helps build SQL queries type QueryBuilder struct { - table string - columns []string - where []string - whereArgs []interface{} - orderBy string - limit int - offset int + table string + columns []string + joins []string + where []string + whereArgs []interface{} + groupBy []string + having []string + havingArgs []interface{} + orderBy string + limit int + offset int } // NewQueryBuilder creates a new query builder @@ -37,6 +41,25 @@ func (qb *QueryBuilder) Where(condition string, args ...interface{}) *QueryBuild return qb } +// Join adds a JOIN clause. +func (qb *QueryBuilder) Join(clause string) *QueryBuilder { + qb.joins = append(qb.joins, clause) + return qb +} + +// GroupBy sets GROUP BY columns. +func (qb *QueryBuilder) GroupBy(columns ...string) *QueryBuilder { + qb.groupBy = append(qb.groupBy, columns...) + return qb +} + +// Having adds a HAVING condition. +func (qb *QueryBuilder) Having(condition string, args ...interface{}) *QueryBuilder { + qb.having = append(qb.having, condition) + qb.havingArgs = append(qb.havingArgs, args...) + return qb +} + // OrderBy sets the ORDER BY clause func (qb *QueryBuilder) OrderBy(orderBy string) *QueryBuilder { qb.orderBy = orderBy @@ -63,11 +86,26 @@ func (qb *QueryBuilder) Build() (string, []interface{}) { parts = append(parts, fmt.Sprintf("SELECT %s FROM %s", strings.Join(qb.columns, ", "), qb.table)) + // JOIN + if len(qb.joins) > 0 { + parts = append(parts, strings.Join(qb.joins, " ")) + } + // WHERE if len(qb.where) > 0 { parts = append(parts, "WHERE "+strings.Join(qb.where, " AND ")) } + // GROUP BY + if len(qb.groupBy) > 0 { + parts = append(parts, "GROUP BY "+strings.Join(qb.groupBy, ", ")) + } + + // HAVING + if len(qb.having) > 0 { + parts = append(parts, "HAVING "+strings.Join(qb.having, " AND ")) + } + // ORDER BY if qb.orderBy != "" { parts = append(parts, "ORDER BY "+qb.orderBy) @@ -83,7 +121,10 @@ func (qb *QueryBuilder) Build() (string, []interface{}) { parts = append(parts, fmt.Sprintf("OFFSET %d", qb.offset)) } - return strings.Join(parts, " "), qb.whereArgs + args := make([]interface{}, 0, len(qb.whereArgs)+len(qb.havingArgs)) + args = append(args, qb.whereArgs...) + args = append(args, qb.havingArgs...) + return strings.Join(parts, " "), args } // Query executes the query diff --git a/builder_test.go b/builder_test.go index c838553..1ce5d51 100644 --- a/builder_test.go +++ b/builder_test.go @@ -271,6 +271,21 @@ func TestQueryBuilder_Chaining(t *testing.T) { if qb != qb6 { t.Error("Offset should return the same builder instance") } + + qb7 := qb.Join("LEFT JOIN orders o ON o.user_id = users.id") + if qb != qb7 { + t.Error("Join should return the same builder instance") + } + + qb8 := qb.GroupBy("users.id") + if qb != qb8 { + t.Error("GroupBy should return the same builder instance") + } + + qb9 := qb.Having("COUNT(o.id) > ?", 1) + if qb != qb9 { + t.Error("Having should return the same builder instance") + } } func TestQueryBuilder_EmptyWhere(t *testing.T) { @@ -439,6 +454,50 @@ func TestQueryBuilder_JoinLikeWhere(t *testing.T) { } } +func TestQueryBuilder_Build_WithJoinGroupByHaving(t *testing.T) { + qb := NewQueryBuilder("users u"). + Select("u.id", "u.name", "COUNT(o.id) AS order_count"). + Join("LEFT JOIN orders o ON o.user_id = u.id"). + Where("u.active = ?", true). + GroupBy("u.id", "u.name"). + Having("COUNT(o.id) > ?", 2). + OrderBy("order_count DESC") + + query, args := qb.Build() + + expectedQuery := "SELECT u.id, u.name, COUNT(o.id) AS order_count FROM users u LEFT JOIN orders o ON o.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(o.id) > ? ORDER BY order_count DESC" + if query != expectedQuery { + t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query) + } + + expectedArgs := []interface{}{true, 2} + if len(args) != len(expectedArgs) { + t.Fatalf("Expected %d args, got %d", len(expectedArgs), len(args)) + } + for i, expected := range expectedArgs { + if args[i] != expected { + t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i]) + } + } +} + +func TestQueryBuilder_Build_HavingWithoutWhere(t *testing.T) { + qb := NewQueryBuilder("orders"). + Select("user_id", "COUNT(*) AS cnt"). + GroupBy("user_id"). + Having("COUNT(*) >= ?", 3) + + query, args := qb.Build() + + expectedQuery := "SELECT user_id, COUNT(*) AS cnt FROM orders GROUP BY user_id HAVING COUNT(*) >= ?" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + if len(args) != 1 || args[0] != 3 { + t.Errorf("Expected args [3], got %v", args) + } +} + // Benchmark tests func BenchmarkQueryBuilder_Simple(b *testing.B) { for i := 0; i < b.N; i++ { diff --git a/converter.go b/converter.go index 1e0020e..a8d6ff5 100644 --- a/converter.go +++ b/converter.go @@ -1,176 +1,32 @@ package stardb import ( - "strconv" + internalconv "b612.me/stardb/internal/convert" "time" ) // convertToInt64 converts any value to int64 func convertToInt64(val interface{}) int64 { - switch v := val.(type) { - case nil: - return 0 - case int: - return int64(v) - case int32: - return int64(v) - case int64: - return v - case uint64: - return int64(v) - case float32: - return int64(v) - case float64: - return int64(v) - case string: - result, _ := strconv.ParseInt(v, 10, 64) - return result - case bool: - if v { - return 1 - } - return 0 - case time.Time: - return v.Unix() - case []byte: - result, _ := strconv.ParseInt(string(v), 10, 64) - return result - default: - return 0 - } + return internalconv.ToInt64(val) } // convertToUint64 converts any value to uint64 func convertToUint64(val interface{}) uint64 { - switch v := val.(type) { - case nil: - return 0 - case int: - return uint64(v) - case int32: - return uint64(v) - case int64: - return uint64(v) - case uint64: - return v - case float32: - return uint64(v) - case float64: - return uint64(v) - case string: - result, _ := strconv.ParseUint(v, 10, 64) - return result - case bool: - if v { - return 1 - } - return 0 - case time.Time: - return uint64(v.Unix()) - case []byte: - result, _ := strconv.ParseUint(string(v), 10, 64) - return result - default: - return 0 - } + return internalconv.ToUint64(val) } // convertToFloat64 converts any value to float64 func convertToFloat64(val interface{}) float64 { - switch v := val.(type) { - case nil: - return 0 - case int: - return float64(v) - case int32: - return float64(v) - case int64: - return float64(v) - case uint64: - return float64(v) - case float32: - return float64(v) - case float64: - return v - case string: - result, _ := strconv.ParseFloat(v, 64) - return result - case bool: - if v { - return 1 - } - return 0 - case time.Time: - return float64(v.Unix()) - case []byte: - result, _ := strconv.ParseFloat(string(v), 64) - return result - default: - return 0 - } + return internalconv.ToFloat64(val) } // convertToBool converts any value to bool // Non-zero numbers are considered true func convertToBool(val interface{}) bool { - switch v := val.(type) { - case nil: - return false - case bool: - return v - case int: - return v != 0 - case int32: - return v != 0 - case int64: - return v != 0 - case uint64: - return v != 0 - case float32: - return v != 0 - case float64: - return v != 0 - case string: - result, _ := strconv.ParseBool(v) - return result - case []byte: - result, _ := strconv.ParseBool(string(v)) - return result - default: - return false - } + return internalconv.ToBool(val) } // convertToTime converts any value to time.Time func convertToTime(val interface{}, layout string) time.Time { - switch v := val.(type) { - case nil: - return time.Time{} - case time.Time: - return v - case int: - return time.Unix(int64(v), 0) - case int32: - return time.Unix(int64(v), 0) - case int64: - return time.Unix(v, 0) - case uint64: - return time.Unix(int64(v), 0) - case float32: - sec := int64(v) - nsec := int64((v - float32(sec)) * 1e9) - return time.Unix(sec, nsec) - case float64: - sec := int64(v) - nsec := int64((v - float64(sec)) * 1e9) - return time.Unix(sec, nsec) - case string: - result, _ := time.Parse(layout, v) - return result - case []byte: - result, _ := time.Parse(layout, string(v)) - return result - default: - return time.Time{} - } + return internalconv.ToTime(val, layout) } diff --git a/converter_safe.go b/converter_safe.go index 9793c82..e635098 100644 --- a/converter_safe.go +++ b/converter_safe.go @@ -1,68 +1,13 @@ package stardb -import ( - "fmt" - "strconv" - "time" -) +import internalconv "b612.me/stardb/internal/convert" // ConvertToInt64Safe converts any value to int64 with error handling func ConvertToInt64Safe(val interface{}) (int64, error) { - switch v := val.(type) { - case nil: - return 0, nil - case int: - return int64(v), nil - case int32: - return int64(v), nil - case int64: - return v, nil - case uint64: - return int64(v), nil - case float32: - return int64(v), nil - case float64: - return int64(v), nil - case string: - return strconv.ParseInt(v, 10, 64) - case bool: - if v { - return 1, nil - } - return 0, nil - case time.Time: - return v.Unix(), nil - case []byte: - return strconv.ParseInt(string(v), 10, 64) - default: - return 0, fmt.Errorf("cannot convert %T to int64", val) - } + return internalconv.ToInt64Safe(val) } // ConvertToStringSafe converts any value to string with error handling func ConvertToStringSafe(val interface{}) (string, error) { - switch v := val.(type) { - case nil: - return "", nil - case string: - return v, nil - case int: - return strconv.Itoa(v), nil - case int32: - return strconv.FormatInt(int64(v), 10), nil - case int64: - return strconv.FormatInt(v, 10), nil - case float32: - return strconv.FormatFloat(float64(v), 'f', -1, 32), nil - case float64: - return strconv.FormatFloat(v, 'f', -1, 64), nil - case bool: - return strconv.FormatBool(v), nil - case time.Time: - return v.String(), nil - case []byte: - return string(v), nil - default: - return "", fmt.Errorf("cannot convert %T to string", val) - } + return internalconv.ToStringSafe(val) } diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..1c9b3f1 --- /dev/null +++ b/errors.go @@ -0,0 +1,64 @@ +package stardb + +import ( + "errors" + "fmt" +) + +var ( + // Lifecycle errors. + ErrDBNotInitialized = errors.New("database is not initialized; call Open or SetDB first") + ErrTxNotInitialized = errors.New("transaction is not initialized") + ErrStmtNotInitialized = errors.New("statement is not initialized") + ErrStmtDBNotInitialized = errors.New("statement database context is not initialized") + + // SQL input errors. + ErrQueryEmpty = errors.New("query string cannot be empty") + ErrScanStopped = errors.New("scan stopped by callback") + ErrScanFuncNil = errors.New("scan callback cannot be nil") + ErrScanORMFuncNil = errors.New("scan orm callback cannot be nil") + + // Mapping and schema errors. + ErrColumnNotFound = errors.New("column not found") + ErrFieldNotFound = errors.New("field not found") + ErrRowIndexOutOfRange = errors.New("row index out of range") + + // Target validation errors. + ErrTargetNil = errors.New("target cannot be nil") + ErrTargetsNil = errors.New("targets cannot be nil") + ErrTargetNotPointer = errors.New("target must be a pointer") + ErrTargetPointerNil = errors.New("target pointer cannot be nil") + ErrTargetsPointerNil = errors.New("targets pointer is nil") + ErrTargetNotStruct = errors.New("target is not a struct") + ErrTargetNotWritable = errors.New("target is not writable") + ErrPointerTargetNil = errors.New("pointer target is nil") + + // SQL builder errors. + ErrTableNameEmpty = errors.New("table name cannot be empty") + ErrPrimaryKeyRequired = errors.New("at least one primary key is required") + ErrPrimaryKeyEmpty = errors.New("primary key cannot be empty") + ErrNoInsertColumns = errors.New("no columns to insert") + ErrNoInsertValues = errors.New("no values to insert") + ErrBatchInsertMaxParamsTooLow = errors.New("batch insert max params is lower than column count") + ErrNoUpdateFields = errors.New("no fields to update after excluding primary keys") + ErrBatchRowValueCountMismatch = errors.New("row values count does not match columns") + ErrStructsNil = errors.New("structs cannot be nil") + ErrStructsPointerNil = errors.New("structs pointer is nil") + ErrStructsNotSlice = errors.New("structs must be a slice or array") + ErrNoStructsToInsert = errors.New("no structs to insert") + + // Transaction helper errors. + ErrTxFuncNil = errors.New("transaction callback cannot be nil") +) + +func wrapColumnNotFound(column string) error { + return fmt.Errorf("%w: %s", ErrColumnNotFound, column) +} + +func wrapFieldNotFound(field string) error { + return fmt.Errorf("%w: %s", ErrFieldNotFound, field) +} + +func wrapBatchRowValueCountMismatch(rowIndex, got, expected int) error { + return fmt.Errorf("%w: row %d has %d values, expected %d", ErrBatchRowValueCountMismatch, rowIndex, got, expected) +} diff --git a/internal/convert/basic.go b/internal/convert/basic.go new file mode 100644 index 0000000..6786078 --- /dev/null +++ b/internal/convert/basic.go @@ -0,0 +1,359 @@ +package convert + +import ( + "fmt" + "strconv" + "strings" + "time" +) + +var defaultNullTimeLayouts = []string{ + time.RFC3339Nano, + time.RFC3339, + "2006-01-02 15:04:05", + "2006-01-02", +} + +// ToInt64 converts any value to int64. +func ToInt64(val interface{}) int64 { + switch v := val.(type) { + case nil: + return 0 + case int: + return int64(v) + case int32: + return int64(v) + case int64: + return v + case uint64: + return int64(v) + case float32: + return int64(v) + case float64: + return int64(v) + case string: + result, _ := strconv.ParseInt(v, 10, 64) + return result + case bool: + if v { + return 1 + } + return 0 + case time.Time: + return v.Unix() + case []byte: + result, _ := strconv.ParseInt(string(v), 10, 64) + return result + default: + return 0 + } +} + +// ToUint64 converts any value to uint64. +func ToUint64(val interface{}) uint64 { + switch v := val.(type) { + case nil: + return 0 + case int: + return uint64(v) + case int32: + return uint64(v) + case int64: + return uint64(v) + case uint64: + return v + case float32: + return uint64(v) + case float64: + return uint64(v) + case string: + result, _ := strconv.ParseUint(v, 10, 64) + return result + case bool: + if v { + return 1 + } + return 0 + case time.Time: + return uint64(v.Unix()) + case []byte: + result, _ := strconv.ParseUint(string(v), 10, 64) + return result + default: + return 0 + } +} + +// ToFloat64 converts any value to float64. +func ToFloat64(val interface{}) float64 { + switch v := val.(type) { + case nil: + return 0 + case int: + return float64(v) + case int32: + return float64(v) + case int64: + return float64(v) + case uint64: + return float64(v) + case float32: + return float64(v) + case float64: + return v + case string: + result, _ := strconv.ParseFloat(v, 64) + return result + case bool: + if v { + return 1 + } + return 0 + case time.Time: + return float64(v.Unix()) + case []byte: + result, _ := strconv.ParseFloat(string(v), 64) + return result + default: + return 0 + } +} + +// ToBool converts any value to bool. +func ToBool(val interface{}) bool { + switch v := val.(type) { + case nil: + return false + case bool: + return v + case int: + return v != 0 + case int32: + return v != 0 + case int64: + return v != 0 + case uint64: + return v != 0 + case float32: + return v != 0 + case float64: + return v != 0 + case string: + result, _ := strconv.ParseBool(v) + return result + case []byte: + result, _ := strconv.ParseBool(string(v)) + return result + default: + return false + } +} + +// ToTime converts any value to time.Time. +func ToTime(val interface{}, layout string) time.Time { + switch v := val.(type) { + case nil: + return time.Time{} + case time.Time: + return v + case int: + return time.Unix(int64(v), 0) + case int32: + return time.Unix(int64(v), 0) + case int64: + return time.Unix(v, 0) + case uint64: + return time.Unix(int64(v), 0) + case float32: + sec := int64(v) + nsec := int64((v - float32(sec)) * 1e9) + return time.Unix(sec, nsec) + case float64: + sec := int64(v) + nsec := int64((v - float64(sec)) * 1e9) + return time.Unix(sec, nsec) + case string: + result, _ := time.Parse(layout, v) + return result + case []byte: + result, _ := time.Parse(layout, string(v)) + return result + default: + return time.Time{} + } +} + +// ToInt64Safe converts any value to int64 with error handling. +func ToInt64Safe(val interface{}) (int64, error) { + switch v := val.(type) { + case nil: + return 0, nil + case int: + return int64(v), nil + case int32: + return int64(v), nil + case int64: + return v, nil + case uint64: + return int64(v), nil + case float32: + return int64(v), nil + case float64: + return int64(v), nil + case string: + return strconv.ParseInt(v, 10, 64) + case bool: + if v { + return 1, nil + } + return 0, nil + case time.Time: + return v.Unix(), nil + case []byte: + return strconv.ParseInt(string(v), 10, 64) + default: + return 0, fmt.Errorf("cannot convert %T to int64", val) + } +} + +// ToStringSafe converts any value to string with error handling. +func ToStringSafe(val interface{}) (string, error) { + switch v := val.(type) { + case nil: + return "", nil + case string: + return v, nil + case int: + return strconv.Itoa(v), nil + case int32: + return strconv.FormatInt(int64(v), 10), nil + case int64: + return strconv.FormatInt(v, 10), nil + case float32: + return strconv.FormatFloat(float64(v), 'f', -1, 32), nil + case float64: + return strconv.FormatFloat(v, 'f', -1, 64), nil + case bool: + return strconv.FormatBool(v), nil + case time.Time: + return v.String(), nil + case []byte: + return string(v), nil + default: + return "", fmt.Errorf("cannot convert %T to string", val) + } +} + +// ToFloat64Safe converts any value to float64 with error handling. +func ToFloat64Safe(val interface{}) (float64, error) { + switch v := val.(type) { + case nil: + return 0, nil + case float64: + return v, nil + case float32: + return float64(v), nil + case int, int32, int64, uint64: + intVal, err := ToInt64Safe(v) + return float64(intVal), err + case string: + return strconv.ParseFloat(v, 64) + case []byte: + return strconv.ParseFloat(string(v), 64) + default: + return 0, fmt.Errorf("cannot convert %T to float64", val) + } +} + +// ToBoolSafe converts any value to bool with error handling. +func ToBoolSafe(val interface{}) (bool, error) { + switch v := val.(type) { + case nil: + return false, nil + case bool: + return v, nil + case int: + return v != 0, nil + case int8: + return v != 0, nil + case int16: + return v != 0, nil + case int32: + return v != 0, nil + case int64: + return v != 0, nil + case uint: + return v != 0, nil + case uint8: + return v != 0, nil + case uint16: + return v != 0, nil + case uint32: + return v != 0, nil + case uint64: + return v != 0, nil + case float32: + return v != 0, nil + case float64: + return v != 0, nil + case string: + return ParseBoolString(v) + case []byte: + return ParseBoolString(string(v)) + default: + return false, fmt.Errorf("cannot convert %T to bool", val) + } +} + +// ParseBoolString parses string-like bool values. +func ParseBoolString(raw string) (bool, error) { + normalized := strings.TrimSpace(strings.ToLower(raw)) + switch normalized { + case "", "0", "false", "f", "off", "no", "n": + return false, nil + case "1", "true", "t", "on", "yes", "y": + return true, nil + default: + return false, fmt.Errorf("cannot parse bool value: %q", raw) + } +} + +// ToTimeSafe converts any value to time.Time with error handling. +func ToTimeSafe(val interface{}) (time.Time, error) { + switch v := val.(type) { + case nil: + return time.Time{}, nil + case time.Time: + return v, nil + case int: + return time.Unix(int64(v), 0), nil + case int64: + return time.Unix(v, 0), nil + case string: + return ParseTimeValue(v) + case []byte: + return ParseTimeValue(string(v)) + default: + return time.Time{}, fmt.Errorf("cannot convert %T to time.Time", val) + } +} + +// ParseTimeValue parses common SQL date-time formats and unix timestamp. +func ParseTimeValue(raw string) (time.Time, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return time.Time{}, nil + } + + for _, layout := range defaultNullTimeLayouts { + if t, err := time.Parse(layout, trimmed); err == nil { + return t, nil + } + } + + if ts, err := strconv.ParseInt(trimmed, 10, 64); err == nil { + return time.Unix(ts, 0), nil + } + + return time.Time{}, fmt.Errorf("cannot parse time value: %q", raw) +} diff --git a/internal/scanutil/value_clone.go b/internal/scanutil/value_clone.go new file mode 100644 index 0000000..c3f5437 --- /dev/null +++ b/internal/scanutil/value_clone.go @@ -0,0 +1,12 @@ +package scanutil + +// CloneScannedValue copies driver-scanned values that may be reused by driver. +// []byte is deep-copied; other types are returned as-is. +func CloneScannedValue(val interface{}) interface{} { + if b, ok := val.([]byte); ok { + copied := make([]byte, len(b)) + copy(copied, b) + return copied + } + return val +} diff --git a/internal/sqlplaceholder/convert.go b/internal/sqlplaceholder/convert.go new file mode 100644 index 0000000..0ca431d --- /dev/null +++ b/internal/sqlplaceholder/convert.go @@ -0,0 +1,119 @@ +package sqlplaceholder + +import ( + "strconv" + "strings" +) + +// ConvertQuestionToDollarPlaceholders converts '?' to '$1,$2,...' in SQL text. +// It skips quoted strings, quoted identifiers and comments. +func ConvertQuestionToDollarPlaceholders(query string) string { + if query == "" || !strings.Contains(query, "?") { + return query + } + + const ( + stateNormal = iota + stateSingleQuote + stateDoubleQuote + stateBacktick + stateLineComment + stateBlockComment + ) + + var b strings.Builder + b.Grow(len(query) + 8) + + state := stateNormal + index := 1 + + for i := 0; i < len(query); i++ { + c := query[i] + + switch state { + case stateNormal: + if c == '\'' { + state = stateSingleQuote + b.WriteByte(c) + continue + } + if c == '"' { + state = stateDoubleQuote + b.WriteByte(c) + continue + } + if c == '`' { + state = stateBacktick + b.WriteByte(c) + continue + } + if c == '-' && i+1 < len(query) && query[i+1] == '-' { + state = stateLineComment + b.WriteByte(c) + i++ + b.WriteByte(query[i]) + continue + } + if c == '/' && i+1 < len(query) && query[i+1] == '*' { + state = stateBlockComment + b.WriteByte(c) + i++ + b.WriteByte(query[i]) + continue + } + if c == '?' { + b.WriteByte('$') + b.WriteString(strconv.Itoa(index)) + index++ + continue + } + b.WriteByte(c) + + case stateSingleQuote: + b.WriteByte(c) + if c == '\'' { + // SQL escaped single quote: '' + if i+1 < len(query) && query[i+1] == '\'' { + i++ + b.WriteByte(query[i]) + continue + } + state = stateNormal + } + + case stateDoubleQuote: + b.WriteByte(c) + if c == '"' { + // escaped double quote: "" + if i+1 < len(query) && query[i+1] == '"' { + i++ + b.WriteByte(query[i]) + continue + } + state = stateNormal + } + + case stateBacktick: + b.WriteByte(c) + if c == '`' { + state = stateNormal + } + + case stateLineComment: + b.WriteByte(c) + if c == '\n' { + state = stateNormal + } + + case stateBlockComment: + b.WriteByte(c) + if c == '*' && i+1 < len(query) && query[i+1] == '/' { + i++ + b.WriteByte(query[i]) + state = stateNormal + } + } + } + + return b.String() +} diff --git a/internal/sqlruntime/fingerprint.go b/internal/sqlruntime/fingerprint.go new file mode 100644 index 0000000..a90a7aa --- /dev/null +++ b/internal/sqlruntime/fingerprint.go @@ -0,0 +1,323 @@ +package sqlruntime + +import "strings" + +// FingerprintSQL creates a normalized SQL fingerprint. +// mode controls literal masking; keepComments controls whether comments are preserved. +func FingerprintSQL(query string, mode int, keepComments bool) string { + prepared := query + if !keepComments { + prepared = stripSQLComments(prepared) + } + + normalized := normalizeSQL(prepared) + if normalized == "" { + return "" + } + + if NormalizeFingerprintMode(mode) == fingerprintModeMaskLiterals { + return maskSQLLiterals(normalized, keepComments) + } + return normalized +} + +func normalizeSQL(query string) string { + normalized := strings.ToLower(strings.TrimSpace(query)) + if normalized == "" { + return "" + } + return strings.Join(strings.Fields(normalized), " ") +} + +func stripSQLComments(query string) string { + if query == "" { + return "" + } + + const ( + stateNormal = iota + stateSingleQuote + stateDoubleQuote + stateBacktick + stateLineComment + stateBlockComment + ) + + var b strings.Builder + b.Grow(len(query)) + state := stateNormal + + for i := 0; i < len(query); i++ { + c := query[i] + + switch state { + case stateNormal: + if c == '\'' { + state = stateSingleQuote + b.WriteByte(c) + continue + } + if c == '"' { + state = stateDoubleQuote + b.WriteByte(c) + continue + } + if c == '`' { + state = stateBacktick + b.WriteByte(c) + continue + } + if c == '-' && i+1 < len(query) && query[i+1] == '-' { + b.WriteByte(' ') + i++ + state = stateLineComment + continue + } + if c == '/' && i+1 < len(query) && query[i+1] == '*' { + b.WriteByte(' ') + i++ + state = stateBlockComment + continue + } + b.WriteByte(c) + + case stateSingleQuote: + b.WriteByte(c) + if c == '\'' { + if i+1 < len(query) && query[i+1] == '\'' { + i++ + b.WriteByte(query[i]) + continue + } + state = stateNormal + } + + case stateDoubleQuote: + b.WriteByte(c) + if c == '"' { + if i+1 < len(query) && query[i+1] == '"' { + i++ + b.WriteByte(query[i]) + continue + } + state = stateNormal + } + + case stateBacktick: + b.WriteByte(c) + if c == '`' { + state = stateNormal + } + + case stateLineComment: + if c == '\n' { + b.WriteByte(' ') + state = stateNormal + } + + case stateBlockComment: + if c == '*' && i+1 < len(query) && query[i+1] == '/' { + b.WriteByte(' ') + i++ + state = stateNormal + } + } + } + + return b.String() +} + +func maskSQLLiterals(query string, keepComments bool) string { + if query == "" { + return "" + } + + const ( + stateNormal = iota + stateSingleQuote + stateDoubleQuote + stateBacktick + stateLineComment + stateBlockComment + ) + + var b strings.Builder + b.Grow(len(query)) + state := stateNormal + + for i := 0; i < len(query); i++ { + c := query[i] + + switch state { + case stateNormal: + if c == '\'' { + b.WriteByte('?') + state = stateSingleQuote + continue + } + if c == '"' { + b.WriteByte(c) + state = stateDoubleQuote + continue + } + if c == '`' { + b.WriteByte(c) + state = stateBacktick + continue + } + if c == '-' && i+1 < len(query) && query[i+1] == '-' { + if keepComments { + b.WriteByte(c) + i++ + b.WriteByte(query[i]) + } else { + b.WriteByte(' ') + i++ + } + state = stateLineComment + continue + } + if c == '/' && i+1 < len(query) && query[i+1] == '*' { + if keepComments { + b.WriteByte(c) + i++ + b.WriteByte(query[i]) + } else { + b.WriteByte(' ') + i++ + } + state = stateBlockComment + continue + } + if c == '$' { + j := i + 1 + for j < len(query) && isDigit(query[j]) { + j++ + } + if j > i+1 { + b.WriteByte('?') + i = j - 1 + continue + } + } + if c == '-' && i+1 < len(query) && isDigit(query[i+1]) && isNumberBoundaryBefore(query, i) { + j := scanNumber(query, i+1) + if isNumberBoundaryAfter(query, j) { + b.WriteByte('?') + i = j - 1 + continue + } + } + if isDigit(c) && isNumberBoundaryBefore(query, i) { + j := scanNumber(query, i) + if isNumberBoundaryAfter(query, j) { + b.WriteByte('?') + i = j - 1 + continue + } + } + b.WriteByte(c) + + case stateSingleQuote: + if c == '\'' { + if i+1 < len(query) && query[i+1] == '\'' { + i++ + continue + } + state = stateNormal + } + + case stateDoubleQuote: + b.WriteByte(c) + if c == '"' { + if i+1 < len(query) && query[i+1] == '"' { + i++ + b.WriteByte(query[i]) + continue + } + state = stateNormal + } + + case stateBacktick: + b.WriteByte(c) + if c == '`' { + state = stateNormal + } + + case stateLineComment: + if keepComments { + b.WriteByte(c) + } + if c == '\n' { + if !keepComments { + b.WriteByte(' ') + } + state = stateNormal + } + + case stateBlockComment: + if keepComments { + b.WriteByte(c) + } + if c == '*' && i+1 < len(query) && query[i+1] == '/' { + if keepComments { + i++ + b.WriteByte(query[i]) + } else { + b.WriteByte(' ') + i++ + } + state = stateNormal + } + } + } + + return strings.Join(strings.Fields(b.String()), " ") +} + +func isDigit(c byte) bool { + return c >= '0' && c <= '9' +} + +func isIdentifierChar(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_' +} + +func isNumberBoundaryBefore(query string, index int) bool { + if index <= 0 { + return true + } + prev := query[index-1] + return !isIdentifierChar(prev) && prev != '$' && prev != '.' +} + +func isNumberBoundaryAfter(query string, index int) bool { + if index >= len(query) { + return true + } + next := query[index] + return !isIdentifierChar(next) && next != '.' +} + +func scanNumber(query string, start int) int { + i := start + for i < len(query) && isDigit(query[i]) { + i++ + } + if i < len(query) && query[i] == '.' { + i++ + for i < len(query) && isDigit(query[i]) { + i++ + } + } + if i < len(query) && (query[i] == 'e' || query[i] == 'E') { + i++ + if i < len(query) && (query[i] == '+' || query[i] == '-') { + i++ + } + for i < len(query) && isDigit(query[i]) { + i++ + } + } + return i +} diff --git a/internal/sqlruntime/hooks.go b/internal/sqlruntime/hooks.go new file mode 100644 index 0000000..9eb93d8 --- /dev/null +++ b/internal/sqlruntime/hooks.go @@ -0,0 +1,27 @@ +package sqlruntime + +import "time" + +// CloneHookArgs creates a shallow copy for hook consumers to avoid mutation races. +func CloneHookArgs(args []interface{}) []interface{} { + if len(args) == 0 { + return nil + } + copied := make([]interface{}, len(args)) + copy(copied, args) + return copied +} + +// ShouldRunAfterHook decides whether after-hook should run. +func ShouldRunAfterHook(hasAfterHook bool, slowThreshold, duration time.Duration, err error) bool { + if !hasAfterHook { + return false + } + if err != nil { + return true + } + if slowThreshold <= 0 { + return true + } + return duration >= slowThreshold +} diff --git a/internal/sqlruntime/state.go b/internal/sqlruntime/state.go new file mode 100644 index 0000000..30312f9 --- /dev/null +++ b/internal/sqlruntime/state.go @@ -0,0 +1,269 @@ +package sqlruntime + +import ( + "sync" + "time" +) + +const ( + placeholderQuestion = 0 + placeholderDollar = 1 + + fingerprintModeBasic = 0 + fingerprintModeMaskLiterals = 1 +) + +// NormalizePlaceholderStyle converts unknown style values to default question style. +func NormalizePlaceholderStyle(style int) int { + switch style { + case placeholderDollar: + return placeholderDollar + default: + return placeholderQuestion + } +} + +// NormalizeFingerprintMode converts unknown mode values to default basic mode. +func NormalizeFingerprintMode(mode int) int { + switch mode { + case fingerprintModeMaskLiterals: + return fingerprintModeMaskLiterals + default: + return fingerprintModeBasic + } +} + +// State stores runtime SQL behavior toggles in a thread-safe manner. +type State struct { + mu sync.RWMutex + beforeHook interface{} + afterHook interface{} + placeholder int + slowThreshold time.Duration + fingerprintEnabled bool + fingerprintMode int + fingerprintKeepComments bool + fingerprintCounterEnabled bool + fingerprintCounts map[string]uint64 +} + +// Options returns snapshot of current runtime options. +func (s *State) Options() (before, after interface{}, placeholder int, slowThreshold time.Duration) { + if s == nil { + return nil, nil, placeholderQuestion, 0 + } + s.mu.RLock() + before = s.beforeHook + after = s.afterHook + placeholder = NormalizePlaceholderStyle(s.placeholder) + slowThreshold = s.slowThreshold + s.mu.RUnlock() + return before, after, placeholder, slowThreshold +} + +// Hooks returns before/after hooks and slow threshold. +func (s *State) Hooks() (before, after interface{}, slowThreshold time.Duration) { + before, after, _, slowThreshold = s.Options() + return before, after, slowThreshold +} + +// SetHooks sets before/after hooks. +func (s *State) SetHooks(before, after interface{}) { + if s == nil { + return + } + s.mu.Lock() + s.beforeHook = before + s.afterHook = after + s.mu.Unlock() +} + +// SetBeforeHook sets before hook. +func (s *State) SetBeforeHook(before interface{}) { + if s == nil { + return + } + s.mu.Lock() + s.beforeHook = before + s.mu.Unlock() +} + +// SetAfterHook sets after hook. +func (s *State) SetAfterHook(after interface{}) { + if s == nil { + return + } + s.mu.Lock() + s.afterHook = after + s.mu.Unlock() +} + +// SetPlaceholderStyle sets placeholder style. +func (s *State) SetPlaceholderStyle(style int) { + if s == nil { + return + } + s.mu.Lock() + s.placeholder = NormalizePlaceholderStyle(style) + s.mu.Unlock() +} + +// PlaceholderStyle returns placeholder style. +func (s *State) PlaceholderStyle() int { + if s == nil { + return placeholderQuestion + } + s.mu.RLock() + style := NormalizePlaceholderStyle(s.placeholder) + s.mu.RUnlock() + return style +} + +// SetSlowThreshold sets minimum duration for triggering after hook. +func (s *State) SetSlowThreshold(threshold time.Duration) { + if s == nil { + return + } + if threshold < 0 { + threshold = 0 + } + s.mu.Lock() + s.slowThreshold = threshold + s.mu.Unlock() +} + +// SlowThreshold returns current slow threshold. +func (s *State) SlowThreshold() time.Duration { + if s == nil { + return 0 + } + s.mu.RLock() + threshold := s.slowThreshold + s.mu.RUnlock() + return threshold +} + +// SetFingerprintEnabled toggles SQL fingerprint metadata generation for hooks. +func (s *State) SetFingerprintEnabled(enabled bool) { + if s == nil { + return + } + s.mu.Lock() + s.fingerprintEnabled = enabled + s.mu.Unlock() +} + +// FingerprintEnabled reports whether SQL fingerprint metadata generation is enabled. +func (s *State) FingerprintEnabled() bool { + if s == nil { + return false + } + s.mu.RLock() + enabled := s.fingerprintEnabled + s.mu.RUnlock() + return enabled +} + +// SetFingerprintMode sets SQL fingerprint mode. +func (s *State) SetFingerprintMode(mode int) { + if s == nil { + return + } + s.mu.Lock() + s.fingerprintMode = NormalizeFingerprintMode(mode) + s.mu.Unlock() +} + +// FingerprintMode returns SQL fingerprint mode. +func (s *State) FingerprintMode() int { + if s == nil { + return fingerprintModeBasic + } + s.mu.RLock() + mode := NormalizeFingerprintMode(s.fingerprintMode) + s.mu.RUnlock() + return mode +} + +// SetFingerprintKeepComments toggles comment preservation in generated SQL fingerprints. +func (s *State) SetFingerprintKeepComments(keep bool) { + if s == nil { + return + } + s.mu.Lock() + s.fingerprintKeepComments = keep + s.mu.Unlock() +} + +// FingerprintKeepComments reports whether comments are kept in generated SQL fingerprints. +func (s *State) FingerprintKeepComments() bool { + if s == nil { + return false + } + s.mu.RLock() + keep := s.fingerprintKeepComments + s.mu.RUnlock() + return keep +} + +// SetFingerprintCounterEnabled toggles in-memory SQL fingerprint hit counter. +func (s *State) SetFingerprintCounterEnabled(enabled bool) { + if s == nil { + return + } + s.mu.Lock() + s.fingerprintCounterEnabled = enabled + s.mu.Unlock() +} + +// FingerprintCounterEnabled reports whether in-memory SQL fingerprint hit counter is enabled. +func (s *State) FingerprintCounterEnabled() bool { + if s == nil { + return false + } + s.mu.RLock() + enabled := s.fingerprintCounterEnabled + s.mu.RUnlock() + return enabled +} + +// IncFingerprintCount increments hit count for a fingerprint. +func (s *State) IncFingerprintCount(fingerprint string) { + if s == nil || fingerprint == "" { + return + } + s.mu.Lock() + if s.fingerprintCounts == nil { + s.fingerprintCounts = make(map[string]uint64) + } + s.fingerprintCounts[fingerprint]++ + s.mu.Unlock() +} + +// FingerprintCountsSnapshot returns a snapshot copy of fingerprint counters. +func (s *State) FingerprintCountsSnapshot() map[string]uint64 { + if s == nil { + return map[string]uint64{} + } + s.mu.RLock() + if len(s.fingerprintCounts) == 0 { + s.mu.RUnlock() + return map[string]uint64{} + } + out := make(map[string]uint64, len(s.fingerprintCounts)) + for k, v := range s.fingerprintCounts { + out[k] = v + } + s.mu.RUnlock() + return out +} + +// ResetFingerprintCounts clears all fingerprint counters. +func (s *State) ResetFingerprintCounts() { + if s == nil { + return + } + s.mu.Lock() + s.fingerprintCounts = nil + s.mu.Unlock() +} diff --git a/orm.go b/orm.go index 2453b88..1c2b12e 100644 --- a/orm.go +++ b/orm.go @@ -3,7 +3,6 @@ package stardb import ( "context" "database/sql" - "errors" "fmt" "reflect" "strings" @@ -23,20 +22,27 @@ func (r *StarRows) Orm(target interface{}) error { } } + if target == nil { + return ErrTargetNil + } + targetType := reflect.TypeOf(target) targetValue := reflect.ValueOf(target) if targetType.Kind() != reflect.Ptr { - return errors.New("target must be a pointer") + return ErrTargetNotPointer + } + if targetValue.IsNil() { + return ErrTargetPointerNil } targetType = targetType.Elem() targetValue = targetValue.Elem() - // Handle slice/array - if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array { + // Handle slice + if targetValue.Kind() == reflect.Slice { elementType := targetType.Elem() - result := reflect.New(targetType).Elem() + result := reflect.MakeSlice(targetType, 0, r.Length()) if r.Length() == 0 { targetValue.Set(result) @@ -55,6 +61,29 @@ func (r *StarRows) Orm(target interface{}) error { return nil } + // Handle array + if targetValue.Kind() == reflect.Array { + elementType := targetType.Elem() + + if r.Length() == 0 { + return nil + } + + if r.Length() > targetValue.Len() { + return fmt.Errorf("target array length %d is smaller than rows %d", targetValue.Len(), r.Length()) + } + + for i := 0; i < r.Length(); i++ { + element := reflect.New(elementType) + if err := r.setStructFieldsFromRow(element.Interface(), "db", i); err != nil { + return err + } + targetValue.Index(i).Set(element.Elem()) + } + + return nil + } + // Handle single struct if r.Length() == 0 { return nil @@ -63,6 +92,35 @@ func (r *StarRows) Orm(target interface{}) error { return r.setStructFieldsFromRow(target, "db", 0) } +func bindNamedArgs(args []interface{}, fieldValues map[string]interface{}) ([]interface{}, error) { + processedArgs := make([]interface{}, len(args)) + for i, arg := range args { + str, ok := arg.(string) + if !ok { + processedArgs[i] = arg + continue + } + + if strings.HasPrefix(str, `\:`) { + processedArgs[i] = str[1:] + continue + } + + if strings.HasPrefix(str, ":") { + fieldName := str[1:] + val, exists := fieldValues[fieldName] + if !exists { + return nil, wrapFieldNotFound(fieldName) + } + processedArgs[i] = val + continue + } + + processedArgs[i] = arg + } + return processedArgs, nil +} + // QueryX executes a query with named parameter binding // Usage: QueryX(&user, "SELECT * FROM users WHERE id = ?", ":id") func (s *StarDB) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) { @@ -81,25 +139,9 @@ func (s *StarDB) queryX(ctx context.Context, target interface{}, query string, a return nil, err } - // Replace named parameters with actual values - processedArgs := make([]interface{}, len(args)) - for i, arg := range args { - if str, ok := arg.(string); ok { - if strings.HasPrefix(str, ":") { - fieldName := str[1:] - if val, exists := fieldValues[fieldName]; exists { - processedArgs[i] = val - } else { - processedArgs[i] = "" - } - } else if strings.HasPrefix(str, `\:`) { - processedArgs[i] = str[1:] - } else { - processedArgs[i] = arg - } - } else { - processedArgs[i] = arg - } + processedArgs, err := bindNamedArgs(args, fieldValues) + if err != nil { + return nil, err } return s.query(ctx, query, processedArgs...) @@ -118,16 +160,23 @@ func (s *StarDB) QueryXSContext(ctx context.Context, targets interface{}, query // queryXS is the internal implementation func (s *StarDB) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) { var results []*StarRows + if targets == nil { + return results, ErrTargetsNil + } targetType := reflect.TypeOf(targets) targetValue := reflect.ValueOf(targets) if targetType.Kind() == reflect.Ptr { + if targetValue.IsNil() { + return results, ErrTargetsPointerNil + } targetType = targetType.Elem() targetValue = targetValue.Elem() } if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array { + results = make([]*StarRows, 0, targetValue.Len()) for i := 0; i < targetValue.Len(); i++ { result, err := s.queryX(ctx, targetValue.Index(i).Interface(), query, args...) if err != nil { @@ -163,25 +212,9 @@ func (s *StarDB) execX(ctx context.Context, target interface{}, query string, ar return nil, err } - // Replace named parameters with actual values - processedArgs := make([]interface{}, len(args)) - for i, arg := range args { - if str, ok := arg.(string); ok { - if strings.HasPrefix(str, ":") { - fieldName := str[1:] - if val, exists := fieldValues[fieldName]; exists { - processedArgs[i] = val - } else { - processedArgs[i] = "" - } - } else if strings.HasPrefix(str, `\:`) { - processedArgs[i] = str[1:] - } else { - processedArgs[i] = arg - } - } else { - processedArgs[i] = arg - } + processedArgs, err := bindNamedArgs(args, fieldValues) + if err != nil { + return nil, err } return s.exec(ctx, query, processedArgs...) @@ -200,16 +233,23 @@ func (s *StarDB) ExecXSContext(ctx context.Context, targets interface{}, query s // execXS is the internal implementation func (s *StarDB) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) { var results []sql.Result + if targets == nil { + return results, ErrTargetsNil + } targetType := reflect.TypeOf(targets) targetValue := reflect.ValueOf(targets) if targetType.Kind() == reflect.Ptr { + if targetValue.IsNil() { + return results, ErrTargetsPointerNil + } targetType = targetType.Elem() targetValue = targetValue.Elem() } if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array { + results = make([]sql.Result, 0, targetValue.Len()) for i := 0; i < targetValue.Len(); i++ { result, err := s.execX(ctx, targetValue.Index(i).Interface(), query, args...) if err != nil { @@ -246,9 +286,9 @@ func (s *StarDB) insert(ctx context.Context, target interface{}, tableName strin return nil, err } - args := []interface{}{} - for _, param := range params { - args = append(args, param) + args := make([]interface{}, len(params)) + for i, param := range params { + args[i] = param } return s.execX(ctx, target, query, args...) @@ -272,9 +312,9 @@ func (s *StarDB) update(ctx context.Context, target interface{}, tableName strin return nil, err } - args := []interface{}{} - for _, param := range params { - args = append(args, param) + args := make([]interface{}, len(params)) + for i, param := range params { + args[i] = param } return s.execX(ctx, target, query, args...) @@ -282,6 +322,10 @@ func (s *StarDB) update(ctx context.Context, target interface{}, tableName strin // buildInsertSQL builds an INSERT SQL statement func buildInsertSQL(target interface{}, tableName string, autoIncrementFields ...string) (string, []string, error) { + if strings.TrimSpace(tableName) == "" { + return "", []string{}, ErrTableNameEmpty + } + fieldNames, err := getStructFieldNames(target, "db") if err != nil { return "", []string{}, err @@ -290,17 +334,14 @@ func buildInsertSQL(target interface{}, tableName string, autoIncrementFields .. var columns []string var placeholders []string var params []string + autoIncrementSet := make(map[string]struct{}, len(autoIncrementFields)) + for _, autoField := range autoIncrementFields { + autoIncrementSet[autoField] = struct{}{} + } for _, fieldName := range fieldNames { // Skip auto-increment fields - isAutoIncrement := false - for _, autoField := range autoIncrementFields { - if fieldName == autoField { - isAutoIncrement = true - break - } - } - if isAutoIncrement { + if _, isAutoIncrement := autoIncrementSet[fieldName]; isAutoIncrement { continue } @@ -309,6 +350,10 @@ func buildInsertSQL(target interface{}, tableName string, autoIncrementFields .. params = append(params, ":"+fieldName) } + if len(columns) == 0 { + return "", []string{}, ErrNoInsertColumns + } + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", tableName, strings.Join(columns, ", "), @@ -319,20 +364,43 @@ func buildInsertSQL(target interface{}, tableName string, autoIncrementFields .. // buildUpdateSQL builds an UPDATE SQL statement func buildUpdateSQL(target interface{}, tableName string, primaryKeys ...string) (string, []string, error) { + if strings.TrimSpace(tableName) == "" { + return "", []string{}, ErrTableNameEmpty + } + fieldNames, err := getStructFieldNames(target, "db") if err != nil { return "", []string{}, err } + if len(primaryKeys) == 0 { + return "", []string{}, ErrPrimaryKeyRequired + } + + primaryKeySet := make(map[string]struct{}, len(primaryKeys)) + for _, pk := range primaryKeys { + if pk == "" { + return "", []string{}, ErrPrimaryKeyEmpty + } + primaryKeySet[pk] = struct{}{} + } + var setClauses []string var params []string // Build SET clause for _, fieldName := range fieldNames { + if _, isPrimaryKey := primaryKeySet[fieldName]; isPrimaryKey { + continue + } setClauses = append(setClauses, fmt.Sprintf("%s = ?", fieldName)) params = append(params, ":"+fieldName) } + if len(setClauses) == 0 { + return "", []string{}, ErrNoUpdateFields + } + // Build WHERE clause var whereClauses []string for _, pk := range primaryKeys { @@ -367,24 +435,9 @@ func (t *StarTx) queryX(ctx context.Context, target interface{}, query string, a return nil, err } - processedArgs := make([]interface{}, len(args)) - for i, arg := range args { - if str, ok := arg.(string); ok { - if strings.HasPrefix(str, ":") { - fieldName := str[1:] - if val, exists := fieldValues[fieldName]; exists { - processedArgs[i] = val - } else { - processedArgs[i] = "" - } - } else if strings.HasPrefix(str, `\:`) { - processedArgs[i] = str[1:] - } else { - processedArgs[i] = arg - } - } else { - processedArgs[i] = arg - } + processedArgs, err := bindNamedArgs(args, fieldValues) + if err != nil { + return nil, err } return t.query(ctx, query, processedArgs...) @@ -403,16 +456,23 @@ func (t *StarTx) QueryXSContext(ctx context.Context, targets interface{}, query // queryXS is the internal implementation func (t *StarTx) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) { var results []*StarRows + if targets == nil { + return results, ErrTargetsNil + } targetType := reflect.TypeOf(targets) targetValue := reflect.ValueOf(targets) if targetType.Kind() == reflect.Ptr { + if targetValue.IsNil() { + return results, ErrTargetsPointerNil + } targetType = targetType.Elem() targetValue = targetValue.Elem() } if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array { + results = make([]*StarRows, 0, targetValue.Len()) for i := 0; i < targetValue.Len(); i++ { result, err := t.queryX(ctx, targetValue.Index(i).Interface(), query, args...) if err != nil { @@ -448,24 +508,9 @@ func (t *StarTx) execX(ctx context.Context, target interface{}, query string, ar return nil, err } - processedArgs := make([]interface{}, len(args)) - for i, arg := range args { - if str, ok := arg.(string); ok { - if strings.HasPrefix(str, ":") { - fieldName := str[1:] - if val, exists := fieldValues[fieldName]; exists { - processedArgs[i] = val - } else { - processedArgs[i] = "" - } - } else if strings.HasPrefix(str, `\:`) { - processedArgs[i] = str[1:] - } else { - processedArgs[i] = arg - } - } else { - processedArgs[i] = arg - } + processedArgs, err := bindNamedArgs(args, fieldValues) + if err != nil { + return nil, err } return t.exec(ctx, query, processedArgs...) @@ -484,16 +529,23 @@ func (t *StarTx) ExecXSContext(ctx context.Context, targets interface{}, query s // execXS is the internal implementation func (t *StarTx) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) { var results []sql.Result + if targets == nil { + return results, ErrTargetsNil + } targetType := reflect.TypeOf(targets) targetValue := reflect.ValueOf(targets) if targetType.Kind() == reflect.Ptr { + if targetValue.IsNil() { + return results, ErrTargetsPointerNil + } targetType = targetType.Elem() targetValue = targetValue.Elem() } if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array { + results = make([]sql.Result, 0, targetValue.Len()) for i := 0; i < targetValue.Len(); i++ { result, err := t.execX(ctx, targetValue.Index(i).Interface(), query, args...) if err != nil { @@ -529,9 +581,9 @@ func (t *StarTx) insert(ctx context.Context, target interface{}, tableName strin return nil, err } - args := []interface{}{} - for _, param := range params { - args = append(args, param) + args := make([]interface{}, len(params)) + for i, param := range params { + args[i] = param } return t.execX(ctx, target, query, args...) @@ -554,9 +606,9 @@ func (t *StarTx) update(ctx context.Context, target interface{}, tableName strin return nil, err } - args := []interface{}{} - for _, param := range params { - args = append(args, param) + args := make([]interface{}, len(params)) + for i, param := range params { + args[i] = param } return t.execX(ctx, target, query, args...) diff --git a/orm_test.go b/orm_test.go index 61c684c..f7e6023 100644 --- a/orm_test.go +++ b/orm_test.go @@ -1,6 +1,8 @@ package stardb import ( + "errors" + "reflect" "testing" "time" ) @@ -27,6 +29,21 @@ type NestedUser struct { Profile `db:"---"` } +type NestedUserPtr struct { + ID int64 `db:"id"` + Name string `db:"name"` + Profile *Profile `db:"---"` +} + +type AutoIDOnly struct { + ID int64 `db:"id"` +} + +type HiddenTagged struct { + ID int64 `db:"id"` + hidden string `db:"hidden"` +} + func TestBuildInsertSQL(t *testing.T) { user := User{ ID: 1, @@ -76,13 +93,183 @@ func TestBuildUpdateSQL(t *testing.T) { t.Fatalf("buildUpdateSQL failed: %v", err) } - expectedQuery := "UPDATE users SET id = ?, name = ?, email = ?, age = ?, balance = ?, active = ?, created_at = ? WHERE id = ?" + expectedQuery := "UPDATE users SET name = ?, email = ?, age = ?, balance = ?, active = ?, created_at = ? WHERE id = ?" if query != expectedQuery { t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query) } - expectedParamCount := 8 // 7 fields + 1 primary key + expectedParamCount := 7 // 6 fields + 1 primary key if len(params) != expectedParamCount { t.Errorf("Expected %d params, got %d", expectedParamCount, len(params)) } } + +func TestBuildInsertSQL_NoColumns(t *testing.T) { + model := AutoIDOnly{ID: 1} + + _, _, err := buildInsertSQL(&model, "users", "id") + if err == nil { + t.Fatal("Expected error when no columns remain to insert, got nil") + } + if !errors.Is(err, ErrNoInsertColumns) { + t.Fatalf("Expected ErrNoInsertColumns, got %v", err) + } +} + +func TestBuildInsertSQL_EmptyTableName(t *testing.T) { + user := User{ + ID: 1, + Name: "Test", + } + + _, _, err := buildInsertSQL(&user, "", "id") + if err == nil { + t.Fatal("Expected error when table name is empty, got nil") + } + if !errors.Is(err, ErrTableNameEmpty) { + t.Fatalf("Expected ErrTableNameEmpty, got %v", err) + } +} + +func TestBuildUpdateSQL_NoPrimaryKey(t *testing.T) { + user := User{ + ID: 1, + Name: "Test", + } + + _, _, err := buildUpdateSQL(&user, "users") + if err == nil { + t.Fatal("Expected error when no primary key is provided, got nil") + } + if !errors.Is(err, ErrPrimaryKeyRequired) { + t.Fatalf("Expected ErrPrimaryKeyRequired, got %v", err) + } +} + +func TestBuildUpdateSQL_EmptyTableName(t *testing.T) { + user := User{ + ID: 1, + Name: "Test", + } + + _, _, err := buildUpdateSQL(&user, "", "id") + if err == nil { + t.Fatal("Expected error when table name is empty, got nil") + } + if !errors.Is(err, ErrTableNameEmpty) { + t.Fatalf("Expected ErrTableNameEmpty, got %v", err) + } +} + +func TestBuildUpdateSQL_OnlyPrimaryKey(t *testing.T) { + model := AutoIDOnly{ID: 1} + + _, _, err := buildUpdateSQL(&model, "users", "id") + if err == nil { + t.Fatal("Expected error when no fields remain for SET clause, got nil") + } + if !errors.Is(err, ErrNoUpdateFields) { + t.Fatalf("Expected ErrNoUpdateFields, got %v", err) + } +} + +func TestGetStructFieldValues_NilNestedPointer(t *testing.T) { + user := NestedUserPtr{ + ID: 1, + Name: "Test", + Profile: nil, + } + + values, err := getStructFieldValues(user, "db") + if err != nil { + t.Fatalf("getStructFieldValues failed: %v", err) + } + + if values["id"] != int64(1) { + t.Errorf("Expected id=1, got %v", values["id"]) + } + if values["name"] != "Test" { + t.Errorf("Expected name=Test, got %v", values["name"]) + } +} + +func TestGetStructFieldNames_NilNestedPointer(t *testing.T) { + user := NestedUserPtr{ + ID: 1, + Name: "Test", + Profile: nil, + } + + names, err := getStructFieldNames(user, "db") + if err != nil { + t.Fatalf("getStructFieldNames failed: %v", err) + } + + expected := []string{"id", "name"} + if !reflect.DeepEqual(names, expected) { + t.Errorf("Expected names %v, got %v", expected, names) + } +} + +func TestGetStructFieldNames_SkipUnexportedField(t *testing.T) { + model := HiddenTagged{ + ID: 1, + hidden: "secret", + } + + names, err := getStructFieldNames(model, "db") + if err != nil { + t.Fatalf("getStructFieldNames failed: %v", err) + } + + expected := []string{"id"} + if !reflect.DeepEqual(names, expected) { + t.Errorf("Expected names %v, got %v", expected, names) + } +} + +func TestGetStructFieldValues_NilTarget(t *testing.T) { + _, err := getStructFieldValues(nil, "db") + if err == nil { + t.Fatal("Expected error for nil target, got nil") + } + if !errors.Is(err, ErrTargetNil) { + t.Fatalf("Expected ErrTargetNil, got %v", err) + } +} + +func TestGetStructFieldNames_NilTarget(t *testing.T) { + _, err := getStructFieldNames(nil, "db") + if err == nil { + t.Fatal("Expected error for nil target, got nil") + } + if !errors.Is(err, ErrTargetNil) { + t.Fatalf("Expected ErrTargetNil, got %v", err) + } +} + +func TestClearReflectCache(t *testing.T) { + type cacheUser struct { + ID int64 `db:"id"` + Name string `db:"name"` + } + + typ := reflect.TypeOf(cacheUser{}) + plan1, err := getStructTagPlan(typ, "db") + if err != nil { + t.Fatalf("getStructTagPlan failed: %v", err) + } + if len(plan1) != 2 { + t.Fatalf("Expected 2 fields in plan, got %d", len(plan1)) + } + + ClearReflectCache() + + plan2, err := getStructTagPlan(typ, "db") + if err != nil { + t.Fatalf("getStructTagPlan after clear failed: %v", err) + } + if len(plan2) != 2 { + t.Fatalf("Expected 2 fields in plan after clear, got %d", len(plan2)) + } +} diff --git a/pool.go b/pool.go index e608ae4..b82284c 100644 --- a/pool.go +++ b/pool.go @@ -24,6 +24,10 @@ func DefaultPoolConfig() *PoolConfig { // SetPoolConfig applies pool configuration to the database func (s *StarDB) SetPoolConfig(config *PoolConfig) { + if s == nil || s.db == nil || config == nil { + return + } + if config.MaxOpenConns > 0 { s.db.SetMaxOpenConns(config.MaxOpenConns) } diff --git a/reflect.go b/reflect.go index 78a73f4..9497f1d 100644 --- a/reflect.go +++ b/reflect.go @@ -1,13 +1,126 @@ package stardb import ( - "errors" "reflect" + "sync" "time" ) +type structTagField struct { + path []int + tag string +} + +type structTagPlanKey struct { + typ reflect.Type + tagKey string +} + +var structTagPlanCache sync.Map + +// ClearReflectCache clears internal reflection metadata cache. +// Useful after schema/tag refactors in long-running processes. +func ClearReflectCache() { + structTagPlanCache = sync.Map{} +} + +func getStructTagPlan(targetType reflect.Type, tagKey string) ([]structTagField, error) { + if targetType.Kind() == reflect.Ptr { + targetType = targetType.Elem() + } + if targetType.Kind() != reflect.Struct { + return nil, ErrTargetNotStruct + } + + cacheKey := structTagPlanKey{ + typ: targetType, + tagKey: tagKey, + } + if cached, ok := structTagPlanCache.Load(cacheKey); ok { + return cached.([]structTagField), nil + } + + fields := make([]structTagField, 0, targetType.NumField()) + if err := buildStructTagPlan(targetType, tagKey, nil, &fields); err != nil { + return nil, err + } + + structTagPlanCache.Store(cacheKey, fields) + return fields, nil +} + +func buildStructTagPlan(currentType reflect.Type, tagKey string, prefix []int, out *[]structTagField) error { + if currentType.Kind() == reflect.Ptr { + currentType = currentType.Elem() + } + if currentType.Kind() != reflect.Struct { + return ErrTargetNotStruct + } + + for i := 0; i < currentType.NumField(); i++ { + field := currentType.Field(i) + tagValue := field.Tag.Get(tagKey) + fieldType := field.Type + + path := make([]int, len(prefix)+1) + copy(path, prefix) + path[len(prefix)] = i + + if tagValue == "---" { + if fieldType.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct { + if err := buildStructTagPlan(fieldType.Elem(), tagKey, path, out); err != nil { + return err + } + continue + } + if fieldType.Kind() == reflect.Struct { + if err := buildStructTagPlan(fieldType, tagKey, path, out); err != nil { + return err + } + continue + } + } + + if tagValue != "" { + *out = append(*out, structTagField{ + path: path, + tag: tagValue, + }) + } + } + + return nil +} + +func resolveFieldByPath(root reflect.Value, path []int) (reflect.Value, bool) { + current := root + for _, idx := range path { + if current.Kind() == reflect.Ptr { + if current.IsNil() { + return reflect.Value{}, false + } + current = current.Elem() + } + + if current.Kind() != reflect.Struct { + return reflect.Value{}, false + } + if idx < 0 || idx >= current.NumField() { + return reflect.Value{}, false + } + + current = current.Field(idx) + } + + return current, true +} + // setStructFieldsFromRow sets struct fields from a row result using reflection func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, rowIndex int) error { + if target == nil { + return ErrTargetNil + } + targetType := reflect.TypeOf(target) targetValue := reflect.ValueOf(target) @@ -16,7 +129,7 @@ func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, row } if targetType.Kind() != reflect.Ptr && !targetValue.CanSet() { - return errors.New("target is not writable") + return ErrTargetNotWritable } if targetType.Kind() == reflect.Ptr { @@ -24,7 +137,12 @@ func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, row } if targetValue.Kind() != reflect.Struct { - return errors.New("target is not a struct") + return ErrTargetNotStruct + } + + row := r.Row(rowIndex) + if row.columnIndex == nil { + return ErrRowIndexOutOfRange } for i := 0; i < targetType.NumField(); i++ { @@ -32,14 +150,25 @@ func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, row fieldValue := targetValue.Field(i) tagValue := field.Tag.Get(tagKey) + if !fieldValue.CanInterface() { + continue + } + + // Skip unexported or otherwise non-settable fields. + if !fieldValue.CanSet() { + continue + } + // Handle nested structs - if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct { + if fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct { if tagValue == "" { continue } if tagValue == "---" { - nestedPtr := reflect.New(reflect.TypeOf(fieldValue.Interface()).Elem()).Interface() - r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex) + nestedPtr := reflect.New(fieldValue.Type().Elem()).Interface() + if err := r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex); err != nil { + return err + } targetValue.Field(i).Set(reflect.ValueOf(nestedPtr)) continue } @@ -51,7 +180,9 @@ func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, row } if tagValue == "---" { nestedPtr := reflect.New(reflect.TypeOf(targetValue.Field(i).Interface())).Interface() - r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex) + if err := r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex); err != nil { + return err + } targetValue.Field(i).Set(reflect.ValueOf(nestedPtr).Elem()) continue } @@ -62,20 +193,21 @@ func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, row } // Check if column exists - if _, ok := r.Row(rowIndex).columnIndex[tagValue]; !ok { + if _, ok := row.columnIndex[tagValue]; !ok { + if r.db != nil && r.db.StrictORM { + return wrapColumnNotFound(tagValue) + } continue } // Set field value based on type - r.setFieldValue(fieldValue, tagValue, rowIndex) + r.setFieldValue(fieldValue, tagValue, row) } return nil } // setFieldValue sets a single field value -func (r *StarRows) setFieldValue(fieldValue reflect.Value, columnName string, rowIndex int) { - row := r.Row(rowIndex) - +func (r *StarRows) setFieldValue(fieldValue reflect.Value, columnName string, row *StarResult) { switch fieldValue.Kind() { case reflect.String: fieldValue.SetString(row.MustString(columnName)) @@ -105,79 +237,70 @@ func (r *StarRows) setFieldValue(fieldValue reflect.Value, columnName string, ro fieldValue.SetFloat(float64(row.MustFloat32(columnName))) case reflect.Float64: fieldValue.SetFloat(row.MustFloat64(columnName)) - case reflect.Interface, reflect.Struct, reflect.Ptr: - // Handle special types like time.Time - colIndex := r.columnIndex[columnName] + case reflect.Struct: + // Handle special struct types like time.Time + colIndex := row.columnIndex[columnName] val := row.Result()[colIndex] - if t, ok := val.(time.Time); ok { - fieldValue.Set(reflect.ValueOf(t)) + if fieldValue.Type() == reflect.TypeOf(time.Time{}) { + if t, ok := val.(time.Time); ok { + fieldValue.Set(reflect.ValueOf(t)) + } + } + case reflect.Ptr: + // Handle pointer to special types like *time.Time + colIndex := row.columnIndex[columnName] + val := row.Result()[colIndex] + if fieldValue.Type().Elem() == reflect.TypeOf(time.Time{}) { + if t, ok := val.(time.Time); ok { + tCopy := t + fieldValue.Set(reflect.ValueOf(&tCopy)) + } + } + case reflect.Interface: + colIndex := row.columnIndex[columnName] + val := row.Result()[colIndex] + if val != nil { + fieldValue.Set(reflect.ValueOf(val)) } } } // getStructFieldValues extracts all field values from a struct func getStructFieldValues(target interface{}, tagKey string) (map[string]interface{}, error) { - result := make(map[string]interface{}) + if target == nil { + return nil, ErrTargetNil + } + targetType := reflect.TypeOf(target) targetValue := reflect.ValueOf(target) if targetType.Kind() == reflect.Ptr { if targetValue.IsNil() { - return nil, errors.New("pointer target is nil") + return nil, ErrPointerTargetNil } targetType = targetType.Elem() targetValue = targetValue.Elem() } if targetValue.Kind() != reflect.Struct { - return nil, errors.New("target is not a struct") + return nil, ErrTargetNotStruct } - for i := 0; i < targetType.NumField(); i++ { - field := targetType.Field(i) - fieldValue := targetValue.Field(i) - tagValue := field.Tag.Get(tagKey) + plan, err := getStructTagPlan(targetType, tagKey) + if err != nil { + return nil, err + } - // Handle nested pointer structs - if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct { - if fieldValue.IsNil() { - continue - } - if tagValue == "---" { - nestedValues, err := getStructFieldValues(reflect.ValueOf(fieldValue.Elem().Interface()).Interface(), tagKey) - if err != nil { - return result, err - } - for k, v := range nestedValues { - result[k] = v - } - continue - } - } - - // Handle nested structs - if targetValue.Field(i).Kind() == reflect.Struct { - if tagValue == "---" { - nestedValues, err := getStructFieldValues(targetValue.Field(i).Interface(), tagKey) - if err != nil { - return result, err - } - for k, v := range nestedValues { - result[k] = v - } - continue - } - } - - if tagValue == "" { + result := make(map[string]interface{}, len(plan)) + for _, field := range plan { + fieldValue, ok := resolveFieldByPath(targetValue, field.path) + if !ok { continue } - if !fieldValue.CanInterface() { continue } - - result[tagValue] = fieldValue.Interface() + result[field.tag] = fieldValue.Interface() } return result, nil @@ -185,10 +308,8 @@ func getStructFieldValues(target interface{}, tagKey string) (map[string]interfa // getStructFieldNames extracts all field names (tag values) from a struct func getStructFieldNames(target interface{}, tagKey string) ([]string, error) { - var result []string - - if !isStruct(target) { - return []string{}, errors.New("target is not a struct") + if target == nil { + return []string{}, ErrTargetNil } targetType := reflect.TypeOf(target) @@ -196,45 +317,31 @@ func getStructFieldNames(target interface{}, tagKey string) ([]string, error) { if targetType.Kind() == reflect.Ptr { if targetValue.IsNil() { - return []string{}, errors.New("pointer target is nil") + return []string{}, ErrPointerTargetNil } targetType = targetType.Elem() targetValue = targetValue.Elem() } - for i := 0; i < targetType.NumField(); i++ { - fieldValue := targetValue.Field(i) - field := targetType.Field(i) - tagValue := field.Tag.Get(tagKey) + if targetValue.Kind() != reflect.Struct { + return []string{}, ErrTargetNotStruct + } - // Handle nested pointer structs - if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct { - if fieldValue.IsNil() { - continue - } - if tagValue == "---" { - nestedNames, err := getStructFieldNames(reflect.ValueOf(fieldValue.Elem().Interface()).Interface(), tagKey) - if err != nil { - return result, err - } - result = append(result, nestedNames...) - continue - } - } + plan, err := getStructTagPlan(targetType, tagKey) + if err != nil { + return []string{}, err + } - // Handle nested structs - if targetValue.Field(i).Kind() == reflect.Struct && tagValue == "---" { - nestedNames, err := getStructFieldNames(targetValue.Field(i).Interface(), tagKey) - if err != nil { - return result, err - } - result = append(result, nestedNames...) + result := make([]string, 0, len(plan)) + for _, field := range plan { + fieldValue, ok := resolveFieldByPath(targetValue, field.path) + if !ok { continue } - - if tagValue != "" { - result = append(result, tagValue) + if !fieldValue.CanInterface() { + continue } + result = append(result, field.tag) } return result, nil @@ -242,6 +349,10 @@ func getStructFieldNames(target interface{}, tagKey string) ([]string, error) { // isWritable checks if a value is writable func isWritable(target interface{}) bool { + if target == nil { + return false + } + targetType := reflect.TypeOf(target) targetValue := reflect.ValueOf(target) return targetType.Kind() == reflect.Ptr || targetValue.CanSet() @@ -249,8 +360,15 @@ func isWritable(target interface{}) bool { // isStruct checks if a value is a struct func isStruct(target interface{}) bool { + if target == nil { + return false + } + targetValue := reflect.ValueOf(target) if targetValue.Kind() == reflect.Ptr { + if targetValue.IsNil() { + return false + } targetValue = targetValue.Elem() } return targetValue.Kind() == reflect.Struct diff --git a/result_safe.go b/result_safe.go index 63dd486..8c07aa7 100644 --- a/result_safe.go +++ b/result_safe.go @@ -1,51 +1,124 @@ package stardb import ( - "errors" - "fmt" - "strconv" + internalconv "b612.me/stardb/internal/convert" + "database/sql" ) +func (r *StarResult) getColumnValue(name string) (interface{}, error) { + if r == nil || r.columnIndex == nil { + return nil, wrapColumnNotFound(name) + } + index, ok := r.columnIndex[name] + if !ok { + return nil, wrapColumnNotFound(name) + } + return r.Result()[index], nil +} + // GetString returns column value as string with error func (r *StarResult) GetString(name string) (string, error) { - index, ok := r.columnIndex[name] - if !ok { - return "", errors.New("column not found: " + name) + val, err := r.getColumnValue(name) + if err != nil { + return "", err } - return ConvertToStringSafe(r.Result()[index]) + return ConvertToStringSafe(val) } // GetInt64 returns column value as int64 with error func (r *StarResult) GetInt64(name string) (int64, error) { - index, ok := r.columnIndex[name] - if !ok { - return 0, errors.New("column not found: " + name) + val, err := r.getColumnValue(name) + if err != nil { + return 0, err } - return ConvertToInt64Safe(r.Result()[index]) + return ConvertToInt64Safe(val) } // GetFloat64 returns column value as float64 with error func (r *StarResult) GetFloat64(name string) (float64, error) { - index, ok := r.columnIndex[name] - if !ok { - return 0, errors.New("column not found: " + name) - } - - switch v := r.Result()[index].(type) { - case nil: - return 0, nil - case float64: - return v, nil - case float32: - return float64(v), nil - case int, int32, int64, uint64: - val, err := ConvertToInt64Safe(v) - return float64(val), err - case string: - return strconv.ParseFloat(v, 64) - case []byte: - return strconv.ParseFloat(string(v), 64) - default: - return 0, fmt.Errorf("cannot convert %T to float64", v) + val, err := r.getColumnValue(name) + if err != nil { + return 0, err } + return internalconv.ToFloat64Safe(val) +} + +// GetNullString returns a nullable string value. +func (r *StarResult) GetNullString(name string) (sql.NullString, error) { + val, err := r.getColumnValue(name) + if err != nil { + return sql.NullString{}, err + } + if val == nil { + return sql.NullString{}, nil + } + str, err := ConvertToStringSafe(val) + if err != nil { + return sql.NullString{}, err + } + return sql.NullString{String: str, Valid: true}, nil +} + +// GetNullInt64 returns a nullable int64 value. +func (r *StarResult) GetNullInt64(name string) (sql.NullInt64, error) { + val, err := r.getColumnValue(name) + if err != nil { + return sql.NullInt64{}, err + } + if val == nil { + return sql.NullInt64{}, nil + } + i, err := ConvertToInt64Safe(val) + if err != nil { + return sql.NullInt64{}, err + } + return sql.NullInt64{Int64: i, Valid: true}, nil +} + +// GetNullFloat64 returns a nullable float64 value. +func (r *StarResult) GetNullFloat64(name string) (sql.NullFloat64, error) { + val, err := r.getColumnValue(name) + if err != nil { + return sql.NullFloat64{}, err + } + if val == nil { + return sql.NullFloat64{}, nil + } + f, err := internalconv.ToFloat64Safe(val) + if err != nil { + return sql.NullFloat64{}, err + } + return sql.NullFloat64{Float64: f, Valid: true}, nil +} + +// GetNullBool returns a nullable bool value. +func (r *StarResult) GetNullBool(name string) (sql.NullBool, error) { + val, err := r.getColumnValue(name) + if err != nil { + return sql.NullBool{}, err + } + if val == nil { + return sql.NullBool{}, nil + } + b, err := internalconv.ToBoolSafe(val) + if err != nil { + return sql.NullBool{}, err + } + return sql.NullBool{Bool: b, Valid: true}, nil +} + +// GetNullTime returns a nullable time value. +func (r *StarResult) GetNullTime(name string) (sql.NullTime, error) { + val, err := r.getColumnValue(name) + if err != nil { + return sql.NullTime{}, err + } + if val == nil { + return sql.NullTime{}, nil + } + t, err := internalconv.ToTimeSafe(val) + if err != nil { + return sql.NullTime{}, err + } + return sql.NullTime{Time: t, Valid: true}, nil } diff --git a/rows.go b/rows.go index 4af1ee1..87995f8 100644 --- a/rows.go +++ b/rows.go @@ -1,6 +1,7 @@ package stardb import ( + "b612.me/stardb/internal/scanutil" "database/sql" "reflect" "strconv" @@ -53,7 +54,7 @@ func (r *StarRows) Rescan() error { // Row returns a specific row by index func (r *StarRows) Row(index int) *StarResult { result := &StarResult{} - if index >= len(r.data) { + if index < 0 || index >= len(r.data) { return result } result.result = r.data[index] @@ -82,13 +83,11 @@ func (r *StarRows) parse() error { if r.parsed { return nil } - defer func() { - r.parsed = true - }() r.data = [][]interface{}{} r.columnIndex = make(map[string]int) r.stringResult = []map[string]string{} + r.columnsType = []reflect.Type{} var err error r.columns, err = r.rows.Columns() @@ -127,18 +126,28 @@ func (r *StarRows) parse() error { rowCopy := make([]interface{}, len(values)) for i, val := range values { - rowCopy[i] = val - record[r.columns[i]] = convertToString(val) + copiedVal := cloneScannedValue(val) + rowCopy[i] = copiedVal + record[r.columns[i]] = convertToString(copiedVal) } r.data = append(r.data, rowCopy) r.stringResult = append(r.stringResult, record) } + if err := r.rows.Err(); err != nil { + return err + } + r.length = len(r.stringResult) + r.parsed = true return nil } +func cloneScannedValue(val interface{}) interface{} { + return scanutil.CloneScannedValue(val) +} + // convertToString converts any value to string func convertToString(val interface{}) string { switch v := val.(type) { diff --git a/rows_internal_test.go b/rows_internal_test.go new file mode 100644 index 0000000..1ed068e --- /dev/null +++ b/rows_internal_test.go @@ -0,0 +1,29 @@ +package stardb + +import "testing" + +func TestCloneScannedValue_BytesAreCopied(t *testing.T) { + original := []byte("hello") + clonedAny := cloneScannedValue(original) + cloned, ok := clonedAny.([]byte) + if !ok { + t.Fatalf("expected []byte, got %T", clonedAny) + } + + original[0] = 'H' + if string(cloned) != "hello" { + t.Fatalf("expected cloned value to remain 'hello', got '%s'", string(cloned)) + } + + if len(cloned) > 0 && &cloned[0] == &original[0] { + t.Fatal("expected cloned bytes to have a different backing array") + } +} + +func TestCloneScannedValue_NonBytesKeepReference(t *testing.T) { + in := int64(42) + out := cloneScannedValue(in) + if out != in { + t.Fatalf("expected %v, got %v", in, out) + } +} diff --git a/scan_each.go b/scan_each.go new file mode 100644 index 0000000..4256fbb --- /dev/null +++ b/scan_each.go @@ -0,0 +1,124 @@ +package stardb + +import ( + "context" + "database/sql" + "errors" + "reflect" +) + +// ScanEachFunc is called for each scanned row in streaming mode. +type ScanEachFunc func(row *StarResult) error + +// ScanEach executes query in streaming mode and invokes fn for each row. +func (s *StarDB) ScanEach(query string, fn ScanEachFunc, args ...interface{}) error { + return s.scanEach(nil, query, fn, args...) +} + +// ScanEachContext executes query with context in streaming mode and invokes fn for each row. +func (s *StarDB) ScanEachContext(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error { + return s.scanEach(ctx, query, fn, args...) +} + +func (s *StarDB) scanEach(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error { + rows, err := s.queryRaw(ctx, query, args...) + if err != nil { + return err + } + return scanEachSQLRows(rows, fn) +} + +// ScanEach executes query in transaction streaming mode and invokes fn for each row. +func (t *StarTx) ScanEach(query string, fn ScanEachFunc, args ...interface{}) error { + return t.scanEach(nil, query, fn, args...) +} + +// ScanEachContext executes query with context in transaction streaming mode and invokes fn for each row. +func (t *StarTx) ScanEachContext(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error { + return t.scanEach(ctx, query, fn, args...) +} + +func (t *StarTx) scanEach(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error { + rows, err := t.queryRaw(ctx, query, args...) + if err != nil { + return err + } + return scanEachSQLRows(rows, fn) +} + +// ScanEach executes prepared statement query in streaming mode and invokes fn for each row. +func (s *StarStmt) ScanEach(fn ScanEachFunc, args ...interface{}) error { + return s.scanEach(nil, fn, args...) +} + +// ScanEachContext executes prepared statement query with context in streaming mode and invokes fn for each row. +func (s *StarStmt) ScanEachContext(ctx context.Context, fn ScanEachFunc, args ...interface{}) error { + return s.scanEach(ctx, fn, args...) +} + +func (s *StarStmt) scanEach(ctx context.Context, fn ScanEachFunc, args ...interface{}) error { + rows, err := s.queryRaw(ctx, args...) + if err != nil { + return err + } + return scanEachSQLRows(rows, fn) +} + +func scanEachSQLRows(rows *sql.Rows, fn ScanEachFunc) error { + if fn == nil { + _ = rows.Close() + return ErrScanFuncNil + } + defer rows.Close() + + columns, err := rows.Columns() + if err != nil { + return err + } + + columnTypes, err := rows.ColumnTypes() + if err != nil { + return err + } + types := make([]reflect.Type, len(columnTypes)) + for i, colType := range columnTypes { + types[i] = colType.ScanType() + } + + columnIndex := make(map[string]int, len(columns)) + for i, colName := range columns { + columnIndex[colName] = i + } + + scanArgs := make([]interface{}, len(columns)) + values := make([]interface{}, len(columns)) + for i := range values { + scanArgs[i] = &values[i] + } + + for rows.Next() { + if err := rows.Scan(scanArgs...); err != nil { + return err + } + + rowCopy := make([]interface{}, len(values)) + for i, val := range values { + rowCopy[i] = cloneScannedValue(val) + } + + row := &StarResult{ + result: rowCopy, + columns: columns, + columnIndex: columnIndex, + columnsType: types, + } + if err := fn(row); err != nil { + if errors.Is(err, ErrScanStopped) { + return nil + } + return err + } + } + + return rows.Err() +} diff --git a/scan_each_orm.go b/scan_each_orm.go new file mode 100644 index 0000000..8792383 --- /dev/null +++ b/scan_each_orm.go @@ -0,0 +1,119 @@ +package stardb + +import ( + "context" + "reflect" +) + +// ScanEachORMFunc is called for each mapped struct in streaming ORM mode. +type ScanEachORMFunc func(target interface{}) error + +// ScanEachORM streams query rows and maps each row to target before invoking fn. +// target must be a pointer to struct; it is reused for each row. +func (s *StarDB) ScanEachORM(query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error { + return s.ScanEachORMContext(nil, query, target, fn, args...) +} + +// ScanEachORMContext streams query rows with context and maps each row to target before invoking fn. +// target must be a pointer to struct; it is reused for each row. +func (s *StarDB) ScanEachORMContext(ctx context.Context, query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error { + if fn == nil { + return ErrScanORMFuncNil + } + if err := validateScanORMTarget(target); err != nil { + return err + } + + return s.ScanEachContext(ctx, query, func(row *StarResult) error { + if err := mapResultToStructTarget(row, target, s); err != nil { + return err + } + return fn(target) + }, args...) +} + +// ScanEachORM streams transaction rows and maps each row to target before invoking fn. +// target must be a pointer to struct; it is reused for each row. +func (t *StarTx) ScanEachORM(query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error { + return t.ScanEachORMContext(nil, query, target, fn, args...) +} + +// ScanEachORMContext streams transaction rows with context and maps each row to target before invoking fn. +// target must be a pointer to struct; it is reused for each row. +func (t *StarTx) ScanEachORMContext(ctx context.Context, query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error { + if fn == nil { + return ErrScanORMFuncNil + } + if err := validateScanORMTarget(target); err != nil { + return err + } + + return t.ScanEachContext(ctx, query, func(row *StarResult) error { + if err := mapResultToStructTarget(row, target, t.db); err != nil { + return err + } + return fn(target) + }, args...) +} + +// ScanEachORM streams prepared statement rows and maps each row to target before invoking fn. +// target must be a pointer to struct; it is reused for each row. +func (s *StarStmt) ScanEachORM(target interface{}, fn ScanEachORMFunc, args ...interface{}) error { + return s.ScanEachORMContext(nil, target, fn, args...) +} + +// ScanEachORMContext streams prepared statement rows with context and maps each row to target before invoking fn. +// target must be a pointer to struct; it is reused for each row. +func (s *StarStmt) ScanEachORMContext(ctx context.Context, target interface{}, fn ScanEachORMFunc, args ...interface{}) error { + if fn == nil { + return ErrScanORMFuncNil + } + if err := validateScanORMTarget(target); err != nil { + return err + } + + return s.ScanEachContext(ctx, func(row *StarResult) error { + if err := mapResultToStructTarget(row, target, s.db); err != nil { + return err + } + return fn(target) + }, args...) +} + +func validateScanORMTarget(target interface{}) error { + if target == nil { + return ErrTargetNil + } + + targetType := reflect.TypeOf(target) + targetValue := reflect.ValueOf(target) + + if targetType.Kind() != reflect.Ptr { + return ErrTargetNotPointer + } + if targetValue.IsNil() { + return ErrTargetPointerNil + } + if targetValue.Elem().Kind() != reflect.Struct { + return ErrTargetNotStruct + } + + return nil +} + +func mapResultToStructTarget(row *StarResult, target interface{}, db *StarDB) error { + targetValue := reflect.ValueOf(target) + targetValue.Elem().Set(reflect.Zero(targetValue.Elem().Type())) + + rowWrapper := &StarRows{ + db: db, + length: 1, + columns: row.columns, + columnsType: row.columnsType, + columnIndex: row.columnIndex, + data: [][]interface{}{row.result}, + parsed: true, + } + + return rowWrapper.setStructFieldsFromRow(target, "db", 0) +} diff --git a/sql_placeholder.go b/sql_placeholder.go new file mode 100644 index 0000000..a7b342c --- /dev/null +++ b/sql_placeholder.go @@ -0,0 +1,19 @@ +package stardb + +import internalsqlplaceholder "b612.me/stardb/internal/sqlplaceholder" + +// ConvertPlaceholders converts placeholders according to style. +func ConvertPlaceholders(query string, style PlaceholderStyle) string { + switch normalizePlaceholderStyle(style) { + case PlaceholderDollar: + return ConvertQuestionToDollarPlaceholders(query) + default: + return query + } +} + +// ConvertQuestionToDollarPlaceholders converts '?' to '$1,$2,...' in SQL text. +// It skips quoted strings, quoted identifiers and comments. +func ConvertQuestionToDollarPlaceholders(query string) string { + return internalsqlplaceholder.ConvertQuestionToDollarPlaceholders(query) +} diff --git a/sql_placeholder_test.go b/sql_placeholder_test.go new file mode 100644 index 0000000..c43ed18 --- /dev/null +++ b/sql_placeholder_test.go @@ -0,0 +1,34 @@ +package stardb + +import "testing" + +func TestConvertQuestionToDollarPlaceholders(t *testing.T) { + query := "SELECT * FROM users WHERE id = ? AND name = ?" + got := ConvertQuestionToDollarPlaceholders(query) + want := "SELECT * FROM users WHERE id = $1 AND name = $2" + if got != want { + t.Fatalf("expected %q, got %q", want, got) + } +} + +func TestConvertQuestionToDollarPlaceholders_SkipQuotedAndComments(t *testing.T) { + query := "SELECT '?', \"?\", `?`, col FROM t WHERE id = ? -- ?\nAND note = '??' /* ? */ AND x = ?" + got := ConvertQuestionToDollarPlaceholders(query) + want := "SELECT '?', \"?\", `?`, col FROM t WHERE id = $1 -- ?\nAND note = '??' /* ? */ AND x = $2" + if got != want { + t.Fatalf("expected %q, got %q", want, got) + } +} + +func TestConvertPlaceholders(t *testing.T) { + query := "SELECT * FROM t WHERE a = ? AND b = ?" + if got := ConvertPlaceholders(query, PlaceholderQuestion); got != query { + t.Fatalf("question style should keep query unchanged, got %q", got) + } + + got := ConvertPlaceholders(query, PlaceholderDollar) + want := "SELECT * FROM t WHERE a = $1 AND b = $2" + if got != want { + t.Fatalf("expected %q, got %q", want, got) + } +} diff --git a/sql_runtime.go b/sql_runtime.go new file mode 100644 index 0000000..4608997 --- /dev/null +++ b/sql_runtime.go @@ -0,0 +1,276 @@ +package stardb + +import ( + internalsqlruntime "b612.me/stardb/internal/sqlruntime" + "context" + "time" +) + +// SQLBeforeHook runs before a SQL statement is executed. +type SQLBeforeHook func(ctx context.Context, query string, args []interface{}) + +// SQLAfterHook runs after a SQL statement is executed. +type SQLAfterHook func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) + +// PlaceholderStyle controls SQL placeholder format conversion. +type PlaceholderStyle int + +const ( + // PlaceholderQuestion keeps '?' placeholders unchanged. + PlaceholderQuestion PlaceholderStyle = iota + // PlaceholderDollar converts '?' placeholders to '$1,$2,...'. + PlaceholderDollar +) + +// SQLFingerprintMode controls SQL fingerprint generation strategy. +type SQLFingerprintMode int + +const ( + // SQLFingerprintBasic lowercases SQL and collapses whitespace. + SQLFingerprintBasic SQLFingerprintMode = iota + // SQLFingerprintMaskLiterals also masks numeric/string literals and $n placeholders. + SQLFingerprintMaskLiterals +) + +type sqlHookMetaKey struct{} + +// SQLHookMeta contains extra hook metadata attached to context. +type SQLHookMeta struct { + Fingerprint string +} + +// SQLHookMetaFromContext extracts SQL hook metadata from context. +func SQLHookMetaFromContext(ctx context.Context) (SQLHookMeta, bool) { + if ctx == nil { + return SQLHookMeta{}, false + } + meta, ok := ctx.Value(sqlHookMetaKey{}).(SQLHookMeta) + return meta, ok +} + +func withSQLHookMeta(ctx context.Context, meta SQLHookMeta) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, sqlHookMetaKey{}, meta) +} + +type sqlRuntime struct { + state internalsqlruntime.State +} + +func normalizePlaceholderStyle(style PlaceholderStyle) PlaceholderStyle { + return PlaceholderStyle(internalsqlruntime.NormalizePlaceholderStyle(int(style))) +} + +func normalizeSQLFingerprintMode(mode SQLFingerprintMode) SQLFingerprintMode { + return SQLFingerprintMode(internalsqlruntime.NormalizeFingerprintMode(int(mode))) +} + +func cloneHookArgs(args []interface{}) []interface{} { + return internalsqlruntime.CloneHookArgs(args) +} + +func (s *StarDB) runtimeOptions() (SQLBeforeHook, SQLAfterHook, PlaceholderStyle, time.Duration) { + if s == nil { + return nil, nil, PlaceholderQuestion, 0 + } + + beforeAny, afterAny, rawStyle, slowThreshold := s.runtime.state.Options() + var before SQLBeforeHook + if b, ok := beforeAny.(SQLBeforeHook); ok { + before = b + } + var after SQLAfterHook + if a, ok := afterAny.(SQLAfterHook); ok { + after = a + } + + return before, after, normalizePlaceholderStyle(PlaceholderStyle(rawStyle)), slowThreshold +} + +func (s *StarDB) sqlHooks() (SQLBeforeHook, SQLAfterHook, time.Duration) { + before, after, _, slowThreshold := s.runtimeOptions() + return before, after, slowThreshold +} + +func (s *StarDB) prepareSQLCall(query string, args []interface{}) (string, SQLBeforeHook, SQLAfterHook, []interface{}, time.Duration) { + before, after, style, slowThreshold := s.runtimeOptions() + query = ConvertPlaceholders(query, style) + if before == nil && after == nil { + return query, nil, nil, nil, slowThreshold + } + return query, before, after, cloneHookArgs(args), slowThreshold +} + +func (s *StarDB) hookContext(ctx context.Context, query string, before SQLBeforeHook, after SQLAfterHook) context.Context { + if s == nil { + return ctx + } + + needCounter := s.SQLFingerprintCounterEnabled() + needMeta := (before != nil || after != nil) && s.SQLFingerprintEnabled() + if !needCounter && !needMeta { + return ctx + } + + mode := s.SQLFingerprintMode() + keepComments := s.SQLFingerprintKeepComments() + fingerprint := internalsqlruntime.FingerprintSQL(query, int(mode), keepComments) + if fingerprint == "" { + return ctx + } + + if needCounter { + s.runtime.state.IncFingerprintCount(fingerprint) + } + if !needMeta { + return ctx + } + return withSQLHookMeta(ctx, SQLHookMeta{Fingerprint: fingerprint}) +} + +func shouldRunAfterHook(after SQLAfterHook, slowThreshold, duration time.Duration, err error) bool { + return internalsqlruntime.ShouldRunAfterHook(after != nil, slowThreshold, duration, err) +} + +// SetSQLHooks sets SQL before/after hooks. +func (s *StarDB) SetSQLHooks(before SQLBeforeHook, after SQLAfterHook) { + if s == nil { + return + } + s.runtime.state.SetHooks(before, after) +} + +// SetSQLBeforeHook sets SQL before hook. +func (s *StarDB) SetSQLBeforeHook(before SQLBeforeHook) { + if s == nil { + return + } + s.runtime.state.SetBeforeHook(before) +} + +// SetSQLAfterHook sets SQL after hook. +func (s *StarDB) SetSQLAfterHook(after SQLAfterHook) { + if s == nil { + return + } + s.runtime.state.SetAfterHook(after) +} + +// SetPlaceholderStyle sets placeholder conversion style. +func (s *StarDB) SetPlaceholderStyle(style PlaceholderStyle) { + if s == nil { + return + } + s.runtime.state.SetPlaceholderStyle(int(style)) +} + +// SetSQLSlowThreshold sets minimum duration for triggering after hook. +// When threshold > 0, after hook runs only for statements slower than threshold or those with error. +func (s *StarDB) SetSQLSlowThreshold(threshold time.Duration) { + if s == nil { + return + } + s.runtime.state.SetSlowThreshold(threshold) +} + +// SQLSlowThreshold returns current slow SQL threshold. +func (s *StarDB) SQLSlowThreshold() time.Duration { + if s == nil { + return 0 + } + return s.runtime.state.SlowThreshold() +} + +// PlaceholderStyle returns current placeholder style. +func (s *StarDB) PlaceholderStyle() PlaceholderStyle { + if s == nil { + return PlaceholderQuestion + } + style := PlaceholderStyle(s.runtime.state.PlaceholderStyle()) + return normalizePlaceholderStyle(style) +} + +// SetSQLFingerprintEnabled toggles SQL fingerprint metadata generation for hooks. +func (s *StarDB) SetSQLFingerprintEnabled(enabled bool) { + if s == nil { + return + } + s.runtime.state.SetFingerprintEnabled(enabled) +} + +// SQLFingerprintEnabled reports whether SQL fingerprint metadata generation is enabled. +func (s *StarDB) SQLFingerprintEnabled() bool { + if s == nil { + return false + } + return s.runtime.state.FingerprintEnabled() +} + +// SetSQLFingerprintMode sets SQL fingerprint generation mode. +func (s *StarDB) SetSQLFingerprintMode(mode SQLFingerprintMode) { + if s == nil { + return + } + s.runtime.state.SetFingerprintMode(int(mode)) +} + +// SQLFingerprintMode returns SQL fingerprint generation mode. +func (s *StarDB) SQLFingerprintMode() SQLFingerprintMode { + if s == nil { + return SQLFingerprintBasic + } + mode := SQLFingerprintMode(s.runtime.state.FingerprintMode()) + return normalizeSQLFingerprintMode(mode) +} + +// SetSQLFingerprintKeepComments controls whether comments are preserved in SQL fingerprints. +// Default is false. +func (s *StarDB) SetSQLFingerprintKeepComments(keep bool) { + if s == nil { + return + } + s.runtime.state.SetFingerprintKeepComments(keep) +} + +// SQLFingerprintKeepComments reports whether SQL fingerprints preserve comments. +func (s *StarDB) SQLFingerprintKeepComments() bool { + if s == nil { + return false + } + return s.runtime.state.FingerprintKeepComments() +} + +// SetSQLFingerprintCounterEnabled toggles in-memory SQL fingerprint hit counter. +// Default is false. +func (s *StarDB) SetSQLFingerprintCounterEnabled(enabled bool) { + if s == nil { + return + } + s.runtime.state.SetFingerprintCounterEnabled(enabled) +} + +// SQLFingerprintCounterEnabled reports whether in-memory SQL fingerprint hit counter is enabled. +func (s *StarDB) SQLFingerprintCounterEnabled() bool { + if s == nil { + return false + } + return s.runtime.state.FingerprintCounterEnabled() +} + +// SQLFingerprintCounters returns a snapshot of fingerprint hit counters. +func (s *StarDB) SQLFingerprintCounters() map[string]uint64 { + if s == nil { + return map[string]uint64{} + } + return s.runtime.state.FingerprintCountsSnapshot() +} + +// ResetSQLFingerprintCounters clears all in-memory fingerprint hit counters. +func (s *StarDB) ResetSQLFingerprintCounters() { + if s == nil { + return + } + s.runtime.state.ResetFingerprintCounts() +} diff --git a/sql_runtime_test.go b/sql_runtime_test.go new file mode 100644 index 0000000..14775c6 --- /dev/null +++ b/sql_runtime_test.go @@ -0,0 +1,102 @@ +package stardb + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestStarDB_RuntimeConfigConcurrent(t *testing.T) { + db := NewStarDB() + + before := func(ctx context.Context, query string, args []interface{}) {} + after := func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {} + + var wg sync.WaitGroup + for i := 0; i < 16; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + for j := 0; j < 1000; j++ { + if (i+j)%2 == 0 { + db.SetPlaceholderStyle(PlaceholderDollar) + } else { + db.SetPlaceholderStyle(PlaceholderQuestion) + } + db.SetSQLSlowThreshold(time.Duration((i+j)%5) * time.Millisecond) + db.SetSQLFingerprintEnabled((i+j)%3 == 0) + db.SetSQLFingerprintMode(SQLFingerprintMode((i + j) % 3)) + db.SetSQLFingerprintKeepComments((i+j)%4 == 0) + db.SetSQLFingerprintCounterEnabled((i+j)%5 == 0) + if (i+j)%7 == 0 { + db.ResetSQLFingerprintCounters() + } + db.SetSQLHooks(before, after) + _ = db.PlaceholderStyle() + _ = db.SQLSlowThreshold() + _ = db.SQLFingerprintEnabled() + _ = db.SQLFingerprintMode() + _ = db.SQLFingerprintKeepComments() + _ = db.SQLFingerprintCounterEnabled() + _ = db.SQLFingerprintCounters() + _, _, _, _ = db.runtimeOptions() + } + }(i) + } + wg.Wait() +} + +func TestStarDB_SQLFingerprintMode(t *testing.T) { + db := NewStarDB() + + if got := db.SQLFingerprintMode(); got != SQLFingerprintBasic { + t.Fatalf("expected default mode SQLFingerprintBasic, got %v", got) + } + + db.SetSQLFingerprintMode(SQLFingerprintMaskLiterals) + if got := db.SQLFingerprintMode(); got != SQLFingerprintMaskLiterals { + t.Fatalf("expected SQLFingerprintMaskLiterals, got %v", got) + } + + db.SetSQLFingerprintMode(SQLFingerprintMode(99)) + if got := db.SQLFingerprintMode(); got != SQLFingerprintBasic { + t.Fatalf("expected invalid mode fallback to SQLFingerprintBasic, got %v", got) + } +} + +func TestStarDB_SQLFingerprintKeepComments(t *testing.T) { + db := NewStarDB() + + if db.SQLFingerprintKeepComments() { + t.Fatal("expected default keep comments to be false") + } + + db.SetSQLFingerprintKeepComments(true) + if !db.SQLFingerprintKeepComments() { + t.Fatal("expected keep comments to be true") + } + + db.SetSQLFingerprintKeepComments(false) + if db.SQLFingerprintKeepComments() { + t.Fatal("expected keep comments to be false") + } +} + +func TestStarDB_SQLFingerprintCounterSwitch(t *testing.T) { + db := NewStarDB() + + if db.SQLFingerprintCounterEnabled() { + t.Fatal("expected default counter switch to be false") + } + + db.SetSQLFingerprintCounterEnabled(true) + if !db.SQLFingerprintCounterEnabled() { + t.Fatal("expected counter switch to be true") + } + + db.ResetSQLFingerprintCounters() + if got := len(db.SQLFingerprintCounters()); got != 0 { + t.Fatalf("expected empty counters after reset, got %d", got) + } +} diff --git a/stardb.go b/stardb.go index db43d6d..683268a 100644 --- a/stardb.go +++ b/stardb.go @@ -3,13 +3,20 @@ package stardb import ( "context" "database/sql" - "errors" + "strings" + "time" ) // StarDB is a simple wrapper around sql.DB providing enhanced functionality type StarDB struct { db *sql.DB ManualScan bool // If true, rows won't be automatically parsed + StrictORM bool // If true, Orm requires all tagged columns to exist in query results + // batchInsertMaxRows controls batch split size for BatchInsert/BatchInsertStructs. + // <= 0 means no split (single SQL statement). + batchInsertMaxRows int64 + batchInsertMaxParams int64 + runtime sqlRuntime } // NewStarDB creates a new StarDB instance @@ -32,6 +39,21 @@ func (s *StarDB) SetDB(db *sql.DB) { s.db = db } +// SetStrictORM enables or disables strict column validation for Orm mapping. +func (s *StarDB) SetStrictORM(strict bool) { + if s == nil { + return + } + s.StrictORM = strict +} + +func (s *StarDB) ensureDB() error { + if s == nil || s.db == nil { + return ErrDBNotInitialized + } + return nil +} + // Open opens a new database connection func (s *StarDB) Open(driver, connStr string) error { var err error @@ -41,36 +63,57 @@ func (s *StarDB) Open(driver, connStr string) error { // Close closes the database connection func (s *StarDB) Close() error { + if err := s.ensureDB(); err != nil { + return err + } return s.db.Close() } // Ping verifies the database connection is alive func (s *StarDB) Ping() error { + if err := s.ensureDB(); err != nil { + return err + } return s.db.Ping() } // PingContext verifies the database connection with context func (s *StarDB) PingContext(ctx context.Context) error { + if err := s.ensureDB(); err != nil { + return err + } return s.db.PingContext(ctx) } // Stats returns database statistics func (s *StarDB) Stats() sql.DBStats { + if s == nil || s.db == nil { + return sql.DBStats{} + } return s.db.Stats() } // SetMaxOpenConns sets the maximum number of open connections func (s *StarDB) SetMaxOpenConns(n int) { + if s == nil || s.db == nil { + return + } s.db.SetMaxOpenConns(n) } // SetMaxIdleConns sets the maximum number of idle connections func (s *StarDB) SetMaxIdleConns(n int) { + if s == nil || s.db == nil { + return + } s.db.SetMaxIdleConns(n) } // Conn returns a single connection from the pool func (s *StarDB) Conn(ctx context.Context) (*sql.Conn, error) { + if err := s.ensureDB(); err != nil { + return nil, err + } return s.db.Conn(ctx) } @@ -85,21 +128,55 @@ func (s *StarDB) QueryContext(ctx context.Context, query string, args ...interfa return s.query(ctx, query, args...) } -// query is the internal query implementation -func (s *StarDB) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { - if err := s.db.Ping(); err != nil { +// QueryRaw executes a query and returns *sql.Rows without automatic parsing. +func (s *StarDB) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) { + return s.queryRaw(nil, query, args...) +} + +// QueryRawContext executes a query with context and returns *sql.Rows without automatic parsing. +func (s *StarDB) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return s.queryRaw(ctx, query, args...) +} + +func (s *StarDB) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + if err := s.ensureDB(); err != nil { return nil, err } + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty + } - var rows *sql.Rows - var err error + query, beforeHook, afterHook, hookArgs, slowThreshold := s.prepareSQLCall(query, args) + hookCtx := s.hookContext(ctx, query, beforeHook, afterHook) + if beforeHook != nil { + beforeHook(hookCtx, query, hookArgs) + } + start := time.Now() + var ( + rows *sql.Rows + err error + ) if ctx == nil { rows, err = s.db.Query(query, args...) } else { rows, err = s.db.QueryContext(ctx, query, args...) } + duration := time.Since(start) + if shouldRunAfterHook(afterHook, slowThreshold, duration, err) { + afterHook(hookCtx, query, hookArgs, duration, err) + } + + if err != nil { + return nil, err + } + return rows, nil +} + +// query is the internal query implementation +func (s *StarDB) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { + rows, err := s.queryRaw(ctx, query, args...) if err != nil { return nil, err } @@ -110,10 +187,13 @@ func (s *StarDB) query(ctx context.Context, query string, args ...interface{}) ( } if !s.ManualScan { - err = starRows.parse() + if err := starRows.parse(); err != nil { + _ = rows.Close() + return nil, err + } } - return starRows, err + return starRows, nil } // Exec executes a query that doesn't return rows @@ -129,39 +209,100 @@ func (s *StarDB) ExecContext(ctx context.Context, query string, args ...interfac // exec is the internal exec implementation func (s *StarDB) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - if err := s.db.Ping(); err != nil { + if err := s.ensureDB(); err != nil { return nil, err } - - if ctx == nil { - return s.db.Exec(query, args...) + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty } - return s.db.ExecContext(ctx, query, args...) + + query, beforeHook, afterHook, hookArgs, slowThreshold := s.prepareSQLCall(query, args) + hookCtx := s.hookContext(ctx, query, beforeHook, afterHook) + if beforeHook != nil { + beforeHook(hookCtx, query, hookArgs) + } + start := time.Now() + + var ( + result sql.Result + err error + ) + if ctx == nil { + result, err = s.db.Exec(query, args...) + } else { + result, err = s.db.ExecContext(ctx, query, args...) + } + + duration := time.Since(start) + if shouldRunAfterHook(afterHook, slowThreshold, duration, err) { + afterHook(hookCtx, query, hookArgs, duration, err) + } + + if err != nil { + return nil, err + } + return result, nil } // Prepare creates a prepared statement func (s *StarDB) Prepare(query string) (*StarStmt, error) { + if err := s.ensureDB(); err != nil { + return nil, err + } + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty + } + + query, beforeHook, afterHook, _, slowThreshold := s.prepareSQLCall(query, nil) + hookCtx := s.hookContext(nil, query, beforeHook, afterHook) + if beforeHook != nil { + beforeHook(hookCtx, query, nil) + } + start := time.Now() + stmt, err := s.db.Prepare(query) + duration := time.Since(start) + if shouldRunAfterHook(afterHook, slowThreshold, duration, err) { + afterHook(hookCtx, query, nil, duration, err) + } if err != nil { return nil, err } - return &StarStmt{stmt: stmt, db: s}, nil + return &StarStmt{stmt: stmt, db: s, sqlText: query}, nil } // PrepareContext creates a prepared statement with context func (s *StarDB) PrepareContext(ctx context.Context, query string) (*StarStmt, error) { + if err := s.ensureDB(); err != nil { + return nil, err + } + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty + } + + query, beforeHook, afterHook, _, slowThreshold := s.prepareSQLCall(query, nil) + hookCtx := s.hookContext(ctx, query, beforeHook, afterHook) + if beforeHook != nil { + beforeHook(hookCtx, query, nil) + } + start := time.Now() + stmt, err := s.db.PrepareContext(ctx, query) + duration := time.Since(start) + if shouldRunAfterHook(afterHook, slowThreshold, duration, err) { + afterHook(hookCtx, query, nil, duration, err) + } if err != nil { return nil, err } - return &StarStmt{stmt: stmt, db: s}, nil + return &StarStmt{stmt: stmt, db: s, sqlText: query}, nil } // QueryStmt executes a prepared statement query // Usage: QueryStmt("SELECT * FROM users WHERE id = ?", 1) func (s *StarDB) QueryStmt(query string, args ...interface{}) (*StarRows, error) { - if query == "" { - return nil, errors.New("query string cannot be empty") + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty } stmt, err := s.Prepare(query) if err != nil { @@ -173,8 +314,8 @@ func (s *StarDB) QueryStmt(query string, args ...interface{}) (*StarRows, error) // QueryStmtContext executes a prepared statement query with context func (s *StarDB) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { - if query == "" { - return nil, errors.New("query string cannot be empty") + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty } stmt, err := s.PrepareContext(ctx, query) if err != nil { @@ -187,8 +328,8 @@ func (s *StarDB) QueryStmtContext(ctx context.Context, query string, args ...int // ExecStmt executes a prepared statement // Usage: ExecStmt("INSERT INTO users (name) VALUES (?)", "John") func (s *StarDB) ExecStmt(query string, args ...interface{}) (sql.Result, error) { - if query == "" { - return nil, errors.New("query string cannot be empty") + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty } stmt, err := s.Prepare(query) if err != nil { @@ -200,8 +341,8 @@ func (s *StarDB) ExecStmt(query string, args ...interface{}) (sql.Result, error) // ExecStmtContext executes a prepared statement with context func (s *StarDB) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - if query == "" { - return nil, errors.New("query string cannot be empty") + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty } stmt, err := s.PrepareContext(ctx, query) if err != nil { @@ -213,6 +354,9 @@ func (s *StarDB) ExecStmtContext(ctx context.Context, query string, args ...inte // Begin starts a transaction func (s *StarDB) Begin() (*StarTx, error) { + if err := s.ensureDB(); err != nil { + return nil, err + } tx, err := s.db.Begin() if err != nil { return nil, err @@ -222,6 +366,9 @@ func (s *StarDB) Begin() (*StarTx, error) { // BeginTx starts a transaction with options func (s *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, error) { + if err := s.ensureDB(); err != nil { + return nil, err + } tx, err := s.db.BeginTx(ctx, opts) if err != nil { return nil, err @@ -231,8 +378,26 @@ func (s *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, err // StarStmt represents a prepared statement type StarStmt struct { - stmt *sql.Stmt - db *StarDB + stmt *sql.Stmt + db *StarDB + sqlText string +} + +func (s *StarStmt) ensureStmt() error { + if s == nil || s.stmt == nil { + return ErrStmtNotInitialized + } + return nil +} + +func (s *StarStmt) ensureStmtWithDB() error { + if err := s.ensureStmt(); err != nil { + return err + } + if s.db == nil { + return ErrStmtDBNotInitialized + } + return nil } // Query executes a prepared statement query @@ -245,17 +410,66 @@ func (s *StarStmt) QueryContext(ctx context.Context, args ...interface{}) (*Star return s.query(ctx, args...) } -// query is the internal query implementation -func (s *StarStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) { - var rows *sql.Rows - var err error +// QueryRaw executes a prepared statement query and returns *sql.Rows without automatic parsing. +func (s *StarStmt) QueryRaw(args ...interface{}) (*sql.Rows, error) { + return s.queryRaw(nil, args...) +} +// QueryRawContext executes a prepared statement query with context and returns *sql.Rows without automatic parsing. +func (s *StarStmt) QueryRawContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) { + return s.queryRaw(ctx, args...) +} + +func (s *StarStmt) queryRaw(ctx context.Context, args ...interface{}) (*sql.Rows, error) { + if err := s.ensureStmt(); err != nil { + return nil, err + } + + var beforeHook SQLBeforeHook + var afterHook SQLAfterHook + var slowThreshold time.Duration + if s.db != nil { + beforeHook, afterHook, slowThreshold = s.db.sqlHooks() + } + var hookArgs []interface{} + if beforeHook != nil || afterHook != nil { + hookArgs = cloneHookArgs(args) + } + hookCtx := ctx + if s.db != nil { + hookCtx = s.db.hookContext(ctx, s.sqlText, beforeHook, afterHook) + } + if beforeHook != nil { + beforeHook(hookCtx, s.sqlText, hookArgs) + } + start := time.Now() + + var ( + rows *sql.Rows + err error + ) if ctx == nil { rows, err = s.stmt.Query(args...) } else { rows, err = s.stmt.QueryContext(ctx, args...) } + duration := time.Since(start) + if shouldRunAfterHook(afterHook, slowThreshold, duration, err) { + afterHook(hookCtx, s.sqlText, hookArgs, duration, err) + } + if err != nil { + return nil, err + } + return rows, nil +} +// query is the internal query implementation +func (s *StarStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) { + if err := s.ensureStmtWithDB(); err != nil { + return nil, err + } + + rows, err := s.queryRaw(ctx, args...) if err != nil { return nil, err } @@ -266,10 +480,13 @@ func (s *StarStmt) query(ctx context.Context, args ...interface{}) (*StarRows, e } if !s.db.ManualScan { - err = starRows.parse() + if err := starRows.parse(); err != nil { + _ = rows.Close() + return nil, err + } } - return starRows, err + return starRows, nil } // Exec executes a prepared statement @@ -284,13 +501,52 @@ func (s *StarStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Re // exec is the internal exec implementation func (s *StarStmt) exec(ctx context.Context, args ...interface{}) (sql.Result, error) { - if ctx == nil { - return s.stmt.Exec(args...) + if err := s.ensureStmt(); err != nil { + return nil, err } - return s.stmt.ExecContext(ctx, args...) + + var beforeHook SQLBeforeHook + var afterHook SQLAfterHook + var slowThreshold time.Duration + if s.db != nil { + beforeHook, afterHook, slowThreshold = s.db.sqlHooks() + } + var hookArgs []interface{} + if beforeHook != nil || afterHook != nil { + hookArgs = cloneHookArgs(args) + } + hookCtx := ctx + if s.db != nil { + hookCtx = s.db.hookContext(ctx, s.sqlText, beforeHook, afterHook) + } + if beforeHook != nil { + beforeHook(hookCtx, s.sqlText, hookArgs) + } + start := time.Now() + + var ( + result sql.Result + err error + ) + if ctx == nil { + result, err = s.stmt.Exec(args...) + } else { + result, err = s.stmt.ExecContext(ctx, args...) + } + duration := time.Since(start) + if shouldRunAfterHook(afterHook, slowThreshold, duration, err) { + afterHook(hookCtx, s.sqlText, hookArgs, duration, err) + } + if err != nil { + return nil, err + } + return result, nil } // Close closes the prepared statement func (s *StarStmt) Close() error { + if err := s.ensureStmt(); err != nil { + return err + } return s.stmt.Close() } diff --git a/stardb_safe_test.go b/stardb_safe_test.go new file mode 100644 index 0000000..1348b00 --- /dev/null +++ b/stardb_safe_test.go @@ -0,0 +1,127 @@ +package stardb + +import ( + "context" + "errors" + "testing" +) + +func TestStarDB_NotInitialized(t *testing.T) { + db := NewStarDB() + + if err := db.Close(); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from Close, got %v", err) + } + if err := db.Ping(); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from Ping, got %v", err) + } + if err := db.PingContext(context.Background()); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from PingContext, got %v", err) + } + if _, err := db.Conn(context.Background()); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from Conn, got %v", err) + } + if _, err := db.Query("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from Query, got %v", err) + } + if err := db.ScanEach("SELECT 1", func(row *StarResult) error { return nil }); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from ScanEach, got %v", err) + } + var model struct { + ID int `db:"id"` + } + if err := db.ScanEachORM("SELECT 1", &model, func(target interface{}) error { return nil }); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from ScanEachORM, got %v", err) + } + if _, err := db.QueryRaw("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from QueryRaw, got %v", err) + } + if _, err := db.Exec("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from Exec, got %v", err) + } + if _, err := db.Prepare("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from Prepare, got %v", err) + } + if _, err := db.Begin(); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from Begin, got %v", err) + } + if err := db.WithTx(nil); !errors.Is(err, ErrTxFuncNil) { + t.Fatalf("expected ErrTxFuncNil from WithTx, got %v", err) + } + if err := db.WithTx(func(tx *StarTx) error { return nil }); !errors.Is(err, ErrDBNotInitialized) { + t.Fatalf("expected ErrDBNotInitialized from WithTx, got %v", err) + } + + if _, err := db.QueryStmt(""); !errors.Is(err, ErrQueryEmpty) { + t.Fatalf("expected ErrQueryEmpty from QueryStmt, got %v", err) + } + if _, err := db.ExecStmt(""); !errors.Is(err, ErrQueryEmpty) { + t.Fatalf("expected ErrQueryEmpty from ExecStmt, got %v", err) + } + + _ = db.Stats() + db.SetMaxOpenConns(5) + db.SetMaxIdleConns(2) +} + +func TestStarTx_NotInitialized(t *testing.T) { + tx := &StarTx{} + var model struct { + ID int `db:"id"` + } + + if _, err := tx.Query("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) { + t.Fatalf("expected ErrTxNotInitialized from Query, got %v", err) + } + if err := tx.ScanEach("SELECT 1", func(row *StarResult) error { return nil }); !errors.Is(err, ErrTxNotInitialized) { + t.Fatalf("expected ErrTxNotInitialized from ScanEach, got %v", err) + } + if err := tx.ScanEachORM("SELECT 1", &model, func(target interface{}) error { return nil }); !errors.Is(err, ErrTxNotInitialized) { + t.Fatalf("expected ErrTxNotInitialized from ScanEachORM, got %v", err) + } + if _, err := tx.QueryRaw("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) { + t.Fatalf("expected ErrTxNotInitialized from QueryRaw, got %v", err) + } + if _, err := tx.Exec("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) { + t.Fatalf("expected ErrTxNotInitialized from Exec, got %v", err) + } + if _, err := tx.Prepare("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) { + t.Fatalf("expected ErrTxNotInitialized from Prepare, got %v", err) + } + if err := tx.Commit(); !errors.Is(err, ErrTxNotInitialized) { + t.Fatalf("expected ErrTxNotInitialized from Commit, got %v", err) + } + if err := tx.Rollback(); !errors.Is(err, ErrTxNotInitialized) { + t.Fatalf("expected ErrTxNotInitialized from Rollback, got %v", err) + } + + if _, err := tx.QueryStmt(""); !errors.Is(err, ErrQueryEmpty) { + t.Fatalf("expected ErrQueryEmpty from QueryStmt, got %v", err) + } + if _, err := tx.ExecStmt(""); !errors.Is(err, ErrQueryEmpty) { + t.Fatalf("expected ErrQueryEmpty from ExecStmt, got %v", err) + } +} + +func TestStarStmt_NotInitialized(t *testing.T) { + stmt := &StarStmt{} + + if _, err := stmt.Query(); !errors.Is(err, ErrStmtNotInitialized) { + t.Fatalf("expected ErrStmtNotInitialized from Query, got %v", err) + } + if err := stmt.ScanEach(func(row *StarResult) error { return nil }); !errors.Is(err, ErrStmtNotInitialized) { + t.Fatalf("expected ErrStmtNotInitialized from ScanEach, got %v", err) + } + if err := stmt.ScanEachORM(nil, func(target interface{}) error { return nil }); !errors.Is(err, ErrTargetNil) { + t.Fatalf("expected ErrTargetNil from ScanEachORM, got %v", err) + } + if _, err := stmt.QueryRaw(); !errors.Is(err, ErrStmtNotInitialized) { + t.Fatalf("expected ErrStmtNotInitialized from QueryRaw, got %v", err) + } + if _, err := stmt.Exec(); !errors.Is(err, ErrStmtNotInitialized) { + t.Fatalf("expected ErrStmtNotInitialized from Exec, got %v", err) + } + if err := stmt.Close(); !errors.Is(err, ErrStmtNotInitialized) { + t.Fatalf("expected ErrStmtNotInitialized from Close, got %v", err) + } +} diff --git a/testing/batch_test.go b/testing/batch_test.go index a8a78a3..d183d03 100644 --- a/testing/batch_test.go +++ b/testing/batch_test.go @@ -2,6 +2,7 @@ package testing import ( "context" + "errors" "testing" "time" @@ -110,8 +111,53 @@ func TestStarDB_BatchInsert_Empty(t *testing.T) { values := [][]interface{}{} _, err := db.BatchInsert("users", columns, values) - if err == nil { - t.Error("Expected error with empty values, got nil") + if !errors.Is(err, stardb.ErrNoInsertValues) { + t.Errorf("Expected ErrNoInsertValues, got %v", err) + } +} + +func TestStarDB_BatchInsert_EmptyColumns(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + values := [][]interface{}{ + {"Alice"}, + } + + _, err := db.BatchInsert("users", nil, values) + if !errors.Is(err, stardb.ErrNoInsertColumns) { + t.Errorf("Expected ErrNoInsertColumns, got %v", err) + } +} + +func TestStarDB_BatchInsert_EmptyTableName(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + } + + _, err := db.BatchInsert("", columns, values) + if !errors.Is(err, stardb.ErrTableNameEmpty) { + t.Errorf("Expected ErrTableNameEmpty, got %v", err) + } +} + +func TestStarDB_BatchInsert_RowLengthMismatch(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + {"Bob", "bob@example.com"}, + } + + _, err := db.BatchInsert("users", columns, values) + if !errors.Is(err, stardb.ErrBatchRowValueCountMismatch) { + t.Errorf("Expected ErrBatchRowValueCountMismatch, got %v", err) } } @@ -204,6 +250,236 @@ func TestStarDB_BatchInsertContext_Timeout(t *testing.T) { } } +func TestStarDB_BatchInsertMaxRows_Config(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + if got := db.BatchInsertMaxRows(); got != 0 { + t.Fatalf("Expected default chunk size 0, got %d", got) + } + + db.SetBatchInsertMaxRows(3) + if got := db.BatchInsertMaxRows(); got != 3 { + t.Fatalf("Expected chunk size 3, got %d", got) + } + + db.SetBatchInsertMaxRows(-10) + if got := db.BatchInsertMaxRows(); got != 0 { + t.Fatalf("Expected chunk size reset to 0, got %d", got) + } +} + +func TestStarDB_BatchInsertMaxParams_Config(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + if got := db.BatchInsertMaxParams(); got != 0 { + t.Fatalf("Expected default max params 0, got %d", got) + } + + db.SetBatchInsertMaxParams(100) + if got := db.BatchInsertMaxParams(); got != 100 { + t.Fatalf("Expected max params 100, got %d", got) + } + + db.SetBatchInsertMaxParams(-1) + if got := db.BatchInsertMaxParams(); got != 0 { + t.Fatalf("Expected max params reset to 0, got %d", got) + } +} + +func TestStarDB_BatchInsert_Chunked(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + db.SetBatchInsertMaxRows(2) + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + {"Bob", "bob@example.com", 30}, + {"Charlie", "charlie@example.com", 35}, + {"David", "david@example.com", 40}, + {"Eva", "eva@example.com", 28}, + } + + result, err := db.BatchInsert("users", columns, values) + if err != nil { + t.Fatalf("Chunked BatchInsert failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + if affected != int64(len(values)) { + t.Fatalf("Expected %d affected rows, got %d", len(values), affected) + } + + rows, err := db.Query("SELECT COUNT(*) as count FROM users") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + if count := rows.Row(0).MustInt("count"); count != len(values) { + t.Fatalf("Expected %d rows in db, got %d", len(values), count) + } +} + +func TestStarDB_BatchInsert_ChunkedRollbackOnError(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + db.SetBatchInsertMaxRows(2) + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + {"Bob", "bob@example.com", 30}, + {"Charlie", nil, 35}, // email NOT NULL, forces second chunk failure + } + + if _, err := db.BatchInsert("users", columns, values); err == nil { + t.Fatal("Expected chunked BatchInsert to fail, got nil") + } + + rows, err := db.Query("SELECT COUNT(*) as count FROM users") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + if count := rows.Row(0).MustInt("count"); count != 0 { + t.Fatalf("Expected rollback to keep table empty, got %d rows", count) + } +} + +func TestStarDB_BatchInsert_ChunkedByMaxParams(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + db.SetBatchInsertMaxRows(0) // disabled + db.SetBatchInsertMaxParams(4) // 3 columns -> 1 row per chunk + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + {"Bob", "bob@example.com", 30}, + {"Charlie", "charlie@example.com", 35}, + } + + result, err := db.BatchInsert("users", columns, values) + if err != nil { + t.Fatalf("BatchInsert by max params failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + if affected != int64(len(values)) { + t.Fatalf("Expected %d affected rows, got %d", len(values), affected) + } +} + +func TestStarDB_BatchInsert_MaxParamsTooLow(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + db.SetBatchInsertMaxRows(0) + db.SetBatchInsertMaxParams(2) // columns=3 -> invalid + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + } + + _, err := db.BatchInsert("users", columns, values) + if !errors.Is(err, stardb.ErrBatchInsertMaxParamsTooLow) { + t.Fatalf("Expected ErrBatchInsertMaxParamsTooLow, got %v", err) + } +} + +func TestStarDB_BatchInsert_ChunkedHookMeta(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + db.SetBatchInsertMaxRows(2) + db.SetBatchInsertMaxParams(0) + + var metas []stardb.BatchExecMeta + db.SetSQLHooks(nil, func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) { + if meta, ok := stardb.BatchExecMetaFromContext(ctx); ok { + metas = append(metas, meta) + } + }) + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + {"Bob", "bob@example.com", 30}, + {"Charlie", "charlie@example.com", 35}, + {"David", "david@example.com", 40}, + {"Eva", "eva@example.com", 28}, + } + + _, err := db.BatchInsertContext(context.Background(), "users", columns, values) + if err != nil { + t.Fatalf("BatchInsertContext failed: %v", err) + } + + if len(metas) != 3 { + t.Fatalf("Expected 3 chunk metas, got %d", len(metas)) + } + + wantRows := []int{2, 2, 1} + for i, meta := range metas { + if meta.ChunkIndex != i+1 { + t.Fatalf("Chunk %d: expected index %d, got %d", i, i+1, meta.ChunkIndex) + } + if meta.ChunkCount != 3 { + t.Fatalf("Chunk %d: expected count 3, got %d", i, meta.ChunkCount) + } + if meta.ChunkRows != wantRows[i] { + t.Fatalf("Chunk %d: expected rows %d, got %d", i, wantRows[i], meta.ChunkRows) + } + if meta.TotalRows != len(values) { + t.Fatalf("Chunk %d: expected total rows %d, got %d", i, len(values), meta.TotalRows) + } + if meta.ColumnCount != len(columns) { + t.Fatalf("Chunk %d: expected column count %d, got %d", i, len(columns), meta.ColumnCount) + } + } +} + +func TestStarDB_BatchInsert_HookMetaAbsentWithoutChunking(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + db.SetBatchInsertMaxRows(0) + db.SetBatchInsertMaxParams(0) + + metaFound := false + db.SetSQLHooks(nil, func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) { + if _, ok := stardb.BatchExecMetaFromContext(ctx); ok { + metaFound = true + } + }) + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + {"Bob", "bob@example.com", 30}, + } + + _, err := db.BatchInsertContext(context.Background(), "users", columns, values) + if err != nil { + t.Fatalf("BatchInsertContext failed: %v", err) + } + if metaFound { + t.Fatal("Expected no batch meta for non-chunked execution") + } +} + func TestStarDB_BatchInsertStructs_Basic(t *testing.T) { db := setupBatchTestDB(t) defer db.Close() @@ -246,6 +522,34 @@ func TestStarDB_BatchInsertStructs_Basic(t *testing.T) { } } +func TestStarDB_BatchInsertStructs_Chunked(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + db.SetBatchInsertMaxRows(2) + + users := []TestUser{ + {Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()}, + {Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()}, + {Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()}, + {Name: "David", Email: "david@example.com", Age: 40, CreatedAt: time.Now()}, + {Name: "Eva", Email: "eva@example.com", Age: 28, CreatedAt: time.Now()}, + } + + result, err := db.BatchInsertStructs("users", users, "id") + if err != nil { + t.Fatalf("Chunked BatchInsertStructs failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + if affected != int64(len(users)) { + t.Fatalf("Expected %d affected rows, got %d", len(users), affected) + } +} + func TestStarDB_BatchInsertStructs_Single(t *testing.T) { db := setupBatchTestDB(t) defer db.Close() @@ -276,8 +580,8 @@ func TestStarDB_BatchInsertStructs_Empty(t *testing.T) { users := []TestUser{} _, err := db.BatchInsertStructs("users", users, "id") - if err == nil { - t.Error("Expected error with empty slice, got nil") + if !errors.Is(err, stardb.ErrNoStructsToInsert) { + t.Errorf("Expected ErrNoStructsToInsert, got %v", err) } } @@ -288,8 +592,29 @@ func TestStarDB_BatchInsertStructs_NotSlice(t *testing.T) { user := TestUser{Name: "Alice", Email: "alice@example.com", Age: 25} _, err := db.BatchInsertStructs("users", user, "id") - if err == nil { - t.Error("Expected error with non-slice, got nil") + if !errors.Is(err, stardb.ErrStructsNotSlice) { + t.Errorf("Expected ErrStructsNotSlice, got %v", err) + } +} + +func TestStarDB_BatchInsertStructs_Nil(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + _, err := db.BatchInsertStructs("users", nil, "id") + if !errors.Is(err, stardb.ErrStructsNil) { + t.Errorf("Expected ErrStructsNil, got %v", err) + } +} + +func TestStarDB_BatchInsertStructs_NilPointer(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + var users *[]TestUser + _, err := db.BatchInsertStructs("users", users, "id") + if !errors.Is(err, stardb.ErrStructsPointerNil) { + t.Errorf("Expected ErrStructsPointerNil, got %v", err) } } diff --git a/testing/orm_test.go b/testing/orm_test.go index d856d2d..cb4d76e 100644 --- a/testing/orm_test.go +++ b/testing/orm_test.go @@ -2,8 +2,11 @@ package testing import ( "context" + "errors" "testing" "time" + + "b612.me/stardb" ) type User struct { @@ -28,6 +31,12 @@ type NestedUser struct { Profile `db:"---"` } +type UserWithPrivateField struct { + ID int64 `db:"id"` + Name string `db:"name"` + age int `db:"age"` +} + func TestStarRows_Orm_Single(t *testing.T) { db := setupTestDB(t) defer db.Close() @@ -93,6 +102,47 @@ func TestStarRows_Orm_Multiple(t *testing.T) { } } +func TestStarRows_Orm_Array(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var users [3]User + err = rows.Orm(&users) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + expectedNames := []string{"Alice", "Bob", "Charlie"} + for i, user := range users { + if user.Name != expectedNames[i] { + t.Errorf("Expected name '%s', got '%s'", expectedNames[i], user.Name) + } + } +} + +func TestStarRows_Orm_ArrayTooSmall(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var users [2]User + err = rows.Orm(&users) + if err == nil { + t.Error("Expected error when target array is smaller than row count, got nil") + } +} + func TestStarRows_Orm_Empty(t *testing.T) { db := setupTestDB(t) defer db.Close() @@ -126,8 +176,105 @@ func TestStarRows_Orm_NotPointer(t *testing.T) { var user User err = rows.Orm(user) // Not a pointer + if !errors.Is(err, stardb.ErrTargetNotPointer) { + t.Errorf("Expected ErrTargetNotPointer, got %v", err) + } +} + +func TestStarRows_Orm_NilTarget(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + err = rows.Orm(nil) + if !errors.Is(err, stardb.ErrTargetNil) { + t.Errorf("Expected ErrTargetNil, got %v", err) + } +} + +func TestStarRows_Orm_NilPointerTarget(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var user *User + err = rows.Orm(user) + if !errors.Is(err, stardb.ErrTargetPointerNil) { + t.Errorf("Expected ErrTargetPointerNil, got %v", err) + } +} + +func TestStarRows_Orm_MissingColumns_NonStrict(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT id, name FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var user User + err = rows.Orm(&user) + if err != nil { + t.Fatalf("Expected non-strict ORM to ignore missing columns, got error: %v", err) + } + + if user.Name != "Alice" { + t.Errorf("Expected name 'Alice', got '%s'", user.Name) + } +} + +func TestStarRows_Orm_MissingColumns_Strict(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetStrictORM(true) + + rows, err := db.Query("SELECT id, name FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var user User + err = rows.Orm(&user) if err == nil { - t.Errorf("Expected error when passing non-pointer, got nil") + t.Fatalf("Expected strict ORM to fail on missing columns, got nil") + } +} + +func TestStarRows_Orm_UnexportedTaggedField(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var user UserWithPrivateField + err = rows.Orm(&user) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + if user.ID <= 0 { + t.Errorf("Expected positive ID, got %d", user.ID) + } + if user.Name != "Alice" { + t.Errorf("Expected name 'Alice', got '%s'", user.Name) } } @@ -355,6 +502,28 @@ func TestStarDB_QueryXContext(t *testing.T) { } } +func TestStarDB_QueryX_MissingField(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + user := User{Name: "Alice"} + + _, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":unknown") + if !errors.Is(err, stardb.ErrFieldNotFound) { + t.Errorf("Expected ErrFieldNotFound, got %v", err) + } +} + +func TestStarDB_QueryX_NilTarget(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + _, err := db.QueryX(nil, "SELECT * FROM users WHERE id = ?", ":id") + if !errors.Is(err, stardb.ErrTargetNil) { + t.Errorf("Expected ErrTargetNil, got %v", err) + } +} + func TestStarDB_QueryXS(t *testing.T) { db := setupTestDB(t) defer db.Close() @@ -380,6 +549,16 @@ func TestStarDB_QueryXS(t *testing.T) { } } +func TestStarDB_QueryXS_NilTargets(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + _, err := db.QueryXS(nil, "SELECT * FROM users") + if !errors.Is(err, stardb.ErrTargetsNil) { + t.Errorf("Expected ErrTargetsNil, got %v", err) + } +} + func TestStarDB_ExecX(t *testing.T) { db := setupTestDB(t) defer db.Close() @@ -441,6 +620,28 @@ func TestStarDB_ExecXContext(t *testing.T) { } } +func TestStarDB_ExecX_MissingField(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + user := User{Name: "Alice", Age: 99} + + _, err := db.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":unknown") + if !errors.Is(err, stardb.ErrFieldNotFound) { + t.Errorf("Expected ErrFieldNotFound, got %v", err) + } +} + +func TestStarDB_ExecX_NilTarget(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + _, err := db.ExecX(nil, "UPDATE users SET age = ? WHERE id = ?", ":age", ":id") + if !errors.Is(err, stardb.ErrTargetNil) { + t.Errorf("Expected ErrTargetNil, got %v", err) + } +} + func TestStarDB_ExecXS(t *testing.T) { db := setupTestDB(t) defer db.Close() @@ -470,6 +671,16 @@ func TestStarDB_ExecXS(t *testing.T) { } } +func TestStarDB_ExecXS_NilTargets(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + _, err := db.ExecXS(nil, "UPDATE users SET age = age") + if !errors.Is(err, stardb.ErrTargetsNil) { + t.Errorf("Expected ErrTargetsNil, got %v", err) + } +} + func TestStarTx_Insert(t *testing.T) { db := setupTestDB(t) defer db.Close() @@ -592,6 +803,22 @@ func TestStarTx_QueryX(t *testing.T) { } } +func TestStarTx_QueryX_NilTarget(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + defer tx.Rollback() + + _, err = tx.QueryX(nil, "SELECT * FROM users WHERE id = ?", ":id") + if !errors.Is(err, stardb.ErrTargetNil) { + t.Errorf("Expected ErrTargetNil, got %v", err) + } +} + func TestStarTx_ExecX(t *testing.T) { db := setupTestDB(t) defer db.Close() @@ -627,6 +854,22 @@ func TestStarTx_ExecX(t *testing.T) { } } +func TestStarTx_ExecX_NilTarget(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + defer tx.Rollback() + + _, err = tx.ExecX(nil, "UPDATE users SET age = ? WHERE id = ?", ":age", ":id") + if !errors.Is(err, stardb.ErrTargetNil) { + t.Errorf("Expected ErrTargetNil, got %v", err) + } +} + func TestStarTx_Rollback(t *testing.T) { db := setupTestDB(t) defer db.Close() diff --git a/testing/perf_test.go b/testing/perf_test.go new file mode 100644 index 0000000..106fec2 --- /dev/null +++ b/testing/perf_test.go @@ -0,0 +1,128 @@ +package testing + +import ( + "fmt" + "testing" + "time" + + "b612.me/stardb" + _ "modernc.org/sqlite" +) + +func setupBenchmarkDB(b *testing.B) *stardb.StarDB { + b.Helper() + + db := &stardb.StarDB{} + if err := db.Open("sqlite", ":memory:"); err != nil { + b.Fatalf("Failed to open database: %v", err) + } + + _, err := db.Exec(` + CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + email TEXT NOT NULL, + age INTEGER, + balance REAL, + active BOOLEAN, + created_at DATETIME + ) + `) + if err != nil { + b.Fatalf("Failed to create table: %v", err) + } + + _, err = db.Exec(` + INSERT INTO users (name, email, age, balance, active, created_at) VALUES + ('Alice', 'alice@example.com', 25, 100.50, 1, '2024-01-01 10:00:00'), + ('Bob', 'bob@example.com', 30, 200.75, 1, '2024-01-02 11:00:00'), + ('Charlie', 'charlie@example.com', 35, 300.25, 0, '2024-01-03 12:00:00') + `) + if err != nil { + b.Fatalf("Failed to insert seed data: %v", err) + } + + return db +} + +func BenchmarkQueryX(b *testing.B) { + db := setupBenchmarkDB(b) + defer db.Close() + + target := User{Name: "Alice"} + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.QueryX(&target, "SELECT * FROM users WHERE name = ?", ":name") + if err != nil { + b.Fatalf("QueryX failed: %v", err) + } + _ = rows.Close() + } +} + +func BenchmarkOrm(b *testing.B) { + db := setupBenchmarkDB(b) + defer db.Close() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + b.Fatalf("Query failed: %v", err) + } + + var users []User + if err := rows.Orm(&users); err != nil { + _ = rows.Close() + b.Fatalf("Orm failed: %v", err) + } + _ = rows.Close() + } +} + +func BenchmarkScanEach(b *testing.B) { + db := setupBenchmarkDB(b) + defer db.Close() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + count := 0 + err := db.ScanEach("SELECT * FROM users ORDER BY name", func(row *stardb.StarResult) error { + _ = row.MustString("name") + count++ + return nil + }) + if err != nil { + b.Fatalf("ScanEach failed: %v", err) + } + if count != 3 { + b.Fatalf("Unexpected row count: %d", count) + } + } +} + +func BenchmarkBatchInsert(b *testing.B) { + db := setupBenchmarkDB(b) + defer db.Close() + + columns := []string{"name", "email", "age", "balance", "active", "created_at"} + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + base := i * 2 + values := [][]interface{}{ + {fmt.Sprintf("bench_user_%d", base), fmt.Sprintf("bench_%d@example.com", base), 20 + (base % 20), 99.5, true, time.Now()}, + {fmt.Sprintf("bench_user_%d", base+1), fmt.Sprintf("bench_%d@example.com", base+1), 20 + ((base + 1) % 20), 199.5, false, time.Now()}, + } + if _, err := db.BatchInsert("users", columns, values); err != nil { + b.Fatalf("BatchInsert failed: %v", err) + } + } +} diff --git a/testing/pool_test.go b/testing/pool_test.go index d9a9552..b87a624 100644 --- a/testing/pool_test.go +++ b/testing/pool_test.go @@ -95,6 +95,32 @@ func TestStarDB_SetPoolConfig_Zero(t *testing.T) { } } +func TestStarDB_SetPoolConfig_NilConfig(t *testing.T) { + db := stardb.NewStarDB() + err := db.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + db.SetPoolConfig(nil) + + err = db.Ping() + if err != nil { + t.Errorf("Ping failed after SetPoolConfig(nil): %v", err) + } +} + +func TestStarDB_SetPoolConfig_BeforeOpen(t *testing.T) { + db := stardb.NewStarDB() + + db.SetPoolConfig(&stardb.PoolConfig{ + MaxOpenConns: 10, + }) + + // should not panic when called before Open +} + func TestOpenWithPool_Default(t *testing.T) { db, err := stardb.OpenWithPool("sqlite", ":memory:", nil) if err != nil { diff --git a/testing/result_test.go b/testing/result_test.go index 0d56fde..0164654 100644 --- a/testing/result_test.go +++ b/testing/result_test.go @@ -1,8 +1,11 @@ package testing import ( + "errors" "testing" "time" + + "b612.me/stardb" ) func TestStarResult_MustString(t *testing.T) { @@ -229,3 +232,100 @@ func TestStarResultCol_MustBool(t *testing.T) { } } } + +func TestStarResult_GetColumnNotFoundError(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + row := rows.Row(0) + _, err = row.GetString("does_not_exist") + if !errors.Is(err, stardb.ErrColumnNotFound) { + t.Fatalf("Expected ErrColumnNotFound, got %v", err) + } +} + +func TestStarResult_GetNullValues(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + _, err := db.Exec( + "INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, ?, ?, ?, ?)", + "NullUser", "null@example.com", nil, nil, nil, nil, + ) + if err != nil { + t.Fatalf("Insert failed: %v", err) + } + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "NullUser") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + row := rows.Row(0) + + name, err := row.GetNullString("name") + if err != nil { + t.Fatalf("GetNullString failed: %v", err) + } + if !name.Valid || name.String != "NullUser" { + t.Fatalf("Expected valid name NullUser, got %+v", name) + } + + age, err := row.GetNullInt64("age") + if err != nil { + t.Fatalf("GetNullInt64 failed: %v", err) + } + if age.Valid { + t.Fatalf("Expected NULL age, got %+v", age) + } + + balance, err := row.GetNullFloat64("balance") + if err != nil { + t.Fatalf("GetNullFloat64 failed: %v", err) + } + if balance.Valid { + t.Fatalf("Expected NULL balance, got %+v", balance) + } + + active, err := row.GetNullBool("active") + if err != nil { + t.Fatalf("GetNullBool failed: %v", err) + } + if active.Valid { + t.Fatalf("Expected NULL active, got %+v", active) + } + + createdAt, err := row.GetNullTime("created_at") + if err != nil { + t.Fatalf("GetNullTime failed: %v", err) + } + if createdAt.Valid { + t.Fatalf("Expected NULL created_at, got %+v", createdAt) + } +} + +func TestStarResult_GetNullTime_Valid(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT created_at FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + value, err := rows.Row(0).GetNullTime("created_at") + if err != nil { + t.Fatalf("GetNullTime failed: %v", err) + } + if !value.Valid { + t.Fatal("Expected valid created_at") + } +} diff --git a/testing/rows_test.go b/testing/rows_test.go index 8f18aa0..18a81bf 100644 --- a/testing/rows_test.go +++ b/testing/rows_test.go @@ -26,6 +26,12 @@ func TestStarRows_Row(t *testing.T) { if len(row.Result()) != 0 { t.Errorf("Expected empty result for out of bounds index") } + + // Test negative index + row = rows.Row(-1) + if len(row.Result()) != 0 { + t.Errorf("Expected empty result for negative index") + } } func TestStarRows_Col(t *testing.T) { diff --git a/testing/stardb_test.go b/testing/stardb_test.go index ffa0261..68a148f 100644 --- a/testing/stardb_test.go +++ b/testing/stardb_test.go @@ -2,6 +2,10 @@ package testing import ( "context" + "errors" + "strings" + "sync" + "sync/atomic" "testing" "time" @@ -60,6 +64,457 @@ func TestStarDB_QueryContext(t *testing.T) { } } +func TestStarDB_QueryRaw(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.QueryRaw("SELECT name FROM users WHERE age > ? ORDER BY name", 25) + if err != nil { + t.Fatalf("QueryRaw failed: %v", err) + } + defer rows.Close() + + var ( + name string + count int + ) + for rows.Next() { + if err := rows.Scan(&name); err != nil { + t.Fatalf("Scan failed: %v", err) + } + count++ + } + if err := rows.Err(); err != nil { + t.Fatalf("Rows.Err failed: %v", err) + } + if count != 2 { + t.Errorf("Expected 2 rows, got %d", count) + } +} + +func TestStarDB_QueryRawContext(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + rows, err := db.QueryRawContext(ctx, "SELECT name FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("QueryRawContext failed: %v", err) + } + defer rows.Close() + + if !rows.Next() { + t.Fatal("Expected at least one row") + } + var name string + if err := rows.Scan(&name); err != nil { + t.Fatalf("Scan failed: %v", err) + } + if name != "Alice" { + t.Errorf("Expected name Alice, got %s", name) + } +} + +func TestStarDB_QueryRaw_EmptyQuery(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + _, err := db.QueryRaw(" ") + if !errors.Is(err, stardb.ErrQueryEmpty) { + t.Fatalf("Expected ErrQueryEmpty, got %v", err) + } +} + +func TestStarDB_ScanEach(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + var names []string + err := db.ScanEach("SELECT name FROM users ORDER BY name", func(row *stardb.StarResult) error { + names = append(names, row.MustString("name")) + return nil + }) + if err != nil { + t.Fatalf("ScanEach failed: %v", err) + } + + if len(names) != 3 { + t.Fatalf("Expected 3 rows, got %d", len(names)) + } + if names[0] != "Alice" || names[1] != "Bob" || names[2] != "Charlie" { + t.Fatalf("Unexpected names: %v", names) + } +} + +func TestStarDB_ScanEach_Stop(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + count := 0 + err := db.ScanEach("SELECT name FROM users ORDER BY name", func(row *stardb.StarResult) error { + count++ + if row.MustString("name") == "Bob" { + return stardb.ErrScanStopped + } + return nil + }) + if err != nil { + t.Fatalf("ScanEach stop failed: %v", err) + } + if count != 2 { + t.Fatalf("Expected callback count 2, got %d", count) + } +} + +func TestStarDB_ScanEach_NilCallback(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + err := db.ScanEach("SELECT name FROM users", nil) + if !errors.Is(err, stardb.ErrScanFuncNil) { + t.Fatalf("Expected ErrScanFuncNil, got %v", err) + } +} + +func TestStarDB_ScanEachORM(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + var user User + var names []string + err := db.ScanEachORM("SELECT * FROM users ORDER BY name", &user, func(target interface{}) error { + u := target.(*User) + names = append(names, u.Name) + return nil + }) + if err != nil { + t.Fatalf("ScanEachORM failed: %v", err) + } + + if len(names) != 3 { + t.Fatalf("Expected 3 names, got %d", len(names)) + } + if names[0] != "Alice" || names[1] != "Bob" || names[2] != "Charlie" { + t.Fatalf("Unexpected names: %v", names) + } +} + +func TestStarDB_ScanEachORM_NilCallback(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + var user User + err := db.ScanEachORM("SELECT * FROM users", &user, nil) + if !errors.Is(err, stardb.ErrScanORMFuncNil) { + t.Fatalf("Expected ErrScanORMFuncNil, got %v", err) + } +} + +func TestStarDB_PlaceholderDollar(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetPlaceholderStyle(stardb.PlaceholderDollar) + + rows, err := db.Query("SELECT name FROM users WHERE name = ? AND age = ?", "Alice", 25) + if err != nil { + t.Fatalf("Query with dollar placeholders failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 1 { + t.Fatalf("Expected 1 row, got %d", rows.Length()) + } + if got := rows.Row(0).MustString("name"); got != "Alice" { + t.Fatalf("Expected Alice, got %s", got) + } + + _, err = db.Exec("UPDATE users SET age = ? WHERE name = ?", 26, "Alice") + if err != nil { + t.Fatalf("Exec with dollar placeholders failed: %v", err) + } +} + +func TestStarDB_SQLHooks(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetPlaceholderStyle(stardb.PlaceholderDollar) + + var beforeCount int64 + var afterCount int64 + var mu sync.Mutex + var beforeQuery string + var afterQuery string + var afterErr error + + db.SetSQLHooks( + func(ctx context.Context, query string, args []interface{}) { + atomic.AddInt64(&beforeCount, 1) + mu.Lock() + beforeQuery = query + mu.Unlock() + }, + func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) { + atomic.AddInt64(&afterCount, 1) + mu.Lock() + afterQuery = query + afterErr = err + mu.Unlock() + }, + ) + + if _, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 31, "Bob"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + + if atomic.LoadInt64(&beforeCount) == 0 { + t.Fatal("Expected before hook to be called") + } + if atomic.LoadInt64(&afterCount) == 0 { + t.Fatal("Expected after hook to be called") + } + if !strings.Contains(beforeQuery, "$1") { + t.Fatalf("Expected converted query in before hook, got %s", beforeQuery) + } + if !strings.Contains(afterQuery, "$1") { + t.Fatalf("Expected converted query in after hook, got %s", afterQuery) + } + if afterErr != nil { + t.Fatalf("Expected nil error in after hook, got %v", afterErr) + } + + _, execErr := db.Exec("UPDATE table_does_not_exist SET age = ? WHERE name = ?", 31, "Bob") + if execErr == nil { + t.Fatal("Expected execution error for invalid table") + } + if atomic.LoadInt64(&afterCount) < 2 { + t.Fatalf("Expected after hook call count >= 2, got %d", atomic.LoadInt64(&afterCount)) + } + if afterErr == nil { + t.Fatal("Expected after hook to capture non-nil error for failed SQL") + } +} + +func TestStarDB_SQLHooks_SlowThreshold(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + var afterCount int64 + db.SetSQLHooks( + nil, + func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) { + atomic.AddInt64(&afterCount, 1) + }, + ) + + db.SetSQLSlowThreshold(time.Hour) + if got := db.SQLSlowThreshold(); got != time.Hour { + t.Fatalf("Expected threshold 1h, got %v", got) + } + + if _, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 41, "Alice"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + if atomic.LoadInt64(&afterCount) != 0 { + t.Fatalf("Expected after hook to be skipped under threshold, got %d", afterCount) + } + + _, err := db.Exec("UPDATE table_does_not_exist SET age = ? WHERE name = ?", 31, "Bob") + if err == nil { + t.Fatal("Expected execution error for invalid table") + } + if atomic.LoadInt64(&afterCount) != 1 { + t.Fatalf("Expected error path to trigger after hook, got %d", afterCount) + } + + db.SetSQLSlowThreshold(0) + if _, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 42, "Alice"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + if atomic.LoadInt64(&afterCount) != 2 { + t.Fatalf("Expected after hook to run after disabling threshold, got %d", afterCount) + } +} + +func TestStarDB_SQLHooks_FingerprintMeta(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetSQLFingerprintEnabled(true) + + var gotMeta stardb.SQLHookMeta + var metaFound bool + db.SetSQLHooks( + nil, + func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) { + gotMeta, metaFound = stardb.SQLHookMetaFromContext(ctx) + }, + ) + + if _, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 31, "Bob"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + + if !metaFound { + t.Fatal("Expected SQL fingerprint metadata in hook context") + } + if gotMeta.Fingerprint != "update users set age = ? where name = ?" { + t.Fatalf("Unexpected fingerprint: %q", gotMeta.Fingerprint) + } +} + +func TestStarDB_SQLHooks_FingerprintMetaMaskLiterals(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetSQLFingerprintEnabled(true) + db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals) + + var gotMeta stardb.SQLHookMeta + var metaFound bool + db.SetSQLHooks( + nil, + func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) { + gotMeta, metaFound = stardb.SQLHookMetaFromContext(ctx) + }, + ) + + if _, err := db.Exec("UPDATE users SET age = 42 WHERE name = 'Bob' AND age < 100"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + + if !metaFound { + t.Fatal("Expected SQL fingerprint metadata in hook context") + } + if gotMeta.Fingerprint != "update users set age = ? where name = ? and age < ?" { + t.Fatalf("Unexpected fingerprint for mask mode: %q", gotMeta.Fingerprint) + } +} + +func TestStarDB_SQLHooks_FingerprintMetaMaskLiterals_StripComments(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetSQLFingerprintEnabled(true) + db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals) + db.SetSQLFingerprintKeepComments(false) // default, explicit for clarity + + var gotMeta stardb.SQLHookMeta + var metaFound bool + db.SetSQLHooks( + nil, + func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) { + gotMeta, metaFound = stardb.SQLHookMetaFromContext(ctx) + }, + ) + + if _, err := db.Exec("UPDATE users SET age = 42 /* trace:abc */ WHERE name = 'Bob'"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + + if !metaFound { + t.Fatal("Expected SQL fingerprint metadata in hook context") + } + if gotMeta.Fingerprint != "update users set age = ? where name = ?" { + t.Fatalf("Unexpected fingerprint with stripped comments: %q", gotMeta.Fingerprint) + } +} + +func TestStarDB_SQLHooks_FingerprintMetaMaskLiterals_KeepComments(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetSQLFingerprintEnabled(true) + db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals) + db.SetSQLFingerprintKeepComments(true) + + var gotMeta stardb.SQLHookMeta + var metaFound bool + db.SetSQLHooks( + nil, + func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) { + gotMeta, metaFound = stardb.SQLHookMetaFromContext(ctx) + }, + ) + + if _, err := db.Exec("UPDATE users SET age = 42 /* trace:abc */ WHERE name = 'Bob'"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + + if !metaFound { + t.Fatal("Expected SQL fingerprint metadata in hook context") + } + if gotMeta.Fingerprint != "update users set age = ? /* trace:abc */ where name = ?" { + t.Fatalf("Unexpected fingerprint with kept comments: %q", gotMeta.Fingerprint) + } +} + +func TestStarDB_SQLFingerprintCounter(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetSQLFingerprintCounterEnabled(true) + db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals) + db.SetSQLFingerprintKeepComments(false) + + if _, err := db.Exec("UPDATE users SET age = 41 WHERE name = 'Alice'"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + if _, err := db.Exec("UPDATE users SET age = 42 WHERE name = 'Bob'"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + + counters := db.SQLFingerprintCounters() + key := "update users set age = ? where name = ?" + if counters[key] != 2 { + t.Fatalf("Expected fingerprint %q count=2, got %d", key, counters[key]) + } + + db.ResetSQLFingerprintCounters() + if got := len(db.SQLFingerprintCounters()); got != 0 { + t.Fatalf("Expected counters to be reset, got %d", got) + } +} + +func TestStarDB_SQLFingerprintCounterDisabled(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetSQLFingerprintCounterEnabled(false) + if _, err := db.Exec("UPDATE users SET age = 31 WHERE name = ?", "Bob"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + + if got := len(db.SQLFingerprintCounters()); got != 0 { + t.Fatalf("Expected no counters when disabled, got %d", got) + } +} + +func TestStarDB_SQLHooks_FingerprintMetaDisabled(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetSQLFingerprintEnabled(false) + + metaFound := false + db.SetSQLHooks( + nil, + func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) { + _, metaFound = stardb.SQLHookMetaFromContext(ctx) + }, + ) + + if _, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 31, "Bob"); err != nil { + t.Fatalf("Exec failed: %v", err) + } + if metaFound { + t.Fatal("Expected no SQL fingerprint metadata when disabled") + } +} + func TestStarDB_Exec(t *testing.T) { db := setupTestDB(t) defer db.Close() @@ -154,6 +609,119 @@ func TestStarDB_Prepare(t *testing.T) { } } +func TestStarStmt_QueryRaw(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + stmt, err := db.Prepare("SELECT name FROM users WHERE name = ?") + if err != nil { + t.Fatalf("Prepare failed: %v", err) + } + defer stmt.Close() + + rows, err := stmt.QueryRaw("Bob") + if err != nil { + t.Fatalf("Stmt.QueryRaw failed: %v", err) + } + defer rows.Close() + + if !rows.Next() { + t.Fatal("Expected one row for Bob") + } + var name string + if err := rows.Scan(&name); err != nil { + t.Fatalf("Scan failed: %v", err) + } + if name != "Bob" { + t.Errorf("Expected Bob, got %s", name) + } +} + +func TestStarStmt_ScanEach(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + stmt, err := db.Prepare("SELECT name FROM users WHERE age >= ? ORDER BY name") + if err != nil { + t.Fatalf("Prepare failed: %v", err) + } + defer stmt.Close() + + var names []string + err = stmt.ScanEach(func(row *stardb.StarResult) error { + names = append(names, row.MustString("name")) + return nil + }, 30) + if err != nil { + t.Fatalf("Stmt.ScanEach failed: %v", err) + } + if len(names) != 2 || names[0] != "Bob" || names[1] != "Charlie" { + t.Fatalf("Unexpected names: %v", names) + } +} + +func TestStarStmt_ScanEachORM(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + stmt, err := db.Prepare("SELECT * FROM users WHERE age >= ? ORDER BY name") + if err != nil { + t.Fatalf("Prepare failed: %v", err) + } + defer stmt.Close() + + var user User + var names []string + err = stmt.ScanEachORM(&user, func(target interface{}) error { + u := target.(*User) + names = append(names, u.Name) + return nil + }, 30) + if err != nil { + t.Fatalf("Stmt.ScanEachORM failed: %v", err) + } + if len(names) != 2 || names[0] != "Bob" || names[1] != "Charlie" { + t.Fatalf("Unexpected names: %v", names) + } +} + +func TestStarStmt_SQLHooks(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + var ( + beforeCount int64 + afterCount int64 + ) + db.SetSQLHooks( + func(ctx context.Context, query string, args []interface{}) { + atomic.AddInt64(&beforeCount, 1) + }, + func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) { + atomic.AddInt64(&afterCount, 1) + }, + ) + + stmt, err := db.Prepare("SELECT name FROM users WHERE name = ?") + if err != nil { + t.Fatalf("Prepare failed: %v", err) + } + defer stmt.Close() + + rows, err := stmt.Query("Alice") + if err != nil { + t.Fatalf("Stmt.Query failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 1 { + t.Fatalf("Expected 1 row, got %d", rows.Length()) + } + if atomic.LoadInt64(&beforeCount) == 0 || atomic.LoadInt64(&afterCount) == 0 { + t.Fatalf("Expected stmt execution to trigger hooks, before=%d after=%d", beforeCount, afterCount) + } +} + func TestStarDB_Transaction(t *testing.T) { db := setupTestDB(t) defer db.Close() @@ -219,6 +787,131 @@ func TestStarDB_TransactionRollback(t *testing.T) { } } +func TestStarTx_QueryRaw(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + defer tx.Rollback() + + rows, err := tx.QueryRaw("SELECT name FROM users WHERE name = ?", "Charlie") + if err != nil { + t.Fatalf("Tx.QueryRaw failed: %v", err) + } + defer rows.Close() + + if !rows.Next() { + t.Fatal("Expected one row for Charlie") + } + var name string + if err := rows.Scan(&name); err != nil { + t.Fatalf("Scan failed: %v", err) + } + if name != "Charlie" { + t.Errorf("Expected Charlie, got %s", name) + } +} + +func TestStarTx_ScanEach(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + defer tx.Rollback() + + var names []string + err = tx.ScanEach("SELECT name FROM users WHERE age >= ? ORDER BY name", func(row *stardb.StarResult) error { + names = append(names, row.MustString("name")) + return nil + }, 30) + if err != nil { + t.Fatalf("Tx.ScanEach failed: %v", err) + } + if len(names) != 2 || names[0] != "Bob" || names[1] != "Charlie" { + t.Fatalf("Unexpected names: %v", names) + } +} + +func TestStarTx_ScanEachORM(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + defer tx.Rollback() + + var user User + var names []string + err = tx.ScanEachORM("SELECT * FROM users WHERE age >= ? ORDER BY name", &user, func(target interface{}) error { + u := target.(*User) + names = append(names, u.Name) + return nil + }, 30) + if err != nil { + t.Fatalf("Tx.ScanEachORM failed: %v", err) + } + if len(names) != 2 || names[0] != "Bob" || names[1] != "Charlie" { + t.Fatalf("Unexpected names: %v", names) + } +} + +func TestStarDB_WithTx(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + err := db.WithTx(func(tx *stardb.StarTx) error { + _, err := tx.Exec("UPDATE users SET age = ? WHERE name = ?", 41, "Alice") + return err + }) + if err != nil { + t.Fatalf("WithTx failed: %v", err) + } + + rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + if got := rows.Row(0).MustInt("age"); got != 41 { + t.Fatalf("Expected age 41, got %d", got) + } +} + +func TestStarDB_WithTx_RollbackOnError(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + expectedErr := errors.New("business error") + err := db.WithTx(func(tx *stardb.StarTx) error { + if _, err := tx.Exec("UPDATE users SET age = ? WHERE name = ?", 55, "Alice"); err != nil { + return err + } + return expectedErr + }) + if !errors.Is(err, expectedErr) { + t.Fatalf("Expected %v, got %v", expectedErr, err) + } + + rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + if got := rows.Row(0).MustInt("age"); got == 55 { + t.Fatalf("Expected rollback to keep original age, got %d", got) + } +} + func TestStarDB_SetMaxConnections(t *testing.T) { db := setupTestDB(t) defer db.Close() @@ -231,3 +924,80 @@ func TestStarDB_SetMaxConnections(t *testing.T) { t.Errorf("Expected MaxOpenConnections 10, got %d", stats.MaxOpenConnections) } } + +func TestStarDB_ConcurrentRuntimeAndQuery(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(1) + + db.SetSQLHooks( + func(ctx context.Context, query string, args []interface{}) {}, + func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {}, + ) + + var wg sync.WaitGroup + + for i := 0; i < 4; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + rows, err := db.Query("SELECT id FROM users WHERE name = ?", "Alice") + if err != nil { + t.Errorf("Concurrent query failed: %v", err) + return + } + _ = rows.Close() + } + }() + } + + for i := 0; i < 2; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + if j%2 == 0 { + db.SetPlaceholderStyle(stardb.PlaceholderDollar) + db.SetSQLSlowThreshold(time.Millisecond) + } else { + db.SetPlaceholderStyle(stardb.PlaceholderQuestion) + db.SetSQLSlowThreshold(0) + } + } + }() + } + + wg.Wait() +} + +func TestStarTx_SQLHooks(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + var beforeCount int64 + var afterCount int64 + db.SetSQLHooks( + func(ctx context.Context, query string, args []interface{}) { + atomic.AddInt64(&beforeCount, 1) + }, + func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) { + atomic.AddInt64(&afterCount, 1) + }, + ) + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + defer tx.Rollback() + + if _, err := tx.Exec("UPDATE users SET age = ? WHERE name = ?", 27, "Alice"); err != nil { + t.Fatalf("Tx.Exec failed: %v", err) + } + + if atomic.LoadInt64(&beforeCount) == 0 || atomic.LoadInt64(&afterCount) == 0 { + t.Fatalf("Expected tx execution to trigger hooks, before=%d after=%d", beforeCount, afterCount) + } +} diff --git a/tx.go b/tx.go index 541f453..7e2478a 100644 --- a/tx.go +++ b/tx.go @@ -3,7 +3,8 @@ package stardb import ( "context" "database/sql" - "errors" + "strings" + "time" ) // StarTx represents a database transaction @@ -12,6 +13,13 @@ type StarTx struct { db *StarDB } +func (t *StarTx) ensureTx() error { + if t == nil || t.tx == nil || t.db == nil { + return ErrTxNotInitialized + } + return nil +} + // Query executes a query within the transaction func (t *StarTx) Query(query string, args ...interface{}) (*StarRows, error) { return t.query(nil, query, args...) @@ -22,21 +30,53 @@ func (t *StarTx) QueryContext(ctx context.Context, query string, args ...interfa return t.query(ctx, query, args...) } -// query is the internal query implementation -func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { - if err := t.db.Ping(); err != nil { +// QueryRaw executes a query in transaction and returns *sql.Rows without automatic parsing. +func (t *StarTx) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) { + return t.queryRaw(nil, query, args...) +} + +// QueryRawContext executes a query with context in transaction and returns *sql.Rows without automatic parsing. +func (t *StarTx) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + return t.queryRaw(ctx, query, args...) +} + +func (t *StarTx) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { + if err := t.ensureTx(); err != nil { return nil, err } + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty + } - var rows *sql.Rows - var err error + query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args) + hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook) + if beforeHook != nil { + beforeHook(hookCtx, query, hookArgs) + } + start := time.Now() + var ( + rows *sql.Rows + err error + ) if ctx == nil { rows, err = t.tx.Query(query, args...) } else { rows, err = t.tx.QueryContext(ctx, query, args...) } + duration := time.Since(start) + if shouldRunAfterHook(afterHook, slowThreshold, duration, err) { + afterHook(hookCtx, query, hookArgs, duration, err) + } + if err != nil { + return nil, err + } + return rows, nil +} +// query is the internal query implementation +func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { + rows, err := t.queryRaw(ctx, query, args...) if err != nil { return nil, err } @@ -47,10 +87,13 @@ func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) ( } if !t.db.ManualScan { - err = starRows.parse() + if err := starRows.parse(); err != nil { + _ = rows.Close() + return nil, err + } } - return starRows, err + return starRows, nil } // Exec executes a query within the transaction @@ -65,38 +108,97 @@ func (t *StarTx) ExecContext(ctx context.Context, query string, args ...interfac // exec is the internal exec implementation func (t *StarTx) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - if err := t.db.Ping(); err != nil { + if err := t.ensureTx(); err != nil { return nil, err } - - if ctx == nil { - return t.tx.Exec(query, args...) + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty } - return t.tx.ExecContext(ctx, query, args...) + + query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args) + hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook) + if beforeHook != nil { + beforeHook(hookCtx, query, hookArgs) + } + start := time.Now() + + var ( + result sql.Result + err error + ) + if ctx == nil { + result, err = t.tx.Exec(query, args...) + } else { + result, err = t.tx.ExecContext(ctx, query, args...) + } + duration := time.Since(start) + if shouldRunAfterHook(afterHook, slowThreshold, duration, err) { + afterHook(hookCtx, query, hookArgs, duration, err) + } + if err != nil { + return nil, err + } + return result, nil } // Prepare creates a prepared statement within the transaction func (t *StarTx) Prepare(query string) (*StarStmt, error) { + if err := t.ensureTx(); err != nil { + return nil, err + } + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty + } + + query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil) + hookCtx := t.db.hookContext(nil, query, beforeHook, afterHook) + if beforeHook != nil { + beforeHook(hookCtx, query, nil) + } + start := time.Now() + stmt, err := t.tx.Prepare(query) + duration := time.Since(start) + if shouldRunAfterHook(afterHook, slowThreshold, duration, err) { + afterHook(hookCtx, query, nil, duration, err) + } if err != nil { return nil, err } - return &StarStmt{stmt: stmt, db: t.db}, nil + return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil } // PrepareContext creates a prepared statement with context func (t *StarTx) PrepareContext(ctx context.Context, query string) (*StarStmt, error) { + if err := t.ensureTx(); err != nil { + return nil, err + } + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty + } + + query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil) + hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook) + if beforeHook != nil { + beforeHook(hookCtx, query, nil) + } + start := time.Now() + stmt, err := t.tx.PrepareContext(ctx, query) + duration := time.Since(start) + if shouldRunAfterHook(afterHook, slowThreshold, duration, err) { + afterHook(hookCtx, query, nil, duration, err) + } if err != nil { return nil, err } - return &StarStmt{stmt: stmt, db: t.db}, nil + return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil } // QueryStmt executes a prepared statement query within the transaction func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) { - if query == "" { - return nil, errors.New("query string cannot be empty") + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty } stmt, err := t.Prepare(query) if err != nil { @@ -108,8 +210,8 @@ func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) // QueryStmtContext executes a prepared statement query with context func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { - if query == "" { - return nil, errors.New("query string cannot be empty") + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty } stmt, err := t.PrepareContext(ctx, query) if err != nil { @@ -121,8 +223,8 @@ func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...int // ExecStmt executes a prepared statement within the transaction func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) { - if query == "" { - return nil, errors.New("query string cannot be empty") + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty } stmt, err := t.Prepare(query) if err != nil { @@ -134,8 +236,8 @@ func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) // ExecStmtContext executes a prepared statement with context func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - if query == "" { - return nil, errors.New("query string cannot be empty") + if strings.TrimSpace(query) == "" { + return nil, ErrQueryEmpty } stmt, err := t.PrepareContext(ctx, query) if err != nil { @@ -147,10 +249,16 @@ func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...inte // Commit commits the transaction func (t *StarTx) Commit() error { + if err := t.ensureTx(); err != nil { + return err + } return t.tx.Commit() } // Rollback rolls back the transaction func (t *StarTx) Rollback() error { + if err := t.ensureTx(); err != nil { + return err + } return t.tx.Rollback() } diff --git a/tx_helper.go b/tx_helper.go new file mode 100644 index 0000000..109355c --- /dev/null +++ b/tx_helper.go @@ -0,0 +1,45 @@ +package stardb + +import ( + "context" + "database/sql" +) + +// WithTx runs fn in a transaction and handles commit/rollback automatically. +func (s *StarDB) WithTx(fn func(tx *StarTx) error) error { + return s.WithTxContext(context.Background(), nil, fn) +} + +// WithTxContext runs fn in a transaction with context/options and handles commit/rollback automatically. +func (s *StarDB) WithTxContext(ctx context.Context, opts *sql.TxOptions, fn func(tx *StarTx) error) (err error) { + if fn == nil { + return ErrTxFuncNil + } + if ctx == nil { + ctx = context.Background() + } + + tx, err := s.BeginTx(ctx, opts) + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + _ = tx.Rollback() + panic(p) + } + }() + + if err := fn(tx); err != nil { + _ = tx.Rollback() + return err + } + + if err := tx.Commit(); err != nil { + _ = tx.Rollback() + return err + } + + return nil +}