Compare commits

..

No commits in common. "master" and "v1.0.0" have entirely different histories.

52 changed files with 864 additions and 10697 deletions

24
.gitignore vendored
View File

@ -1,24 +0,0 @@
# Binaries
*.exe
*.exe~
*.dll
*.so
*.dylib
*.test
# Go build cache
*.out
# Dependency directories
vendor/
# IDE
.idea
# OS
.DS_Store
# Agent local governance files
.sentrux/
agent_readme.md
target.md

8
.idea/.gitignore generated vendored
View File

@ -1,8 +0,0 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 数据源本地存储已忽略文件
/dataSources/
/dataSources.local.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/

8
.idea/modules.xml generated
View File

@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/stardb.iml" filepath="$PROJECT_DIR$/.idea/stardb.iml" />
</modules>
</component>
</project>

17
.idea/stardb.iml generated
View File

@ -1,17 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go" enabled="true">
<buildTags>
<option name="customFlags">
<array>
<option value="testing" />
</array>
</option>
</buildTags>
</component>
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>

6
.idea/vcs.xml generated
View File

@ -1,6 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$" vcs="Git" />
</component>
</project>

View File

@ -1,59 +0,0 @@
# Changelog
本文档记录 StarDB 的主要变更。
## [Unreleased] - 2026-03-20
### Added
- 新增可判定错误类型(`errors.Is` 友好):
- 生命周期:`ErrDBNotInitialized` `ErrTxNotInitialized` `ErrStmtNotInitialized`
- 参数/目标校验:`ErrQueryEmpty` `ErrTargetNil` `ErrTargetNotPointer` 等
- 映射与批量写入:`ErrColumnNotFound` `ErrNoInsertValues` `ErrBatchRowValueCountMismatch` 等
- 新增流式查询能力DB / Tx / Stmt
- `QueryRaw` / `QueryRawContext`
- `ScanEach` / `ScanEachContext`
- `ScanEachORM` / `ScanEachORMContext`
- 新增 NULL 安全取值:
- `GetNullString` `GetNullInt64` `GetNullFloat64` `GetNullBool` `GetNullTime`
- 新增 ORM 行为开关:
- `SetStrictORM(true)` 启用严格列检查
- `ClearReflectCache()` 清理反射缓存
- 新增 SQL 运行时可观测能力:
- Hook`SetSQLHooks` `SetSQLBeforeHook` `SetSQLAfterHook`
- 慢 SQL 阈值:`SetSQLSlowThreshold`
- 指纹:`SetSQLFingerprintEnabled` `SetSQLFingerprintMode` `SetSQLFingerprintKeepComments`
- 指纹计数:`SetSQLFingerprintCounterEnabled` `SQLFingerprintCounters` `ResetSQLFingerprintCounters`
- Context 元信息:`SQLHookMetaFromContext` `BatchExecMetaFromContext`
- 新增占位符方言适配:
- `SetPlaceholderStyle(PlaceholderQuestion|PlaceholderDollar)``?` / `$1,$2...`
- 新增批量插入分片控制:
- `SetBatchInsertMaxRows`
- `SetBatchInsertMaxParams`
- 常见驱动参数上限自动识别SQLite / PostgreSQL / MySQL / SQL Server
### Changed
- 批量写入在开启分片或触发参数阈值时,改为事务内多分片执行,降低单条 SQL 过大风险。
- 分片批量写入结果语义明确:
- `RowsAffected()` 返回分片累计值
- `LastInsertId()` 返回最后一个分片的 insert id
- 内部结构按模块归档到 `internal/`,保持外部 API 稳定:
- `internal/convert`
- `internal/scanutil`
- `internal/sqlplaceholder`
- `internal/sqlruntime`
- README 重写为面向使用场景的说明,补齐能力边界、接入顺序和 API 细节。
### Behavior Notes
- 默认查询 `Query` 仍为内存模式(解析到 `StarRows`)。
- 关闭内存预读时,使用 `QueryRaw` / `ScanEach` / `ScanEachORM`。
- SQL Hook、指纹与指纹计数默认关闭需显式开启。
- 批量分片关闭条件:`maxRows <= 0` 且 `maxParams <= 0` 且未命中驱动自动阈值。
### Tests
- 新增/补强测试覆盖:
- 流式查询与流式 ORM
- NULL 安全取值
- 严格 ORM 行为
- 占位符转换
- SQL Hook、慢 SQL 阈值、指纹模式、注释保留开关、指纹计数
- BatchInsert 分片(按行数/参数)、失败回滚与结果语义

201
LICENSE
View File

@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

560
README.MD
View File

@ -1,560 +0,0 @@
# StarDB
StarDB 是对 Go `database/sql` 的轻量封装,目标是把常见的数据库操作做得更直白:
- 少量 API 覆盖日常 CRUD、事务、批量写入、结构体映射。
- 兼容原生 `database/sql` 心智,不引入重量级依赖。
- 在可读性、可调试性和性能之间做实用平衡。
适合:
- 想保留 SQL 控制权,但不想反复写样板代码。
- 需要轻量 ORM 映射(不是全功能 ORM
- 需要在生产里追踪 SQL可选 Hook + 慢 SQL 阈值)。
不适合:
- 需要完整领域模型关系管理、自动迁移、复杂查询 DSL 的项目。
## 安装
```bash
go get b612.me/stardb
```
要求:
- Go `>= 1.16`
- 自行选择并导入数据库驱动(本库只封装 `database/sql`
## 常见 DSN 示例
下面示例都可以直接用于 `db.Open(driver, dsn)`,替换为实际账号、密码、库名即可。
### MySQL`github.com/go-sql-driver/mysql`
```go
import _ "github.com/go-sql-driver/mysql"
dsn := "app:secret@tcp(127.0.0.1:3306)/demo?charset=utf8mb4&parseTime=true&loc=Local"
if err := db.Open("mysql", dsn); err != nil {
log.Fatal(err)
}
```
常用参数说明:
- `charset=utf8mb4`:避免字符集问题。
- `parseTime=true`:把 `DATETIME/TIMESTAMP` 解析为 `time.Time`。
- `loc=Local`:指定时间解析时区(也可改成 `Asia/Shanghai`)。
### PostgreSQL`github.com/lib/pq`
```go
import _ "github.com/lib/pq"
dsn := "host=127.0.0.1 port=5432 user=postgres password=secret dbname=demo sslmode=disable"
if err := db.Open("postgres", dsn); err != nil {
log.Fatal(err)
}
```
也可以用 URL 形式:
```go
urlDSN := "postgres://postgres:secret@127.0.0.1:5432/demo?sslmode=disable"
if err := db.Open("postgres", urlDSN); err != nil {
log.Fatal(err)
}
```
### SQLite`modernc.org/sqlite`
```go
import _ "modernc.org/sqlite"
// 文件数据库
if err := db.Open("sqlite", "file:demo.db"); err != nil {
log.Fatal(err)
}
// 内存数据库(适合测试)
if err := db.Open("sqlite", "file::memory:?cache=shared"); err != nil {
log.Fatal(err)
}
```
Windows 路径建议使用 `file:C:/data/demo.db` 这种写法,跨平台更稳。
## 能力概览
| 能力 | 主要 API | 说明 |
|---|---|---|
| 连接与连接池 | `Open` `Close` `Ping` `SetPoolConfig` | 保留原生 `sql.DB` 用法 |
| 常规查询 | `Query` `QueryContext` | 自动解析为 `StarRows` |
| 流式查询 | `QueryRaw` `ScanEach` | 大结果集不必全量进内存 |
| 流式 ORM | `ScanEachORM` | 逐行映射结构体 |
| 安全取值 | `Get*` `GetNull*` | 明确错误与 NULL 语义 |
| 结构体 ORM | `rows.Orm` | 支持单个、切片、数组映射 |
| 命名参数 | `QueryX` `ExecX` | `:field` 绑定结构体字段 |
| 结构体写入 | `Insert` `Update` | 通过 `db` tag 生成 SQL |
| 批量写入 | `BatchInsert` `BatchInsertStructs` `SetBatchInsertMaxRows` `SetBatchInsertMaxParams` | 多行插入,支持按行数/参数阈值分片 |
| 事务 | `Begin/Commit/Rollback` `WithTx` | 手动或托管事务 |
| 可观测性 | `SetSQLHooks` `SetSQLSlowThreshold` `SetSQLFingerprintEnabled` `SetSQLFingerprintMode` `SetSQLFingerprintKeepComments` `SetSQLFingerprintCounterEnabled` `SQLFingerprintCounters` `ResetSQLFingerprintCounters` `SQLHookMetaFromContext` `BatchExecMetaFromContext` | Before/After Hook默认关闭支持指纹策略、命中计数与批量分片元信息 |
| 方言占位符 | `SetPlaceholderStyle` | `?` / `$1,$2...` |
| 查询构建 | `QueryBuilder` | 支持 `JOIN/GROUP BY/HAVING` |
## 场景选型
| 场景 | 首选 API | 说明 |
|---|---|---|
| 中小结果集查询 | `Query` + `rows.Orm` | 读取方便,开发效率高 |
| 大结果集查询 | `ScanEach` / `ScanEachORM` | 逐行处理,避免全量缓存 |
| 需要底层 `Scan` 控制 | `QueryRaw` | 直接返回 `*sql.Rows` |
| 批量写入 | `BatchInsert` + 分片阈值 | 控制单条 SQL 大小与参数数量 |
| SQL 可观测 | `SetSQLHooks` + `SetSQLSlowThreshold` + 指纹配置 | 支持慢 SQL、指纹、分片元信息 |
## 快速开始
```go
package main
import (
"log"
"b612.me/stardb"
_ "modernc.org/sqlite"
)
type User struct {
ID int64 `db:"id"`
Name string `db:"name"`
Age int `db:"age"`
}
func main() {
db := stardb.NewStarDB()
if err := db.Open("sqlite", "test.db"); err != nil {
log.Fatal(err)
}
defer db.Close()
rows, err := db.Query("SELECT id, name, age FROM users WHERE age >= ?", 18)
if err != nil {
log.Fatal(err)
}
defer rows.Close()
var users []User
if err := rows.Orm(&users); err != nil {
log.Fatal(err)
}
log.Printf("users: %d", len(users))
}
```
## 接入流程
按下面顺序接入,可在开发阶段先固定运行边界:
1. 建立连接并设置连接池。
2. 在查询路径中区分内存模式与流式模式。
3. 在批量写入路径设置分片阈值(按行数和参数数)。
4. 启用 SQL Hook、慢 SQL 阈值、指纹策略。
5. 在调用侧统一使用 `errors.Is` 判定错误类别。
一个常用初始化示例:
```go
db := stardb.NewStarDB()
if err := db.Open("mysql", dsn); err != nil {
return err
}
db.SetPoolConfig(&stardb.PoolConfig{
MaxOpenConns: 25,
MaxIdleConns: 5,
ConnMaxLifetime: time.Hour,
ConnMaxIdleTime: 10 * time.Minute,
})
db.SetBatchInsertMaxRows(500)
db.SetBatchInsertMaxParams(60000)
db.SetSQLSlowThreshold(200 * time.Millisecond)
db.SetSQLFingerprintEnabled(true)
db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals)
db.SetSQLFingerprintKeepComments(false)
```
## API 指南
### 1) 连接与连接池
```go
db := stardb.NewStarDB()
_ = db.Open("mysql", "user:pass@tcp(localhost:3306)/app")
db.SetPoolConfig(&stardb.PoolConfig{
MaxOpenConns: 25,
MaxIdleConns: 5,
ConnMaxLifetime: time.Hour,
ConnMaxIdleTime: 10 * time.Minute,
})
```
也可以直接拿到底层连接:
```go
raw := db.DB()
raw.SetMaxOpenConns(50)
```
### 2) 三种查询模式
#### A. 内存模式(默认)
`Query` 会把结果解析为 `StarRows`,适合中小结果集:
```go
rows, err := db.Query("SELECT * FROM users WHERE active = ?", true)
if err != nil { /* ... */ }
defer rows.Close()
for i := 0; i < rows.Length(); i++ {
row := rows.Row(i)
_ = row.MustString("name")
}
```
#### B. 原生流式模式
`QueryRaw` 返回 `*sql.Rows`,完全按原生 `Scan` 处理:
```go
rawRows, err := db.QueryRaw("SELECT id, name FROM users")
if err != nil { /* ... */ }
defer rawRows.Close()
```
#### C. 回调流式模式(常用)
`ScanEach` 逐行回调,避免全量缓存:
关闭内存预读时,使用 `QueryRaw` / `ScanEach` / `ScanEachORM`,不使用 `Query`。
```go
err := db.ScanEach("SELECT id, name FROM users", func(row *stardb.StarResult) error {
id := row.MustInt64("id")
name := row.MustString("name")
_ = id
_ = name
return nil
})
```
可通过 `stardb.ErrScanStopped` 提前终止:
```go
count := 0
_ = db.ScanEach("SELECT * FROM users", func(row *stardb.StarResult) error {
count++
if count >= 1000 {
return stardb.ErrScanStopped
}
return nil
})
```
### 3) 流式 ORM逐行映射
`ScanEachORM` 将每行映射到结构体,再回调。
```go
var model User
var users []User
err := db.ScanEachORM("SELECT id, name, age FROM users", &model, func(target interface{}) error {
u := *(target.(*User)) // 注意拷贝一份target 会被复用
users = append(users, u)
return nil
})
```
同样支持 `Tx` / `Stmt`
- `tx.ScanEachORM(...)`
- `stmt.ScanEachORM(...)`
### 4) 结果读取与 NULL 语义
#### Must 系列(无错误,失败给零值)
- `MustString` `MustInt64` `MustFloat64` `MustBool` ...
#### 安全系列(带错误)
- `GetString` `GetInt64` `GetFloat64`
- `GetNullString` `GetNullInt64` `GetNullFloat64` `GetNullBool` `GetNullTime`
```go
name, err := row.GetString("name")
age, err := row.GetNullInt64("age")
if age.Valid {
// use age.Int64
}
```
### 5) ORM 映射
```go
type User struct {
ID int64 `db:"id"`
Name string `db:"name"`
}
var u User
_ = rows.Orm(&u)
var list []User
_ = rows.Orm(&list)
```
严格列检查(字段/SQL 变更敏感场景可开启):
```go
db.SetStrictORM(true)
```
若结构体 tag 大范围调整,可清理反射缓存:
```go
stardb.ClearReflectCache()
```
### 6) 命名参数绑定
```go
type Filter struct {
Name string `db:"name"`
MinAge int `db:"min_age"`
}
f := Filter{Name: "Alice", MinAge: 18}
rows, err := db.QueryX(&f,
"SELECT * FROM users WHERE name = ? AND age >= ?",
":name", ":min_age")
```
### 7) 写入能力
#### Insert / Update
```go
_, _ = db.Insert(&user, "users", "id") // id 作为自增字段跳过
_, _ = db.Update(&user, "users", "id") // id 作为主键
```
#### BatchInsert
```go
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
}
_, _ = db.BatchInsert("users", columns, values)
```
如需避免单条 SQL 过大(参数过多),可打开分片:
```go
db.SetBatchInsertMaxRows(500) // 0 或负数表示关闭分片(默认)
db.SetBatchInsertMaxParams(60000) // 0 表示自动识别常见驱动参数上限
```
分片模式下会在一个事务里按块执行,避免部分写入成功。
自动识别当前覆盖SQLite `999`、PostgreSQL `65535`、MySQL `65535`、SQL Server `2100`。
分片行为细节:
- 分片阈值按更严格条件生效:`min(maxRows, maxParams/列数)`(忽略未设置项)。
- 分片关闭条件:`maxRows <= 0` 且 `maxParams <= 0` 且未命中驱动自动阈值。
- 分片执行失败会回滚整个批次。
- 分片结果语义:
- `RowsAffected()` 返回所有分片累计值。
- `LastInsertId()` 返回最后一个分片的 insert id。
#### BatchInsertStructs
```go
users := []User{{Name: "Alice"}, {Name: "Bob"}}
_, _ = db.BatchInsertStructs("users", users, "id")
```
### 8) 事务
#### 手动事务
```go
tx, err := db.Begin()
if err != nil { /* ... */ }
defer tx.Rollback()
if _, err := tx.Exec("UPDATE users SET age = age + 1 WHERE id = ?", 1); err != nil {
return err
}
return tx.Commit()
```
#### 托管事务(常用)
```go
err := db.WithTx(func(tx *stardb.StarTx) error {
if _, err := tx.Exec("UPDATE users SET age = ? WHERE id = ?", 26, 1); err != nil {
return err
}
if _, err := tx.Exec("INSERT INTO logs (msg) VALUES (?)", "age updated"); err != nil {
return err
}
return nil
})
```
`WithTx` 规则:
- `fn` 返回 `nil` -> `Commit`
- `fn` 返回错误 -> `Rollback`
- `fn` panic -> `Rollback` 后继续抛出 panic
### 9) SQL Hook 与慢 SQL 阈值
默认关闭;仅在显式设置时生效。
```go
db.SetSQLSlowThreshold(200 * time.Millisecond)
db.SetSQLFingerprintEnabled(true) // 可选:在 Hook context 附带 SQL 指纹
db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals) // 可选:指纹里脱敏数字/字符串字面量
db.SetSQLFingerprintKeepComments(false) // 默认 false指纹不保留 SQL 注释
db.SetSQLFingerprintCounterEnabled(true) // 可选:记录指纹命中次数(内存级)
db.SetSQLHooks(
func(ctx context.Context, query string, args []interface{}) {
// before
},
func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) {
// after
if hookMeta, ok := stardb.SQLHookMetaFromContext(ctx); ok {
_ = hookMeta.Fingerprint
}
if meta, ok := stardb.BatchExecMetaFromContext(ctx); ok {
// chunked batch insert metadata
// meta.ChunkIndex / meta.ChunkCount / meta.ChunkRows ...
}
},
)
```
阈值行为:
- `threshold <= 0``After` 每次都触发
- `threshold > 0`:仅在“慢于阈值”或“执行出错”时触发
- 打开 `SetSQLFingerprintEnabled(true)` 后,可从 `SQLHookMetaFromContext` 获取 SQL 指纹
- 指纹模式:`SQLFingerprintBasic`(默认,仅归一化)/ `SQLFingerprintMaskLiterals`(归一化 + 字面量脱敏)
- `SetSQLFingerprintKeepComments(true)` 可保留注释文本(默认关闭,利于聚合)
- `SetSQLFingerprintCounterEnabled(true)` 后,可通过 `SQLFingerprintCounters()` 查看命中次数,`ResetSQLFingerprintCounters()` 清空
- 若是分片批量写入Hook 可通过 `BatchExecMetaFromContext` 读取分片元信息
Hook 上下文字段说明:
- `SQLHookMetaFromContext(ctx)`
- `Fingerprint`:按配置生成的 SQL 指纹。
- `BatchExecMetaFromContext(ctx)`(仅分片批量写入):
- `ChunkIndex`:当前分片序号(从 1 开始)
- `ChunkCount`:总分片数
- `ChunkRows`:当前分片行数
- `TotalRows`:本次批量总行数
- `ColumnCount`:本次写入列数
### 10) 占位符方言
```go
db.SetPlaceholderStyle(stardb.PlaceholderQuestion) // 默认
// 或
db.SetPlaceholderStyle(stardb.PlaceholderDollar) // ? -> $1,$2...
```
### 11) QueryBuilder
```go
query, args := stardb.NewQueryBuilder("users u").
Select("u.id", "u.name", "COUNT(o.id) AS order_count").
Join("LEFT JOIN orders o ON o.user_id = u.id").
Where("u.active = ?", true).
GroupBy("u.id", "u.name").
Having("COUNT(o.id) > ?", 2).
OrderBy("order_count DESC").
Limit(20).
Offset(0).
Build()
_ = query
_ = args
```
## 错误处理
库内置可判定错误,调用侧使用 `errors.Is` 做分支处理:
```go
if errors.Is(err, stardb.ErrDBNotInitialized) {
// 未初始化
}
if errors.Is(err, stardb.ErrColumnNotFound) {
// 字段/列不匹配
}
if errors.Is(err, stardb.ErrNoInsertValues) {
// 批量插入空数据
}
```
常见错误类别:
- 生命周期:`ErrDBNotInitialized` `ErrTxNotInitialized` `ErrStmtNotInitialized`
- 参数校验:`ErrQueryEmpty` `ErrTargetNil` `ErrTargetNotPointer` ...
- 映射问题:`ErrColumnNotFound` `ErrFieldNotFound`
- 批量写入:`ErrNoInsertColumns` `ErrNoInsertValues` `ErrBatchRowValueCountMismatch`
- 流式回调:`ErrScanFuncNil` `ErrScanORMFuncNil`
## 使用边界
1. 这是轻量封装,不是全功能 ORM。
- 不做模型关系管理has-many/association
- 不做自动迁移
- 不做复杂查询 DSL
2. 大结果集优先用流式 API。
- `Query` 适合中小结果集
- `ScanEach` / `ScanEachORM` 更稳
3. 日志 Hook 按需打开。
- 生产环境最好配合慢 SQL 阈值,减少噪音
4. `ScanEachORM` 回调里的 target 会复用。
- 需要持久化时请拷贝结构体值
## 测试、竞态与基准
```bash
# 根模块
go test ./...
go test -race ./...
go test -run ^$ -bench BenchmarkQueryBuilder_ -benchmem ./...
# testing 子模块(集成测试/基准)
cd testing
go test ./...
go test -race ./...
go test -run ^$ -bench "Benchmark(QueryX|Orm|ScanEach|BatchInsert)" -benchmem
```
## 支持数据库驱动
本库兼容所有实现 `database/sql` 的驱动。常见示例:
- SQLite: `_ "modernc.org/sqlite"`
- MySQL: `_ "github.com/go-sql-driver/mysql"`
- PostgreSQL: `_ "github.com/lib/pq"`
## License
Apache License 2.0

338
batch.go
View File

@ -1,338 +0,0 @@
package stardb
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
"sync/atomic"
)
type multiSQLResult struct {
results []sql.Result
}
type batchExecMetaKey struct{}
// BatchExecMeta contains chunk execution metadata for batch insert operations.
// It is attached to context for chunked execution and can be read in SQL hooks.
type BatchExecMeta struct {
ChunkIndex int
ChunkCount int
ChunkRows int
TotalRows int
ColumnCount int
}
// BatchExecMetaFromContext extracts batch chunk metadata from context.
func BatchExecMetaFromContext(ctx context.Context) (BatchExecMeta, bool) {
if ctx == nil {
return BatchExecMeta{}, false
}
meta, ok := ctx.Value(batchExecMetaKey{}).(BatchExecMeta)
return meta, ok
}
func withBatchExecMeta(ctx context.Context, meta BatchExecMeta) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, batchExecMetaKey{}, meta)
}
func (m multiSQLResult) LastInsertId() (int64, error) {
if len(m.results) == 0 {
return 0, nil
}
return m.results[len(m.results)-1].LastInsertId()
}
func (m multiSQLResult) RowsAffected() (int64, error) {
var total int64
for _, result := range m.results {
affected, err := result.RowsAffected()
if err != nil {
return 0, err
}
total += affected
}
return total, nil
}
// SetBatchInsertMaxRows configures max row count per INSERT statement for batch APIs.
// <= 0 disables splitting and keeps single-statement behavior.
func (s *StarDB) SetBatchInsertMaxRows(maxRows int) {
if s == nil {
return
}
if maxRows < 0 {
maxRows = 0
}
atomic.StoreInt64(&s.batchInsertMaxRows, int64(maxRows))
}
// BatchInsertMaxRows returns current batch split threshold.
// 0 means disabled.
func (s *StarDB) BatchInsertMaxRows() int {
if s == nil {
return 0
}
value := atomic.LoadInt64(&s.batchInsertMaxRows)
if value <= 0 {
return 0
}
return int(value)
}
// SetBatchInsertMaxParams configures max bind parameter count per INSERT statement for batch APIs.
// <= 0 means auto mode (use built-in defaults for known drivers) or no limit for unknown drivers.
func (s *StarDB) SetBatchInsertMaxParams(maxParams int) {
if s == nil {
return
}
if maxParams < 0 {
maxParams = 0
}
atomic.StoreInt64(&s.batchInsertMaxParams, int64(maxParams))
}
// BatchInsertMaxParams returns configured max bind parameter threshold.
// 0 means auto mode.
func (s *StarDB) BatchInsertMaxParams() int {
if s == nil {
return 0
}
value := atomic.LoadInt64(&s.batchInsertMaxParams)
if value <= 0 {
return 0
}
return int(value)
}
func detectBatchInsertMaxParams(db *sql.DB) int {
if db == nil {
return 0
}
driverType := strings.ToLower(fmt.Sprintf("%T", db.Driver()))
switch {
case strings.Contains(driverType, "sqlite"):
// Keep conservative default for wide compatibility.
return 999
case strings.Contains(driverType, "postgres"), strings.Contains(driverType, "pgx"), strings.Contains(driverType, "pq"):
return 65535
case strings.Contains(driverType, "mysql"):
return 65535
case strings.Contains(driverType, "sqlserver"), strings.Contains(driverType, "mssql"):
return 2100
default:
return 0
}
}
func minPositive(a, b int) int {
if a <= 0 {
return b
}
if b <= 0 {
return a
}
if a < b {
return a
}
return b
}
func (s *StarDB) batchInsertChunkSize(columnCount int) (int, error) {
maxRows := s.BatchInsertMaxRows()
maxParams := s.BatchInsertMaxParams()
if maxParams <= 0 {
maxParams = detectBatchInsertMaxParams(s.db)
}
maxRowsByParams := 0
if maxParams > 0 {
maxRowsByParams = maxParams / columnCount
if maxRowsByParams <= 0 {
return 0, ErrBatchInsertMaxParamsTooLow
}
}
return minPositive(maxRows, maxRowsByParams), nil
}
func buildBatchInsertQuery(tableName string, columns []string, values [][]interface{}) (string, []interface{}) {
placeholderGroup := "(" + strings.Repeat("?, ", len(columns)-1) + "?)"
placeholders := strings.Repeat(placeholderGroup+", ", len(values)-1) + placeholderGroup
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
tableName,
strings.Join(columns, ", "),
placeholders)
args := make([]interface{}, 0, len(values)*len(columns))
for _, row := range values {
args = append(args, row...)
}
return query, args
}
func (s *StarDB) batchInsertChunked(ctx context.Context, tableName string, columns []string, values [][]interface{}, chunkSize int) (sql.Result, error) {
if ctx == nil {
ctx = context.Background()
}
tx, err := s.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
chunkCount := (len(values) + chunkSize - 1) / chunkSize
results := make([]sql.Result, 0, chunkCount)
for start := 0; start < len(values); start += chunkSize {
end := start + chunkSize
if end > len(values) {
end = len(values)
}
chunkIndex := start/chunkSize + 1
query, args := buildBatchInsertQuery(tableName, columns, values[start:end])
chunkCtx := withBatchExecMeta(ctx, BatchExecMeta{
ChunkIndex: chunkIndex,
ChunkCount: chunkCount,
ChunkRows: end - start,
TotalRows: len(values),
ColumnCount: len(columns),
})
result, execErr := tx.exec(chunkCtx, query, args...)
if execErr != nil {
_ = tx.Rollback()
return nil, execErr
}
results = append(results, result)
}
if err := tx.Commit(); err != nil {
_ = tx.Rollback()
return nil, err
}
return multiSQLResult{results: results}, nil
}
// BatchInsert performs batch insert operation
// Usage: BatchInsert("users", []string{"name", "age"}, [][]interface{}{{"Alice", 25}, {"Bob", 30}})
func (s *StarDB) BatchInsert(tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
return s.batchInsert(nil, tableName, columns, values)
}
// BatchInsertContext performs batch insert with context
func (s *StarDB) BatchInsertContext(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
return s.batchInsert(ctx, tableName, columns, values)
}
// batchInsert is the internal implementation
func (s *StarDB) batchInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
if strings.TrimSpace(tableName) == "" {
return nil, ErrTableNameEmpty
}
if len(columns) == 0 {
return nil, ErrNoInsertColumns
}
if len(values) == 0 {
return nil, ErrNoInsertValues
}
for i, row := range values {
if len(row) != len(columns) {
return nil, wrapBatchRowValueCountMismatch(i, len(row), len(columns))
}
}
chunkSize, err := s.batchInsertChunkSize(len(columns))
if err != nil {
return nil, err
}
if chunkSize > 0 && len(values) > chunkSize {
return s.batchInsertChunked(ctx, tableName, columns, values, chunkSize)
}
query, args := buildBatchInsertQuery(tableName, columns, values)
return s.exec(ctx, query, args...)
}
// BatchInsertStructs performs batch insert using structs
func (s *StarDB) BatchInsertStructs(tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
return s.batchInsertStructs(nil, tableName, structs, autoIncrementFields...)
}
// BatchInsertStructsContext performs batch insert using structs with context
func (s *StarDB) BatchInsertStructsContext(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
return s.batchInsertStructs(ctx, tableName, structs, autoIncrementFields...)
}
// batchInsertStructs is the internal implementation
func (s *StarDB) batchInsertStructs(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
if structs == nil {
return nil, ErrStructsNil
}
// Get slice of structs
targetValue := reflect.ValueOf(structs)
if targetValue.Kind() == reflect.Ptr {
if targetValue.IsNil() {
return nil, ErrStructsPointerNil
}
targetValue = targetValue.Elem()
}
if targetValue.Kind() != reflect.Slice && targetValue.Kind() != reflect.Array {
return nil, ErrStructsNotSlice
}
if targetValue.Len() == 0 {
return nil, ErrNoStructsToInsert
}
// Get field names from first struct
firstStruct := targetValue.Index(0).Interface()
fieldNames, err := getStructFieldNames(firstStruct, "db")
if err != nil {
return nil, err
}
// Filter out auto-increment fields
var columns []string
for _, fieldName := range fieldNames {
isAutoIncrement := false
for _, autoField := range autoIncrementFields {
if fieldName == autoField {
isAutoIncrement = true
break
}
}
if !isAutoIncrement {
columns = append(columns, fieldName)
}
}
// Extract values from all structs
var values [][]interface{}
for i := 0; i < targetValue.Len(); i++ {
structVal := targetValue.Index(i).Interface()
fieldValues, err := getStructFieldValues(structVal, "db")
if err != nil {
return nil, err
}
var row []interface{}
for _, col := range columns {
row = append(row, fieldValues[col])
}
values = append(values, row)
}
return s.batchInsert(ctx, tableName, columns, values)
}

View File

@ -1,134 +0,0 @@
package stardb
import (
"fmt"
"strings"
)
// QueryBuilder helps build SQL queries
type QueryBuilder struct {
table string
columns []string
joins []string
where []string
whereArgs []interface{}
groupBy []string
having []string
havingArgs []interface{}
orderBy string
limit int
offset int
}
// NewQueryBuilder creates a new query builder
func NewQueryBuilder(table string) *QueryBuilder {
return &QueryBuilder{
table: table,
columns: []string{"*"},
}
}
// Select sets the columns to select
func (qb *QueryBuilder) Select(columns ...string) *QueryBuilder {
qb.columns = columns
return qb
}
// Where adds a WHERE condition
func (qb *QueryBuilder) Where(condition string, args ...interface{}) *QueryBuilder {
qb.where = append(qb.where, condition)
qb.whereArgs = append(qb.whereArgs, args...)
return qb
}
// Join adds a JOIN clause.
func (qb *QueryBuilder) Join(clause string) *QueryBuilder {
qb.joins = append(qb.joins, clause)
return qb
}
// GroupBy sets GROUP BY columns.
func (qb *QueryBuilder) GroupBy(columns ...string) *QueryBuilder {
qb.groupBy = append(qb.groupBy, columns...)
return qb
}
// Having adds a HAVING condition.
func (qb *QueryBuilder) Having(condition string, args ...interface{}) *QueryBuilder {
qb.having = append(qb.having, condition)
qb.havingArgs = append(qb.havingArgs, args...)
return qb
}
// OrderBy sets the ORDER BY clause
func (qb *QueryBuilder) OrderBy(orderBy string) *QueryBuilder {
qb.orderBy = orderBy
return qb
}
// Limit sets the LIMIT
func (qb *QueryBuilder) Limit(limit int) *QueryBuilder {
qb.limit = limit
return qb
}
// Offset sets the OFFSET
func (qb *QueryBuilder) Offset(offset int) *QueryBuilder {
qb.offset = offset
return qb
}
// Build builds the SQL query and returns query string and args
func (qb *QueryBuilder) Build() (string, []interface{}) {
var parts []string
// SELECT
parts = append(parts, fmt.Sprintf("SELECT %s FROM %s",
strings.Join(qb.columns, ", "), qb.table))
// JOIN
if len(qb.joins) > 0 {
parts = append(parts, strings.Join(qb.joins, " "))
}
// WHERE
if len(qb.where) > 0 {
parts = append(parts, "WHERE "+strings.Join(qb.where, " AND "))
}
// GROUP BY
if len(qb.groupBy) > 0 {
parts = append(parts, "GROUP BY "+strings.Join(qb.groupBy, ", "))
}
// HAVING
if len(qb.having) > 0 {
parts = append(parts, "HAVING "+strings.Join(qb.having, " AND "))
}
// ORDER BY
if qb.orderBy != "" {
parts = append(parts, "ORDER BY "+qb.orderBy)
}
// LIMIT
if qb.limit > 0 {
parts = append(parts, fmt.Sprintf("LIMIT %d", qb.limit))
}
// OFFSET
if qb.offset > 0 {
parts = append(parts, fmt.Sprintf("OFFSET %d", qb.offset))
}
args := make([]interface{}, 0, len(qb.whereArgs)+len(qb.havingArgs))
args = append(args, qb.whereArgs...)
args = append(args, qb.havingArgs...)
return strings.Join(parts, " "), args
}
// Query executes the query
func (qb *QueryBuilder) Query(db *StarDB) (*StarRows, error) {
query, args := qb.Build()
return db.Query(query, args...)
}

View File

@ -1,521 +0,0 @@
package stardb
import (
"reflect"
"testing"
)
func TestNewQueryBuilder(t *testing.T) {
qb := NewQueryBuilder("users")
if qb.table != "users" {
t.Errorf("Expected table 'users', got '%s'", qb.table)
}
if len(qb.columns) != 1 || qb.columns[0] != "*" {
t.Errorf("Expected default columns ['*'], got %v", qb.columns)
}
}
func TestQueryBuilder_Select(t *testing.T) {
qb := NewQueryBuilder("users").Select("id", "name", "email")
expected := []string{"id", "name", "email"}
if !reflect.DeepEqual(qb.columns, expected) {
t.Errorf("Expected columns %v, got %v", expected, qb.columns)
}
}
func TestQueryBuilder_Where(t *testing.T) {
qb := NewQueryBuilder("users").
Where("age > ?", 18).
Where("active = ?", true)
if len(qb.where) != 2 {
t.Errorf("Expected 2 where conditions, got %d", len(qb.where))
}
if qb.where[0] != "age > ?" {
t.Errorf("Expected first where 'age > ?', got '%s'", qb.where[0])
}
if qb.where[1] != "active = ?" {
t.Errorf("Expected second where 'active = ?', got '%s'", qb.where[1])
}
if len(qb.whereArgs) != 2 {
t.Errorf("Expected 2 where args, got %d", len(qb.whereArgs))
}
if qb.whereArgs[0] != 18 {
t.Errorf("Expected first arg 18, got %v", qb.whereArgs[0])
}
if qb.whereArgs[1] != true {
t.Errorf("Expected second arg true, got %v", qb.whereArgs[1])
}
}
func TestQueryBuilder_OrderBy(t *testing.T) {
qb := NewQueryBuilder("users").OrderBy("name ASC")
if qb.orderBy != "name ASC" {
t.Errorf("Expected orderBy 'name ASC', got '%s'", qb.orderBy)
}
}
func TestQueryBuilder_Limit(t *testing.T) {
qb := NewQueryBuilder("users").Limit(10)
if qb.limit != 10 {
t.Errorf("Expected limit 10, got %d", qb.limit)
}
}
func TestQueryBuilder_Offset(t *testing.T) {
qb := NewQueryBuilder("users").Offset(20)
if qb.offset != 20 {
t.Errorf("Expected offset 20, got %d", qb.offset)
}
}
func TestQueryBuilder_Build_Simple(t *testing.T) {
qb := NewQueryBuilder("users")
query, args := qb.Build()
expectedQuery := "SELECT * FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_WithSelect(t *testing.T) {
qb := NewQueryBuilder("users").Select("id", "name", "email")
query, args := qb.Build()
expectedQuery := "SELECT id, name, email FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_WithWhere(t *testing.T) {
qb := NewQueryBuilder("users").
Where("age > ?", 18).
Where("active = ?", true)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users WHERE age > ? AND active = ?"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 2 {
t.Errorf("Expected 2 args, got %d", len(args))
}
if args[0] != 18 {
t.Errorf("Expected first arg 18, got %v", args[0])
}
if args[1] != true {
t.Errorf("Expected second arg true, got %v", args[1])
}
}
func TestQueryBuilder_Build_WithOrderBy(t *testing.T) {
qb := NewQueryBuilder("users").OrderBy("created_at DESC")
query, args := qb.Build()
expectedQuery := "SELECT * FROM users ORDER BY created_at DESC"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_WithLimit(t *testing.T) {
qb := NewQueryBuilder("users").Limit(10)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users LIMIT 10"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_WithOffset(t *testing.T) {
qb := NewQueryBuilder("users").Offset(20)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users OFFSET 20"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_WithLimitAndOffset(t *testing.T) {
qb := NewQueryBuilder("users").Limit(10).Offset(20)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users LIMIT 10 OFFSET 20"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_Complex(t *testing.T) {
qb := NewQueryBuilder("users").
Select("id", "name", "email", "age").
Where("age > ?", 18).
Where("active = ?", true).
Where("country = ?", "US").
OrderBy("name ASC").
Limit(10).
Offset(20)
query, args := qb.Build()
expectedQuery := "SELECT id, name, email, age FROM users WHERE age > ? AND active = ? AND country = ? ORDER BY name ASC LIMIT 10 OFFSET 20"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
expectedArgs := []interface{}{18, true, "US"}
if len(args) != len(expectedArgs) {
t.Errorf("Expected %d args, got %d", len(expectedArgs), len(args))
}
for i, expected := range expectedArgs {
if args[i] != expected {
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
}
}
}
func TestQueryBuilder_Build_MultipleWhere(t *testing.T) {
qb := NewQueryBuilder("orders").
Where("user_id = ?", 123).
Where("status IN (?, ?, ?)", "pending", "processing", "shipped").
Where("created_at > ?", "2024-01-01")
query, args := qb.Build()
expectedQuery := "SELECT * FROM orders WHERE user_id = ? AND status IN (?, ?, ?) AND created_at > ?"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
if len(args) != 5 {
t.Errorf("Expected 5 args, got %d", len(args))
}
expectedArgs := []interface{}{123, "pending", "processing", "shipped", "2024-01-01"}
for i, expected := range expectedArgs {
if args[i] != expected {
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
}
}
}
func TestQueryBuilder_Chaining(t *testing.T) {
// Test method chaining returns the same builder
qb := NewQueryBuilder("users")
qb2 := qb.Select("id", "name")
if qb != qb2 {
t.Error("Select should return the same builder instance")
}
qb3 := qb.Where("age > ?", 18)
if qb != qb3 {
t.Error("Where should return the same builder instance")
}
qb4 := qb.OrderBy("name ASC")
if qb != qb4 {
t.Error("OrderBy should return the same builder instance")
}
qb5 := qb.Limit(10)
if qb != qb5 {
t.Error("Limit should return the same builder instance")
}
qb6 := qb.Offset(20)
if qb != qb6 {
t.Error("Offset should return the same builder instance")
}
qb7 := qb.Join("LEFT JOIN orders o ON o.user_id = users.id")
if qb != qb7 {
t.Error("Join should return the same builder instance")
}
qb8 := qb.GroupBy("users.id")
if qb != qb8 {
t.Error("GroupBy should return the same builder instance")
}
qb9 := qb.Having("COUNT(o.id) > ?", 1)
if qb != qb9 {
t.Error("Having should return the same builder instance")
}
}
func TestQueryBuilder_EmptyWhere(t *testing.T) {
qb := NewQueryBuilder("users").Select("id", "name")
query, args := qb.Build()
expectedQuery := "SELECT id, name FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_OnlyLimit(t *testing.T) {
qb := NewQueryBuilder("users").Limit(5)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users LIMIT 5"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_OnlyOffset(t *testing.T) {
qb := NewQueryBuilder("users").Offset(10)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users OFFSET 10"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_ZeroLimit(t *testing.T) {
qb := NewQueryBuilder("users").Limit(0)
query, args := qb.Build()
// Limit 0 should not be included
expectedQuery := "SELECT * FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_ZeroOffset(t *testing.T) {
qb := NewQueryBuilder("users").Offset(0)
query, args := qb.Build()
// Offset 0 should not be included
expectedQuery := "SELECT * FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_MultipleSelect(t *testing.T) {
// Last Select should override previous ones
qb := NewQueryBuilder("users").
Select("id", "name").
Select("email", "age")
query, _ := qb.Build()
expectedQuery := "SELECT email, age FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
}
func TestQueryBuilder_MultipleOrderBy(t *testing.T) {
// Last OrderBy should override previous ones
qb := NewQueryBuilder("users").
OrderBy("name ASC").
OrderBy("created_at DESC")
query, _ := qb.Build()
expectedQuery := "SELECT * FROM users ORDER BY created_at DESC"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
}
func TestQueryBuilder_ComplexWhereConditions(t *testing.T) {
qb := NewQueryBuilder("products").
Select("id", "name", "price").
Where("category_id = ?", 5).
Where("price BETWEEN ? AND ?", 10.0, 100.0).
Where("stock > ?", 0).
Where("name LIKE ?", "%phone%").
OrderBy("price DESC").
Limit(20)
query, args := qb.Build()
expectedQuery := "SELECT id, name, price FROM products WHERE category_id = ? AND price BETWEEN ? AND ? AND stock > ? AND name LIKE ? ORDER BY price DESC LIMIT 20"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
expectedArgs := []interface{}{5, 10.0, 100.0, 0, "%phone%"}
if len(args) != len(expectedArgs) {
t.Errorf("Expected %d args, got %d", len(expectedArgs), len(args))
}
for i, expected := range expectedArgs {
if args[i] != expected {
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
}
}
}
func TestQueryBuilder_SingleColumn(t *testing.T) {
qb := NewQueryBuilder("users").Select("COUNT(*)")
query, args := qb.Build()
expectedQuery := "SELECT COUNT(*) FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_JoinLikeWhere(t *testing.T) {
// Test that WHERE can handle JOIN-like conditions
qb := NewQueryBuilder("users").
Select("users.id", "users.name", "orders.total").
Where("users.id = orders.user_id").
Where("orders.status = ?", "completed")
query, args := qb.Build()
expectedQuery := "SELECT users.id, users.name, orders.total FROM users WHERE users.id = orders.user_id AND orders.status = ?"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
if len(args) != 1 {
t.Errorf("Expected 1 arg, got %d", len(args))
}
if args[0] != "completed" {
t.Errorf("Expected arg 'completed', got %v", args[0])
}
}
func TestQueryBuilder_Build_WithJoinGroupByHaving(t *testing.T) {
qb := NewQueryBuilder("users u").
Select("u.id", "u.name", "COUNT(o.id) AS order_count").
Join("LEFT JOIN orders o ON o.user_id = u.id").
Where("u.active = ?", true).
GroupBy("u.id", "u.name").
Having("COUNT(o.id) > ?", 2).
OrderBy("order_count DESC")
query, args := qb.Build()
expectedQuery := "SELECT u.id, u.name, COUNT(o.id) AS order_count FROM users u LEFT JOIN orders o ON o.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(o.id) > ? ORDER BY order_count DESC"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
expectedArgs := []interface{}{true, 2}
if len(args) != len(expectedArgs) {
t.Fatalf("Expected %d args, got %d", len(expectedArgs), len(args))
}
for i, expected := range expectedArgs {
if args[i] != expected {
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
}
}
}
func TestQueryBuilder_Build_HavingWithoutWhere(t *testing.T) {
qb := NewQueryBuilder("orders").
Select("user_id", "COUNT(*) AS cnt").
GroupBy("user_id").
Having("COUNT(*) >= ?", 3)
query, args := qb.Build()
expectedQuery := "SELECT user_id, COUNT(*) AS cnt FROM orders GROUP BY user_id HAVING COUNT(*) >= ?"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 1 || args[0] != 3 {
t.Errorf("Expected args [3], got %v", args)
}
}
// Benchmark tests
func BenchmarkQueryBuilder_Simple(b *testing.B) {
for i := 0; i < b.N; i++ {
qb := NewQueryBuilder("users")
qb.Build()
}
}
func BenchmarkQueryBuilder_Complex(b *testing.B) {
for i := 0; i < b.N; i++ {
qb := NewQueryBuilder("users").
Select("id", "name", "email", "age").
Where("age > ?", 18).
Where("active = ?", true).
Where("country = ?", "US").
OrderBy("name ASC").
Limit(10).
Offset(20)
qb.Build()
}
}

View File

@ -1,32 +0,0 @@
package stardb
import (
internalconv "b612.me/stardb/internal/convert"
"time"
)
// convertToInt64 converts any value to int64
func convertToInt64(val interface{}) int64 {
return internalconv.ToInt64(val)
}
// convertToUint64 converts any value to uint64
func convertToUint64(val interface{}) uint64 {
return internalconv.ToUint64(val)
}
// convertToFloat64 converts any value to float64
func convertToFloat64(val interface{}) float64 {
return internalconv.ToFloat64(val)
}
// convertToBool converts any value to bool
// Non-zero numbers are considered true
func convertToBool(val interface{}) bool {
return internalconv.ToBool(val)
}
// convertToTime converts any value to time.Time
func convertToTime(val interface{}, layout string) time.Time {
return internalconv.ToTime(val, layout)
}

View File

@ -1,13 +0,0 @@
package stardb
import internalconv "b612.me/stardb/internal/convert"
// ConvertToInt64Safe converts any value to int64 with error handling
func ConvertToInt64Safe(val interface{}) (int64, error) {
return internalconv.ToInt64Safe(val)
}
// ConvertToStringSafe converts any value to string with error handling
func ConvertToStringSafe(val interface{}) (string, error) {
return internalconv.ToStringSafe(val)
}

View File

@ -1,183 +0,0 @@
package stardb
import (
"testing"
"time"
)
func TestConvertToInt64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected int64
}{
{"nil", nil, 0},
{"int", 42, 42},
{"int32", int32(42), 42},
{"int64", int64(42), 42},
{"uint64", uint64(42), 42},
{"float32", float32(42.7), 42},
{"float64", float64(42.7), 42},
{"string", "42", 42},
{"bool true", true, 1},
{"bool false", false, 0},
{"bytes", []byte("42"), 42},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToInt64(tt.input)
if result != tt.expected {
t.Errorf("Expected %d, got %d", tt.expected, result)
}
})
}
}
func TestConvertToUint64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected uint64
}{
{"nil", nil, 0},
{"int", 42, 42},
{"int32", int32(42), 42},
{"int64", int64(42), 42},
{"uint64", uint64(42), 42},
{"float32", float32(42.7), 42},
{"float64", float64(42.7), 42},
{"string", "42", 42},
{"bool true", true, 1},
{"bool false", false, 0},
{"bytes", []byte("42"), 42},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToUint64(tt.input)
if result != tt.expected {
t.Errorf("Expected %d, got %d", tt.expected, result)
}
})
}
}
func TestConvertToFloat64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected float64
}{
{"nil", nil, 0.0},
{"int", 42, 42.0},
{"int32", int32(42), 42.0},
{"int64", int64(42), 42.0},
{"uint64", uint64(42), 42.0},
{"float32", float32(42.5), 42.5},
{"float64", float64(42.5), 42.5},
{"string", "42.5", 42.5},
{"bool true", true, 1.0},
{"bool false", false, 0.0},
{"bytes", []byte("42.5"), 42.5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToFloat64(tt.input)
if result != tt.expected {
t.Errorf("Expected %f, got %f", tt.expected, result)
}
})
}
}
func TestConvertToBool(t *testing.T) {
tests := []struct {
name string
input interface{}
expected bool
}{
{"nil", nil, false},
{"bool true", true, true},
{"bool false", false, false},
{"int positive", 1, true},
{"int zero", 0, false},
{"int negative", -1, true},
{"float positive", 1.5, true},
{"float zero", 0.0, false},
{"string true", "true", true},
{"string false", "false", false},
{"string 1", "1", true},
{"string 0", "0", false},
{"bytes true", []byte("true"), true},
{"bytes false", []byte("false"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToBool(tt.input)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
func TestConvertToString(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
expected string
}{
{"nil", nil, ""},
{"string", "hello", "hello"},
{"int", 42, "42"},
{"int32", int32(42), "42"},
{"int64", int64(42), "42"},
{"float32", float32(42.5), "42.5"},
{"float64", float64(42.5), "42.5"},
{"bool true", true, "true"},
{"bool false", false, "false"},
{"time", now, now.String()},
{"bytes", []byte("hello"), "hello"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToString(tt.input)
if result != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
}
})
}
}
func TestConvertToTime(t *testing.T) {
layout := "2006-01-02 15:04:05"
timeStr := "2024-01-01 10:00:00"
expectedTime, _ := time.Parse(layout, timeStr)
tests := []struct {
name string
input interface{}
layout string
expected time.Time
}{
{"nil", nil, layout, time.Time{}},
{"time.Time", expectedTime, layout, expectedTime},
{"int64 unix", int64(1704103200), layout, time.Unix(1704103200, 0)},
{"string", timeStr, layout, expectedTime},
{"bytes", []byte(timeStr), layout, expectedTime},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToTime(tt.input, tt.layout)
if !result.Equal(tt.expected) {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}

View File

@ -1,64 +0,0 @@
package stardb
import (
"errors"
"fmt"
)
var (
// Lifecycle errors.
ErrDBNotInitialized = errors.New("database is not initialized; call Open or SetDB first")
ErrTxNotInitialized = errors.New("transaction is not initialized")
ErrStmtNotInitialized = errors.New("statement is not initialized")
ErrStmtDBNotInitialized = errors.New("statement database context is not initialized")
// SQL input errors.
ErrQueryEmpty = errors.New("query string cannot be empty")
ErrScanStopped = errors.New("scan stopped by callback")
ErrScanFuncNil = errors.New("scan callback cannot be nil")
ErrScanORMFuncNil = errors.New("scan orm callback cannot be nil")
// Mapping and schema errors.
ErrColumnNotFound = errors.New("column not found")
ErrFieldNotFound = errors.New("field not found")
ErrRowIndexOutOfRange = errors.New("row index out of range")
// Target validation errors.
ErrTargetNil = errors.New("target cannot be nil")
ErrTargetsNil = errors.New("targets cannot be nil")
ErrTargetNotPointer = errors.New("target must be a pointer")
ErrTargetPointerNil = errors.New("target pointer cannot be nil")
ErrTargetsPointerNil = errors.New("targets pointer is nil")
ErrTargetNotStruct = errors.New("target is not a struct")
ErrTargetNotWritable = errors.New("target is not writable")
ErrPointerTargetNil = errors.New("pointer target is nil")
// SQL builder errors.
ErrTableNameEmpty = errors.New("table name cannot be empty")
ErrPrimaryKeyRequired = errors.New("at least one primary key is required")
ErrPrimaryKeyEmpty = errors.New("primary key cannot be empty")
ErrNoInsertColumns = errors.New("no columns to insert")
ErrNoInsertValues = errors.New("no values to insert")
ErrBatchInsertMaxParamsTooLow = errors.New("batch insert max params is lower than column count")
ErrNoUpdateFields = errors.New("no fields to update after excluding primary keys")
ErrBatchRowValueCountMismatch = errors.New("row values count does not match columns")
ErrStructsNil = errors.New("structs cannot be nil")
ErrStructsPointerNil = errors.New("structs pointer is nil")
ErrStructsNotSlice = errors.New("structs must be a slice or array")
ErrNoStructsToInsert = errors.New("no structs to insert")
// Transaction helper errors.
ErrTxFuncNil = errors.New("transaction callback cannot be nil")
)
func wrapColumnNotFound(column string) error {
return fmt.Errorf("%w: %s", ErrColumnNotFound, column)
}
func wrapFieldNotFound(field string) error {
return fmt.Errorf("%w: %s", ErrFieldNotFound, field)
}
func wrapBatchRowValueCountMismatch(rowIndex, got, expected int) error {
return fmt.Errorf("%w: row %d has %d values, expected %d", ErrBatchRowValueCountMismatch, rowIndex, got, expected)
}

3
go.mod
View File

@ -1,3 +0,0 @@
module b612.me/stardb
go 1.16

2
go.sum
View File

@ -1,2 +0,0 @@
github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=

View File

@ -1,359 +0,0 @@
package convert
import (
"fmt"
"strconv"
"strings"
"time"
)
var defaultNullTimeLayouts = []string{
time.RFC3339Nano,
time.RFC3339,
"2006-01-02 15:04:05",
"2006-01-02",
}
// ToInt64 converts any value to int64.
func ToInt64(val interface{}) int64 {
switch v := val.(type) {
case nil:
return 0
case int:
return int64(v)
case int32:
return int64(v)
case int64:
return v
case uint64:
return int64(v)
case float32:
return int64(v)
case float64:
return int64(v)
case string:
result, _ := strconv.ParseInt(v, 10, 64)
return result
case bool:
if v {
return 1
}
return 0
case time.Time:
return v.Unix()
case []byte:
result, _ := strconv.ParseInt(string(v), 10, 64)
return result
default:
return 0
}
}
// ToUint64 converts any value to uint64.
func ToUint64(val interface{}) uint64 {
switch v := val.(type) {
case nil:
return 0
case int:
return uint64(v)
case int32:
return uint64(v)
case int64:
return uint64(v)
case uint64:
return v
case float32:
return uint64(v)
case float64:
return uint64(v)
case string:
result, _ := strconv.ParseUint(v, 10, 64)
return result
case bool:
if v {
return 1
}
return 0
case time.Time:
return uint64(v.Unix())
case []byte:
result, _ := strconv.ParseUint(string(v), 10, 64)
return result
default:
return 0
}
}
// ToFloat64 converts any value to float64.
func ToFloat64(val interface{}) float64 {
switch v := val.(type) {
case nil:
return 0
case int:
return float64(v)
case int32:
return float64(v)
case int64:
return float64(v)
case uint64:
return float64(v)
case float32:
return float64(v)
case float64:
return v
case string:
result, _ := strconv.ParseFloat(v, 64)
return result
case bool:
if v {
return 1
}
return 0
case time.Time:
return float64(v.Unix())
case []byte:
result, _ := strconv.ParseFloat(string(v), 64)
return result
default:
return 0
}
}
// ToBool converts any value to bool.
func ToBool(val interface{}) bool {
switch v := val.(type) {
case nil:
return false
case bool:
return v
case int:
return v != 0
case int32:
return v != 0
case int64:
return v != 0
case uint64:
return v != 0
case float32:
return v != 0
case float64:
return v != 0
case string:
result, _ := strconv.ParseBool(v)
return result
case []byte:
result, _ := strconv.ParseBool(string(v))
return result
default:
return false
}
}
// ToTime converts any value to time.Time.
func ToTime(val interface{}, layout string) time.Time {
switch v := val.(type) {
case nil:
return time.Time{}
case time.Time:
return v
case int:
return time.Unix(int64(v), 0)
case int32:
return time.Unix(int64(v), 0)
case int64:
return time.Unix(v, 0)
case uint64:
return time.Unix(int64(v), 0)
case float32:
sec := int64(v)
nsec := int64((v - float32(sec)) * 1e9)
return time.Unix(sec, nsec)
case float64:
sec := int64(v)
nsec := int64((v - float64(sec)) * 1e9)
return time.Unix(sec, nsec)
case string:
result, _ := time.Parse(layout, v)
return result
case []byte:
result, _ := time.Parse(layout, string(v))
return result
default:
return time.Time{}
}
}
// ToInt64Safe converts any value to int64 with error handling.
func ToInt64Safe(val interface{}) (int64, error) {
switch v := val.(type) {
case nil:
return 0, nil
case int:
return int64(v), nil
case int32:
return int64(v), nil
case int64:
return v, nil
case uint64:
return int64(v), nil
case float32:
return int64(v), nil
case float64:
return int64(v), nil
case string:
return strconv.ParseInt(v, 10, 64)
case bool:
if v {
return 1, nil
}
return 0, nil
case time.Time:
return v.Unix(), nil
case []byte:
return strconv.ParseInt(string(v), 10, 64)
default:
return 0, fmt.Errorf("cannot convert %T to int64", val)
}
}
// ToStringSafe converts any value to string with error handling.
func ToStringSafe(val interface{}) (string, error) {
switch v := val.(type) {
case nil:
return "", nil
case string:
return v, nil
case int:
return strconv.Itoa(v), nil
case int32:
return strconv.FormatInt(int64(v), 10), nil
case int64:
return strconv.FormatInt(v, 10), nil
case float32:
return strconv.FormatFloat(float64(v), 'f', -1, 32), nil
case float64:
return strconv.FormatFloat(v, 'f', -1, 64), nil
case bool:
return strconv.FormatBool(v), nil
case time.Time:
return v.String(), nil
case []byte:
return string(v), nil
default:
return "", fmt.Errorf("cannot convert %T to string", val)
}
}
// ToFloat64Safe converts any value to float64 with error handling.
func ToFloat64Safe(val interface{}) (float64, error) {
switch v := val.(type) {
case nil:
return 0, nil
case float64:
return v, nil
case float32:
return float64(v), nil
case int, int32, int64, uint64:
intVal, err := ToInt64Safe(v)
return float64(intVal), err
case string:
return strconv.ParseFloat(v, 64)
case []byte:
return strconv.ParseFloat(string(v), 64)
default:
return 0, fmt.Errorf("cannot convert %T to float64", val)
}
}
// ToBoolSafe converts any value to bool with error handling.
func ToBoolSafe(val interface{}) (bool, error) {
switch v := val.(type) {
case nil:
return false, nil
case bool:
return v, nil
case int:
return v != 0, nil
case int8:
return v != 0, nil
case int16:
return v != 0, nil
case int32:
return v != 0, nil
case int64:
return v != 0, nil
case uint:
return v != 0, nil
case uint8:
return v != 0, nil
case uint16:
return v != 0, nil
case uint32:
return v != 0, nil
case uint64:
return v != 0, nil
case float32:
return v != 0, nil
case float64:
return v != 0, nil
case string:
return ParseBoolString(v)
case []byte:
return ParseBoolString(string(v))
default:
return false, fmt.Errorf("cannot convert %T to bool", val)
}
}
// ParseBoolString parses string-like bool values.
func ParseBoolString(raw string) (bool, error) {
normalized := strings.TrimSpace(strings.ToLower(raw))
switch normalized {
case "", "0", "false", "f", "off", "no", "n":
return false, nil
case "1", "true", "t", "on", "yes", "y":
return true, nil
default:
return false, fmt.Errorf("cannot parse bool value: %q", raw)
}
}
// ToTimeSafe converts any value to time.Time with error handling.
func ToTimeSafe(val interface{}) (time.Time, error) {
switch v := val.(type) {
case nil:
return time.Time{}, nil
case time.Time:
return v, nil
case int:
return time.Unix(int64(v), 0), nil
case int64:
return time.Unix(v, 0), nil
case string:
return ParseTimeValue(v)
case []byte:
return ParseTimeValue(string(v))
default:
return time.Time{}, fmt.Errorf("cannot convert %T to time.Time", val)
}
}
// ParseTimeValue parses common SQL date-time formats and unix timestamp.
func ParseTimeValue(raw string) (time.Time, error) {
trimmed := strings.TrimSpace(raw)
if trimmed == "" {
return time.Time{}, nil
}
for _, layout := range defaultNullTimeLayouts {
if t, err := time.Parse(layout, trimmed); err == nil {
return t, nil
}
}
if ts, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
return time.Unix(ts, 0), nil
}
return time.Time{}, fmt.Errorf("cannot parse time value: %q", raw)
}

View File

@ -1,12 +0,0 @@
package scanutil
// CloneScannedValue copies driver-scanned values that may be reused by driver.
// []byte is deep-copied; other types are returned as-is.
func CloneScannedValue(val interface{}) interface{} {
if b, ok := val.([]byte); ok {
copied := make([]byte, len(b))
copy(copied, b)
return copied
}
return val
}

View File

@ -1,119 +0,0 @@
package sqlplaceholder
import (
"strconv"
"strings"
)
// ConvertQuestionToDollarPlaceholders converts '?' to '$1,$2,...' in SQL text.
// It skips quoted strings, quoted identifiers and comments.
func ConvertQuestionToDollarPlaceholders(query string) string {
if query == "" || !strings.Contains(query, "?") {
return query
}
const (
stateNormal = iota
stateSingleQuote
stateDoubleQuote
stateBacktick
stateLineComment
stateBlockComment
)
var b strings.Builder
b.Grow(len(query) + 8)
state := stateNormal
index := 1
for i := 0; i < len(query); i++ {
c := query[i]
switch state {
case stateNormal:
if c == '\'' {
state = stateSingleQuote
b.WriteByte(c)
continue
}
if c == '"' {
state = stateDoubleQuote
b.WriteByte(c)
continue
}
if c == '`' {
state = stateBacktick
b.WriteByte(c)
continue
}
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
state = stateLineComment
b.WriteByte(c)
i++
b.WriteByte(query[i])
continue
}
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
state = stateBlockComment
b.WriteByte(c)
i++
b.WriteByte(query[i])
continue
}
if c == '?' {
b.WriteByte('$')
b.WriteString(strconv.Itoa(index))
index++
continue
}
b.WriteByte(c)
case stateSingleQuote:
b.WriteByte(c)
if c == '\'' {
// SQL escaped single quote: ''
if i+1 < len(query) && query[i+1] == '\'' {
i++
b.WriteByte(query[i])
continue
}
state = stateNormal
}
case stateDoubleQuote:
b.WriteByte(c)
if c == '"' {
// escaped double quote: ""
if i+1 < len(query) && query[i+1] == '"' {
i++
b.WriteByte(query[i])
continue
}
state = stateNormal
}
case stateBacktick:
b.WriteByte(c)
if c == '`' {
state = stateNormal
}
case stateLineComment:
b.WriteByte(c)
if c == '\n' {
state = stateNormal
}
case stateBlockComment:
b.WriteByte(c)
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
i++
b.WriteByte(query[i])
state = stateNormal
}
}
}
return b.String()
}

View File

@ -1,323 +0,0 @@
package sqlruntime
import "strings"
// FingerprintSQL creates a normalized SQL fingerprint.
// mode controls literal masking; keepComments controls whether comments are preserved.
func FingerprintSQL(query string, mode int, keepComments bool) string {
prepared := query
if !keepComments {
prepared = stripSQLComments(prepared)
}
normalized := normalizeSQL(prepared)
if normalized == "" {
return ""
}
if NormalizeFingerprintMode(mode) == fingerprintModeMaskLiterals {
return maskSQLLiterals(normalized, keepComments)
}
return normalized
}
func normalizeSQL(query string) string {
normalized := strings.ToLower(strings.TrimSpace(query))
if normalized == "" {
return ""
}
return strings.Join(strings.Fields(normalized), " ")
}
func stripSQLComments(query string) string {
if query == "" {
return ""
}
const (
stateNormal = iota
stateSingleQuote
stateDoubleQuote
stateBacktick
stateLineComment
stateBlockComment
)
var b strings.Builder
b.Grow(len(query))
state := stateNormal
for i := 0; i < len(query); i++ {
c := query[i]
switch state {
case stateNormal:
if c == '\'' {
state = stateSingleQuote
b.WriteByte(c)
continue
}
if c == '"' {
state = stateDoubleQuote
b.WriteByte(c)
continue
}
if c == '`' {
state = stateBacktick
b.WriteByte(c)
continue
}
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
b.WriteByte(' ')
i++
state = stateLineComment
continue
}
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
b.WriteByte(' ')
i++
state = stateBlockComment
continue
}
b.WriteByte(c)
case stateSingleQuote:
b.WriteByte(c)
if c == '\'' {
if i+1 < len(query) && query[i+1] == '\'' {
i++
b.WriteByte(query[i])
continue
}
state = stateNormal
}
case stateDoubleQuote:
b.WriteByte(c)
if c == '"' {
if i+1 < len(query) && query[i+1] == '"' {
i++
b.WriteByte(query[i])
continue
}
state = stateNormal
}
case stateBacktick:
b.WriteByte(c)
if c == '`' {
state = stateNormal
}
case stateLineComment:
if c == '\n' {
b.WriteByte(' ')
state = stateNormal
}
case stateBlockComment:
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
b.WriteByte(' ')
i++
state = stateNormal
}
}
}
return b.String()
}
func maskSQLLiterals(query string, keepComments bool) string {
if query == "" {
return ""
}
const (
stateNormal = iota
stateSingleQuote
stateDoubleQuote
stateBacktick
stateLineComment
stateBlockComment
)
var b strings.Builder
b.Grow(len(query))
state := stateNormal
for i := 0; i < len(query); i++ {
c := query[i]
switch state {
case stateNormal:
if c == '\'' {
b.WriteByte('?')
state = stateSingleQuote
continue
}
if c == '"' {
b.WriteByte(c)
state = stateDoubleQuote
continue
}
if c == '`' {
b.WriteByte(c)
state = stateBacktick
continue
}
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
if keepComments {
b.WriteByte(c)
i++
b.WriteByte(query[i])
} else {
b.WriteByte(' ')
i++
}
state = stateLineComment
continue
}
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
if keepComments {
b.WriteByte(c)
i++
b.WriteByte(query[i])
} else {
b.WriteByte(' ')
i++
}
state = stateBlockComment
continue
}
if c == '$' {
j := i + 1
for j < len(query) && isDigit(query[j]) {
j++
}
if j > i+1 {
b.WriteByte('?')
i = j - 1
continue
}
}
if c == '-' && i+1 < len(query) && isDigit(query[i+1]) && isNumberBoundaryBefore(query, i) {
j := scanNumber(query, i+1)
if isNumberBoundaryAfter(query, j) {
b.WriteByte('?')
i = j - 1
continue
}
}
if isDigit(c) && isNumberBoundaryBefore(query, i) {
j := scanNumber(query, i)
if isNumberBoundaryAfter(query, j) {
b.WriteByte('?')
i = j - 1
continue
}
}
b.WriteByte(c)
case stateSingleQuote:
if c == '\'' {
if i+1 < len(query) && query[i+1] == '\'' {
i++
continue
}
state = stateNormal
}
case stateDoubleQuote:
b.WriteByte(c)
if c == '"' {
if i+1 < len(query) && query[i+1] == '"' {
i++
b.WriteByte(query[i])
continue
}
state = stateNormal
}
case stateBacktick:
b.WriteByte(c)
if c == '`' {
state = stateNormal
}
case stateLineComment:
if keepComments {
b.WriteByte(c)
}
if c == '\n' {
if !keepComments {
b.WriteByte(' ')
}
state = stateNormal
}
case stateBlockComment:
if keepComments {
b.WriteByte(c)
}
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
if keepComments {
i++
b.WriteByte(query[i])
} else {
b.WriteByte(' ')
i++
}
state = stateNormal
}
}
}
return strings.Join(strings.Fields(b.String()), " ")
}
func isDigit(c byte) bool {
return c >= '0' && c <= '9'
}
func isIdentifierChar(c byte) bool {
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_'
}
func isNumberBoundaryBefore(query string, index int) bool {
if index <= 0 {
return true
}
prev := query[index-1]
return !isIdentifierChar(prev) && prev != '$' && prev != '.'
}
func isNumberBoundaryAfter(query string, index int) bool {
if index >= len(query) {
return true
}
next := query[index]
return !isIdentifierChar(next) && next != '.'
}
func scanNumber(query string, start int) int {
i := start
for i < len(query) && isDigit(query[i]) {
i++
}
if i < len(query) && query[i] == '.' {
i++
for i < len(query) && isDigit(query[i]) {
i++
}
}
if i < len(query) && (query[i] == 'e' || query[i] == 'E') {
i++
if i < len(query) && (query[i] == '+' || query[i] == '-') {
i++
}
for i < len(query) && isDigit(query[i]) {
i++
}
}
return i
}

View File

@ -1,27 +0,0 @@
package sqlruntime
import "time"
// CloneHookArgs creates a shallow copy for hook consumers to avoid mutation races.
func CloneHookArgs(args []interface{}) []interface{} {
if len(args) == 0 {
return nil
}
copied := make([]interface{}, len(args))
copy(copied, args)
return copied
}
// ShouldRunAfterHook decides whether after-hook should run.
func ShouldRunAfterHook(hasAfterHook bool, slowThreshold, duration time.Duration, err error) bool {
if !hasAfterHook {
return false
}
if err != nil {
return true
}
if slowThreshold <= 0 {
return true
}
return duration >= slowThreshold
}

View File

@ -1,269 +0,0 @@
package sqlruntime
import (
"sync"
"time"
)
const (
placeholderQuestion = 0
placeholderDollar = 1
fingerprintModeBasic = 0
fingerprintModeMaskLiterals = 1
)
// NormalizePlaceholderStyle converts unknown style values to default question style.
func NormalizePlaceholderStyle(style int) int {
switch style {
case placeholderDollar:
return placeholderDollar
default:
return placeholderQuestion
}
}
// NormalizeFingerprintMode converts unknown mode values to default basic mode.
func NormalizeFingerprintMode(mode int) int {
switch mode {
case fingerprintModeMaskLiterals:
return fingerprintModeMaskLiterals
default:
return fingerprintModeBasic
}
}
// State stores runtime SQL behavior toggles in a thread-safe manner.
type State struct {
mu sync.RWMutex
beforeHook interface{}
afterHook interface{}
placeholder int
slowThreshold time.Duration
fingerprintEnabled bool
fingerprintMode int
fingerprintKeepComments bool
fingerprintCounterEnabled bool
fingerprintCounts map[string]uint64
}
// Options returns snapshot of current runtime options.
func (s *State) Options() (before, after interface{}, placeholder int, slowThreshold time.Duration) {
if s == nil {
return nil, nil, placeholderQuestion, 0
}
s.mu.RLock()
before = s.beforeHook
after = s.afterHook
placeholder = NormalizePlaceholderStyle(s.placeholder)
slowThreshold = s.slowThreshold
s.mu.RUnlock()
return before, after, placeholder, slowThreshold
}
// Hooks returns before/after hooks and slow threshold.
func (s *State) Hooks() (before, after interface{}, slowThreshold time.Duration) {
before, after, _, slowThreshold = s.Options()
return before, after, slowThreshold
}
// SetHooks sets before/after hooks.
func (s *State) SetHooks(before, after interface{}) {
if s == nil {
return
}
s.mu.Lock()
s.beforeHook = before
s.afterHook = after
s.mu.Unlock()
}
// SetBeforeHook sets before hook.
func (s *State) SetBeforeHook(before interface{}) {
if s == nil {
return
}
s.mu.Lock()
s.beforeHook = before
s.mu.Unlock()
}
// SetAfterHook sets after hook.
func (s *State) SetAfterHook(after interface{}) {
if s == nil {
return
}
s.mu.Lock()
s.afterHook = after
s.mu.Unlock()
}
// SetPlaceholderStyle sets placeholder style.
func (s *State) SetPlaceholderStyle(style int) {
if s == nil {
return
}
s.mu.Lock()
s.placeholder = NormalizePlaceholderStyle(style)
s.mu.Unlock()
}
// PlaceholderStyle returns placeholder style.
func (s *State) PlaceholderStyle() int {
if s == nil {
return placeholderQuestion
}
s.mu.RLock()
style := NormalizePlaceholderStyle(s.placeholder)
s.mu.RUnlock()
return style
}
// SetSlowThreshold sets minimum duration for triggering after hook.
func (s *State) SetSlowThreshold(threshold time.Duration) {
if s == nil {
return
}
if threshold < 0 {
threshold = 0
}
s.mu.Lock()
s.slowThreshold = threshold
s.mu.Unlock()
}
// SlowThreshold returns current slow threshold.
func (s *State) SlowThreshold() time.Duration {
if s == nil {
return 0
}
s.mu.RLock()
threshold := s.slowThreshold
s.mu.RUnlock()
return threshold
}
// SetFingerprintEnabled toggles SQL fingerprint metadata generation for hooks.
func (s *State) SetFingerprintEnabled(enabled bool) {
if s == nil {
return
}
s.mu.Lock()
s.fingerprintEnabled = enabled
s.mu.Unlock()
}
// FingerprintEnabled reports whether SQL fingerprint metadata generation is enabled.
func (s *State) FingerprintEnabled() bool {
if s == nil {
return false
}
s.mu.RLock()
enabled := s.fingerprintEnabled
s.mu.RUnlock()
return enabled
}
// SetFingerprintMode sets SQL fingerprint mode.
func (s *State) SetFingerprintMode(mode int) {
if s == nil {
return
}
s.mu.Lock()
s.fingerprintMode = NormalizeFingerprintMode(mode)
s.mu.Unlock()
}
// FingerprintMode returns SQL fingerprint mode.
func (s *State) FingerprintMode() int {
if s == nil {
return fingerprintModeBasic
}
s.mu.RLock()
mode := NormalizeFingerprintMode(s.fingerprintMode)
s.mu.RUnlock()
return mode
}
// SetFingerprintKeepComments toggles comment preservation in generated SQL fingerprints.
func (s *State) SetFingerprintKeepComments(keep bool) {
if s == nil {
return
}
s.mu.Lock()
s.fingerprintKeepComments = keep
s.mu.Unlock()
}
// FingerprintKeepComments reports whether comments are kept in generated SQL fingerprints.
func (s *State) FingerprintKeepComments() bool {
if s == nil {
return false
}
s.mu.RLock()
keep := s.fingerprintKeepComments
s.mu.RUnlock()
return keep
}
// SetFingerprintCounterEnabled toggles in-memory SQL fingerprint hit counter.
func (s *State) SetFingerprintCounterEnabled(enabled bool) {
if s == nil {
return
}
s.mu.Lock()
s.fingerprintCounterEnabled = enabled
s.mu.Unlock()
}
// FingerprintCounterEnabled reports whether in-memory SQL fingerprint hit counter is enabled.
func (s *State) FingerprintCounterEnabled() bool {
if s == nil {
return false
}
s.mu.RLock()
enabled := s.fingerprintCounterEnabled
s.mu.RUnlock()
return enabled
}
// IncFingerprintCount increments hit count for a fingerprint.
func (s *State) IncFingerprintCount(fingerprint string) {
if s == nil || fingerprint == "" {
return
}
s.mu.Lock()
if s.fingerprintCounts == nil {
s.fingerprintCounts = make(map[string]uint64)
}
s.fingerprintCounts[fingerprint]++
s.mu.Unlock()
}
// FingerprintCountsSnapshot returns a snapshot copy of fingerprint counters.
func (s *State) FingerprintCountsSnapshot() map[string]uint64 {
if s == nil {
return map[string]uint64{}
}
s.mu.RLock()
if len(s.fingerprintCounts) == 0 {
s.mu.RUnlock()
return map[string]uint64{}
}
out := make(map[string]uint64, len(s.fingerprintCounts))
for k, v := range s.fingerprintCounts {
out[k] = v
}
s.mu.RUnlock()
return out
}
// ResetFingerprintCounts clears all fingerprint counters.
func (s *State) ResetFingerprintCounts() {
if s == nil {
return
}
s.mu.Lock()
s.fingerprintCounts = nil
s.mu.Unlock()
}

615
orm.go
View File

@ -1,615 +0,0 @@
package stardb
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
)
// Orm maps query results to a struct or slice of structs
// Usage:
//
// var user User
// rows.Orm(&user) // single row
// var users []User
// rows.Orm(&users) // multiple rows
func (r *StarRows) Orm(target interface{}) error {
if !r.parsed {
if err := r.parse(); err != nil {
return err
}
}
if target == nil {
return ErrTargetNil
}
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
if targetType.Kind() != reflect.Ptr {
return ErrTargetNotPointer
}
if targetValue.IsNil() {
return ErrTargetPointerNil
}
targetType = targetType.Elem()
targetValue = targetValue.Elem()
// Handle slice
if targetValue.Kind() == reflect.Slice {
elementType := targetType.Elem()
result := reflect.MakeSlice(targetType, 0, r.Length())
if r.Length() == 0 {
targetValue.Set(result)
return nil
}
for i := 0; i < r.Length(); i++ {
element := reflect.New(elementType)
if err := r.setStructFieldsFromRow(element.Interface(), "db", i); err != nil {
return err
}
result = reflect.Append(result, element.Elem())
}
targetValue.Set(result)
return nil
}
// Handle array
if targetValue.Kind() == reflect.Array {
elementType := targetType.Elem()
if r.Length() == 0 {
return nil
}
if r.Length() > targetValue.Len() {
return fmt.Errorf("target array length %d is smaller than rows %d", targetValue.Len(), r.Length())
}
for i := 0; i < r.Length(); i++ {
element := reflect.New(elementType)
if err := r.setStructFieldsFromRow(element.Interface(), "db", i); err != nil {
return err
}
targetValue.Index(i).Set(element.Elem())
}
return nil
}
// Handle single struct
if r.Length() == 0 {
return nil
}
return r.setStructFieldsFromRow(target, "db", 0)
}
func bindNamedArgs(args []interface{}, fieldValues map[string]interface{}) ([]interface{}, error) {
processedArgs := make([]interface{}, len(args))
for i, arg := range args {
str, ok := arg.(string)
if !ok {
processedArgs[i] = arg
continue
}
if strings.HasPrefix(str, `\:`) {
processedArgs[i] = str[1:]
continue
}
if strings.HasPrefix(str, ":") {
fieldName := str[1:]
val, exists := fieldValues[fieldName]
if !exists {
return nil, wrapFieldNotFound(fieldName)
}
processedArgs[i] = val
continue
}
processedArgs[i] = arg
}
return processedArgs, nil
}
// QueryX executes a query with named parameter binding
// Usage: QueryX(&user, "SELECT * FROM users WHERE id = ?", ":id")
func (s *StarDB) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) {
return s.queryX(nil, target, query, args...)
}
// QueryXContext executes a query with context and named parameter binding
func (s *StarDB) QueryXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
return s.queryX(ctx, target, query, args...)
}
// queryX is the internal implementation
func (s *StarDB) queryX(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
fieldValues, err := getStructFieldValues(target, "db")
if err != nil {
return nil, err
}
processedArgs, err := bindNamedArgs(args, fieldValues)
if err != nil {
return nil, err
}
return s.query(ctx, query, processedArgs...)
}
// QueryXS executes queries for multiple structs
func (s *StarDB) QueryXS(targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
return s.queryXS(nil, targets, query, args...)
}
// QueryXSContext executes queries with context for multiple structs
func (s *StarDB) QueryXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
return s.queryXS(ctx, targets, query, args...)
}
// queryXS is the internal implementation
func (s *StarDB) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
var results []*StarRows
if targets == nil {
return results, ErrTargetsNil
}
targetType := reflect.TypeOf(targets)
targetValue := reflect.ValueOf(targets)
if targetType.Kind() == reflect.Ptr {
if targetValue.IsNil() {
return results, ErrTargetsPointerNil
}
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
results = make([]*StarRows, 0, targetValue.Len())
for i := 0; i < targetValue.Len(); i++ {
result, err := s.queryX(ctx, targetValue.Index(i).Interface(), query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
} else {
result, err := s.queryX(ctx, targets, query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
return results, nil
}
// ExecX executes a statement with named parameter binding
func (s *StarDB) ExecX(target interface{}, query string, args ...interface{}) (sql.Result, error) {
return s.execX(nil, target, query, args...)
}
// ExecXContext executes a statement with context and named parameter binding
func (s *StarDB) ExecXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
return s.execX(ctx, target, query, args...)
}
// execX is the internal implementation
func (s *StarDB) execX(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
fieldValues, err := getStructFieldValues(target, "db")
if err != nil {
return nil, err
}
processedArgs, err := bindNamedArgs(args, fieldValues)
if err != nil {
return nil, err
}
return s.exec(ctx, query, processedArgs...)
}
// ExecXS executes statements for multiple structs
func (s *StarDB) ExecXS(targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
return s.execXS(nil, targets, query, args...)
}
// ExecXSContext executes statements with context for multiple structs
func (s *StarDB) ExecXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
return s.execXS(ctx, targets, query, args...)
}
// execXS is the internal implementation
func (s *StarDB) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
var results []sql.Result
if targets == nil {
return results, ErrTargetsNil
}
targetType := reflect.TypeOf(targets)
targetValue := reflect.ValueOf(targets)
if targetType.Kind() == reflect.Ptr {
if targetValue.IsNil() {
return results, ErrTargetsPointerNil
}
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
results = make([]sql.Result, 0, targetValue.Len())
for i := 0; i < targetValue.Len(); i++ {
result, err := s.execX(ctx, targetValue.Index(i).Interface(), query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
} else {
result, err := s.execX(ctx, targets, query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
return results, nil
}
// Insert inserts a struct into the database
// Usage: Insert(&user, "users", "id") // id is auto-increment
func (s *StarDB) Insert(target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
return s.insert(nil, target, tableName, autoIncrementFields...)
}
// InsertContext inserts a struct with context
func (s *StarDB) InsertContext(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
return s.insert(ctx, target, tableName, autoIncrementFields...)
}
// insert is the internal implementation
func (s *StarDB) insert(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
query, params, err := buildInsertSQL(target, tableName, autoIncrementFields...)
if err != nil {
return nil, err
}
args := make([]interface{}, len(params))
for i, param := range params {
args[i] = param
}
return s.execX(ctx, target, query, args...)
}
// Update updates a struct in the database
// Usage: Update(&user, "users", "id") // id is primary key
func (s *StarDB) Update(target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
return s.update(nil, target, tableName, primaryKeys...)
}
// UpdateContext updates a struct with context
func (s *StarDB) UpdateContext(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
return s.update(ctx, target, tableName, primaryKeys...)
}
// update is the internal implementation
func (s *StarDB) update(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
query, params, err := buildUpdateSQL(target, tableName, primaryKeys...)
if err != nil {
return nil, err
}
args := make([]interface{}, len(params))
for i, param := range params {
args[i] = param
}
return s.execX(ctx, target, query, args...)
}
// buildInsertSQL builds an INSERT SQL statement
func buildInsertSQL(target interface{}, tableName string, autoIncrementFields ...string) (string, []string, error) {
if strings.TrimSpace(tableName) == "" {
return "", []string{}, ErrTableNameEmpty
}
fieldNames, err := getStructFieldNames(target, "db")
if err != nil {
return "", []string{}, err
}
var columns []string
var placeholders []string
var params []string
autoIncrementSet := make(map[string]struct{}, len(autoIncrementFields))
for _, autoField := range autoIncrementFields {
autoIncrementSet[autoField] = struct{}{}
}
for _, fieldName := range fieldNames {
// Skip auto-increment fields
if _, isAutoIncrement := autoIncrementSet[fieldName]; isAutoIncrement {
continue
}
columns = append(columns, fieldName)
placeholders = append(placeholders, "?")
params = append(params, ":"+fieldName)
}
if len(columns) == 0 {
return "", []string{}, ErrNoInsertColumns
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
tableName,
strings.Join(columns, ", "),
strings.Join(placeholders, ", "))
return query, params, nil
}
// buildUpdateSQL builds an UPDATE SQL statement
func buildUpdateSQL(target interface{}, tableName string, primaryKeys ...string) (string, []string, error) {
if strings.TrimSpace(tableName) == "" {
return "", []string{}, ErrTableNameEmpty
}
fieldNames, err := getStructFieldNames(target, "db")
if err != nil {
return "", []string{}, err
}
if len(primaryKeys) == 0 {
return "", []string{}, ErrPrimaryKeyRequired
}
primaryKeySet := make(map[string]struct{}, len(primaryKeys))
for _, pk := range primaryKeys {
if pk == "" {
return "", []string{}, ErrPrimaryKeyEmpty
}
primaryKeySet[pk] = struct{}{}
}
var setClauses []string
var params []string
// Build SET clause
for _, fieldName := range fieldNames {
if _, isPrimaryKey := primaryKeySet[fieldName]; isPrimaryKey {
continue
}
setClauses = append(setClauses, fmt.Sprintf("%s = ?", fieldName))
params = append(params, ":"+fieldName)
}
if len(setClauses) == 0 {
return "", []string{}, ErrNoUpdateFields
}
// Build WHERE clause
var whereClauses []string
for _, pk := range primaryKeys {
whereClauses = append(whereClauses, fmt.Sprintf("%s = ?", pk))
params = append(params, ":"+pk)
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s",
tableName,
strings.Join(setClauses, ", "),
strings.Join(whereClauses, " AND "))
return query, params, nil
}
// Transaction ORM methods
// QueryX executes a query with named parameter binding in a transaction
func (t *StarTx) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) {
return t.queryX(nil, target, query, args...)
}
// QueryXContext executes a query with context in a transaction
func (t *StarTx) QueryXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
return t.queryX(ctx, target, query, args...)
}
// queryX is the internal implementation
func (t *StarTx) queryX(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
fieldValues, err := getStructFieldValues(target, "db")
if err != nil {
return nil, err
}
processedArgs, err := bindNamedArgs(args, fieldValues)
if err != nil {
return nil, err
}
return t.query(ctx, query, processedArgs...)
}
// QueryXS executes queries for multiple structs in a transaction
func (t *StarTx) QueryXS(targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
return t.queryXS(nil, targets, query, args...)
}
// QueryXSContext executes queries with context for multiple structs in a transaction
func (t *StarTx) QueryXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
return t.queryXS(ctx, targets, query, args...)
}
// queryXS is the internal implementation
func (t *StarTx) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
var results []*StarRows
if targets == nil {
return results, ErrTargetsNil
}
targetType := reflect.TypeOf(targets)
targetValue := reflect.ValueOf(targets)
if targetType.Kind() == reflect.Ptr {
if targetValue.IsNil() {
return results, ErrTargetsPointerNil
}
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
results = make([]*StarRows, 0, targetValue.Len())
for i := 0; i < targetValue.Len(); i++ {
result, err := t.queryX(ctx, targetValue.Index(i).Interface(), query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
} else {
result, err := t.queryX(ctx, targets, query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
return results, nil
}
// ExecX executes a statement with named parameter binding in a transaction
func (t *StarTx) ExecX(target interface{}, query string, args ...interface{}) (sql.Result, error) {
return t.execX(nil, target, query, args...)
}
// ExecXContext executes a statement with context in a transaction
func (t *StarTx) ExecXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
return t.execX(ctx, target, query, args...)
}
// execX is the internal implementation
func (t *StarTx) execX(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
fieldValues, err := getStructFieldValues(target, "db")
if err != nil {
return nil, err
}
processedArgs, err := bindNamedArgs(args, fieldValues)
if err != nil {
return nil, err
}
return t.exec(ctx, query, processedArgs...)
}
// ExecXS executes statements for multiple structs in a transaction
func (t *StarTx) ExecXS(targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
return t.execXS(nil, targets, query, args...)
}
// ExecXSContext executes statements with context for multiple structs in a transaction
func (t *StarTx) ExecXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
return t.execXS(ctx, targets, query, args...)
}
// execXS is the internal implementation
func (t *StarTx) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
var results []sql.Result
if targets == nil {
return results, ErrTargetsNil
}
targetType := reflect.TypeOf(targets)
targetValue := reflect.ValueOf(targets)
if targetType.Kind() == reflect.Ptr {
if targetValue.IsNil() {
return results, ErrTargetsPointerNil
}
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
results = make([]sql.Result, 0, targetValue.Len())
for i := 0; i < targetValue.Len(); i++ {
result, err := t.execX(ctx, targetValue.Index(i).Interface(), query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
} else {
result, err := t.execX(ctx, targets, query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
return results, nil
}
// Insert inserts a struct in a transaction
func (t *StarTx) Insert(target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
return t.insert(nil, target, tableName, autoIncrementFields...)
}
// InsertContext inserts a struct with context in a transaction
func (t *StarTx) InsertContext(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
return t.insert(ctx, target, tableName, autoIncrementFields...)
}
// insert is the internal implementation
func (t *StarTx) insert(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
query, params, err := buildInsertSQL(target, tableName, autoIncrementFields...)
if err != nil {
return nil, err
}
args := make([]interface{}, len(params))
for i, param := range params {
args[i] = param
}
return t.execX(ctx, target, query, args...)
}
// Update updates a struct in a transaction
func (t *StarTx) Update(target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
return t.update(nil, target, tableName, primaryKeys...)
}
// UpdateContext updates a struct with context in a transaction
func (t *StarTx) UpdateContext(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
return t.update(ctx, target, tableName, primaryKeys...)
}
// update is the internal implementation
func (t *StarTx) update(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
query, params, err := buildUpdateSQL(target, tableName, primaryKeys...)
if err != nil {
return nil, err
}
args := make([]interface{}, len(params))
for i, param := range params {
args[i] = param
}
return t.execX(ctx, target, query, args...)
}

View File

@ -1,275 +0,0 @@
package stardb
import (
"errors"
"reflect"
"testing"
"time"
)
type User struct {
ID int64 `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
Balance float64 `db:"balance"`
Active bool `db:"active"`
CreatedAt time.Time `db:"created_at"`
}
type Profile struct {
UserID int `db:"user_id"`
Bio string `db:"bio"`
Avatar string `db:"avatar"`
}
type NestedUser struct {
ID int64 `db:"id"`
Name string `db:"name"`
Profile `db:"---"`
}
type NestedUserPtr struct {
ID int64 `db:"id"`
Name string `db:"name"`
Profile *Profile `db:"---"`
}
type AutoIDOnly struct {
ID int64 `db:"id"`
}
type HiddenTagged struct {
ID int64 `db:"id"`
hidden string `db:"hidden"`
}
func TestBuildInsertSQL(t *testing.T) {
user := User{
ID: 1,
Name: "Test",
Email: "test@example.com",
Age: 30,
Balance: 100.0,
Active: true,
CreatedAt: time.Now(),
}
query, params, err := buildInsertSQL(&user, "users", "id")
if err != nil {
t.Fatalf("buildInsertSQL failed: %v", err)
}
expectedQuery := "INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, ?, ?, ?, ?)"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
expectedParams := []string{":name", ":email", ":age", ":balance", ":active", ":created_at"}
if len(params) != len(expectedParams) {
t.Errorf("Expected %d params, got %d", len(expectedParams), len(params))
}
for i, param := range params {
if param != expectedParams[i] {
t.Errorf("Expected param %s, got %s", expectedParams[i], param)
}
}
}
func TestBuildUpdateSQL(t *testing.T) {
user := User{
ID: 1,
Name: "Test",
Email: "test@example.com",
Age: 30,
Balance: 100.0,
Active: true,
CreatedAt: time.Now(),
}
query, params, err := buildUpdateSQL(&user, "users", "id")
if err != nil {
t.Fatalf("buildUpdateSQL failed: %v", err)
}
expectedQuery := "UPDATE users SET name = ?, email = ?, age = ?, balance = ?, active = ?, created_at = ? WHERE id = ?"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
expectedParamCount := 7 // 6 fields + 1 primary key
if len(params) != expectedParamCount {
t.Errorf("Expected %d params, got %d", expectedParamCount, len(params))
}
}
func TestBuildInsertSQL_NoColumns(t *testing.T) {
model := AutoIDOnly{ID: 1}
_, _, err := buildInsertSQL(&model, "users", "id")
if err == nil {
t.Fatal("Expected error when no columns remain to insert, got nil")
}
if !errors.Is(err, ErrNoInsertColumns) {
t.Fatalf("Expected ErrNoInsertColumns, got %v", err)
}
}
func TestBuildInsertSQL_EmptyTableName(t *testing.T) {
user := User{
ID: 1,
Name: "Test",
}
_, _, err := buildInsertSQL(&user, "", "id")
if err == nil {
t.Fatal("Expected error when table name is empty, got nil")
}
if !errors.Is(err, ErrTableNameEmpty) {
t.Fatalf("Expected ErrTableNameEmpty, got %v", err)
}
}
func TestBuildUpdateSQL_NoPrimaryKey(t *testing.T) {
user := User{
ID: 1,
Name: "Test",
}
_, _, err := buildUpdateSQL(&user, "users")
if err == nil {
t.Fatal("Expected error when no primary key is provided, got nil")
}
if !errors.Is(err, ErrPrimaryKeyRequired) {
t.Fatalf("Expected ErrPrimaryKeyRequired, got %v", err)
}
}
func TestBuildUpdateSQL_EmptyTableName(t *testing.T) {
user := User{
ID: 1,
Name: "Test",
}
_, _, err := buildUpdateSQL(&user, "", "id")
if err == nil {
t.Fatal("Expected error when table name is empty, got nil")
}
if !errors.Is(err, ErrTableNameEmpty) {
t.Fatalf("Expected ErrTableNameEmpty, got %v", err)
}
}
func TestBuildUpdateSQL_OnlyPrimaryKey(t *testing.T) {
model := AutoIDOnly{ID: 1}
_, _, err := buildUpdateSQL(&model, "users", "id")
if err == nil {
t.Fatal("Expected error when no fields remain for SET clause, got nil")
}
if !errors.Is(err, ErrNoUpdateFields) {
t.Fatalf("Expected ErrNoUpdateFields, got %v", err)
}
}
func TestGetStructFieldValues_NilNestedPointer(t *testing.T) {
user := NestedUserPtr{
ID: 1,
Name: "Test",
Profile: nil,
}
values, err := getStructFieldValues(user, "db")
if err != nil {
t.Fatalf("getStructFieldValues failed: %v", err)
}
if values["id"] != int64(1) {
t.Errorf("Expected id=1, got %v", values["id"])
}
if values["name"] != "Test" {
t.Errorf("Expected name=Test, got %v", values["name"])
}
}
func TestGetStructFieldNames_NilNestedPointer(t *testing.T) {
user := NestedUserPtr{
ID: 1,
Name: "Test",
Profile: nil,
}
names, err := getStructFieldNames(user, "db")
if err != nil {
t.Fatalf("getStructFieldNames failed: %v", err)
}
expected := []string{"id", "name"}
if !reflect.DeepEqual(names, expected) {
t.Errorf("Expected names %v, got %v", expected, names)
}
}
func TestGetStructFieldNames_SkipUnexportedField(t *testing.T) {
model := HiddenTagged{
ID: 1,
hidden: "secret",
}
names, err := getStructFieldNames(model, "db")
if err != nil {
t.Fatalf("getStructFieldNames failed: %v", err)
}
expected := []string{"id"}
if !reflect.DeepEqual(names, expected) {
t.Errorf("Expected names %v, got %v", expected, names)
}
}
func TestGetStructFieldValues_NilTarget(t *testing.T) {
_, err := getStructFieldValues(nil, "db")
if err == nil {
t.Fatal("Expected error for nil target, got nil")
}
if !errors.Is(err, ErrTargetNil) {
t.Fatalf("Expected ErrTargetNil, got %v", err)
}
}
func TestGetStructFieldNames_NilTarget(t *testing.T) {
_, err := getStructFieldNames(nil, "db")
if err == nil {
t.Fatal("Expected error for nil target, got nil")
}
if !errors.Is(err, ErrTargetNil) {
t.Fatalf("Expected ErrTargetNil, got %v", err)
}
}
func TestClearReflectCache(t *testing.T) {
type cacheUser struct {
ID int64 `db:"id"`
Name string `db:"name"`
}
typ := reflect.TypeOf(cacheUser{})
plan1, err := getStructTagPlan(typ, "db")
if err != nil {
t.Fatalf("getStructTagPlan failed: %v", err)
}
if len(plan1) != 2 {
t.Fatalf("Expected 2 fields in plan, got %d", len(plan1))
}
ClearReflectCache()
plan2, err := getStructTagPlan(typ, "db")
if err != nil {
t.Fatalf("getStructTagPlan after clear failed: %v", err)
}
if len(plan2) != 2 {
t.Fatalf("Expected 2 fields in plan after clear, got %d", len(plan2))
}
}

58
pool.go
View File

@ -1,58 +0,0 @@
package stardb
import (
"time"
)
// PoolConfig represents database connection pool configuration
type PoolConfig struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
// DefaultPoolConfig returns default pool configuration
func DefaultPoolConfig() *PoolConfig {
return &PoolConfig{
MaxOpenConns: 25,
MaxIdleConns: 5,
ConnMaxLifetime: time.Hour,
ConnMaxIdleTime: 10 * time.Minute,
}
}
// SetPoolConfig applies pool configuration to the database
func (s *StarDB) SetPoolConfig(config *PoolConfig) {
if s == nil || s.db == nil || config == nil {
return
}
if config.MaxOpenConns > 0 {
s.db.SetMaxOpenConns(config.MaxOpenConns)
}
if config.MaxIdleConns > 0 {
s.db.SetMaxIdleConns(config.MaxIdleConns)
}
if config.ConnMaxLifetime > 0 {
s.db.SetConnMaxLifetime(config.ConnMaxLifetime)
}
if config.ConnMaxIdleTime > 0 {
s.db.SetConnMaxIdleTime(config.ConnMaxIdleTime)
}
}
// OpenWithPool opens a database connection with pool configuration
func OpenWithPool(driver, connStr string, config *PoolConfig) (*StarDB, error) {
db := NewStarDB()
if err := db.Open(driver, connStr); err != nil {
return nil, err
}
if config == nil {
config = DefaultPoolConfig()
}
db.SetPoolConfig(config)
return db, nil
}

View File

@ -1,375 +0,0 @@
package stardb
import (
"reflect"
"sync"
"time"
)
type structTagField struct {
path []int
tag string
}
type structTagPlanKey struct {
typ reflect.Type
tagKey string
}
var structTagPlanCache sync.Map
// ClearReflectCache clears internal reflection metadata cache.
// Useful after schema/tag refactors in long-running processes.
func ClearReflectCache() {
structTagPlanCache = sync.Map{}
}
func getStructTagPlan(targetType reflect.Type, tagKey string) ([]structTagField, error) {
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
if targetType.Kind() != reflect.Struct {
return nil, ErrTargetNotStruct
}
cacheKey := structTagPlanKey{
typ: targetType,
tagKey: tagKey,
}
if cached, ok := structTagPlanCache.Load(cacheKey); ok {
return cached.([]structTagField), nil
}
fields := make([]structTagField, 0, targetType.NumField())
if err := buildStructTagPlan(targetType, tagKey, nil, &fields); err != nil {
return nil, err
}
structTagPlanCache.Store(cacheKey, fields)
return fields, nil
}
func buildStructTagPlan(currentType reflect.Type, tagKey string, prefix []int, out *[]structTagField) error {
if currentType.Kind() == reflect.Ptr {
currentType = currentType.Elem()
}
if currentType.Kind() != reflect.Struct {
return ErrTargetNotStruct
}
for i := 0; i < currentType.NumField(); i++ {
field := currentType.Field(i)
tagValue := field.Tag.Get(tagKey)
fieldType := field.Type
path := make([]int, len(prefix)+1)
copy(path, prefix)
path[len(prefix)] = i
if tagValue == "---" {
if fieldType.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct {
if err := buildStructTagPlan(fieldType.Elem(), tagKey, path, out); err != nil {
return err
}
continue
}
if fieldType.Kind() == reflect.Struct {
if err := buildStructTagPlan(fieldType, tagKey, path, out); err != nil {
return err
}
continue
}
}
if tagValue != "" {
*out = append(*out, structTagField{
path: path,
tag: tagValue,
})
}
}
return nil
}
func resolveFieldByPath(root reflect.Value, path []int) (reflect.Value, bool) {
current := root
for _, idx := range path {
if current.Kind() == reflect.Ptr {
if current.IsNil() {
return reflect.Value{}, false
}
current = current.Elem()
}
if current.Kind() != reflect.Struct {
return reflect.Value{}, false
}
if idx < 0 || idx >= current.NumField() {
return reflect.Value{}, false
}
current = current.Field(idx)
}
return current, true
}
// setStructFieldsFromRow sets struct fields from a row result using reflection
func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, rowIndex int) error {
if target == nil {
return ErrTargetNil
}
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
if targetType.Kind() == reflect.Ptr {
targetValue = targetValue.Elem()
}
if targetType.Kind() != reflect.Ptr && !targetValue.CanSet() {
return ErrTargetNotWritable
}
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
if targetValue.Kind() != reflect.Struct {
return ErrTargetNotStruct
}
row := r.Row(rowIndex)
if row.columnIndex == nil {
return ErrRowIndexOutOfRange
}
for i := 0; i < targetType.NumField(); i++ {
field := targetType.Field(i)
fieldValue := targetValue.Field(i)
tagValue := field.Tag.Get(tagKey)
if !fieldValue.CanInterface() {
continue
}
// Skip unexported or otherwise non-settable fields.
if !fieldValue.CanSet() {
continue
}
// Handle nested structs
if fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct {
if tagValue == "" {
continue
}
if tagValue == "---" {
nestedPtr := reflect.New(fieldValue.Type().Elem()).Interface()
if err := r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex); err != nil {
return err
}
targetValue.Field(i).Set(reflect.ValueOf(nestedPtr))
continue
}
}
if fieldValue.Kind() == reflect.Struct {
if tagValue == "" {
continue
}
if tagValue == "---" {
nestedPtr := reflect.New(reflect.TypeOf(targetValue.Field(i).Interface())).Interface()
if err := r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex); err != nil {
return err
}
targetValue.Field(i).Set(reflect.ValueOf(nestedPtr).Elem())
continue
}
}
if tagValue == "" {
continue
}
// Check if column exists
if _, ok := row.columnIndex[tagValue]; !ok {
if r.db != nil && r.db.StrictORM {
return wrapColumnNotFound(tagValue)
}
continue
}
// Set field value based on type
r.setFieldValue(fieldValue, tagValue, row)
}
return nil
}
// setFieldValue sets a single field value
func (r *StarRows) setFieldValue(fieldValue reflect.Value, columnName string, row *StarResult) {
switch fieldValue.Kind() {
case reflect.String:
fieldValue.SetString(row.MustString(columnName))
case reflect.Int:
fieldValue.SetInt(int64(row.MustInt(columnName)))
case reflect.Int8:
fieldValue.SetInt(int64(int8(row.MustInt64(columnName))))
case reflect.Int16:
fieldValue.SetInt(int64(int16(row.MustInt64(columnName))))
case reflect.Int32:
fieldValue.SetInt(int64(row.MustInt32(columnName)))
case reflect.Int64:
fieldValue.SetInt(row.MustInt64(columnName))
case reflect.Uint:
fieldValue.SetUint(uint64(row.MustUint64(columnName)))
case reflect.Uint8:
fieldValue.SetUint(uint64(uint8(row.MustUint64(columnName))))
case reflect.Uint16:
fieldValue.SetUint(uint64(uint16(row.MustUint64(columnName))))
case reflect.Uint32:
fieldValue.SetUint(uint64(uint32(row.MustUint64(columnName))))
case reflect.Uint64:
fieldValue.SetUint(row.MustUint64(columnName))
case reflect.Bool:
fieldValue.SetBool(row.MustBool(columnName))
case reflect.Float32:
fieldValue.SetFloat(float64(row.MustFloat32(columnName)))
case reflect.Float64:
fieldValue.SetFloat(row.MustFloat64(columnName))
case reflect.Struct:
// Handle special struct types like time.Time
colIndex := row.columnIndex[columnName]
val := row.Result()[colIndex]
if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
if t, ok := val.(time.Time); ok {
fieldValue.Set(reflect.ValueOf(t))
}
}
case reflect.Ptr:
// Handle pointer to special types like *time.Time
colIndex := row.columnIndex[columnName]
val := row.Result()[colIndex]
if fieldValue.Type().Elem() == reflect.TypeOf(time.Time{}) {
if t, ok := val.(time.Time); ok {
tCopy := t
fieldValue.Set(reflect.ValueOf(&tCopy))
}
}
case reflect.Interface:
colIndex := row.columnIndex[columnName]
val := row.Result()[colIndex]
if val != nil {
fieldValue.Set(reflect.ValueOf(val))
}
}
}
// getStructFieldValues extracts all field values from a struct
func getStructFieldValues(target interface{}, tagKey string) (map[string]interface{}, error) {
if target == nil {
return nil, ErrTargetNil
}
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
if targetType.Kind() == reflect.Ptr {
if targetValue.IsNil() {
return nil, ErrPointerTargetNil
}
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if targetValue.Kind() != reflect.Struct {
return nil, ErrTargetNotStruct
}
plan, err := getStructTagPlan(targetType, tagKey)
if err != nil {
return nil, err
}
result := make(map[string]interface{}, len(plan))
for _, field := range plan {
fieldValue, ok := resolveFieldByPath(targetValue, field.path)
if !ok {
continue
}
if !fieldValue.CanInterface() {
continue
}
result[field.tag] = fieldValue.Interface()
}
return result, nil
}
// getStructFieldNames extracts all field names (tag values) from a struct
func getStructFieldNames(target interface{}, tagKey string) ([]string, error) {
if target == nil {
return []string{}, ErrTargetNil
}
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
if targetType.Kind() == reflect.Ptr {
if targetValue.IsNil() {
return []string{}, ErrPointerTargetNil
}
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if targetValue.Kind() != reflect.Struct {
return []string{}, ErrTargetNotStruct
}
plan, err := getStructTagPlan(targetType, tagKey)
if err != nil {
return []string{}, err
}
result := make([]string, 0, len(plan))
for _, field := range plan {
fieldValue, ok := resolveFieldByPath(targetValue, field.path)
if !ok {
continue
}
if !fieldValue.CanInterface() {
continue
}
result = append(result, field.tag)
}
return result, nil
}
// isWritable checks if a value is writable
func isWritable(target interface{}) bool {
if target == nil {
return false
}
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
return targetType.Kind() == reflect.Ptr || targetValue.CanSet()
}
// isStruct checks if a value is a struct
func isStruct(target interface{}) bool {
if target == nil {
return false
}
targetValue := reflect.ValueOf(target)
if targetValue.Kind() == reflect.Ptr {
if targetValue.IsNil() {
return false
}
targetValue = targetValue.Elem()
}
return targetValue.Kind() == reflect.Struct
}

286
result.go
View File

@ -1,286 +0,0 @@
package stardb
import (
"fmt"
"reflect"
"time"
)
// StarResult represents a single row result
type StarResult struct {
result []interface{}
columns []string
columnIndex map[string]int
columnsType []reflect.Type
}
// Result returns the raw result slice
func (r *StarResult) Result() []interface{} {
return r.result
}
// Columns returns column names
func (r *StarResult) Columns() []string {
return r.columns
}
// ColumnsType returns column types
func (r *StarResult) ColumnsType() []reflect.Type {
return r.columnsType
}
// IsNil checks if a column value is nil
func (r *StarResult) IsNil(name string) bool {
index, ok := r.columnIndex[name]
if !ok {
return true
}
return r.result[index] == nil
}
// MustString returns column value as string
func (r *StarResult) MustString(name string) string {
index, ok := r.columnIndex[name]
if !ok {
return ""
}
return convertToString(r.result[index])
}
// MustInt returns column value as int
func (r *StarResult) MustInt(name string) int {
return int(r.MustInt64(name))
}
// MustInt32 returns column value as int32
func (r *StarResult) MustInt32(name string) int32 {
return int32(r.MustInt64(name))
}
// MustInt64 returns column value as int64
func (r *StarResult) MustInt64(name string) int64 {
index, ok := r.columnIndex[name]
if !ok {
return 0
}
return convertToInt64(r.result[index])
}
// MustUint64 returns column value as uint64
func (r *StarResult) MustUint64(name string) uint64 {
index, ok := r.columnIndex[name]
if !ok {
return 0
}
return convertToUint64(r.result[index])
}
// MustFloat32 returns column value as float32
func (r *StarResult) MustFloat32(name string) float32 {
return float32(r.MustFloat64(name))
}
// MustFloat64 returns column value as float64
func (r *StarResult) MustFloat64(name string) float64 {
index, ok := r.columnIndex[name]
if !ok {
return 0
}
return convertToFloat64(r.result[index])
}
// MustBool returns column value as bool
func (r *StarResult) MustBool(name string) bool {
index, ok := r.columnIndex[name]
if !ok {
return false
}
return convertToBool(r.result[index])
}
// MustBytes returns column value as []byte
func (r *StarResult) MustBytes(name string) []byte {
index, ok := r.columnIndex[name]
if !ok {
return []byte{}
}
if b, ok := r.result[index].([]byte); ok {
return b
}
return []byte(r.MustString(name))
}
// MustDate returns column value as time.Time
func (r *StarResult) MustDate(name, layout string) time.Time {
index, ok := r.columnIndex[name]
if !ok {
return time.Time{}
}
return convertToTime(r.result[index], layout)
}
// Scan scans the result into provided pointers
// Usage: row.Scan(&id, &name, &email)
func (r *StarResult) Scan(dest ...interface{}) error {
if len(dest) != len(r.result) {
return fmt.Errorf("expected %d destination arguments, got %d", len(r.result), len(dest))
}
for i, d := range dest {
if err := convertAssign(d, r.result[i]); err != nil {
return fmt.Errorf("error scanning column %d: %v", i, err)
}
}
return nil
}
// convertAssign assigns src to dest
func convertAssign(dest, src interface{}) error {
switch d := dest.(type) {
case *string:
*d = convertToString(src)
case *int:
*d = int(convertToInt64(src))
case *int32:
*d = int32(convertToInt64(src))
case *int64:
*d = convertToInt64(src)
case *uint64:
*d = convertToUint64(src)
case *float32:
*d = float32(convertToFloat64(src))
case *float64:
*d = convertToFloat64(src)
case *bool:
*d = convertToBool(src)
case *time.Time:
if t, ok := src.(time.Time); ok {
*d = t
}
case *[]byte:
if b, ok := src.([]byte); ok {
*d = b
}
default:
return fmt.Errorf("unsupported Scan type: %T", dest)
}
return nil
}
// StarResultCol represents a single column result
type StarResultCol struct {
result []interface{}
}
// Result returns the raw result slice
func (c *StarResultCol) Result() []interface{} {
return c.result
}
// Len returns the number of values
func (c *StarResultCol) Len() int {
return len(c.result)
}
// IsNil returns slice of nil checks for each row
func (c *StarResultCol) IsNil() []bool {
result := make([]bool, len(c.result))
for i, v := range c.result {
result[i] = (v == nil)
}
return result
}
// MustString returns all values as strings
func (c *StarResultCol) MustString() []string {
result := make([]string, len(c.result))
for i, v := range c.result {
result[i] = convertToString(v)
}
return result
}
// MustInt returns all values as ints
func (c *StarResultCol) MustInt() []int {
result := make([]int, len(c.result))
for i, v := range c.result {
result[i] = int(convertToInt64(v))
}
return result
}
// MustInt32 returns all values as int32
func (c *StarResultCol) MustInt32() []int32 {
result := make([]int32, len(c.result))
for i, v := range c.result {
result[i] = int32(convertToInt64(v))
}
return result
}
// MustInt64 returns all values as int64
func (c *StarResultCol) MustInt64() []int64 {
result := make([]int64, len(c.result))
for i, v := range c.result {
result[i] = convertToInt64(v)
}
return result
}
// MustUint64 returns all values as uint64
func (c *StarResultCol) MustUint64() []uint64 {
result := make([]uint64, len(c.result))
for i, v := range c.result {
result[i] = convertToUint64(v)
}
return result
}
// MustFloat32 returns all values as float32
func (c *StarResultCol) MustFloat32() []float32 {
result := make([]float32, len(c.result))
for i, v := range c.result {
result[i] = float32(convertToFloat64(v))
}
return result
}
// MustFloat64 returns all values as float64
func (c *StarResultCol) MustFloat64() []float64 {
result := make([]float64, len(c.result))
for i, v := range c.result {
result[i] = convertToFloat64(v)
}
return result
}
// MustBool returns all values as bool
func (c *StarResultCol) MustBool() []bool {
result := make([]bool, len(c.result))
for i, v := range c.result {
result[i] = convertToBool(v)
}
return result
}
// MustBytes returns all values as []byte
func (c *StarResultCol) MustBytes() [][]byte {
result := make([][]byte, len(c.result))
for i, v := range c.result {
if b, ok := v.([]byte); ok {
result[i] = b
} else {
result[i] = []byte(convertToString(v))
}
}
return result
}
// MustDate returns all values as time.Time
func (c *StarResultCol) MustDate(layout string) []time.Time {
result := make([]time.Time, len(c.result))
for i, v := range c.result {
result[i] = convertToTime(v, layout)
}
return result
}

View File

@ -1,124 +0,0 @@
package stardb
import (
internalconv "b612.me/stardb/internal/convert"
"database/sql"
)
func (r *StarResult) getColumnValue(name string) (interface{}, error) {
if r == nil || r.columnIndex == nil {
return nil, wrapColumnNotFound(name)
}
index, ok := r.columnIndex[name]
if !ok {
return nil, wrapColumnNotFound(name)
}
return r.Result()[index], nil
}
// GetString returns column value as string with error
func (r *StarResult) GetString(name string) (string, error) {
val, err := r.getColumnValue(name)
if err != nil {
return "", err
}
return ConvertToStringSafe(val)
}
// GetInt64 returns column value as int64 with error
func (r *StarResult) GetInt64(name string) (int64, error) {
val, err := r.getColumnValue(name)
if err != nil {
return 0, err
}
return ConvertToInt64Safe(val)
}
// GetFloat64 returns column value as float64 with error
func (r *StarResult) GetFloat64(name string) (float64, error) {
val, err := r.getColumnValue(name)
if err != nil {
return 0, err
}
return internalconv.ToFloat64Safe(val)
}
// GetNullString returns a nullable string value.
func (r *StarResult) GetNullString(name string) (sql.NullString, error) {
val, err := r.getColumnValue(name)
if err != nil {
return sql.NullString{}, err
}
if val == nil {
return sql.NullString{}, nil
}
str, err := ConvertToStringSafe(val)
if err != nil {
return sql.NullString{}, err
}
return sql.NullString{String: str, Valid: true}, nil
}
// GetNullInt64 returns a nullable int64 value.
func (r *StarResult) GetNullInt64(name string) (sql.NullInt64, error) {
val, err := r.getColumnValue(name)
if err != nil {
return sql.NullInt64{}, err
}
if val == nil {
return sql.NullInt64{}, nil
}
i, err := ConvertToInt64Safe(val)
if err != nil {
return sql.NullInt64{}, err
}
return sql.NullInt64{Int64: i, Valid: true}, nil
}
// GetNullFloat64 returns a nullable float64 value.
func (r *StarResult) GetNullFloat64(name string) (sql.NullFloat64, error) {
val, err := r.getColumnValue(name)
if err != nil {
return sql.NullFloat64{}, err
}
if val == nil {
return sql.NullFloat64{}, nil
}
f, err := internalconv.ToFloat64Safe(val)
if err != nil {
return sql.NullFloat64{}, err
}
return sql.NullFloat64{Float64: f, Valid: true}, nil
}
// GetNullBool returns a nullable bool value.
func (r *StarResult) GetNullBool(name string) (sql.NullBool, error) {
val, err := r.getColumnValue(name)
if err != nil {
return sql.NullBool{}, err
}
if val == nil {
return sql.NullBool{}, nil
}
b, err := internalconv.ToBoolSafe(val)
if err != nil {
return sql.NullBool{}, err
}
return sql.NullBool{Bool: b, Valid: true}, nil
}
// GetNullTime returns a nullable time value.
func (r *StarResult) GetNullTime(name string) (sql.NullTime, error) {
val, err := r.getColumnValue(name)
if err != nil {
return sql.NullTime{}, err
}
if val == nil {
return sql.NullTime{}, nil
}
t, err := internalconv.ToTimeSafe(val)
if err != nil {
return sql.NullTime{}, err
}
return sql.NullTime{Time: t, Valid: true}, nil
}

177
rows.go
View File

@ -1,177 +0,0 @@
package stardb
import (
"b612.me/stardb/internal/scanutil"
"database/sql"
"reflect"
"strconv"
"time"
)
// StarRows represents a result set from a query
type StarRows struct {
rows *sql.Rows
db *StarDB
length int
stringResult []map[string]string
columns []string
columnsType []reflect.Type
columnIndex map[string]int
data [][]interface{}
parsed bool
}
// Length returns the number of rows
func (r *StarRows) Length() int {
return r.length
}
// StringResult returns all rows as string maps
func (r *StarRows) StringResult() []map[string]string {
return r.stringResult
}
// Columns returns column names
func (r *StarRows) Columns() []string {
return r.columns
}
// ColumnsType returns column types
func (r *StarRows) ColumnsType() []reflect.Type {
return r.columnsType
}
// Close closes the result set
func (r *StarRows) Close() error {
return r.rows.Close()
}
// Rescan re-parses the result set
func (r *StarRows) Rescan() error {
return r.parse()
}
// Row returns a specific row by index
func (r *StarRows) Row(index int) *StarResult {
result := &StarResult{}
if index < 0 || index >= len(r.data) {
return result
}
result.result = r.data[index]
result.columns = r.columns
result.columnsType = r.columnsType
result.columnIndex = r.columnIndex
return result
}
// Col returns all values for a specific column
func (r *StarRows) Col(name string) *StarResultCol {
result := &StarResultCol{}
if _, ok := r.columnIndex[name]; !ok {
return result
}
colIndex := r.columnIndex[name]
for _, row := range r.data {
result.result = append(result.result, row[colIndex])
}
return result
}
// parse parses the sql.Rows into internal data structures
func (r *StarRows) parse() error {
if r.parsed {
return nil
}
r.data = [][]interface{}{}
r.columnIndex = make(map[string]int)
r.stringResult = []map[string]string{}
r.columnsType = []reflect.Type{}
var err error
r.columns, err = r.rows.Columns()
if err != nil {
return err
}
columnTypes, err := r.rows.ColumnTypes()
if err != nil {
return err
}
for _, colType := range columnTypes {
r.columnsType = append(r.columnsType, colType.ScanType())
}
// Build column index map
for i, colName := range r.columns {
r.columnIndex[colName] = i
}
// Prepare scan arguments
scanArgs := make([]interface{}, len(r.columns))
values := make([]interface{}, len(r.columns))
for i := range values {
scanArgs[i] = &values[i]
}
// Scan all rows
for r.rows.Next() {
if err := r.rows.Scan(scanArgs...); err != nil {
return err
}
record := make(map[string]string)
rowCopy := make([]interface{}, len(values))
for i, val := range values {
copiedVal := cloneScannedValue(val)
rowCopy[i] = copiedVal
record[r.columns[i]] = convertToString(copiedVal)
}
r.data = append(r.data, rowCopy)
r.stringResult = append(r.stringResult, record)
}
if err := r.rows.Err(); err != nil {
return err
}
r.length = len(r.stringResult)
r.parsed = true
return nil
}
func cloneScannedValue(val interface{}) interface{} {
return scanutil.CloneScannedValue(val)
}
// convertToString converts any value to string
func convertToString(val interface{}) string {
switch v := val.(type) {
case nil:
return ""
case string:
return v
case int:
return strconv.Itoa(v)
case int32:
return strconv.FormatInt(int64(v), 10)
case int64:
return strconv.FormatInt(v, 10)
case float32:
return strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
return strconv.FormatFloat(v, 'f', -1, 64)
case bool:
return strconv.FormatBool(v)
case time.Time:
return v.String()
case []byte:
return string(v)
default:
return ""
}
}

View File

@ -1,29 +0,0 @@
package stardb
import "testing"
func TestCloneScannedValue_BytesAreCopied(t *testing.T) {
original := []byte("hello")
clonedAny := cloneScannedValue(original)
cloned, ok := clonedAny.([]byte)
if !ok {
t.Fatalf("expected []byte, got %T", clonedAny)
}
original[0] = 'H'
if string(cloned) != "hello" {
t.Fatalf("expected cloned value to remain 'hello', got '%s'", string(cloned))
}
if len(cloned) > 0 && &cloned[0] == &original[0] {
t.Fatal("expected cloned bytes to have a different backing array")
}
}
func TestCloneScannedValue_NonBytesKeepReference(t *testing.T) {
in := int64(42)
out := cloneScannedValue(in)
if out != in {
t.Fatalf("expected %v, got %v", in, out)
}
}

View File

@ -1,124 +0,0 @@
package stardb
import (
"context"
"database/sql"
"errors"
"reflect"
)
// ScanEachFunc is called for each scanned row in streaming mode.
type ScanEachFunc func(row *StarResult) error
// ScanEach executes query in streaming mode and invokes fn for each row.
func (s *StarDB) ScanEach(query string, fn ScanEachFunc, args ...interface{}) error {
return s.scanEach(nil, query, fn, args...)
}
// ScanEachContext executes query with context in streaming mode and invokes fn for each row.
func (s *StarDB) ScanEachContext(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
return s.scanEach(ctx, query, fn, args...)
}
func (s *StarDB) scanEach(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
rows, err := s.queryRaw(ctx, query, args...)
if err != nil {
return err
}
return scanEachSQLRows(rows, fn)
}
// ScanEach executes query in transaction streaming mode and invokes fn for each row.
func (t *StarTx) ScanEach(query string, fn ScanEachFunc, args ...interface{}) error {
return t.scanEach(nil, query, fn, args...)
}
// ScanEachContext executes query with context in transaction streaming mode and invokes fn for each row.
func (t *StarTx) ScanEachContext(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
return t.scanEach(ctx, query, fn, args...)
}
func (t *StarTx) scanEach(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
rows, err := t.queryRaw(ctx, query, args...)
if err != nil {
return err
}
return scanEachSQLRows(rows, fn)
}
// ScanEach executes prepared statement query in streaming mode and invokes fn for each row.
func (s *StarStmt) ScanEach(fn ScanEachFunc, args ...interface{}) error {
return s.scanEach(nil, fn, args...)
}
// ScanEachContext executes prepared statement query with context in streaming mode and invokes fn for each row.
func (s *StarStmt) ScanEachContext(ctx context.Context, fn ScanEachFunc, args ...interface{}) error {
return s.scanEach(ctx, fn, args...)
}
func (s *StarStmt) scanEach(ctx context.Context, fn ScanEachFunc, args ...interface{}) error {
rows, err := s.queryRaw(ctx, args...)
if err != nil {
return err
}
return scanEachSQLRows(rows, fn)
}
func scanEachSQLRows(rows *sql.Rows, fn ScanEachFunc) error {
if fn == nil {
_ = rows.Close()
return ErrScanFuncNil
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
return err
}
columnTypes, err := rows.ColumnTypes()
if err != nil {
return err
}
types := make([]reflect.Type, len(columnTypes))
for i, colType := range columnTypes {
types[i] = colType.ScanType()
}
columnIndex := make(map[string]int, len(columns))
for i, colName := range columns {
columnIndex[colName] = i
}
scanArgs := make([]interface{}, len(columns))
values := make([]interface{}, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}
for rows.Next() {
if err := rows.Scan(scanArgs...); err != nil {
return err
}
rowCopy := make([]interface{}, len(values))
for i, val := range values {
rowCopy[i] = cloneScannedValue(val)
}
row := &StarResult{
result: rowCopy,
columns: columns,
columnIndex: columnIndex,
columnsType: types,
}
if err := fn(row); err != nil {
if errors.Is(err, ErrScanStopped) {
return nil
}
return err
}
}
return rows.Err()
}

View File

@ -1,119 +0,0 @@
package stardb
import (
"context"
"reflect"
)
// ScanEachORMFunc is called for each mapped struct in streaming ORM mode.
type ScanEachORMFunc func(target interface{}) error
// ScanEachORM streams query rows and maps each row to target before invoking fn.
// target must be a pointer to struct; it is reused for each row.
func (s *StarDB) ScanEachORM(query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
return s.ScanEachORMContext(nil, query, target, fn, args...)
}
// ScanEachORMContext streams query rows with context and maps each row to target before invoking fn.
// target must be a pointer to struct; it is reused for each row.
func (s *StarDB) ScanEachORMContext(ctx context.Context, query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
if fn == nil {
return ErrScanORMFuncNil
}
if err := validateScanORMTarget(target); err != nil {
return err
}
return s.ScanEachContext(ctx, query, func(row *StarResult) error {
if err := mapResultToStructTarget(row, target, s); err != nil {
return err
}
return fn(target)
}, args...)
}
// ScanEachORM streams transaction rows and maps each row to target before invoking fn.
// target must be a pointer to struct; it is reused for each row.
func (t *StarTx) ScanEachORM(query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
return t.ScanEachORMContext(nil, query, target, fn, args...)
}
// ScanEachORMContext streams transaction rows with context and maps each row to target before invoking fn.
// target must be a pointer to struct; it is reused for each row.
func (t *StarTx) ScanEachORMContext(ctx context.Context, query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
if fn == nil {
return ErrScanORMFuncNil
}
if err := validateScanORMTarget(target); err != nil {
return err
}
return t.ScanEachContext(ctx, query, func(row *StarResult) error {
if err := mapResultToStructTarget(row, target, t.db); err != nil {
return err
}
return fn(target)
}, args...)
}
// ScanEachORM streams prepared statement rows and maps each row to target before invoking fn.
// target must be a pointer to struct; it is reused for each row.
func (s *StarStmt) ScanEachORM(target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
return s.ScanEachORMContext(nil, target, fn, args...)
}
// ScanEachORMContext streams prepared statement rows with context and maps each row to target before invoking fn.
// target must be a pointer to struct; it is reused for each row.
func (s *StarStmt) ScanEachORMContext(ctx context.Context, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
if fn == nil {
return ErrScanORMFuncNil
}
if err := validateScanORMTarget(target); err != nil {
return err
}
return s.ScanEachContext(ctx, func(row *StarResult) error {
if err := mapResultToStructTarget(row, target, s.db); err != nil {
return err
}
return fn(target)
}, args...)
}
func validateScanORMTarget(target interface{}) error {
if target == nil {
return ErrTargetNil
}
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
if targetType.Kind() != reflect.Ptr {
return ErrTargetNotPointer
}
if targetValue.IsNil() {
return ErrTargetPointerNil
}
if targetValue.Elem().Kind() != reflect.Struct {
return ErrTargetNotStruct
}
return nil
}
func mapResultToStructTarget(row *StarResult, target interface{}, db *StarDB) error {
targetValue := reflect.ValueOf(target)
targetValue.Elem().Set(reflect.Zero(targetValue.Elem().Type()))
rowWrapper := &StarRows{
db: db,
length: 1,
columns: row.columns,
columnsType: row.columnsType,
columnIndex: row.columnIndex,
data: [][]interface{}{row.result},
parsed: true,
}
return rowWrapper.setStructFieldsFromRow(target, "db", 0)
}

View File

@ -1,19 +0,0 @@
package stardb
import internalsqlplaceholder "b612.me/stardb/internal/sqlplaceholder"
// ConvertPlaceholders converts placeholders according to style.
func ConvertPlaceholders(query string, style PlaceholderStyle) string {
switch normalizePlaceholderStyle(style) {
case PlaceholderDollar:
return ConvertQuestionToDollarPlaceholders(query)
default:
return query
}
}
// ConvertQuestionToDollarPlaceholders converts '?' to '$1,$2,...' in SQL text.
// It skips quoted strings, quoted identifiers and comments.
func ConvertQuestionToDollarPlaceholders(query string) string {
return internalsqlplaceholder.ConvertQuestionToDollarPlaceholders(query)
}

View File

@ -1,34 +0,0 @@
package stardb
import "testing"
func TestConvertQuestionToDollarPlaceholders(t *testing.T) {
query := "SELECT * FROM users WHERE id = ? AND name = ?"
got := ConvertQuestionToDollarPlaceholders(query)
want := "SELECT * FROM users WHERE id = $1 AND name = $2"
if got != want {
t.Fatalf("expected %q, got %q", want, got)
}
}
func TestConvertQuestionToDollarPlaceholders_SkipQuotedAndComments(t *testing.T) {
query := "SELECT '?', \"?\", `?`, col FROM t WHERE id = ? -- ?\nAND note = '??' /* ? */ AND x = ?"
got := ConvertQuestionToDollarPlaceholders(query)
want := "SELECT '?', \"?\", `?`, col FROM t WHERE id = $1 -- ?\nAND note = '??' /* ? */ AND x = $2"
if got != want {
t.Fatalf("expected %q, got %q", want, got)
}
}
func TestConvertPlaceholders(t *testing.T) {
query := "SELECT * FROM t WHERE a = ? AND b = ?"
if got := ConvertPlaceholders(query, PlaceholderQuestion); got != query {
t.Fatalf("question style should keep query unchanged, got %q", got)
}
got := ConvertPlaceholders(query, PlaceholderDollar)
want := "SELECT * FROM t WHERE a = $1 AND b = $2"
if got != want {
t.Fatalf("expected %q, got %q", want, got)
}
}

View File

@ -1,276 +0,0 @@
package stardb
import (
internalsqlruntime "b612.me/stardb/internal/sqlruntime"
"context"
"time"
)
// SQLBeforeHook runs before a SQL statement is executed.
type SQLBeforeHook func(ctx context.Context, query string, args []interface{})
// SQLAfterHook runs after a SQL statement is executed.
type SQLAfterHook func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error)
// PlaceholderStyle controls SQL placeholder format conversion.
type PlaceholderStyle int
const (
// PlaceholderQuestion keeps '?' placeholders unchanged.
PlaceholderQuestion PlaceholderStyle = iota
// PlaceholderDollar converts '?' placeholders to '$1,$2,...'.
PlaceholderDollar
)
// SQLFingerprintMode controls SQL fingerprint generation strategy.
type SQLFingerprintMode int
const (
// SQLFingerprintBasic lowercases SQL and collapses whitespace.
SQLFingerprintBasic SQLFingerprintMode = iota
// SQLFingerprintMaskLiterals also masks numeric/string literals and $n placeholders.
SQLFingerprintMaskLiterals
)
type sqlHookMetaKey struct{}
// SQLHookMeta contains extra hook metadata attached to context.
type SQLHookMeta struct {
Fingerprint string
}
// SQLHookMetaFromContext extracts SQL hook metadata from context.
func SQLHookMetaFromContext(ctx context.Context) (SQLHookMeta, bool) {
if ctx == nil {
return SQLHookMeta{}, false
}
meta, ok := ctx.Value(sqlHookMetaKey{}).(SQLHookMeta)
return meta, ok
}
func withSQLHookMeta(ctx context.Context, meta SQLHookMeta) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, sqlHookMetaKey{}, meta)
}
type sqlRuntime struct {
state internalsqlruntime.State
}
func normalizePlaceholderStyle(style PlaceholderStyle) PlaceholderStyle {
return PlaceholderStyle(internalsqlruntime.NormalizePlaceholderStyle(int(style)))
}
func normalizeSQLFingerprintMode(mode SQLFingerprintMode) SQLFingerprintMode {
return SQLFingerprintMode(internalsqlruntime.NormalizeFingerprintMode(int(mode)))
}
func cloneHookArgs(args []interface{}) []interface{} {
return internalsqlruntime.CloneHookArgs(args)
}
func (s *StarDB) runtimeOptions() (SQLBeforeHook, SQLAfterHook, PlaceholderStyle, time.Duration) {
if s == nil {
return nil, nil, PlaceholderQuestion, 0
}
beforeAny, afterAny, rawStyle, slowThreshold := s.runtime.state.Options()
var before SQLBeforeHook
if b, ok := beforeAny.(SQLBeforeHook); ok {
before = b
}
var after SQLAfterHook
if a, ok := afterAny.(SQLAfterHook); ok {
after = a
}
return before, after, normalizePlaceholderStyle(PlaceholderStyle(rawStyle)), slowThreshold
}
func (s *StarDB) sqlHooks() (SQLBeforeHook, SQLAfterHook, time.Duration) {
before, after, _, slowThreshold := s.runtimeOptions()
return before, after, slowThreshold
}
func (s *StarDB) prepareSQLCall(query string, args []interface{}) (string, SQLBeforeHook, SQLAfterHook, []interface{}, time.Duration) {
before, after, style, slowThreshold := s.runtimeOptions()
query = ConvertPlaceholders(query, style)
if before == nil && after == nil {
return query, nil, nil, nil, slowThreshold
}
return query, before, after, cloneHookArgs(args), slowThreshold
}
func (s *StarDB) hookContext(ctx context.Context, query string, before SQLBeforeHook, after SQLAfterHook) context.Context {
if s == nil {
return ctx
}
needCounter := s.SQLFingerprintCounterEnabled()
needMeta := (before != nil || after != nil) && s.SQLFingerprintEnabled()
if !needCounter && !needMeta {
return ctx
}
mode := s.SQLFingerprintMode()
keepComments := s.SQLFingerprintKeepComments()
fingerprint := internalsqlruntime.FingerprintSQL(query, int(mode), keepComments)
if fingerprint == "" {
return ctx
}
if needCounter {
s.runtime.state.IncFingerprintCount(fingerprint)
}
if !needMeta {
return ctx
}
return withSQLHookMeta(ctx, SQLHookMeta{Fingerprint: fingerprint})
}
func shouldRunAfterHook(after SQLAfterHook, slowThreshold, duration time.Duration, err error) bool {
return internalsqlruntime.ShouldRunAfterHook(after != nil, slowThreshold, duration, err)
}
// SetSQLHooks sets SQL before/after hooks.
func (s *StarDB) SetSQLHooks(before SQLBeforeHook, after SQLAfterHook) {
if s == nil {
return
}
s.runtime.state.SetHooks(before, after)
}
// SetSQLBeforeHook sets SQL before hook.
func (s *StarDB) SetSQLBeforeHook(before SQLBeforeHook) {
if s == nil {
return
}
s.runtime.state.SetBeforeHook(before)
}
// SetSQLAfterHook sets SQL after hook.
func (s *StarDB) SetSQLAfterHook(after SQLAfterHook) {
if s == nil {
return
}
s.runtime.state.SetAfterHook(after)
}
// SetPlaceholderStyle sets placeholder conversion style.
func (s *StarDB) SetPlaceholderStyle(style PlaceholderStyle) {
if s == nil {
return
}
s.runtime.state.SetPlaceholderStyle(int(style))
}
// SetSQLSlowThreshold sets minimum duration for triggering after hook.
// When threshold > 0, after hook runs only for statements slower than threshold or those with error.
func (s *StarDB) SetSQLSlowThreshold(threshold time.Duration) {
if s == nil {
return
}
s.runtime.state.SetSlowThreshold(threshold)
}
// SQLSlowThreshold returns current slow SQL threshold.
func (s *StarDB) SQLSlowThreshold() time.Duration {
if s == nil {
return 0
}
return s.runtime.state.SlowThreshold()
}
// PlaceholderStyle returns current placeholder style.
func (s *StarDB) PlaceholderStyle() PlaceholderStyle {
if s == nil {
return PlaceholderQuestion
}
style := PlaceholderStyle(s.runtime.state.PlaceholderStyle())
return normalizePlaceholderStyle(style)
}
// SetSQLFingerprintEnabled toggles SQL fingerprint metadata generation for hooks.
func (s *StarDB) SetSQLFingerprintEnabled(enabled bool) {
if s == nil {
return
}
s.runtime.state.SetFingerprintEnabled(enabled)
}
// SQLFingerprintEnabled reports whether SQL fingerprint metadata generation is enabled.
func (s *StarDB) SQLFingerprintEnabled() bool {
if s == nil {
return false
}
return s.runtime.state.FingerprintEnabled()
}
// SetSQLFingerprintMode sets SQL fingerprint generation mode.
func (s *StarDB) SetSQLFingerprintMode(mode SQLFingerprintMode) {
if s == nil {
return
}
s.runtime.state.SetFingerprintMode(int(mode))
}
// SQLFingerprintMode returns SQL fingerprint generation mode.
func (s *StarDB) SQLFingerprintMode() SQLFingerprintMode {
if s == nil {
return SQLFingerprintBasic
}
mode := SQLFingerprintMode(s.runtime.state.FingerprintMode())
return normalizeSQLFingerprintMode(mode)
}
// SetSQLFingerprintKeepComments controls whether comments are preserved in SQL fingerprints.
// Default is false.
func (s *StarDB) SetSQLFingerprintKeepComments(keep bool) {
if s == nil {
return
}
s.runtime.state.SetFingerprintKeepComments(keep)
}
// SQLFingerprintKeepComments reports whether SQL fingerprints preserve comments.
func (s *StarDB) SQLFingerprintKeepComments() bool {
if s == nil {
return false
}
return s.runtime.state.FingerprintKeepComments()
}
// SetSQLFingerprintCounterEnabled toggles in-memory SQL fingerprint hit counter.
// Default is false.
func (s *StarDB) SetSQLFingerprintCounterEnabled(enabled bool) {
if s == nil {
return
}
s.runtime.state.SetFingerprintCounterEnabled(enabled)
}
// SQLFingerprintCounterEnabled reports whether in-memory SQL fingerprint hit counter is enabled.
func (s *StarDB) SQLFingerprintCounterEnabled() bool {
if s == nil {
return false
}
return s.runtime.state.FingerprintCounterEnabled()
}
// SQLFingerprintCounters returns a snapshot of fingerprint hit counters.
func (s *StarDB) SQLFingerprintCounters() map[string]uint64 {
if s == nil {
return map[string]uint64{}
}
return s.runtime.state.FingerprintCountsSnapshot()
}
// ResetSQLFingerprintCounters clears all in-memory fingerprint hit counters.
func (s *StarDB) ResetSQLFingerprintCounters() {
if s == nil {
return
}
s.runtime.state.ResetFingerprintCounts()
}

View File

@ -1,102 +0,0 @@
package stardb
import (
"context"
"sync"
"testing"
"time"
)
func TestStarDB_RuntimeConfigConcurrent(t *testing.T) {
db := NewStarDB()
before := func(ctx context.Context, query string, args []interface{}) {}
after := func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {}
var wg sync.WaitGroup
for i := 0; i < 16; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
for j := 0; j < 1000; j++ {
if (i+j)%2 == 0 {
db.SetPlaceholderStyle(PlaceholderDollar)
} else {
db.SetPlaceholderStyle(PlaceholderQuestion)
}
db.SetSQLSlowThreshold(time.Duration((i+j)%5) * time.Millisecond)
db.SetSQLFingerprintEnabled((i+j)%3 == 0)
db.SetSQLFingerprintMode(SQLFingerprintMode((i + j) % 3))
db.SetSQLFingerprintKeepComments((i+j)%4 == 0)
db.SetSQLFingerprintCounterEnabled((i+j)%5 == 0)
if (i+j)%7 == 0 {
db.ResetSQLFingerprintCounters()
}
db.SetSQLHooks(before, after)
_ = db.PlaceholderStyle()
_ = db.SQLSlowThreshold()
_ = db.SQLFingerprintEnabled()
_ = db.SQLFingerprintMode()
_ = db.SQLFingerprintKeepComments()
_ = db.SQLFingerprintCounterEnabled()
_ = db.SQLFingerprintCounters()
_, _, _, _ = db.runtimeOptions()
}
}(i)
}
wg.Wait()
}
func TestStarDB_SQLFingerprintMode(t *testing.T) {
db := NewStarDB()
if got := db.SQLFingerprintMode(); got != SQLFingerprintBasic {
t.Fatalf("expected default mode SQLFingerprintBasic, got %v", got)
}
db.SetSQLFingerprintMode(SQLFingerprintMaskLiterals)
if got := db.SQLFingerprintMode(); got != SQLFingerprintMaskLiterals {
t.Fatalf("expected SQLFingerprintMaskLiterals, got %v", got)
}
db.SetSQLFingerprintMode(SQLFingerprintMode(99))
if got := db.SQLFingerprintMode(); got != SQLFingerprintBasic {
t.Fatalf("expected invalid mode fallback to SQLFingerprintBasic, got %v", got)
}
}
func TestStarDB_SQLFingerprintKeepComments(t *testing.T) {
db := NewStarDB()
if db.SQLFingerprintKeepComments() {
t.Fatal("expected default keep comments to be false")
}
db.SetSQLFingerprintKeepComments(true)
if !db.SQLFingerprintKeepComments() {
t.Fatal("expected keep comments to be true")
}
db.SetSQLFingerprintKeepComments(false)
if db.SQLFingerprintKeepComments() {
t.Fatal("expected keep comments to be false")
}
}
func TestStarDB_SQLFingerprintCounterSwitch(t *testing.T) {
db := NewStarDB()
if db.SQLFingerprintCounterEnabled() {
t.Fatal("expected default counter switch to be false")
}
db.SetSQLFingerprintCounterEnabled(true)
if !db.SQLFingerprintCounterEnabled() {
t.Fatal("expected counter switch to be true")
}
db.ResetSQLFingerprintCounters()
if got := len(db.SQLFingerprintCounters()); got != 0 {
t.Fatalf("expected empty counters after reset, got %d", got)
}
}

552
stardb.go
View File

@ -1,552 +0,0 @@
package stardb
import (
"context"
"database/sql"
"strings"
"time"
)
// StarDB is a simple wrapper around sql.DB providing enhanced functionality
type StarDB struct {
db *sql.DB
ManualScan bool // If true, rows won't be automatically parsed
StrictORM bool // If true, Orm requires all tagged columns to exist in query results
// batchInsertMaxRows controls batch split size for BatchInsert/BatchInsertStructs.
// <= 0 means no split (single SQL statement).
batchInsertMaxRows int64
batchInsertMaxParams int64
runtime sqlRuntime
}
// NewStarDB creates a new StarDB instance
func NewStarDB() *StarDB {
return &StarDB{}
}
// NewStarDBWithDB creates a new StarDB instance with an existing *sql.DB
func NewStarDBWithDB(db *sql.DB) *StarDB {
return &StarDB{db: db}
}
// DB returns the underlying *sql.DB
func (s *StarDB) DB() *sql.DB {
return s.db
}
// SetDB sets the underlying *sql.DB
func (s *StarDB) SetDB(db *sql.DB) {
s.db = db
}
// SetStrictORM enables or disables strict column validation for Orm mapping.
func (s *StarDB) SetStrictORM(strict bool) {
if s == nil {
return
}
s.StrictORM = strict
}
func (s *StarDB) ensureDB() error {
if s == nil || s.db == nil {
return ErrDBNotInitialized
}
return nil
}
// Open opens a new database connection
func (s *StarDB) Open(driver, connStr string) error {
var err error
s.db, err = sql.Open(driver, connStr)
return err
}
// Close closes the database connection
func (s *StarDB) Close() error {
if err := s.ensureDB(); err != nil {
return err
}
return s.db.Close()
}
// Ping verifies the database connection is alive
func (s *StarDB) Ping() error {
if err := s.ensureDB(); err != nil {
return err
}
return s.db.Ping()
}
// PingContext verifies the database connection with context
func (s *StarDB) PingContext(ctx context.Context) error {
if err := s.ensureDB(); err != nil {
return err
}
return s.db.PingContext(ctx)
}
// Stats returns database statistics
func (s *StarDB) Stats() sql.DBStats {
if s == nil || s.db == nil {
return sql.DBStats{}
}
return s.db.Stats()
}
// SetMaxOpenConns sets the maximum number of open connections
func (s *StarDB) SetMaxOpenConns(n int) {
if s == nil || s.db == nil {
return
}
s.db.SetMaxOpenConns(n)
}
// SetMaxIdleConns sets the maximum number of idle connections
func (s *StarDB) SetMaxIdleConns(n int) {
if s == nil || s.db == nil {
return
}
s.db.SetMaxIdleConns(n)
}
// Conn returns a single connection from the pool
func (s *StarDB) Conn(ctx context.Context) (*sql.Conn, error) {
if err := s.ensureDB(); err != nil {
return nil, err
}
return s.db.Conn(ctx)
}
// Query executes a query that returns rows
// Usage: Query("SELECT * FROM users WHERE id = ?", 1)
func (s *StarDB) Query(query string, args ...interface{}) (*StarRows, error) {
return s.query(nil, query, args...)
}
// QueryContext executes a query with context
func (s *StarDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
return s.query(ctx, query, args...)
}
// QueryRaw executes a query and returns *sql.Rows without automatic parsing.
func (s *StarDB) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) {
return s.queryRaw(nil, query, args...)
}
// QueryRawContext executes a query with context and returns *sql.Rows without automatic parsing.
func (s *StarDB) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return s.queryRaw(ctx, query, args...)
}
func (s *StarDB) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
if err := s.ensureDB(); err != nil {
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, hookArgs, slowThreshold := s.prepareSQLCall(query, args)
hookCtx := s.hookContext(ctx, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, hookArgs)
}
start := time.Now()
var (
rows *sql.Rows
err error
)
if ctx == nil {
rows, err = s.db.Query(query, args...)
} else {
rows, err = s.db.QueryContext(ctx, query, args...)
}
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, hookArgs, duration, err)
}
if err != nil {
return nil, err
}
return rows, nil
}
// query is the internal query implementation
func (s *StarDB) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
rows, err := s.queryRaw(ctx, query, args...)
if err != nil {
return nil, err
}
starRows := &StarRows{
rows: rows,
db: s,
}
if !s.ManualScan {
if err := starRows.parse(); err != nil {
_ = rows.Close()
return nil, err
}
}
return starRows, nil
}
// Exec executes a query that doesn't return rows
// Usage: Exec("INSERT INTO users (name) VALUES (?)", "John")
func (s *StarDB) Exec(query string, args ...interface{}) (sql.Result, error) {
return s.exec(nil, query, args...)
}
// ExecContext executes a query with context
func (s *StarDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return s.exec(ctx, query, args...)
}
// exec is the internal exec implementation
func (s *StarDB) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if err := s.ensureDB(); err != nil {
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, hookArgs, slowThreshold := s.prepareSQLCall(query, args)
hookCtx := s.hookContext(ctx, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, hookArgs)
}
start := time.Now()
var (
result sql.Result
err error
)
if ctx == nil {
result, err = s.db.Exec(query, args...)
} else {
result, err = s.db.ExecContext(ctx, query, args...)
}
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, hookArgs, duration, err)
}
if err != nil {
return nil, err
}
return result, nil
}
// Prepare creates a prepared statement
func (s *StarDB) Prepare(query string) (*StarStmt, error) {
if err := s.ensureDB(); err != nil {
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, _, slowThreshold := s.prepareSQLCall(query, nil)
hookCtx := s.hookContext(nil, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, nil)
}
start := time.Now()
stmt, err := s.db.Prepare(query)
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, nil, duration, err)
}
if err != nil {
return nil, err
}
return &StarStmt{stmt: stmt, db: s, sqlText: query}, nil
}
// PrepareContext creates a prepared statement with context
func (s *StarDB) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
if err := s.ensureDB(); err != nil {
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, _, slowThreshold := s.prepareSQLCall(query, nil)
hookCtx := s.hookContext(ctx, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, nil)
}
start := time.Now()
stmt, err := s.db.PrepareContext(ctx, query)
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, nil, duration, err)
}
if err != nil {
return nil, err
}
return &StarStmt{stmt: stmt, db: s, sqlText: query}, nil
}
// QueryStmt executes a prepared statement query
// Usage: QueryStmt("SELECT * FROM users WHERE id = ?", 1)
func (s *StarDB) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
stmt, err := s.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Query(args...)
}
// QueryStmtContext executes a prepared statement query with context
func (s *StarDB) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
stmt, err := s.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.QueryContext(ctx, args...)
}
// ExecStmt executes a prepared statement
// Usage: ExecStmt("INSERT INTO users (name) VALUES (?)", "John")
func (s *StarDB) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
stmt, err := s.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Exec(args...)
}
// ExecStmtContext executes a prepared statement with context
func (s *StarDB) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
stmt, err := s.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.ExecContext(ctx, args...)
}
// Begin starts a transaction
func (s *StarDB) Begin() (*StarTx, error) {
if err := s.ensureDB(); err != nil {
return nil, err
}
tx, err := s.db.Begin()
if err != nil {
return nil, err
}
return &StarTx{tx: tx, db: s}, nil
}
// BeginTx starts a transaction with options
func (s *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, error) {
if err := s.ensureDB(); err != nil {
return nil, err
}
tx, err := s.db.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &StarTx{tx: tx, db: s}, nil
}
// StarStmt represents a prepared statement
type StarStmt struct {
stmt *sql.Stmt
db *StarDB
sqlText string
}
func (s *StarStmt) ensureStmt() error {
if s == nil || s.stmt == nil {
return ErrStmtNotInitialized
}
return nil
}
func (s *StarStmt) ensureStmtWithDB() error {
if err := s.ensureStmt(); err != nil {
return err
}
if s.db == nil {
return ErrStmtDBNotInitialized
}
return nil
}
// Query executes a prepared statement query
func (s *StarStmt) Query(args ...interface{}) (*StarRows, error) {
return s.query(nil, args...)
}
// QueryContext executes a prepared statement query with context
func (s *StarStmt) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) {
return s.query(ctx, args...)
}
// QueryRaw executes a prepared statement query and returns *sql.Rows without automatic parsing.
func (s *StarStmt) QueryRaw(args ...interface{}) (*sql.Rows, error) {
return s.queryRaw(nil, args...)
}
// QueryRawContext executes a prepared statement query with context and returns *sql.Rows without automatic parsing.
func (s *StarStmt) QueryRawContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
return s.queryRaw(ctx, args...)
}
func (s *StarStmt) queryRaw(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
if err := s.ensureStmt(); err != nil {
return nil, err
}
var beforeHook SQLBeforeHook
var afterHook SQLAfterHook
var slowThreshold time.Duration
if s.db != nil {
beforeHook, afterHook, slowThreshold = s.db.sqlHooks()
}
var hookArgs []interface{}
if beforeHook != nil || afterHook != nil {
hookArgs = cloneHookArgs(args)
}
hookCtx := ctx
if s.db != nil {
hookCtx = s.db.hookContext(ctx, s.sqlText, beforeHook, afterHook)
}
if beforeHook != nil {
beforeHook(hookCtx, s.sqlText, hookArgs)
}
start := time.Now()
var (
rows *sql.Rows
err error
)
if ctx == nil {
rows, err = s.stmt.Query(args...)
} else {
rows, err = s.stmt.QueryContext(ctx, args...)
}
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, s.sqlText, hookArgs, duration, err)
}
if err != nil {
return nil, err
}
return rows, nil
}
// query is the internal query implementation
func (s *StarStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) {
if err := s.ensureStmtWithDB(); err != nil {
return nil, err
}
rows, err := s.queryRaw(ctx, args...)
if err != nil {
return nil, err
}
starRows := &StarRows{
rows: rows,
db: s.db,
}
if !s.db.ManualScan {
if err := starRows.parse(); err != nil {
_ = rows.Close()
return nil, err
}
}
return starRows, nil
}
// Exec executes a prepared statement
func (s *StarStmt) Exec(args ...interface{}) (sql.Result, error) {
return s.exec(nil, args...)
}
// ExecContext executes a prepared statement with context
func (s *StarStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
return s.exec(ctx, args...)
}
// exec is the internal exec implementation
func (s *StarStmt) exec(ctx context.Context, args ...interface{}) (sql.Result, error) {
if err := s.ensureStmt(); err != nil {
return nil, err
}
var beforeHook SQLBeforeHook
var afterHook SQLAfterHook
var slowThreshold time.Duration
if s.db != nil {
beforeHook, afterHook, slowThreshold = s.db.sqlHooks()
}
var hookArgs []interface{}
if beforeHook != nil || afterHook != nil {
hookArgs = cloneHookArgs(args)
}
hookCtx := ctx
if s.db != nil {
hookCtx = s.db.hookContext(ctx, s.sqlText, beforeHook, afterHook)
}
if beforeHook != nil {
beforeHook(hookCtx, s.sqlText, hookArgs)
}
start := time.Now()
var (
result sql.Result
err error
)
if ctx == nil {
result, err = s.stmt.Exec(args...)
} else {
result, err = s.stmt.ExecContext(ctx, args...)
}
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, s.sqlText, hookArgs, duration, err)
}
if err != nil {
return nil, err
}
return result, nil
}
// Close closes the prepared statement
func (s *StarStmt) Close() error {
if err := s.ensureStmt(); err != nil {
return err
}
return s.stmt.Close()
}

View File

@ -1,127 +0,0 @@
package stardb
import (
"context"
"errors"
"testing"
)
func TestStarDB_NotInitialized(t *testing.T) {
db := NewStarDB()
if err := db.Close(); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from Close, got %v", err)
}
if err := db.Ping(); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from Ping, got %v", err)
}
if err := db.PingContext(context.Background()); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from PingContext, got %v", err)
}
if _, err := db.Conn(context.Background()); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from Conn, got %v", err)
}
if _, err := db.Query("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from Query, got %v", err)
}
if err := db.ScanEach("SELECT 1", func(row *StarResult) error { return nil }); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from ScanEach, got %v", err)
}
var model struct {
ID int `db:"id"`
}
if err := db.ScanEachORM("SELECT 1", &model, func(target interface{}) error { return nil }); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from ScanEachORM, got %v", err)
}
if _, err := db.QueryRaw("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from QueryRaw, got %v", err)
}
if _, err := db.Exec("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from Exec, got %v", err)
}
if _, err := db.Prepare("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from Prepare, got %v", err)
}
if _, err := db.Begin(); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from Begin, got %v", err)
}
if err := db.WithTx(nil); !errors.Is(err, ErrTxFuncNil) {
t.Fatalf("expected ErrTxFuncNil from WithTx, got %v", err)
}
if err := db.WithTx(func(tx *StarTx) error { return nil }); !errors.Is(err, ErrDBNotInitialized) {
t.Fatalf("expected ErrDBNotInitialized from WithTx, got %v", err)
}
if _, err := db.QueryStmt(""); !errors.Is(err, ErrQueryEmpty) {
t.Fatalf("expected ErrQueryEmpty from QueryStmt, got %v", err)
}
if _, err := db.ExecStmt(""); !errors.Is(err, ErrQueryEmpty) {
t.Fatalf("expected ErrQueryEmpty from ExecStmt, got %v", err)
}
_ = db.Stats()
db.SetMaxOpenConns(5)
db.SetMaxIdleConns(2)
}
func TestStarTx_NotInitialized(t *testing.T) {
tx := &StarTx{}
var model struct {
ID int `db:"id"`
}
if _, err := tx.Query("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
t.Fatalf("expected ErrTxNotInitialized from Query, got %v", err)
}
if err := tx.ScanEach("SELECT 1", func(row *StarResult) error { return nil }); !errors.Is(err, ErrTxNotInitialized) {
t.Fatalf("expected ErrTxNotInitialized from ScanEach, got %v", err)
}
if err := tx.ScanEachORM("SELECT 1", &model, func(target interface{}) error { return nil }); !errors.Is(err, ErrTxNotInitialized) {
t.Fatalf("expected ErrTxNotInitialized from ScanEachORM, got %v", err)
}
if _, err := tx.QueryRaw("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
t.Fatalf("expected ErrTxNotInitialized from QueryRaw, got %v", err)
}
if _, err := tx.Exec("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
t.Fatalf("expected ErrTxNotInitialized from Exec, got %v", err)
}
if _, err := tx.Prepare("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
t.Fatalf("expected ErrTxNotInitialized from Prepare, got %v", err)
}
if err := tx.Commit(); !errors.Is(err, ErrTxNotInitialized) {
t.Fatalf("expected ErrTxNotInitialized from Commit, got %v", err)
}
if err := tx.Rollback(); !errors.Is(err, ErrTxNotInitialized) {
t.Fatalf("expected ErrTxNotInitialized from Rollback, got %v", err)
}
if _, err := tx.QueryStmt(""); !errors.Is(err, ErrQueryEmpty) {
t.Fatalf("expected ErrQueryEmpty from QueryStmt, got %v", err)
}
if _, err := tx.ExecStmt(""); !errors.Is(err, ErrQueryEmpty) {
t.Fatalf("expected ErrQueryEmpty from ExecStmt, got %v", err)
}
}
func TestStarStmt_NotInitialized(t *testing.T) {
stmt := &StarStmt{}
if _, err := stmt.Query(); !errors.Is(err, ErrStmtNotInitialized) {
t.Fatalf("expected ErrStmtNotInitialized from Query, got %v", err)
}
if err := stmt.ScanEach(func(row *StarResult) error { return nil }); !errors.Is(err, ErrStmtNotInitialized) {
t.Fatalf("expected ErrStmtNotInitialized from ScanEach, got %v", err)
}
if err := stmt.ScanEachORM(nil, func(target interface{}) error { return nil }); !errors.Is(err, ErrTargetNil) {
t.Fatalf("expected ErrTargetNil from ScanEachORM, got %v", err)
}
if _, err := stmt.QueryRaw(); !errors.Is(err, ErrStmtNotInitialized) {
t.Fatalf("expected ErrStmtNotInitialized from QueryRaw, got %v", err)
}
if _, err := stmt.Exec(); !errors.Is(err, ErrStmtNotInitialized) {
t.Fatalf("expected ErrStmtNotInitialized from Exec, got %v", err)
}
if err := stmt.Close(); !errors.Is(err, ErrStmtNotInitialized) {
t.Fatalf("expected ErrStmtNotInitialized from Close, got %v", err)
}
}

864
stardb_v1.go Normal file
View File

@ -0,0 +1,864 @@
package stardb
import (
"database/sql"
"errors"
"reflect"
"strconv"
"time"
)
// StarDB 一个简单封装的DB库
type StarDB struct {
DB *sql.DB
Rows *sql.Rows
}
// StarRows 为查询结果集(按行)
type StarRows struct {
Rows *sql.Rows
Length int
StringResult []map[string]string
Columns []string
ColumnsType []reflect.Type
columnref map[string]int
result [][]interface{}
}
// StarResult 为查询结果集(总)
type StarResult struct {
Result []interface{}
Columns []string
columnref map[string]int
ColumnsType []reflect.Type
}
// StarResultCol 为查询结果集(按列)
type StarResultCol struct {
Result []interface{}
}
// MustBytes 列查询结果转Bytes
func (star *StarResultCol) MustBytes() [][]byte {
var res [][]byte
for _, v := range star.Result {
res = append(res, v.([]byte))
}
return res
}
// MustBool 列查询结果转Bool
func (star *StarResultCol) MustBool() []bool {
var res []bool
var tmp bool
for _, v := range star.Result {
switch vtype := v.(type) {
case nil:
tmp = false
case bool:
tmp = vtype
case float64:
if vtype > 0 {
tmp = true
} else {
tmp = false
}
case float32:
if vtype > 0 {
tmp = true
} else {
tmp = false
}
case int:
if vtype > 0 {
tmp = true
} else {
tmp = false
}
case int32:
if vtype > 0 {
tmp = true
} else {
tmp = false
}
case int64:
if vtype > 0 {
tmp = true
} else {
tmp = false
}
case string:
tmp, _ = strconv.ParseBool(vtype)
default:
tmp, _ = strconv.ParseBool(string(vtype.([]byte)))
}
res = append(res, tmp)
}
return res
}
// MustFloat32 列查询结果转Float32
func (star *StarResultCol) MustFloat32() []float32 {
var res []float32
var tmp float32
for _, v := range star.Result {
switch vtype := v.(type) {
case nil:
tmp = 0
case float32:
tmp = vtype
case float64:
tmp = float32(vtype)
case string:
tmps, _ := strconv.ParseFloat(vtype, 32)
tmp = float32(tmps)
case int:
tmp = float32(vtype)
case int32:
tmp = float32(vtype)
case int64:
tmp = float32(vtype)
case time.Time:
tmp = float32(vtype.Unix())
default:
tmpt := string(vtype.([]byte))
tmps, _ := strconv.ParseFloat(tmpt, 32)
tmp = float32(tmps)
}
res = append(res, tmp)
}
return res
}
// MustFloat64 列查询结果转Float64
func (star *StarResultCol) MustFloat64() []float64 {
var res []float64
var tmp float64
for _, v := range star.Result {
switch vtype := v.(type) {
case nil:
tmp = 0
case float64:
tmp = vtype
case float32:
tmp = float64(vtype)
case string:
tmp, _ = strconv.ParseFloat(vtype, 64)
case int:
tmp = float64(vtype)
case int32:
tmp = float64(vtype)
case int64:
tmp = float64(vtype)
case time.Time:
tmp = float64(vtype.Unix())
default:
tmpt := string(vtype.([]byte))
tmps, _ := strconv.ParseFloat(tmpt, 64)
tmp = float64(tmps)
}
res = append(res, tmp)
}
return res
}
// MustString 列查询结果转String
func (star *StarResultCol) MustString() []string {
var res []string
var tmp string
for _, v := range star.Result {
switch vtype := v.(type) {
case nil:
tmp = ""
case string:
tmp = vtype
case int64:
tmp = strconv.FormatInt(vtype, 10)
case int32:
tmp = strconv.Itoa(int(vtype))
case bool:
tmp = strconv.FormatBool(vtype)
case float64:
tmp = strconv.FormatFloat(vtype, 'f', 10, 64)
case float32:
tmp = strconv.FormatFloat(float64(vtype), 'f', 10, 32)
case int:
tmp = strconv.Itoa(vtype)
case time.Time:
tmp = vtype.String()
default:
tmp = string(vtype.([]byte))
}
res = append(res, tmp)
}
return res
}
// MustInt32 列查询结果转Int32
func (star *StarResultCol) MustInt32() []int32 {
var res []int32
var tmp int32
for _, v := range star.Result {
switch vtype := v.(type) {
case nil:
tmp = 0
case float64:
tmp = int32(vtype)
case float32:
tmp = int32(vtype)
case string:
tmps, _ := strconv.ParseInt(vtype, 10, 32)
tmp = int32(tmps)
case int:
tmp = int32(vtype)
case int64:
tmp = int32(vtype)
case int32:
tmp = vtype
case time.Time:
tmp = int32(vtype.Unix())
default:
tmpt := string(vtype.([]byte))
tmps, _ := strconv.ParseInt(tmpt, 10, 32)
tmp = int32(tmps)
}
res = append(res, tmp)
}
return res
}
// MustInt64 列查询结果转Int64
func (star *StarResultCol) MustInt64() []int64 {
var res []int64
var tmp int64
for _, v := range star.Result {
switch vtype := v.(type) {
case nil:
tmp = 0
case float64:
tmp = int64(vtype)
case float32:
tmp = int64(vtype)
case string:
tmps, _ := strconv.ParseInt(vtype, 10, 64)
tmp = int64(tmps)
case int:
tmp = int64(vtype)
case int32:
tmp = int64(vtype)
case int64:
tmp = vtype
case time.Time:
tmp = vtype.Unix()
default:
tmpt := string(vtype.([]byte))
tmp, _ = strconv.ParseInt(tmpt, 10, 64)
}
res = append(res, tmp)
}
return res
}
// MustInt 列查询结果转Int
func (star *StarResultCol) MustInt() []int {
var res []int
var tmp int
for _, v := range star.Result {
switch vtype := v.(type) {
case nil:
tmp = 0
case float64:
tmp = int(vtype)
case float32:
tmp = int(vtype)
case string:
tmps, _ := strconv.ParseInt(vtype, 10, 64)
tmp = int(tmps)
case int:
tmp = vtype
case int32:
tmp = int(vtype)
case int64:
tmp = int(vtype)
case time.Time:
tmp = int(vtype.Unix())
default:
tmpt := string(vtype.([]byte))
tmps, _ := strconv.ParseInt(tmpt, 10, 64)
tmp = int(tmps)
}
res = append(res, tmp)
}
return res
}
// MustDate 列查询结果转Date(time.Time)
func (star *StarResultCol) MustDate(layout string) []time.Time {
var res []time.Time
var tmp time.Time
for _, v := range star.Result {
switch vtype := v.(type) {
case nil:
tmp = time.Time{}
case float64:
tmp = time.Unix(int64(vtype), int64(vtype-float64(int64(vtype)))*1000000000)
case float32:
tmp = time.Unix(int64(vtype), int64(vtype-float32(int64(vtype)))*1000000000)
case string:
tmp, _ = time.Parse(layout, vtype)
case int:
tmp = time.Unix(int64(vtype), 0)
case int32:
tmp = time.Unix(int64(vtype), 0)
case int64:
tmp = time.Unix(vtype, 0)
case time.Time:
tmp = vtype
default:
tmpt := string(vtype.([]byte))
tmp, _ = time.Parse(layout, tmpt)
}
res = append(res, tmp)
}
return res
}
// IsNil 检测是不是nil 列查询结果是不是nil
func (star *StarResultCol) IsNil(name string) []bool {
var res []bool
var tmp bool
for _, v := range star.Result {
switch v.(type) {
case nil:
tmp = true
default:
tmp = false
}
res = append(res, tmp)
}
return res
}
// IsNil 检测是不是nil
func (star *StarResult) IsNil(name string) bool {
num, ok := star.columnref[name]
if !ok {
return false
}
tmp := star.Result[num]
switch tmp.(type) {
case nil:
return true
default:
return false
}
}
// MustDate 列查询结果转Date
func (star *StarResult) MustDate(name, layout string) time.Time {
var res time.Time
num, ok := star.columnref[name]
if !ok {
return time.Time{}
}
tmp := star.Result[num]
switch vtype := tmp.(type) {
case nil:
res = time.Time{}
case float64:
res = time.Unix(int64(vtype), int64(vtype-float64(int64(vtype)))*1000000000)
case float32:
res = time.Unix(int64(vtype), int64(vtype-float32(int64(vtype)))*1000000000)
case string:
res, _ = time.Parse(layout, vtype)
case int:
res = time.Unix(int64(vtype), 0)
case int32:
res = time.Unix(int64(vtype), 0)
case int64:
res = time.Unix(vtype, 0)
case time.Time:
res = vtype
default:
res, _ = time.Parse(layout, string(tmp.([]byte)))
}
return res
}
// MustInt64 列查询结果转int64
func (star *StarResult) MustInt64(name string) int64 {
var res int64
num, ok := star.columnref[name]
if !ok {
return 0
}
tmp := star.Result[num]
switch vtype := tmp.(type) {
case nil:
res = 0
case float64:
res = int64(vtype)
case float32:
res = int64(vtype)
case string:
res, _ = strconv.ParseInt(vtype, 10, 64)
case int:
res = int64(vtype)
case int32:
res = int64(vtype)
case int64:
res = vtype
case time.Time:
res = int64(vtype.Unix())
default:
res, _ = strconv.ParseInt(string(tmp.([]byte)), 10, 64)
}
return res
}
// MustInt32 列查询结果转Int32
func (star *StarResult) MustInt32(name string) int32 {
var res int32
num, ok := star.columnref[name]
if !ok {
return 0
}
tmp := star.Result[num]
switch vtype := tmp.(type) {
case nil:
res = 0
case float64:
res = int32(vtype)
case float32:
res = int32(vtype)
case string:
ress, _ := strconv.ParseInt(vtype, 10, 32)
res = int32(ress)
case int:
res = int32(vtype)
case int32:
res = vtype
case int64:
res = int32(vtype)
case time.Time:
res = int32(vtype.Unix())
default:
ress, _ := strconv.ParseInt(string(tmp.([]byte)), 10, 32)
res = int32(ress)
}
return res
}
// MustString 列查询结果转string
func (star *StarResult) MustString(name string) string {
var res string
num, ok := star.columnref[name]
if !ok {
return ""
}
switch vtype := star.Result[num].(type) {
case nil:
res = ""
case string:
res = vtype
case int64:
res = strconv.FormatInt(vtype, 10)
case int32:
res = strconv.Itoa(int(vtype))
case bool:
res = strconv.FormatBool(vtype)
case float64:
res = strconv.FormatFloat(vtype, 'f', 10, 64)
case float32:
res = strconv.FormatFloat(float64(vtype), 'f', 10, 32)
case int:
res = strconv.Itoa(vtype)
case time.Time:
res = vtype.String()
default:
res = string(vtype.([]byte))
}
return res
}
// MustFloat64 列查询结果转float64
func (star *StarResult) MustFloat64(name string) float64 {
var res float64
num, ok := star.columnref[name]
if !ok {
return 0
}
switch vtype := star.Result[num].(type) {
case nil:
res = 0
case string:
res, _ = strconv.ParseFloat(vtype, 64)
case float64:
res = vtype
case int:
res = float64(vtype)
case int64:
res = float64(vtype)
case int32:
res = float64(vtype)
case float32:
res = float64(vtype)
case time.Time:
res = float64(vtype.Unix())
default:
res, _ = strconv.ParseFloat(string(vtype.([]byte)), 64)
}
return res
}
// MustFloat32 列查询结果转float32
func (star *StarResult) MustFloat32(name string) float32 {
var res float32
num, ok := star.columnref[name]
if !ok {
return 0
}
switch vtype := star.Result[num].(type) {
case nil:
res = 0
case string:
tmp, _ := strconv.ParseFloat(vtype, 32)
res = float32(tmp)
case float64:
res = float32(vtype)
case float32:
res = vtype
case int:
res = float32(vtype)
case int64:
res = float32(vtype)
case int32:
res = float32(vtype)
case time.Time:
res = float32(vtype.Unix())
default:
tmp, _ := strconv.ParseFloat(string(vtype.([]byte)), 32)
res = float32(tmp)
}
return res
}
// MustInt 列查询结果转int
func (star *StarResult) MustInt(name string) int {
var res int
num, ok := star.columnref[name]
if !ok {
return 0
}
tmp := star.Result[num]
switch vtype := tmp.(type) {
case nil:
res = 0
case float64:
res = int(vtype)
case float32:
res = int(vtype)
case string:
ress, _ := strconv.ParseInt(vtype, 10, 64)
res = int(ress)
case int:
res = vtype
case int32:
res = int(vtype)
case int64:
res = int(vtype)
case time.Time:
res = int(vtype.Unix())
default:
ress, _ := strconv.ParseInt(string(tmp.([]byte)), 10, 64)
res = int(ress)
}
return res
}
// MustBool 列查询结果转bool
func (star *StarResult) MustBool(name string) bool {
var res bool
num, ok := star.columnref[name]
if !ok {
return false
}
tmp := star.Result[num]
switch vtype := tmp.(type) {
case nil:
res = false
case bool:
res = vtype
case float64:
if vtype > 0 {
res = true
} else {
res = false
}
case float32:
if vtype > 0 {
res = true
} else {
res = false
}
case int:
if vtype > 0 {
res = true
} else {
res = false
}
case int32:
if vtype > 0 {
res = true
} else {
res = false
}
case int64:
if vtype > 0 {
res = true
} else {
res = false
}
case string:
res, _ = strconv.ParseBool(vtype)
default:
res, _ = strconv.ParseBool(string(vtype.([]byte)))
}
return res
}
// MustBytes 列查询结果转byte
func (star *StarResult) MustBytes(name string) []byte {
num, ok := star.columnref[name]
if !ok {
return []byte{}
}
res := star.Result[num].([]byte)
return res
}
// Rescan 重新分析结果集
func (star *StarRows) Rescan() {
star.parserows()
}
// Col 选择需要进行操作的数据结果列
func (star *StarRows) Col(name string) *StarResultCol {
result := new(StarResultCol)
if _, ok := star.columnref[name]; !ok {
return result
}
var rescol []interface{}
for _, v := range star.result {
rescol = append(rescol, v[star.columnref[name]])
}
result.Result = rescol
return result
}
// Row 选择需要进行操作的数据结果行
func (star *StarRows) Row(id int) *StarResult {
result := new(StarResult)
if id+1 > len(star.result) {
return result
}
result.Result = star.result[id]
result.Columns = star.Columns
result.ColumnsType = star.ColumnsType
result.columnref = star.columnref
return result
}
// Close 关闭打开的结果集
func (star *StarRows) Close() error {
return star.Rows.Close()
}
func (star *StarRows) parserows() {
star.result = [][]interface{}{}
star.columnref = make(map[string]int)
star.StringResult = []map[string]string{}
star.Columns, _ = star.Rows.Columns()
types, _ := star.Rows.ColumnTypes()
for _, v := range types {
star.ColumnsType = append(star.ColumnsType, v.ScanType())
}
scanArgs := make([]interface{}, len(star.Columns))
values := make([]interface{}, len(star.Columns))
for i := range values {
star.columnref[star.Columns[i]] = i
scanArgs[i] = &values[i]
}
for star.Rows.Next() {
if err := star.Rows.Scan(scanArgs...); err != nil {
return
}
record := make(map[string]string)
var rescopy []interface{}
for i, col := range values {
rescopy = append(rescopy, col)
switch vtype := col.(type) {
case float32:
record[star.Columns[i]] = strconv.FormatFloat(float64(vtype), 'f', -1, 64)
case float64:
record[star.Columns[i]] = strconv.FormatFloat(vtype, 'f', -1, 64)
case int64:
record[star.Columns[i]] = strconv.FormatInt(vtype, 10)
case int32:
record[star.Columns[i]] = strconv.FormatInt(int64(vtype), 10)
case int:
record[star.Columns[i]] = strconv.Itoa(vtype)
case string:
record[star.Columns[i]] = vtype
case bool:
record[star.Columns[i]] = strconv.FormatBool(vtype)
case time.Time:
record[star.Columns[i]] = vtype.String()
case nil:
record[star.Columns[i]] = ""
default:
record[star.Columns[i]] = string(vtype.([]byte))
}
}
star.result = append(star.result, rescopy)
star.StringResult = append(star.StringResult, record)
}
star.Length = len(star.StringResult)
}
// Query 进行Query操作
func (star *StarDB) Query(args ...interface{}) (*StarRows, error) {
var err error
effect := new(StarRows)
if err = star.DB.Ping(); err != nil {
return effect, err
}
if len(args) == 0 {
return effect, errors.New("no args")
}
if len(args) == 1 {
sql := args[0]
if star.Rows, err = star.DB.Query(sql.(string)); err != nil {
return effect, err
}
effect.Rows = star.Rows
effect.parserows()
return effect, nil
}
sql := args[0]
stmt, err := star.DB.Prepare(sql.(string))
if err != nil {
return effect, err
}
defer stmt.Close()
var para []interface{}
for k, v := range args {
if k != 0 {
switch vtype := v.(type) {
default:
para = append(para, vtype)
}
}
}
if star.Rows, err = stmt.Query(para...); err != nil {
return effect, err
}
effect.Rows = star.Rows
effect.parserows()
return effect, nil
}
// Open 打开一个新的数据库
func (star *StarDB) Open(Method, ConnStr string) error {
var err error
star.DB, err = sql.Open(Method, ConnStr)
if err != nil {
return err
}
err = star.DB.Ping()
return err
}
// Close 关闭打开的数据库
func (star *StarDB) Close() error {
if err := star.DB.Close(); err != nil {
return err
}
return star.DB.Close()
}
// Exec 执行Exec操作
func (star *StarDB) Exec(args ...interface{}) (sql.Result, error) {
var err error
var effect sql.Result
if err = star.DB.Ping(); err != nil {
return effect, err
}
if len(args) == 0 {
return effect, errors.New("no args")
}
if len(args) == 1 {
sql := args[0]
if _, err = star.DB.Exec(sql.(string)); err != nil {
return effect, err
}
return effect, nil
}
sql := args[0]
stmt, err := star.DB.Prepare(sql.(string))
if err != nil {
return effect, err
}
defer stmt.Close()
var para []interface{}
for k, v := range args {
if k != 0 {
switch vtype := v.(type) {
default:
para = append(para, vtype)
}
}
}
if effect, err = stmt.Exec(para...); err != nil {
return effect, err
}
return effect, nil
}
// FetchAll 把结果集全部转为key-value型<string>数据
func FetchAll(rows *sql.Rows) (error, map[int]map[string]string) {
var ii int = 0
records := make(map[int]map[string]string)
columns, err := rows.Columns()
if err != nil {
return err, records
}
scanArgs := make([]interface{}, len(columns))
values := make([]interface{}, len(columns))
for i := range values {
scanArgs[i] = &values[i]
}
for rows.Next() {
if err := rows.Scan(scanArgs...); err != nil {
return err, records
}
record := make(map[string]string)
for i, col := range values {
switch vtype := col.(type) {
case float64:
record[columns[i]] = strconv.FormatFloat(vtype, 'f', -1, 64)
case int64:
record[columns[i]] = strconv.FormatInt(vtype, 10)
case string:
record[columns[i]] = vtype
case nil:
record[columns[i]] = ""
default:
record[columns[i]] = string(vtype.([]byte))
}
}
records[ii] = record
ii++
}
return nil, records
}

View File

@ -1,888 +0,0 @@
package testing
import (
"context"
"errors"
"testing"
"time"
"b612.me/stardb"
_ "modernc.org/sqlite"
)
type TestUser struct {
ID int64 `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
CreatedAt time.Time `db:"created_at"`
}
func setupBatchTestDB(t *testing.T) *stardb.StarDB {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
_, err = db.Exec(`
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT NOT NULL,
age INTEGER,
created_at DATETIME
)
`)
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
return db
}
func TestStarDB_BatchInsert_Basic(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
{"Charlie", "charlie@example.com", 35},
}
result, err := db.BatchInsert("users", columns, values)
if err != nil {
t.Fatalf("BatchInsert failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 3 {
t.Errorf("Expected 3 rows affected, got %d", affected)
}
// Verify insertion
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
count := rows.Row(0).MustInt("count")
if count != 3 {
t.Errorf("Expected 3 rows in database, got %d", count)
}
}
func TestStarDB_BatchInsert_Single(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
}
result, err := db.BatchInsert("users", columns, values)
if err != nil {
t.Fatalf("BatchInsert failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_BatchInsert_Empty(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
values := [][]interface{}{}
_, err := db.BatchInsert("users", columns, values)
if !errors.Is(err, stardb.ErrNoInsertValues) {
t.Errorf("Expected ErrNoInsertValues, got %v", err)
}
}
func TestStarDB_BatchInsert_EmptyColumns(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
values := [][]interface{}{
{"Alice"},
}
_, err := db.BatchInsert("users", nil, values)
if !errors.Is(err, stardb.ErrNoInsertColumns) {
t.Errorf("Expected ErrNoInsertColumns, got %v", err)
}
}
func TestStarDB_BatchInsert_EmptyTableName(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
}
_, err := db.BatchInsert("", columns, values)
if !errors.Is(err, stardb.ErrTableNameEmpty) {
t.Errorf("Expected ErrTableNameEmpty, got %v", err)
}
}
func TestStarDB_BatchInsert_RowLengthMismatch(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com"},
}
_, err := db.BatchInsert("users", columns, values)
if !errors.Is(err, stardb.ErrBatchRowValueCountMismatch) {
t.Errorf("Expected ErrBatchRowValueCountMismatch, got %v", err)
}
}
func TestStarDB_BatchInsert_Large(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
var values [][]interface{}
// Insert 100 rows
for i := 0; i < 100; i++ {
values = append(values, []interface{}{
"User" + string(rune(i)),
"user" + string(rune(i)) + "@example.com",
20 + i%50,
})
}
result, err := db.BatchInsert("users", columns, values)
if err != nil {
t.Fatalf("BatchInsert failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 100 {
t.Errorf("Expected 100 rows affected, got %d", affected)
}
// Verify
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
count := rows.Row(0).MustInt("count")
if count != 100 {
t.Errorf("Expected 100 rows in database, got %d", count)
}
}
func TestStarDB_BatchInsertContext(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
ctx := context.Background()
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
}
result, err := db.BatchInsertContext(ctx, "users", columns, values)
if err != nil {
t.Fatalf("BatchInsertContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 2 {
t.Errorf("Expected 2 rows affected, got %d", affected)
}
}
func TestStarDB_BatchInsertContext_Timeout(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()
time.Sleep(10 * time.Millisecond) // Ensure timeout
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
}
_, err := db.BatchInsertContext(ctx, "users", columns, values)
if err == nil {
t.Error("Expected timeout error, got nil")
}
}
func TestStarDB_BatchInsertMaxRows_Config(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
if got := db.BatchInsertMaxRows(); got != 0 {
t.Fatalf("Expected default chunk size 0, got %d", got)
}
db.SetBatchInsertMaxRows(3)
if got := db.BatchInsertMaxRows(); got != 3 {
t.Fatalf("Expected chunk size 3, got %d", got)
}
db.SetBatchInsertMaxRows(-10)
if got := db.BatchInsertMaxRows(); got != 0 {
t.Fatalf("Expected chunk size reset to 0, got %d", got)
}
}
func TestStarDB_BatchInsertMaxParams_Config(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
if got := db.BatchInsertMaxParams(); got != 0 {
t.Fatalf("Expected default max params 0, got %d", got)
}
db.SetBatchInsertMaxParams(100)
if got := db.BatchInsertMaxParams(); got != 100 {
t.Fatalf("Expected max params 100, got %d", got)
}
db.SetBatchInsertMaxParams(-1)
if got := db.BatchInsertMaxParams(); got != 0 {
t.Fatalf("Expected max params reset to 0, got %d", got)
}
}
func TestStarDB_BatchInsert_Chunked(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
db.SetBatchInsertMaxRows(2)
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
{"Charlie", "charlie@example.com", 35},
{"David", "david@example.com", 40},
{"Eva", "eva@example.com", 28},
}
result, err := db.BatchInsert("users", columns, values)
if err != nil {
t.Fatalf("Chunked BatchInsert failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != int64(len(values)) {
t.Fatalf("Expected %d affected rows, got %d", len(values), affected)
}
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if count := rows.Row(0).MustInt("count"); count != len(values) {
t.Fatalf("Expected %d rows in db, got %d", len(values), count)
}
}
func TestStarDB_BatchInsert_ChunkedRollbackOnError(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
db.SetBatchInsertMaxRows(2)
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
{"Charlie", nil, 35}, // email NOT NULL, forces second chunk failure
}
if _, err := db.BatchInsert("users", columns, values); err == nil {
t.Fatal("Expected chunked BatchInsert to fail, got nil")
}
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if count := rows.Row(0).MustInt("count"); count != 0 {
t.Fatalf("Expected rollback to keep table empty, got %d rows", count)
}
}
func TestStarDB_BatchInsert_ChunkedByMaxParams(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
db.SetBatchInsertMaxRows(0) // disabled
db.SetBatchInsertMaxParams(4) // 3 columns -> 1 row per chunk
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
{"Charlie", "charlie@example.com", 35},
}
result, err := db.BatchInsert("users", columns, values)
if err != nil {
t.Fatalf("BatchInsert by max params failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != int64(len(values)) {
t.Fatalf("Expected %d affected rows, got %d", len(values), affected)
}
}
func TestStarDB_BatchInsert_MaxParamsTooLow(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
db.SetBatchInsertMaxRows(0)
db.SetBatchInsertMaxParams(2) // columns=3 -> invalid
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
}
_, err := db.BatchInsert("users", columns, values)
if !errors.Is(err, stardb.ErrBatchInsertMaxParamsTooLow) {
t.Fatalf("Expected ErrBatchInsertMaxParamsTooLow, got %v", err)
}
}
func TestStarDB_BatchInsert_ChunkedHookMeta(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
db.SetBatchInsertMaxRows(2)
db.SetBatchInsertMaxParams(0)
var metas []stardb.BatchExecMeta
db.SetSQLHooks(nil, func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) {
if meta, ok := stardb.BatchExecMetaFromContext(ctx); ok {
metas = append(metas, meta)
}
})
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
{"Charlie", "charlie@example.com", 35},
{"David", "david@example.com", 40},
{"Eva", "eva@example.com", 28},
}
_, err := db.BatchInsertContext(context.Background(), "users", columns, values)
if err != nil {
t.Fatalf("BatchInsertContext failed: %v", err)
}
if len(metas) != 3 {
t.Fatalf("Expected 3 chunk metas, got %d", len(metas))
}
wantRows := []int{2, 2, 1}
for i, meta := range metas {
if meta.ChunkIndex != i+1 {
t.Fatalf("Chunk %d: expected index %d, got %d", i, i+1, meta.ChunkIndex)
}
if meta.ChunkCount != 3 {
t.Fatalf("Chunk %d: expected count 3, got %d", i, meta.ChunkCount)
}
if meta.ChunkRows != wantRows[i] {
t.Fatalf("Chunk %d: expected rows %d, got %d", i, wantRows[i], meta.ChunkRows)
}
if meta.TotalRows != len(values) {
t.Fatalf("Chunk %d: expected total rows %d, got %d", i, len(values), meta.TotalRows)
}
if meta.ColumnCount != len(columns) {
t.Fatalf("Chunk %d: expected column count %d, got %d", i, len(columns), meta.ColumnCount)
}
}
}
func TestStarDB_BatchInsert_HookMetaAbsentWithoutChunking(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
db.SetBatchInsertMaxRows(0)
db.SetBatchInsertMaxParams(0)
metaFound := false
db.SetSQLHooks(nil, func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) {
if _, ok := stardb.BatchExecMetaFromContext(ctx); ok {
metaFound = true
}
})
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
}
_, err := db.BatchInsertContext(context.Background(), "users", columns, values)
if err != nil {
t.Fatalf("BatchInsertContext failed: %v", err)
}
if metaFound {
t.Fatal("Expected no batch meta for non-chunked execution")
}
}
func TestStarDB_BatchInsertStructs_Basic(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
{Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()},
}
result, err := db.BatchInsertStructs("users", users, "id")
if err != nil {
t.Fatalf("BatchInsertStructs failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 3 {
t.Errorf("Expected 3 rows affected, got %d", affected)
}
// Verify insertion
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if rows.Length() != 3 {
t.Errorf("Expected 3 rows, got %d", rows.Length())
}
// Verify first user
name := rows.Row(0).MustString("name")
if name != "Alice" {
t.Errorf("Expected first user 'Alice', got '%s'", name)
}
}
func TestStarDB_BatchInsertStructs_Chunked(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
db.SetBatchInsertMaxRows(2)
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
{Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()},
{Name: "David", Email: "david@example.com", Age: 40, CreatedAt: time.Now()},
{Name: "Eva", Email: "eva@example.com", Age: 28, CreatedAt: time.Now()},
}
result, err := db.BatchInsertStructs("users", users, "id")
if err != nil {
t.Fatalf("Chunked BatchInsertStructs failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != int64(len(users)) {
t.Fatalf("Expected %d affected rows, got %d", len(users), affected)
}
}
func TestStarDB_BatchInsertStructs_Single(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
}
result, err := db.BatchInsertStructs("users", users, "id")
if err != nil {
t.Fatalf("BatchInsertStructs failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_BatchInsertStructs_Empty(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
users := []TestUser{}
_, err := db.BatchInsertStructs("users", users, "id")
if !errors.Is(err, stardb.ErrNoStructsToInsert) {
t.Errorf("Expected ErrNoStructsToInsert, got %v", err)
}
}
func TestStarDB_BatchInsertStructs_NotSlice(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
user := TestUser{Name: "Alice", Email: "alice@example.com", Age: 25}
_, err := db.BatchInsertStructs("users", user, "id")
if !errors.Is(err, stardb.ErrStructsNotSlice) {
t.Errorf("Expected ErrStructsNotSlice, got %v", err)
}
}
func TestStarDB_BatchInsertStructs_Nil(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
_, err := db.BatchInsertStructs("users", nil, "id")
if !errors.Is(err, stardb.ErrStructsNil) {
t.Errorf("Expected ErrStructsNil, got %v", err)
}
}
func TestStarDB_BatchInsertStructs_NilPointer(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
var users *[]TestUser
_, err := db.BatchInsertStructs("users", users, "id")
if !errors.Is(err, stardb.ErrStructsPointerNil) {
t.Errorf("Expected ErrStructsPointerNil, got %v", err)
}
}
func TestStarDB_BatchInsertStructs_Pointer(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
}
result, err := db.BatchInsertStructs("users", &users, "id")
if err != nil {
t.Fatalf("BatchInsertStructs with pointer failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 2 {
t.Errorf("Expected 2 rows affected, got %d", affected)
}
}
func TestStarDB_BatchInsertStructsContext(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
ctx := context.Background()
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
}
result, err := db.BatchInsertStructsContext(ctx, "users", users, "id")
if err != nil {
t.Fatalf("BatchInsertStructsContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 2 {
t.Errorf("Expected 2 rows affected, got %d", affected)
}
}
func TestStarDB_BatchInsertStructs_Large(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
var users []TestUser
for i := 0; i < 50; i++ {
users = append(users, TestUser{
Name: "User" + string(rune(i)),
Email: "user" + string(rune(i)) + "@example.com",
Age: 20 + i%30,
CreatedAt: time.Now(),
})
}
result, err := db.BatchInsertStructs("users", users, "id")
if err != nil {
t.Fatalf("BatchInsertStructs failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 50 {
t.Errorf("Expected 50 rows affected, got %d", affected)
}
// Verify
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
count := rows.Row(0).MustInt("count")
if count != 50 {
t.Errorf("Expected 50 rows in database, got %d", count)
}
}
func TestStarDB_BatchInsert_VerifyData(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
}
_, err := db.BatchInsert("users", columns, values)
if err != nil {
t.Fatalf("BatchInsert failed: %v", err)
}
// Verify data integrity
rows, err := db.Query("SELECT name, email, age FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
// Check Alice
row0 := rows.Row(0)
if row0.MustString("name") != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", row0.MustString("name"))
}
if row0.MustString("email") != "alice@example.com" {
t.Errorf("Expected email 'alice@example.com', got '%s'", row0.MustString("email"))
}
if row0.MustInt("age") != 25 {
t.Errorf("Expected age 25, got %d", row0.MustInt("age"))
}
// Check Bob
row1 := rows.Row(1)
if row1.MustString("name") != "Bob" {
t.Errorf("Expected name 'Bob', got '%s'", row1.MustString("name"))
}
if row1.MustString("email") != "bob@example.com" {
t.Errorf("Expected email 'bob@example.com', got '%s'", row1.MustString("email"))
}
if row1.MustInt("age") != 30 {
t.Errorf("Expected age 30, got %d", row1.MustInt("age"))
}
}
func TestStarDB_BatchInsertStructs_VerifyData(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
now := time.Now()
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: now},
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: now},
}
_, err := db.BatchInsertStructs("users", users, "id")
if err != nil {
t.Fatalf("BatchInsertStructs failed: %v", err)
}
// Query and verify with ORM
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var resultUsers []TestUser
err = rows.Orm(&resultUsers)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if len(resultUsers) != 2 {
t.Fatalf("Expected 2 users, got %d", len(resultUsers))
}
// Verify Alice
if resultUsers[0].Name != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", resultUsers[0].Name)
}
if resultUsers[0].Email != "alice@example.com" {
t.Errorf("Expected email 'alice@example.com', got '%s'", resultUsers[0].Email)
}
if resultUsers[0].Age != 25 {
t.Errorf("Expected age 25, got %d", resultUsers[0].Age)
}
// Verify Bob
if resultUsers[1].Name != "Bob" {
t.Errorf("Expected name 'Bob', got '%s'", resultUsers[1].Name)
}
if resultUsers[1].Email != "bob@example.com" {
t.Errorf("Expected email 'bob@example.com', got '%s'", resultUsers[1].Email)
}
if resultUsers[1].Age != 30 {
t.Errorf("Expected age 30, got %d", resultUsers[1].Age)
}
}
// Benchmark tests
func BenchmarkBatchInsert_10(b *testing.B) {
db := setupBatchTestDB(&testing.T{})
defer db.Close()
columns := []string{"name", "email", "age"}
var values [][]interface{}
for i := 0; i < 10; i++ {
values = append(values, []interface{}{"User", "user@example.com", 25})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
db.BatchInsert("users", columns, values)
db.Exec("DELETE FROM users")
}
}
func BenchmarkBatchInsert_100(b *testing.B) {
db := setupBatchTestDB(&testing.T{})
defer db.Close()
columns := []string{"name", "email", "age"}
var values [][]interface{}
for i := 0; i < 100; i++ {
values = append(values, []interface{}{"User", "user@example.com", 25})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
db.BatchInsert("users", columns, values)
db.Exec("DELETE FROM users")
}
}
func BenchmarkBatchInsertStructs_10(b *testing.B) {
db := setupBatchTestDB(&testing.T{})
defer db.Close()
var users []TestUser
for i := 0; i < 10; i++ {
users = append(users, TestUser{
Name: "User",
Email: "user@example.com",
Age: 25,
CreatedAt: time.Now(),
})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
db.BatchInsertStructs("users", users, "id")
db.Exec("DELETE FROM users")
}
}
func BenchmarkBatchInsertStructs_100(b *testing.B) {
db := setupBatchTestDB(&testing.T{})
defer db.Close()
var users []TestUser
for i := 0; i < 100; i++ {
users = append(users, TestUser{
Name: "User",
Email: "user@example.com",
Age: 25,
CreatedAt: time.Now(),
})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
db.BatchInsertStructs("users", users, "id")
db.Exec("DELETE FROM users")
}
}

View File

@ -1,23 +0,0 @@
module b612.me/stardb/testing
go 1.25.6
require (
b612.me/stardb v0.0.0
modernc.org/sqlite v1.46.1
)
require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/sys v0.37.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
)
replace b612.me/stardb => ../

View File

@ -1,53 +0,0 @@
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

View File

@ -1,934 +0,0 @@
package testing
import (
"context"
"errors"
"testing"
"time"
"b612.me/stardb"
)
type User struct {
ID int64 `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
Balance float64 `db:"balance"`
Active bool `db:"active"`
CreatedAt time.Time `db:"created_at"`
}
type Profile struct {
UserID int `db:"user_id"`
Bio string `db:"bio"`
Avatar string `db:"avatar"`
}
type NestedUser struct {
ID int64 `db:"id"`
Name string `db:"name"`
Profile `db:"---"`
}
type UserWithPrivateField struct {
ID int64 `db:"id"`
Name string `db:"name"`
age int `db:"age"`
}
func TestStarRows_Orm_Single(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user User
err = rows.Orm(&user)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if user.Name != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", user.Name)
}
if user.Email != "alice@example.com" {
t.Errorf("Expected email 'alice@example.com', got '%s'", user.Email)
}
if user.Age != 25 {
t.Errorf("Expected age 25, got %d", user.Age)
}
if user.Balance != 100.50 {
t.Errorf("Expected balance 100.50, got %f", user.Balance)
}
if !user.Active {
t.Errorf("Expected user to be active")
}
}
func TestStarRows_Orm_Multiple(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var users []User
err = rows.Orm(&users)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if len(users) != 3 {
t.Fatalf("Expected 3 users, got %d", len(users))
}
expectedNames := []string{"Alice", "Bob", "Charlie"}
for i, user := range users {
if user.Name != expectedNames[i] {
t.Errorf("Expected name '%s', got '%s'", expectedNames[i], user.Name)
}
}
}
func TestStarRows_Orm_Array(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var users [3]User
err = rows.Orm(&users)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
expectedNames := []string{"Alice", "Bob", "Charlie"}
for i, user := range users {
if user.Name != expectedNames[i] {
t.Errorf("Expected name '%s', got '%s'", expectedNames[i], user.Name)
}
}
}
func TestStarRows_Orm_ArrayTooSmall(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var users [2]User
err = rows.Orm(&users)
if err == nil {
t.Error("Expected error when target array is smaller than row count, got nil")
}
}
func TestStarRows_Orm_Empty(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "NonExistent")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var users []User
err = rows.Orm(&users)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if len(users) != 0 {
t.Errorf("Expected 0 users, got %d", len(users))
}
}
func TestStarRows_Orm_NotPointer(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user User
err = rows.Orm(user) // Not a pointer
if !errors.Is(err, stardb.ErrTargetNotPointer) {
t.Errorf("Expected ErrTargetNotPointer, got %v", err)
}
}
func TestStarRows_Orm_NilTarget(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
err = rows.Orm(nil)
if !errors.Is(err, stardb.ErrTargetNil) {
t.Errorf("Expected ErrTargetNil, got %v", err)
}
}
func TestStarRows_Orm_NilPointerTarget(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user *User
err = rows.Orm(user)
if !errors.Is(err, stardb.ErrTargetPointerNil) {
t.Errorf("Expected ErrTargetPointerNil, got %v", err)
}
}
func TestStarRows_Orm_MissingColumns_NonStrict(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT id, name FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user User
err = rows.Orm(&user)
if err != nil {
t.Fatalf("Expected non-strict ORM to ignore missing columns, got error: %v", err)
}
if user.Name != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", user.Name)
}
}
func TestStarRows_Orm_MissingColumns_Strict(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
db.SetStrictORM(true)
rows, err := db.Query("SELECT id, name FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user User
err = rows.Orm(&user)
if err == nil {
t.Fatalf("Expected strict ORM to fail on missing columns, got nil")
}
}
func TestStarRows_Orm_UnexportedTaggedField(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user UserWithPrivateField
err = rows.Orm(&user)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if user.ID <= 0 {
t.Errorf("Expected positive ID, got %d", user.ID)
}
if user.Name != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", user.Name)
}
}
func TestStarDB_Insert(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
user := User{
Name: "David",
Email: "david@example.com",
Age: 40,
Balance: 400.00,
Active: true,
CreatedAt: time.Now(),
}
result, err := db.Insert(&user, "users", "id")
if err != nil {
t.Fatalf("Insert failed: %v", err)
}
lastID, err := result.LastInsertId()
if err != nil {
t.Fatalf("LastInsertId failed: %v", err)
}
if lastID <= 0 {
t.Errorf("Expected positive last insert ID, got %d", lastID)
}
// Verify insertion
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "David")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var insertedUser User
err = rows.Orm(&insertedUser)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if insertedUser.Name != "David" {
t.Errorf("Expected name 'David', got '%s'", insertedUser.Name)
}
if insertedUser.Email != "david@example.com" {
t.Errorf("Expected email 'david@example.com', got '%s'", insertedUser.Email)
}
if insertedUser.Age != 40 {
t.Errorf("Expected age 40, got %d", insertedUser.Age)
}
}
func TestStarDB_InsertContext(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
ctx := context.Background()
user := User{
Name: "Eve",
Email: "eve@example.com",
Age: 28,
Balance: 250.00,
Active: true,
CreatedAt: time.Now(),
}
result, err := db.InsertContext(ctx, &user, "users", "id")
if err != nil {
t.Fatalf("InsertContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_Update(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
// First, get the user
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user User
err = rows.Orm(&user)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
// Update the user
user.Age = 26
user.Balance = 150.75
result, err := db.Update(&user, "users", "id")
if err != nil {
t.Fatalf("Update failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
// Verify the update
rows2, err := db.Query("SELECT * FROM users WHERE id = ?", user.ID)
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows2.Close()
var updatedUser User
err = rows2.Orm(&updatedUser)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if updatedUser.Age != 26 {
t.Errorf("Expected age 26, got %d", updatedUser.Age)
}
if updatedUser.Balance != 150.75 {
t.Errorf("Expected balance 150.75, got %f", updatedUser.Balance)
}
}
func TestStarDB_UpdateContext(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
ctx := context.Background()
// Get user
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Bob")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user User
err = rows.Orm(&user)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
// Update
user.Active = false
result, err := db.UpdateContext(ctx, &user, "users", "id")
if err != nil {
t.Fatalf("UpdateContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_QueryX(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
user := User{
Name: "Alice",
}
rows, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":name")
if err != nil {
t.Fatalf("QueryX failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
var result User
err = rows.Orm(&result)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if result.Name != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", result.Name)
}
}
func TestStarDB_QueryXContext(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
ctx := context.Background()
user := User{
Age: 30,
}
rows, err := db.QueryXContext(ctx, &user, "SELECT * FROM users WHERE age = ?", ":age")
if err != nil {
t.Fatalf("QueryXContext failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
}
func TestStarDB_QueryX_MissingField(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
user := User{Name: "Alice"}
_, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":unknown")
if !errors.Is(err, stardb.ErrFieldNotFound) {
t.Errorf("Expected ErrFieldNotFound, got %v", err)
}
}
func TestStarDB_QueryX_NilTarget(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
_, err := db.QueryX(nil, "SELECT * FROM users WHERE id = ?", ":id")
if !errors.Is(err, stardb.ErrTargetNil) {
t.Errorf("Expected ErrTargetNil, got %v", err)
}
}
func TestStarDB_QueryXS(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
users := []User{
{Name: "Alice"},
{Name: "Bob"},
}
results, err := db.QueryXS(&users, "SELECT * FROM users WHERE name = ?", ":name")
if err != nil {
t.Fatalf("QueryXS failed: %v", err)
}
if len(results) != 2 {
t.Errorf("Expected 2 results, got %d", len(results))
}
for i, rows := range results {
if rows.Length() != 1 {
t.Errorf("Expected 1 row in result %d, got %d", i, rows.Length())
}
}
}
func TestStarDB_QueryXS_NilTargets(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
_, err := db.QueryXS(nil, "SELECT * FROM users")
if !errors.Is(err, stardb.ErrTargetsNil) {
t.Errorf("Expected ErrTargetsNil, got %v", err)
}
}
func TestStarDB_ExecX(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
user := User{
Name: "Alice",
Age: 99,
}
result, err := db.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name")
if err != nil {
t.Fatalf("ExecX failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
// Verify
rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
age := rows.Row(0).MustInt("age")
if age != 99 {
t.Errorf("Expected age 99, got %d", age)
}
}
func TestStarDB_ExecXContext(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
ctx := context.Background()
user := User{
Name: "Bob",
Active: false,
}
result, err := db.ExecXContext(ctx, &user, "UPDATE users SET active = ? WHERE name = ?", ":active", ":name")
if err != nil {
t.Fatalf("ExecXContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_ExecX_MissingField(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
user := User{Name: "Alice", Age: 99}
_, err := db.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":unknown")
if !errors.Is(err, stardb.ErrFieldNotFound) {
t.Errorf("Expected ErrFieldNotFound, got %v", err)
}
}
func TestStarDB_ExecX_NilTarget(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
_, err := db.ExecX(nil, "UPDATE users SET age = ? WHERE id = ?", ":age", ":id")
if !errors.Is(err, stardb.ErrTargetNil) {
t.Errorf("Expected ErrTargetNil, got %v", err)
}
}
func TestStarDB_ExecXS(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
users := []User{
{Name: "Alice", Age: 26},
{Name: "Bob", Age: 31},
}
results, err := db.ExecXS(&users, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name")
if err != nil {
t.Fatalf("ExecXS failed: %v", err)
}
if len(results) != 2 {
t.Errorf("Expected 2 results, got %d", len(results))
}
for i, result := range results {
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed for result %d: %v", i, err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected in result %d, got %d", i, affected)
}
}
}
func TestStarDB_ExecXS_NilTargets(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
_, err := db.ExecXS(nil, "UPDATE users SET age = age")
if !errors.Is(err, stardb.ErrTargetsNil) {
t.Errorf("Expected ErrTargetsNil, got %v", err)
}
}
func TestStarTx_Insert(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
user := User{
Name: "Frank",
Email: "frank@example.com",
Age: 45,
Balance: 500.00,
Active: true,
CreatedAt: time.Now(),
}
result, err := tx.Insert(&user, "users", "id")
if err != nil {
tx.Rollback()
t.Fatalf("Tx.Insert failed: %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit failed: %v", err)
}
lastID, err := result.LastInsertId()
if err != nil {
t.Fatalf("LastInsertId failed: %v", err)
}
if lastID <= 0 {
t.Errorf("Expected positive last insert ID, got %d", lastID)
}
// Verify
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Frank")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
}
func TestStarTx_Update(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
// Get user
rows, err := tx.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
tx.Rollback()
t.Fatalf("Query failed: %v", err)
}
var user User
err = rows.Orm(&user)
rows.Close()
if err != nil {
tx.Rollback()
t.Fatalf("Orm failed: %v", err)
}
// Update
user.Age = 27
result, err := tx.Update(&user, "users", "id")
if err != nil {
tx.Rollback()
t.Fatalf("Tx.Update failed: %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarTx_QueryX(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
defer tx.Rollback()
user := User{
Name: "Charlie",
}
rows, err := tx.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":name")
if err != nil {
t.Fatalf("Tx.QueryX failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
}
func TestStarTx_QueryX_NilTarget(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
defer tx.Rollback()
_, err = tx.QueryX(nil, "SELECT * FROM users WHERE id = ?", ":id")
if !errors.Is(err, stardb.ErrTargetNil) {
t.Errorf("Expected ErrTargetNil, got %v", err)
}
}
func TestStarTx_ExecX(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
user := User{
Name: "Charlie",
Age: 36,
}
result, err := tx.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name")
if err != nil {
tx.Rollback()
t.Fatalf("Tx.ExecX failed: %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarTx_ExecX_NilTarget(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
defer tx.Rollback()
_, err = tx.ExecX(nil, "UPDATE users SET age = ? WHERE id = ?", ":age", ":id")
if !errors.Is(err, stardb.ErrTargetNil) {
t.Errorf("Expected ErrTargetNil, got %v", err)
}
}
func TestStarTx_Rollback(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
user := User{
Name: "Rollback User",
Email: "rollback@example.com",
Age: 50,
Balance: 600.00,
Active: true,
CreatedAt: time.Now(),
}
_, err = tx.Insert(&user, "users", "id")
if err != nil {
tx.Rollback()
t.Fatalf("Tx.Insert failed: %v", err)
}
// Rollback instead of commit
err = tx.Rollback()
if err != nil {
t.Fatalf("Rollback failed: %v", err)
}
// Verify the insert was rolled back
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Rollback User")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if rows.Length() != 0 {
t.Errorf("Expected 0 rows after rollback, got %d", rows.Length())
}
}
func TestNamedParameterEscape(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
user := User{
Name: "Alice",
}
// Test escaped colon
rows, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", `\:name`)
if err != nil {
t.Fatalf("QueryX with escaped parameter failed: %v", err)
}
defer rows.Close()
// Should use literal ":name" string, not the field value
if rows.Length() != 0 {
t.Errorf("Expected 0 rows with literal ':name', got %d", rows.Length())
}
}

View File

@ -1,128 +0,0 @@
package testing
import (
"fmt"
"testing"
"time"
"b612.me/stardb"
_ "modernc.org/sqlite"
)
func setupBenchmarkDB(b *testing.B) *stardb.StarDB {
b.Helper()
db := &stardb.StarDB{}
if err := db.Open("sqlite", ":memory:"); err != nil {
b.Fatalf("Failed to open database: %v", err)
}
_, err := db.Exec(`
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT NOT NULL,
age INTEGER,
balance REAL,
active BOOLEAN,
created_at DATETIME
)
`)
if err != nil {
b.Fatalf("Failed to create table: %v", err)
}
_, err = db.Exec(`
INSERT INTO users (name, email, age, balance, active, created_at) VALUES
('Alice', 'alice@example.com', 25, 100.50, 1, '2024-01-01 10:00:00'),
('Bob', 'bob@example.com', 30, 200.75, 1, '2024-01-02 11:00:00'),
('Charlie', 'charlie@example.com', 35, 300.25, 0, '2024-01-03 12:00:00')
`)
if err != nil {
b.Fatalf("Failed to insert seed data: %v", err)
}
return db
}
func BenchmarkQueryX(b *testing.B) {
db := setupBenchmarkDB(b)
defer db.Close()
target := User{Name: "Alice"}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
rows, err := db.QueryX(&target, "SELECT * FROM users WHERE name = ?", ":name")
if err != nil {
b.Fatalf("QueryX failed: %v", err)
}
_ = rows.Close()
}
}
func BenchmarkOrm(b *testing.B) {
db := setupBenchmarkDB(b)
defer db.Close()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
b.Fatalf("Query failed: %v", err)
}
var users []User
if err := rows.Orm(&users); err != nil {
_ = rows.Close()
b.Fatalf("Orm failed: %v", err)
}
_ = rows.Close()
}
}
func BenchmarkScanEach(b *testing.B) {
db := setupBenchmarkDB(b)
defer db.Close()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
count := 0
err := db.ScanEach("SELECT * FROM users ORDER BY name", func(row *stardb.StarResult) error {
_ = row.MustString("name")
count++
return nil
})
if err != nil {
b.Fatalf("ScanEach failed: %v", err)
}
if count != 3 {
b.Fatalf("Unexpected row count: %d", count)
}
}
}
func BenchmarkBatchInsert(b *testing.B) {
db := setupBenchmarkDB(b)
defer db.Close()
columns := []string{"name", "email", "age", "balance", "active", "created_at"}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
base := i * 2
values := [][]interface{}{
{fmt.Sprintf("bench_user_%d", base), fmt.Sprintf("bench_%d@example.com", base), 20 + (base % 20), 99.5, true, time.Now()},
{fmt.Sprintf("bench_user_%d", base+1), fmt.Sprintf("bench_%d@example.com", base+1), 20 + ((base + 1) % 20), 199.5, false, time.Now()},
}
if _, err := db.BatchInsert("users", columns, values); err != nil {
b.Fatalf("BatchInsert failed: %v", err)
}
}
}

View File

@ -1,298 +0,0 @@
package testing
import (
"testing"
"time"
"b612.me/stardb"
_ "modernc.org/sqlite"
)
func TestDefaultPoolConfig(t *testing.T) {
config := stardb.DefaultPoolConfig()
if config.MaxOpenConns != 25 {
t.Errorf("Expected MaxOpenConns 25, got %d", config.MaxOpenConns)
}
if config.MaxIdleConns != 5 {
t.Errorf("Expected MaxIdleConns 5, got %d", config.MaxIdleConns)
}
if config.ConnMaxLifetime != time.Hour {
t.Errorf("Expected ConnMaxLifetime 1 hour, got %v", config.ConnMaxLifetime)
}
if config.ConnMaxIdleTime != 10*time.Minute {
t.Errorf("Expected ConnMaxIdleTime 10 minutes, got %v", config.ConnMaxIdleTime)
}
}
func TestStarDB_SetPoolConfig(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
config := &stardb.PoolConfig{
MaxOpenConns: 50,
MaxIdleConns: 10,
ConnMaxLifetime: 30 * time.Minute,
ConnMaxIdleTime: 5 * time.Minute,
}
db.SetPoolConfig(config)
stats := db.Stats()
if stats.MaxOpenConnections != 50 {
t.Errorf("Expected MaxOpenConnections 50, got %d", stats.MaxOpenConnections)
}
}
func TestStarDB_SetPoolConfig_Partial(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Only set MaxOpenConns
config := &stardb.PoolConfig{
MaxOpenConns: 100,
}
db.SetPoolConfig(config)
stats := db.Stats()
if stats.MaxOpenConnections != 100 {
t.Errorf("Expected MaxOpenConnections 100, got %d", stats.MaxOpenConnections)
}
}
func TestStarDB_SetPoolConfig_Zero(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Zero values should be ignored
config := &stardb.PoolConfig{
MaxOpenConns: 0,
MaxIdleConns: 0,
}
db.SetPoolConfig(config)
// Should not panic or error
err = db.Ping()
if err != nil {
t.Errorf("Ping failed after SetPoolConfig with zero values: %v", err)
}
}
func TestStarDB_SetPoolConfig_NilConfig(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
db.SetPoolConfig(nil)
err = db.Ping()
if err != nil {
t.Errorf("Ping failed after SetPoolConfig(nil): %v", err)
}
}
func TestStarDB_SetPoolConfig_BeforeOpen(t *testing.T) {
db := stardb.NewStarDB()
db.SetPoolConfig(&stardb.PoolConfig{
MaxOpenConns: 10,
})
// should not panic when called before Open
}
func TestOpenWithPool_Default(t *testing.T) {
db, err := stardb.OpenWithPool("sqlite", ":memory:", nil)
if err != nil {
t.Fatalf("OpenWithPool failed: %v", err)
}
defer db.Close()
err = db.Ping()
if err != nil {
t.Errorf("Ping failed: %v", err)
}
stats := db.Stats()
if stats.MaxOpenConnections != 25 {
t.Errorf("Expected default MaxOpenConnections 25, got %d", stats.MaxOpenConnections)
}
}
func TestOpenWithPool_Custom(t *testing.T) {
config := &stardb.PoolConfig{
MaxOpenConns: 15,
MaxIdleConns: 3,
ConnMaxLifetime: 20 * time.Minute,
ConnMaxIdleTime: 3 * time.Minute,
}
db, err := stardb.OpenWithPool("sqlite", ":memory:", config)
if err != nil {
t.Fatalf("OpenWithPool failed: %v", err)
}
defer db.Close()
err = db.Ping()
if err != nil {
t.Errorf("Ping failed: %v", err)
}
stats := db.Stats()
if stats.MaxOpenConnections != 15 {
t.Errorf("Expected MaxOpenConnections 15, got %d", stats.MaxOpenConnections)
}
}
func TestOpenWithPool_InvalidDriver(t *testing.T) {
config := stardb.DefaultPoolConfig()
db, err := stardb.OpenWithPool("invalid_driver", "invalid_conn", config)
if err == nil {
db.Close()
t.Error("Expected error with invalid driver, got nil")
}
}
func TestOpenWithPool_Query(t *testing.T) {
db, err := stardb.OpenWithPool("sqlite", ":memory:", nil)
if err != nil {
t.Fatalf("OpenWithPool failed: %v", err)
}
defer db.Close()
// Create table
_, err = db.Exec(`CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)`)
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
// Insert data
_, err = db.Exec(`INSERT INTO test (name) VALUES (?)`, "Alice")
if err != nil {
t.Fatalf("Failed to insert data: %v", err)
}
// Query data
rows, err := db.Query("SELECT * FROM test")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
}
func TestPoolConfig_AllFields(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
config := &stardb.PoolConfig{
MaxOpenConns: 30,
MaxIdleConns: 8,
ConnMaxLifetime: 45 * time.Minute,
ConnMaxIdleTime: 7 * time.Minute,
}
db.SetPoolConfig(config)
// Verify by checking stats
stats := db.Stats()
if stats.MaxOpenConnections != 30 {
t.Errorf("Expected MaxOpenConnections 30, got %d", stats.MaxOpenConnections)
}
// Test that connections work
err = db.Ping()
if err != nil {
t.Errorf("Ping failed after SetPoolConfig: %v", err)
}
}
func TestPoolConfig_NegativeValues(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Negative values should be ignored
config := &stardb.PoolConfig{
MaxOpenConns: -1,
MaxIdleConns: -1,
}
db.SetPoolConfig(config)
// Should not panic or error
err = db.Ping()
if err != nil {
t.Errorf("Ping failed after SetPoolConfig with negative values: %v", err)
}
}
func TestOpenWithPool_MultipleConnections(t *testing.T) {
config := &stardb.PoolConfig{
MaxOpenConns: 5,
MaxIdleConns: 2,
}
db, err := stardb.OpenWithPool("sqlite", ":memory:", config)
if err != nil {
t.Fatalf("OpenWithPool failed: %v", err)
}
defer db.Close()
// Create table
_, err = db.Exec(`CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)`)
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
// Perform multiple operations
for i := 0; i < 10; i++ {
_, err = db.Exec(`INSERT INTO test (value) VALUES (?)`, "test")
if err != nil {
t.Errorf("Insert %d failed: %v", i, err)
}
}
// Verify
rows, err := db.Query("SELECT COUNT(*) as count FROM test")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
count := rows.Row(0).MustInt("count")
if count != 10 {
t.Errorf("Expected 10 rows, got %d", count)
}
}

View File

@ -1,331 +0,0 @@
package testing
import (
"errors"
"testing"
"time"
"b612.me/stardb"
)
func TestStarResult_MustString(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
name := row.MustString("name")
if name != "Alice" {
t.Errorf("Expected 'Alice', got '%s'", name)
}
email := row.MustString("email")
if email != "alice@example.com" {
t.Errorf("Expected 'alice@example.com', got '%s'", email)
}
// Test non-existent column
nonexistent := row.MustString("nonexistent")
if nonexistent != "" {
t.Errorf("Expected empty string for non-existent column, got '%s'", nonexistent)
}
}
func TestStarResult_MustInt(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Bob")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
age := row.MustInt("age")
if age != 30 {
t.Errorf("Expected age 30, got %d", age)
}
id := row.MustInt("id")
if id <= 0 {
t.Errorf("Expected positive id, got %d", id)
}
}
func TestStarResult_MustInt64(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Charlie")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
age := row.MustInt64("age")
if age != 35 {
t.Errorf("Expected age 35, got %d", age)
}
}
func TestStarResult_MustFloat64(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
balance := row.MustFloat64("balance")
if balance != 100.50 {
t.Errorf("Expected balance 100.50, got %f", balance)
}
}
func TestStarResult_MustBool(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
// Alice - active
row := rows.Row(0)
active := row.MustBool("active")
if !active {
t.Errorf("Expected Alice to be active")
}
// Charlie - inactive
row = rows.Row(2)
active = row.MustBool("active")
if active {
t.Errorf("Expected Charlie to be inactive")
}
}
func TestStarResult_IsNil(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
// Insert a row with NULL value
_, err := db.Exec("INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, NULL, ?, ?, ?)",
"David", "david@example.com", 150.0, true, time.Now())
if err != nil {
t.Fatalf("Insert failed: %v", err)
}
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "David")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
if !row.IsNil("age") {
t.Errorf("Expected age to be NULL")
}
if row.IsNil("name") {
t.Errorf("Expected name to not be NULL")
}
}
func TestStarResultCol_MustString(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
col := rows.Col("name")
names := col.MustString()
expected := []string{"Alice", "Bob", "Charlie"}
for i, name := range names {
if name != expected[i] {
t.Errorf("Expected name '%s', got '%s'", expected[i], name)
}
}
}
func TestStarResultCol_MustInt(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
col := rows.Col("age")
ages := col.MustInt()
expected := []int{25, 30, 35}
for i, age := range ages {
if age != expected[i] {
t.Errorf("Expected age %d, got %d", expected[i], age)
}
}
}
func TestStarResultCol_MustFloat64(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
col := rows.Col("balance")
balances := col.MustFloat64()
expected := []float64{100.50, 200.75, 300.25}
for i, balance := range balances {
if balance != expected[i] {
t.Errorf("Expected balance %f, got %f", expected[i], balance)
}
}
}
func TestStarResultCol_MustBool(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
col := rows.Col("active")
actives := col.MustBool()
expected := []bool{true, true, false}
for i, active := range actives {
if active != expected[i] {
t.Errorf("Expected active %v, got %v", expected[i], active)
}
}
}
func TestStarResult_GetColumnNotFoundError(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
_, err = row.GetString("does_not_exist")
if !errors.Is(err, stardb.ErrColumnNotFound) {
t.Fatalf("Expected ErrColumnNotFound, got %v", err)
}
}
func TestStarResult_GetNullValues(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
_, err := db.Exec(
"INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, ?, ?, ?, ?)",
"NullUser", "null@example.com", nil, nil, nil, nil,
)
if err != nil {
t.Fatalf("Insert failed: %v", err)
}
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "NullUser")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
name, err := row.GetNullString("name")
if err != nil {
t.Fatalf("GetNullString failed: %v", err)
}
if !name.Valid || name.String != "NullUser" {
t.Fatalf("Expected valid name NullUser, got %+v", name)
}
age, err := row.GetNullInt64("age")
if err != nil {
t.Fatalf("GetNullInt64 failed: %v", err)
}
if age.Valid {
t.Fatalf("Expected NULL age, got %+v", age)
}
balance, err := row.GetNullFloat64("balance")
if err != nil {
t.Fatalf("GetNullFloat64 failed: %v", err)
}
if balance.Valid {
t.Fatalf("Expected NULL balance, got %+v", balance)
}
active, err := row.GetNullBool("active")
if err != nil {
t.Fatalf("GetNullBool failed: %v", err)
}
if active.Valid {
t.Fatalf("Expected NULL active, got %+v", active)
}
createdAt, err := row.GetNullTime("created_at")
if err != nil {
t.Fatalf("GetNullTime failed: %v", err)
}
if createdAt.Valid {
t.Fatalf("Expected NULL created_at, got %+v", createdAt)
}
}
func TestStarResult_GetNullTime_Valid(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT created_at FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
value, err := rows.Row(0).GetNullTime("created_at")
if err != nil {
t.Fatalf("GetNullTime failed: %v", err)
}
if !value.Valid {
t.Fatal("Expected valid created_at")
}
}

View File

@ -1,109 +0,0 @@
package testing
import (
"testing"
)
func TestStarRows_Row(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
// Test first row
row := rows.Row(0)
name := row.MustString("name")
if name != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", name)
}
// Test out of bounds
row = rows.Row(999)
if len(row.Result()) != 0 {
t.Errorf("Expected empty result for out of bounds index")
}
// Test negative index
row = rows.Row(-1)
if len(row.Result()) != 0 {
t.Errorf("Expected empty result for negative index")
}
}
func TestStarRows_Col(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
// Test column extraction
col := rows.Col("name")
names := col.MustString()
if len(names) != 3 {
t.Errorf("Expected 3 names, got %d", len(names))
}
if names[0] != "Alice" {
t.Errorf("Expected first name 'Alice', got '%s'", names[0])
}
// Test non-existent column
col = rows.Col("nonexistent")
if len(col.Result()) != 0 {
t.Errorf("Expected empty result for non-existent column")
}
}
func TestStarRows_Rescan(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
db.ManualScan = true
rows, err := db.Query("SELECT * FROM users")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
err = rows.Rescan()
if err != nil {
t.Fatalf("Rescan failed: %v", err)
}
if rows.Length() != 3 {
t.Errorf("Expected 3 rows, got %d", rows.Length())
}
}
func TestStarRows_StringResult(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT name, age FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if len(rows.StringResult()) != 1 {
t.Fatalf("Expected 1 string result, got %d", len(rows.StringResult()))
}
record := rows.StringResult()[0]
if record["name"] != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", record["name"])
}
if record["age"] != "25" {
t.Errorf("Expected age '25', got '%s'", record["age"])
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,47 +0,0 @@
package testing
import (
"testing"
"b612.me/stardb"
_ "modernc.org/sqlite"
)
// setupTestDB creates a test database with sample data
// This function is only available when building with -tags=testing
func setupTestDB(t *testing.T) *stardb.StarDB {
db := &stardb.StarDB{}
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
// Create test table
_, err = db.Exec(`
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT NOT NULL,
age INTEGER,
balance REAL,
active BOOLEAN,
created_at DATETIME
)
`)
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
// Insert test data
_, err = db.Exec(`
INSERT INTO users (name, email, age, balance, active, created_at) VALUES
('Alice', 'alice@example.com', 25, 100.50, 1, '2024-01-01 10:00:00'),
('Bob', 'bob@example.com', 30, 200.75, 1, '2024-01-02 11:00:00'),
('Charlie', 'charlie@example.com', 35, 300.25, 0, '2024-01-03 12:00:00')
`)
if err != nil {
t.Fatalf("Failed to insert test data: %v", err)
}
return db
}

264
tx.go
View File

@ -1,264 +0,0 @@
package stardb
import (
"context"
"database/sql"
"strings"
"time"
)
// StarTx represents a database transaction
type StarTx struct {
tx *sql.Tx
db *StarDB
}
func (t *StarTx) ensureTx() error {
if t == nil || t.tx == nil || t.db == nil {
return ErrTxNotInitialized
}
return nil
}
// Query executes a query within the transaction
func (t *StarTx) Query(query string, args ...interface{}) (*StarRows, error) {
return t.query(nil, query, args...)
}
// QueryContext executes a query with context within the transaction
func (t *StarTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
return t.query(ctx, query, args...)
}
// QueryRaw executes a query in transaction and returns *sql.Rows without automatic parsing.
func (t *StarTx) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) {
return t.queryRaw(nil, query, args...)
}
// QueryRawContext executes a query with context in transaction and returns *sql.Rows without automatic parsing.
func (t *StarTx) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
return t.queryRaw(ctx, query, args...)
}
func (t *StarTx) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
if err := t.ensureTx(); err != nil {
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, hookArgs)
}
start := time.Now()
var (
rows *sql.Rows
err error
)
if ctx == nil {
rows, err = t.tx.Query(query, args...)
} else {
rows, err = t.tx.QueryContext(ctx, query, args...)
}
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, hookArgs, duration, err)
}
if err != nil {
return nil, err
}
return rows, nil
}
// query is the internal query implementation
func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
rows, err := t.queryRaw(ctx, query, args...)
if err != nil {
return nil, err
}
starRows := &StarRows{
rows: rows,
db: t.db,
}
if !t.db.ManualScan {
if err := starRows.parse(); err != nil {
_ = rows.Close()
return nil, err
}
}
return starRows, nil
}
// Exec executes a query within the transaction
func (t *StarTx) Exec(query string, args ...interface{}) (sql.Result, error) {
return t.exec(nil, query, args...)
}
// ExecContext executes a query with context within the transaction
func (t *StarTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return t.exec(ctx, query, args...)
}
// exec is the internal exec implementation
func (t *StarTx) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if err := t.ensureTx(); err != nil {
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, hookArgs)
}
start := time.Now()
var (
result sql.Result
err error
)
if ctx == nil {
result, err = t.tx.Exec(query, args...)
} else {
result, err = t.tx.ExecContext(ctx, query, args...)
}
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, hookArgs, duration, err)
}
if err != nil {
return nil, err
}
return result, nil
}
// Prepare creates a prepared statement within the transaction
func (t *StarTx) Prepare(query string) (*StarStmt, error) {
if err := t.ensureTx(); err != nil {
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
hookCtx := t.db.hookContext(nil, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, nil)
}
start := time.Now()
stmt, err := t.tx.Prepare(query)
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, nil, duration, err)
}
if err != nil {
return nil, err
}
return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
}
// PrepareContext creates a prepared statement with context
func (t *StarTx) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
if err := t.ensureTx(); err != nil {
return nil, err
}
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
if beforeHook != nil {
beforeHook(hookCtx, query, nil)
}
start := time.Now()
stmt, err := t.tx.PrepareContext(ctx, query)
duration := time.Since(start)
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
afterHook(hookCtx, query, nil, duration, err)
}
if err != nil {
return nil, err
}
return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
}
// QueryStmt executes a prepared statement query within the transaction
func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
stmt, err := t.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Query(args...)
}
// QueryStmtContext executes a prepared statement query with context
func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
stmt, err := t.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.QueryContext(ctx, args...)
}
// ExecStmt executes a prepared statement within the transaction
func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
stmt, err := t.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Exec(args...)
}
// ExecStmtContext executes a prepared statement with context
func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if strings.TrimSpace(query) == "" {
return nil, ErrQueryEmpty
}
stmt, err := t.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.ExecContext(ctx, args...)
}
// Commit commits the transaction
func (t *StarTx) Commit() error {
if err := t.ensureTx(); err != nil {
return err
}
return t.tx.Commit()
}
// Rollback rolls back the transaction
func (t *StarTx) Rollback() error {
if err := t.ensureTx(); err != nil {
return err
}
return t.tx.Rollback()
}

View File

@ -1,45 +0,0 @@
package stardb
import (
"context"
"database/sql"
)
// WithTx runs fn in a transaction and handles commit/rollback automatically.
func (s *StarDB) WithTx(fn func(tx *StarTx) error) error {
return s.WithTxContext(context.Background(), nil, fn)
}
// WithTxContext runs fn in a transaction with context/options and handles commit/rollback automatically.
func (s *StarDB) WithTxContext(ctx context.Context, opts *sql.TxOptions, fn func(tx *StarTx) error) (err error) {
if fn == nil {
return ErrTxFuncNil
}
if ctx == nil {
ctx = context.Background()
}
tx, err := s.BeginTx(ctx, opts)
if err != nil {
return err
}
defer func() {
if p := recover(); p != nil {
_ = tx.Rollback()
panic(p)
}
}()
if err := fn(tx); err != nil {
_ = tx.Rollback()
return err
}
if err := tx.Commit(); err != nil {
_ = tx.Rollback()
return err
}
return nil
}