重构代码

This commit is contained in:
兔子 2026-03-07 19:27:44 +08:00
parent fb9808a139
commit 88569eb176
Signed by: b612
GPG Key ID: 99DD2222B612B612
33 changed files with 5910 additions and 2003 deletions

19
.gitignore vendored Normal file
View File

@ -0,0 +1,19 @@
# Binaries
*.exe
*.exe~
*.dll
*.so
*.dylib
*.test
# Go build cache
*.out
# Dependency directories
vendor/
# IDE
.idea/
# OS
.DS_Store

10
.idea/stardb.iml generated
View File

@ -1,6 +1,14 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="WEB_MODULE" version="4">
<component name="Go" enabled="true" />
<component name="Go" enabled="true">
<buildTags>
<option name="customFlags">
<array>
<option value="testing" />
</array>
</option>
</buildTags>
</component>
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />

201
LICENSE.txt Normal file
View File

@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

535
README.MD Normal file
View File

@ -0,0 +1,535 @@
# StarDB
一个轻量级的 Go 数据库封装库,个人学习用,提供简洁的 API 和 ORM 功能。
[![Go Version](https://img.shields.io/badge/Go-%3E%3D%201.16-blue)](https://golang.org/)
[![License](https://img.shields.io/badge/license-Apache%20License%202.0-blue)](LICENSE)
## ✨ 特性
- ✅ **零第三方依赖** - 仅使用 Go 标准库
- ✅ **类型安全** - 提供类型安全的结果转换方法
- ✅ **简单 ORM** - 支持结构体与数据库的自动映射
- ✅ **Context 支持** - 所有操作都支持 context
- ✅ **事务支持** - 完整的事务操作支持
- ✅ **预编译语句** - 支持预编译语句以提升性能
- ✅ **批量操作** - 高效的批量插入功能
- ✅ **连接池管理** - 便捷的连接池配置
- ✅ **查询构建器** - 链式调用构建 SQL 查询
## 📦 安装
```bash
go get b612.me/stardb
```
## 🚀 快速开始
### 基本使用
```go
package main
import (
"b612.me/stardb"
_ "github.com/mattn/go-sqlite3"
)
func main() {
// 创建数据库实例
db := stardb.NewStarDB()
err := db.Open("sqlite3", "test.db")
if err != nil {
panic(err)
}
defer db.Close()
// 执行查询
rows, err := db.Query("SELECT * FROM users WHERE age > ?", 18)
if err != nil {
panic(err)
}
defer rows.Close()
// 遍历结果
for i := 0; i < rows.Length(); i++ {
row := rows.Row(i)
name := row.MustString("name")
age := row.MustInt("age")
println(name, age)
}
}
```
### ORM 使用
```go
package main
import (
"b612.me/stardb"
"time"
_ "github.com/mattn/go-sqlite3"
)
// 定义结构体
type User struct {
ID int64 `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
Active bool `db:"active"`
CreatedAt time.Time `db:"created_at"`
}
func main() {
db := stardb.NewStarDB()
db.Open("sqlite3", "test.db")
defer db.Close()
// 查询单个对象
rows, _ := db.Query("SELECT * FROM users WHERE id = ?", 1)
defer rows.Close()
var user User
rows.Orm(&user)
println(user.Name)
// 查询多个对象
rows2, _ := db.Query("SELECT * FROM users WHERE age > ?", 18)
defer rows2.Close()
var users []User
rows2.Orm(&users)
for _, u := range users {
println(u.Name, u.Age)
}
}
```
### 插入和更新
```go
// 插入数据
user := User{
Name: "Alice",
Email: "alice@example.com",
Age: 25,
Active: true,
CreatedAt: time.Now(),
}
result, err := db.Insert(&user, "users", "id") // "id" 是自增字段
if err != nil {
panic(err)
}
lastID, _ := result.LastInsertId()
println("插入的 ID:", lastID)
// 更新数据
user.Age = 26
result, err = db.Update(&user, "users", "id") // "id" 是主键
if err != nil {
panic(err)
}
affected, _ := result.RowsAffected()
println("更新的行数:", affected)
```
### 批量插入
```go
// 方式 1使用原始数据
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
{"Charlie", "charlie@example.com", 35},
}
result, err := db.BatchInsert("users", columns, values)
if err != nil {
panic(err)
}
// 方式 2使用结构体
users := []User{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
{Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()},
}
result, err = db.BatchInsertStructs("users", users, "id")
if err != nil {
panic(err)
}
affected, _ := result.RowsAffected()
println("批量插入行数:", affected)
```
### 事务操作
```go
// 开始事务
tx, err := db.Begin()
if err != nil {
panic(err)
}
// 执行操作
_, err = tx.Exec("INSERT INTO users (name, email) VALUES (?, ?)", "Alice", "alice@example.com")
if err != nil {
tx.Rollback()
panic(err)
}
_, err = tx.Exec("UPDATE accounts SET balance = balance - ? WHERE user_id = ?", 100, 1)
if err != nil {
tx.Rollback()
panic(err)
}
// 提交事务
err = tx.Commit()
if err != nil {
panic(err)
}
```
### 预编译语句
```go
// 创建预编译语句
stmt, err := db.Prepare("SELECT * FROM users WHERE age > ?")
if err != nil {
panic(err)
}
defer stmt.Close()
// 多次执行
rows1, _ := stmt.Query(18)
defer rows1.Close()
rows2, _ := stmt.Query(25)
defer rows2.Close()
rows3, _ := stmt.Query(30)
defer rows3.Close()
```
### 查询构建器
```go
// 使用查询构建器
rows, err := stardb.NewQueryBuilder("users").
Select("id", "name", "email").
Where("age > ?", 18).
Where("active = ?", true).
OrderBy("name ASC").
Limit(10).
Offset(0).
Query(db)
if err != nil {
panic(err)
}
defer rows.Close()
var users []User
rows.Orm(&users)
```
### 连接池配置
```go
// 方式 1使用默认配置
db, err := stardb.OpenWithPool("sqlite3", "test.db", nil)
if err != nil {
panic(err)
}
defer db.Close()
// 方式 2自定义配置
config := &stardb.PoolConfig{
MaxOpenConns: 50, // 最大打开连接数
MaxIdleConns: 10, // 最大空闲连接数
ConnMaxLifetime: 1 * time.Hour, // 连接最大生命周期
ConnMaxIdleTime: 10 * time.Minute, // 连接最大空闲时间
}
db, err = stardb.OpenWithPool("mysql", "user:pass@tcp(localhost:3306)/dbname", config)
if err != nil {
panic(err)
}
defer db.Close()
// 方式 3手动设置
db := stardb.NewStarDB()
db.Open("postgres", "postgres://user:pass@localhost/dbname")
db.SetPoolConfig(config)
```
### 命名参数绑定
```go
user := User{
Name: "Alice",
Age: 25,
}
// 使用 :fieldname 语法绑定结构体字段
rows, err := db.QueryX(&user,
"SELECT * FROM users WHERE name = ? AND age > ?",
":name", ":age")
if err != nil {
panic(err)
}
defer rows.Close()
```
## 📖 详细文档
### 结果转换方法
StarDB 提供了丰富的类型转换方法:
```go
row := rows.Row(0)
// 字符串
name := row.MustString("name")
// 整数
age := row.MustInt("age")
id := row.MustInt64("id")
count := row.MustInt32("count")
uid := row.MustUint64("uid")
// 浮点数
price := row.MustFloat64("price")
rate := row.MustFloat32("rate")
// 布尔值
active := row.MustBool("active")
// 字节数组
data := row.MustBytes("data")
// 时间
createdAt := row.MustDate("created_at", "2006-01-02 15:04:05")
// 检查 NULL
isNull := row.IsNil("optional_field")
```
### 列操作
```go
// 获取某一列的所有值
col := rows.Col("name")
names := col.MustString() // []string
ages := col.MustInt() // []int
prices := col.MustFloat64() // []float64
actives := col.MustBool() // []bool
```
### Context 支持
所有操作都有对应的 Context 版本:
```go
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
// 查询
rows, err := db.QueryContext(ctx, "SELECT * FROM users")
// 执行
result, err := db.ExecContext(ctx, "UPDATE users SET age = ?", 26)
// 事务
tx, err := db.BeginTx(ctx, nil)
// 批量插入
result, err := db.BatchInsertContext(ctx, "users", columns, values)
```
### 错误处理
```go
// 使用 Must* 方法(忽略错误,返回零值)
name := row.MustString("name")
// 使用 Get* 方法(返回错误)
name, err := row.GetString("name")
if err != nil {
// 处理错误
}
```
## 🔧 高级用法
### 嵌套结构体
```go
type Profile struct {
Bio string `db:"bio"`
Avatar string `db:"avatar"`
}
type User struct {
ID int64 `db:"id"`
Name string `db:"name"`
Profile Profile `db:"---"` // 使用 "---" 标记嵌套结构体
}
rows, _ := db.Query("SELECT id, name, bio, avatar FROM users")
defer rows.Close()
var user User
rows.Orm(&user)
println(user.Name)
println(user.Profile.Bio)
```
### 手动扫描模式
```go
db := stardb.NewStarDB()
db.ManualScan = true // 启用手动扫描模式
db.Open("sqlite3", "test.db")
rows, _ := db.Query("SELECT * FROM users")
defer rows.Close()
// 手动触发解析
rows.Rescan()
// 现在可以访问数据
println(rows.Length())
```
### 直接访问底层 *sql.DB
```go
db := stardb.NewStarDB()
db.Open("sqlite3", "test.db")
// 获取底层 *sql.DB
rawDB := db.DB()
rawDB.SetMaxOpenConns(100)
// 或者使用已有的 *sql.DB
sqlDB, _ := sql.Open("sqlite3", "test.db")
db := stardb.NewStarDBWithDB(sqlDB)
```
## 🎯 性能优化建议
### 1. 使用预编译语句
对于重复执行的查询,使用预编译语句可以提升 20-50% 的性能:
```go
stmt, _ := db.Prepare("SELECT * FROM users WHERE id = ?")
defer stmt.Close()
for _, id := range userIDs {
rows, _ := stmt.Query(id)
// 处理结果
rows.Close()
}
```
### 2. 批量操作
批量插入比单条插入快 2-3 倍:
```go
// ❌ 慢
for _, user := range users {
db.Exec("INSERT INTO users (name) VALUES (?)", user.Name)
}
// ✅ 快
db.BatchInsertStructs("users", users, "id")
```
### 3. 合理配置连接池
```go
config := &stardb.PoolConfig{
MaxOpenConns: 25, // 根据数据库服务器调整
MaxIdleConns: 5, // 保持少量空闲连接
ConnMaxLifetime: 1 * time.Hour,
ConnMaxIdleTime: 10 * time.Minute,
}
db.SetPoolConfig(config)
```
### 4. 使用事务
将多个操作放在一个事务中可以显著提升性能:
```go
tx, _ := db.Begin()
for _, user := range users {
tx.Exec("INSERT INTO users (name) VALUES (?)", user.Name)
}
tx.Commit()
```
## 🧪 测试
项目包含完整的单元测试:
```bash
# 运行所有测试
cd testing
go test -v
# 运行特定测试
go test -v -run TestStarDB_Query
# 查看覆盖率
go test -cover
# 生成覆盖率报告
go test -coverprofile=coverage.out
go tool cover -html=coverage.out
# 运行基准测试
go test -bench=. -benchmem
```
## 📊 支持的数据库
StarDB 支持所有实现了 `database/sql` 接口的数据库驱动:
| 数据库 | 驱动 | 导入 |
|--------|------|------|
| SQLite | modernc.org/sqlite | `_ "modernc.org/sqlite"` |
| MySQL | github.com/go-sql-driver/mysql | `_ "github.com/go-sql-driver/mysql"` |
| PostgreSQL | github.com/lib/pq | `_ "github.com/lib/pq"` |
| SQL Server | github.com/denisenkom/go-mssqldb | `_ "github.com/denisenkom/go-mssqldb"` |
| Oracle | github.com/sijms/go-ora/v2 | `_ "github.com/sijms/go-ora/v2"` |
## 📄 许可证
本项目采用 Apache 2.0 许可证 - 详见 [LICENSE](LICENSE) 文件
## 🙏 致谢
- 感谢 Go 标准库提供的 `database/sql` 包
- 灵感来源于 xorm、gorm 等优秀的 ORM 框架
## 📮 联系方式
- 项目主页: https://git.b612.me/b612/stardb.git

112
batch.go Normal file
View File

@ -0,0 +1,112 @@
package stardb
import (
"context"
"database/sql"
"fmt"
"reflect"
"strings"
)
// BatchInsert performs batch insert operation
// Usage: BatchInsert("users", []string{"name", "age"}, [][]interface{}{{"Alice", 25}, {"Bob", 30}})
func (s *StarDB) BatchInsert(tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
return s.batchInsert(nil, tableName, columns, values)
}
// BatchInsertContext performs batch insert with context
func (s *StarDB) BatchInsertContext(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
return s.batchInsert(ctx, tableName, columns, values)
}
// batchInsert is the internal implementation
func (s *StarDB) batchInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
if len(values) == 0 {
return nil, fmt.Errorf("no values to insert")
}
// Build placeholders: (?, ?), (?, ?), ...
placeholderGroup := "(" + strings.Repeat("?, ", len(columns)-1) + "?)"
placeholders := strings.Repeat(placeholderGroup+", ", len(values)-1) + placeholderGroup
// Build SQL
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
tableName,
strings.Join(columns, ", "),
placeholders)
// Flatten values
var args []interface{}
for _, row := range values {
args = append(args, row...)
}
return s.exec(ctx, query, args...)
}
// BatchInsertStructs performs batch insert using structs
func (s *StarDB) BatchInsertStructs(tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
return s.batchInsertStructs(nil, tableName, structs, autoIncrementFields...)
}
// BatchInsertStructsContext performs batch insert using structs with context
func (s *StarDB) BatchInsertStructsContext(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
return s.batchInsertStructs(ctx, tableName, structs, autoIncrementFields...)
}
// batchInsertStructs is the internal implementation
func (s *StarDB) batchInsertStructs(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
// Get slice of structs
targetValue := reflect.ValueOf(structs)
if targetValue.Kind() == reflect.Ptr {
targetValue = targetValue.Elem()
}
if targetValue.Kind() != reflect.Slice && targetValue.Kind() != reflect.Array {
return nil, fmt.Errorf("structs must be a slice or array")
}
if targetValue.Len() == 0 {
return nil, fmt.Errorf("no structs to insert")
}
// Get field names from first struct
firstStruct := targetValue.Index(0).Interface()
fieldNames, err := getStructFieldNames(firstStruct, "db")
if err != nil {
return nil, err
}
// Filter out auto-increment fields
var columns []string
for _, fieldName := range fieldNames {
isAutoIncrement := false
for _, autoField := range autoIncrementFields {
if fieldName == autoField {
isAutoIncrement = true
break
}
}
if !isAutoIncrement {
columns = append(columns, fieldName)
}
}
// Extract values from all structs
var values [][]interface{}
for i := 0; i < targetValue.Len(); i++ {
structVal := targetValue.Index(i).Interface()
fieldValues, err := getStructFieldValues(structVal, "db")
if err != nil {
return nil, err
}
var row []interface{}
for _, col := range columns {
row = append(row, fieldValues[col])
}
values = append(values, row)
}
return s.batchInsert(ctx, tableName, columns, values)
}

93
builder.go Normal file
View File

@ -0,0 +1,93 @@
package stardb
import (
"fmt"
"strings"
)
// QueryBuilder helps build SQL queries
type QueryBuilder struct {
table string
columns []string
where []string
whereArgs []interface{}
orderBy string
limit int
offset int
}
// NewQueryBuilder creates a new query builder
func NewQueryBuilder(table string) *QueryBuilder {
return &QueryBuilder{
table: table,
columns: []string{"*"},
}
}
// Select sets the columns to select
func (qb *QueryBuilder) Select(columns ...string) *QueryBuilder {
qb.columns = columns
return qb
}
// Where adds a WHERE condition
func (qb *QueryBuilder) Where(condition string, args ...interface{}) *QueryBuilder {
qb.where = append(qb.where, condition)
qb.whereArgs = append(qb.whereArgs, args...)
return qb
}
// OrderBy sets the ORDER BY clause
func (qb *QueryBuilder) OrderBy(orderBy string) *QueryBuilder {
qb.orderBy = orderBy
return qb
}
// Limit sets the LIMIT
func (qb *QueryBuilder) Limit(limit int) *QueryBuilder {
qb.limit = limit
return qb
}
// Offset sets the OFFSET
func (qb *QueryBuilder) Offset(offset int) *QueryBuilder {
qb.offset = offset
return qb
}
// Build builds the SQL query and returns query string and args
func (qb *QueryBuilder) Build() (string, []interface{}) {
var parts []string
// SELECT
parts = append(parts, fmt.Sprintf("SELECT %s FROM %s",
strings.Join(qb.columns, ", "), qb.table))
// WHERE
if len(qb.where) > 0 {
parts = append(parts, "WHERE "+strings.Join(qb.where, " AND "))
}
// ORDER BY
if qb.orderBy != "" {
parts = append(parts, "ORDER BY "+qb.orderBy)
}
// LIMIT
if qb.limit > 0 {
parts = append(parts, fmt.Sprintf("LIMIT %d", qb.limit))
}
// OFFSET
if qb.offset > 0 {
parts = append(parts, fmt.Sprintf("OFFSET %d", qb.offset))
}
return strings.Join(parts, " "), qb.whereArgs
}
// Query executes the query
func (qb *QueryBuilder) Query(db *StarDB) (*StarRows, error) {
query, args := qb.Build()
return db.Query(query, args...)
}

462
builder_test.go Normal file
View File

@ -0,0 +1,462 @@
package stardb
import (
"reflect"
"testing"
)
func TestNewQueryBuilder(t *testing.T) {
qb := NewQueryBuilder("users")
if qb.table != "users" {
t.Errorf("Expected table 'users', got '%s'", qb.table)
}
if len(qb.columns) != 1 || qb.columns[0] != "*" {
t.Errorf("Expected default columns ['*'], got %v", qb.columns)
}
}
func TestQueryBuilder_Select(t *testing.T) {
qb := NewQueryBuilder("users").Select("id", "name", "email")
expected := []string{"id", "name", "email"}
if !reflect.DeepEqual(qb.columns, expected) {
t.Errorf("Expected columns %v, got %v", expected, qb.columns)
}
}
func TestQueryBuilder_Where(t *testing.T) {
qb := NewQueryBuilder("users").
Where("age > ?", 18).
Where("active = ?", true)
if len(qb.where) != 2 {
t.Errorf("Expected 2 where conditions, got %d", len(qb.where))
}
if qb.where[0] != "age > ?" {
t.Errorf("Expected first where 'age > ?', got '%s'", qb.where[0])
}
if qb.where[1] != "active = ?" {
t.Errorf("Expected second where 'active = ?', got '%s'", qb.where[1])
}
if len(qb.whereArgs) != 2 {
t.Errorf("Expected 2 where args, got %d", len(qb.whereArgs))
}
if qb.whereArgs[0] != 18 {
t.Errorf("Expected first arg 18, got %v", qb.whereArgs[0])
}
if qb.whereArgs[1] != true {
t.Errorf("Expected second arg true, got %v", qb.whereArgs[1])
}
}
func TestQueryBuilder_OrderBy(t *testing.T) {
qb := NewQueryBuilder("users").OrderBy("name ASC")
if qb.orderBy != "name ASC" {
t.Errorf("Expected orderBy 'name ASC', got '%s'", qb.orderBy)
}
}
func TestQueryBuilder_Limit(t *testing.T) {
qb := NewQueryBuilder("users").Limit(10)
if qb.limit != 10 {
t.Errorf("Expected limit 10, got %d", qb.limit)
}
}
func TestQueryBuilder_Offset(t *testing.T) {
qb := NewQueryBuilder("users").Offset(20)
if qb.offset != 20 {
t.Errorf("Expected offset 20, got %d", qb.offset)
}
}
func TestQueryBuilder_Build_Simple(t *testing.T) {
qb := NewQueryBuilder("users")
query, args := qb.Build()
expectedQuery := "SELECT * FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_WithSelect(t *testing.T) {
qb := NewQueryBuilder("users").Select("id", "name", "email")
query, args := qb.Build()
expectedQuery := "SELECT id, name, email FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_WithWhere(t *testing.T) {
qb := NewQueryBuilder("users").
Where("age > ?", 18).
Where("active = ?", true)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users WHERE age > ? AND active = ?"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 2 {
t.Errorf("Expected 2 args, got %d", len(args))
}
if args[0] != 18 {
t.Errorf("Expected first arg 18, got %v", args[0])
}
if args[1] != true {
t.Errorf("Expected second arg true, got %v", args[1])
}
}
func TestQueryBuilder_Build_WithOrderBy(t *testing.T) {
qb := NewQueryBuilder("users").OrderBy("created_at DESC")
query, args := qb.Build()
expectedQuery := "SELECT * FROM users ORDER BY created_at DESC"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_WithLimit(t *testing.T) {
qb := NewQueryBuilder("users").Limit(10)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users LIMIT 10"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_WithOffset(t *testing.T) {
qb := NewQueryBuilder("users").Offset(20)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users OFFSET 20"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_WithLimitAndOffset(t *testing.T) {
qb := NewQueryBuilder("users").Limit(10).Offset(20)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users LIMIT 10 OFFSET 20"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_Build_Complex(t *testing.T) {
qb := NewQueryBuilder("users").
Select("id", "name", "email", "age").
Where("age > ?", 18).
Where("active = ?", true).
Where("country = ?", "US").
OrderBy("name ASC").
Limit(10).
Offset(20)
query, args := qb.Build()
expectedQuery := "SELECT id, name, email, age FROM users WHERE age > ? AND active = ? AND country = ? ORDER BY name ASC LIMIT 10 OFFSET 20"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
expectedArgs := []interface{}{18, true, "US"}
if len(args) != len(expectedArgs) {
t.Errorf("Expected %d args, got %d", len(expectedArgs), len(args))
}
for i, expected := range expectedArgs {
if args[i] != expected {
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
}
}
}
func TestQueryBuilder_Build_MultipleWhere(t *testing.T) {
qb := NewQueryBuilder("orders").
Where("user_id = ?", 123).
Where("status IN (?, ?, ?)", "pending", "processing", "shipped").
Where("created_at > ?", "2024-01-01")
query, args := qb.Build()
expectedQuery := "SELECT * FROM orders WHERE user_id = ? AND status IN (?, ?, ?) AND created_at > ?"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
if len(args) != 5 {
t.Errorf("Expected 5 args, got %d", len(args))
}
expectedArgs := []interface{}{123, "pending", "processing", "shipped", "2024-01-01"}
for i, expected := range expectedArgs {
if args[i] != expected {
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
}
}
}
func TestQueryBuilder_Chaining(t *testing.T) {
// Test method chaining returns the same builder
qb := NewQueryBuilder("users")
qb2 := qb.Select("id", "name")
if qb != qb2 {
t.Error("Select should return the same builder instance")
}
qb3 := qb.Where("age > ?", 18)
if qb != qb3 {
t.Error("Where should return the same builder instance")
}
qb4 := qb.OrderBy("name ASC")
if qb != qb4 {
t.Error("OrderBy should return the same builder instance")
}
qb5 := qb.Limit(10)
if qb != qb5 {
t.Error("Limit should return the same builder instance")
}
qb6 := qb.Offset(20)
if qb != qb6 {
t.Error("Offset should return the same builder instance")
}
}
func TestQueryBuilder_EmptyWhere(t *testing.T) {
qb := NewQueryBuilder("users").Select("id", "name")
query, args := qb.Build()
expectedQuery := "SELECT id, name FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_OnlyLimit(t *testing.T) {
qb := NewQueryBuilder("users").Limit(5)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users LIMIT 5"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_OnlyOffset(t *testing.T) {
qb := NewQueryBuilder("users").Offset(10)
query, args := qb.Build()
expectedQuery := "SELECT * FROM users OFFSET 10"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_ZeroLimit(t *testing.T) {
qb := NewQueryBuilder("users").Limit(0)
query, args := qb.Build()
// Limit 0 should not be included
expectedQuery := "SELECT * FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_ZeroOffset(t *testing.T) {
qb := NewQueryBuilder("users").Offset(0)
query, args := qb.Build()
// Offset 0 should not be included
expectedQuery := "SELECT * FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_MultipleSelect(t *testing.T) {
// Last Select should override previous ones
qb := NewQueryBuilder("users").
Select("id", "name").
Select("email", "age")
query, _ := qb.Build()
expectedQuery := "SELECT email, age FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
}
func TestQueryBuilder_MultipleOrderBy(t *testing.T) {
// Last OrderBy should override previous ones
qb := NewQueryBuilder("users").
OrderBy("name ASC").
OrderBy("created_at DESC")
query, _ := qb.Build()
expectedQuery := "SELECT * FROM users ORDER BY created_at DESC"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
}
func TestQueryBuilder_ComplexWhereConditions(t *testing.T) {
qb := NewQueryBuilder("products").
Select("id", "name", "price").
Where("category_id = ?", 5).
Where("price BETWEEN ? AND ?", 10.0, 100.0).
Where("stock > ?", 0).
Where("name LIKE ?", "%phone%").
OrderBy("price DESC").
Limit(20)
query, args := qb.Build()
expectedQuery := "SELECT id, name, price FROM products WHERE category_id = ? AND price BETWEEN ? AND ? AND stock > ? AND name LIKE ? ORDER BY price DESC LIMIT 20"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
expectedArgs := []interface{}{5, 10.0, 100.0, 0, "%phone%"}
if len(args) != len(expectedArgs) {
t.Errorf("Expected %d args, got %d", len(expectedArgs), len(args))
}
for i, expected := range expectedArgs {
if args[i] != expected {
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
}
}
}
func TestQueryBuilder_SingleColumn(t *testing.T) {
qb := NewQueryBuilder("users").Select("COUNT(*)")
query, args := qb.Build()
expectedQuery := "SELECT COUNT(*) FROM users"
if query != expectedQuery {
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
}
if len(args) != 0 {
t.Errorf("Expected 0 args, got %d", len(args))
}
}
func TestQueryBuilder_JoinLikeWhere(t *testing.T) {
// Test that WHERE can handle JOIN-like conditions
qb := NewQueryBuilder("users").
Select("users.id", "users.name", "orders.total").
Where("users.id = orders.user_id").
Where("orders.status = ?", "completed")
query, args := qb.Build()
expectedQuery := "SELECT users.id, users.name, orders.total FROM users WHERE users.id = orders.user_id AND orders.status = ?"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
if len(args) != 1 {
t.Errorf("Expected 1 arg, got %d", len(args))
}
if args[0] != "completed" {
t.Errorf("Expected arg 'completed', got %v", args[0])
}
}
// Benchmark tests
func BenchmarkQueryBuilder_Simple(b *testing.B) {
for i := 0; i < b.N; i++ {
qb := NewQueryBuilder("users")
qb.Build()
}
}
func BenchmarkQueryBuilder_Complex(b *testing.B) {
for i := 0; i < b.N; i++ {
qb := NewQueryBuilder("users").
Select("id", "name", "email", "age").
Where("age > ?", 18).
Where("active = ?", true).
Where("country = ?", "US").
OrderBy("name ASC").
Limit(10).
Offset(20)
qb.Build()
}
}

176
converter.go Normal file
View File

@ -0,0 +1,176 @@
package stardb
import (
"strconv"
"time"
)
// convertToInt64 converts any value to int64
func convertToInt64(val interface{}) int64 {
switch v := val.(type) {
case nil:
return 0
case int:
return int64(v)
case int32:
return int64(v)
case int64:
return v
case uint64:
return int64(v)
case float32:
return int64(v)
case float64:
return int64(v)
case string:
result, _ := strconv.ParseInt(v, 10, 64)
return result
case bool:
if v {
return 1
}
return 0
case time.Time:
return v.Unix()
case []byte:
result, _ := strconv.ParseInt(string(v), 10, 64)
return result
default:
return 0
}
}
// convertToUint64 converts any value to uint64
func convertToUint64(val interface{}) uint64 {
switch v := val.(type) {
case nil:
return 0
case int:
return uint64(v)
case int32:
return uint64(v)
case int64:
return uint64(v)
case uint64:
return v
case float32:
return uint64(v)
case float64:
return uint64(v)
case string:
result, _ := strconv.ParseUint(v, 10, 64)
return result
case bool:
if v {
return 1
}
return 0
case time.Time:
return uint64(v.Unix())
case []byte:
result, _ := strconv.ParseUint(string(v), 10, 64)
return result
default:
return 0
}
}
// convertToFloat64 converts any value to float64
func convertToFloat64(val interface{}) float64 {
switch v := val.(type) {
case nil:
return 0
case int:
return float64(v)
case int32:
return float64(v)
case int64:
return float64(v)
case uint64:
return float64(v)
case float32:
return float64(v)
case float64:
return v
case string:
result, _ := strconv.ParseFloat(v, 64)
return result
case bool:
if v {
return 1
}
return 0
case time.Time:
return float64(v.Unix())
case []byte:
result, _ := strconv.ParseFloat(string(v), 64)
return result
default:
return 0
}
}
// convertToBool converts any value to bool
// Non-zero numbers are considered true
func convertToBool(val interface{}) bool {
switch v := val.(type) {
case nil:
return false
case bool:
return v
case int:
return v != 0
case int32:
return v != 0
case int64:
return v != 0
case uint64:
return v != 0
case float32:
return v != 0
case float64:
return v != 0
case string:
result, _ := strconv.ParseBool(v)
return result
case []byte:
result, _ := strconv.ParseBool(string(v))
return result
default:
return false
}
}
// convertToTime converts any value to time.Time
func convertToTime(val interface{}, layout string) time.Time {
switch v := val.(type) {
case nil:
return time.Time{}
case time.Time:
return v
case int:
return time.Unix(int64(v), 0)
case int32:
return time.Unix(int64(v), 0)
case int64:
return time.Unix(v, 0)
case uint64:
return time.Unix(int64(v), 0)
case float32:
sec := int64(v)
nsec := int64((v - float32(sec)) * 1e9)
return time.Unix(sec, nsec)
case float64:
sec := int64(v)
nsec := int64((v - float64(sec)) * 1e9)
return time.Unix(sec, nsec)
case string:
result, _ := time.Parse(layout, v)
return result
case []byte:
result, _ := time.Parse(layout, string(v))
return result
default:
return time.Time{}
}
}

68
converter_safe.go Normal file
View File

@ -0,0 +1,68 @@
package stardb
import (
"fmt"
"strconv"
"time"
)
// ConvertToInt64Safe converts any value to int64 with error handling
func ConvertToInt64Safe(val interface{}) (int64, error) {
switch v := val.(type) {
case nil:
return 0, nil
case int:
return int64(v), nil
case int32:
return int64(v), nil
case int64:
return v, nil
case uint64:
return int64(v), nil
case float32:
return int64(v), nil
case float64:
return int64(v), nil
case string:
return strconv.ParseInt(v, 10, 64)
case bool:
if v {
return 1, nil
}
return 0, nil
case time.Time:
return v.Unix(), nil
case []byte:
return strconv.ParseInt(string(v), 10, 64)
default:
return 0, fmt.Errorf("cannot convert %T to int64", val)
}
}
// ConvertToStringSafe converts any value to string with error handling
func ConvertToStringSafe(val interface{}) (string, error) {
switch v := val.(type) {
case nil:
return "", nil
case string:
return v, nil
case int:
return strconv.Itoa(v), nil
case int32:
return strconv.FormatInt(int64(v), 10), nil
case int64:
return strconv.FormatInt(v, 10), nil
case float32:
return strconv.FormatFloat(float64(v), 'f', -1, 32), nil
case float64:
return strconv.FormatFloat(v, 'f', -1, 64), nil
case bool:
return strconv.FormatBool(v), nil
case time.Time:
return v.String(), nil
case []byte:
return string(v), nil
default:
return "", fmt.Errorf("cannot convert %T to string", val)
}
}

183
converter_test.go Normal file
View File

@ -0,0 +1,183 @@
package stardb
import (
"testing"
"time"
)
func TestConvertToInt64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected int64
}{
{"nil", nil, 0},
{"int", 42, 42},
{"int32", int32(42), 42},
{"int64", int64(42), 42},
{"uint64", uint64(42), 42},
{"float32", float32(42.7), 42},
{"float64", float64(42.7), 42},
{"string", "42", 42},
{"bool true", true, 1},
{"bool false", false, 0},
{"bytes", []byte("42"), 42},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToInt64(tt.input)
if result != tt.expected {
t.Errorf("Expected %d, got %d", tt.expected, result)
}
})
}
}
func TestConvertToUint64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected uint64
}{
{"nil", nil, 0},
{"int", 42, 42},
{"int32", int32(42), 42},
{"int64", int64(42), 42},
{"uint64", uint64(42), 42},
{"float32", float32(42.7), 42},
{"float64", float64(42.7), 42},
{"string", "42", 42},
{"bool true", true, 1},
{"bool false", false, 0},
{"bytes", []byte("42"), 42},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToUint64(tt.input)
if result != tt.expected {
t.Errorf("Expected %d, got %d", tt.expected, result)
}
})
}
}
func TestConvertToFloat64(t *testing.T) {
tests := []struct {
name string
input interface{}
expected float64
}{
{"nil", nil, 0.0},
{"int", 42, 42.0},
{"int32", int32(42), 42.0},
{"int64", int64(42), 42.0},
{"uint64", uint64(42), 42.0},
{"float32", float32(42.5), 42.5},
{"float64", float64(42.5), 42.5},
{"string", "42.5", 42.5},
{"bool true", true, 1.0},
{"bool false", false, 0.0},
{"bytes", []byte("42.5"), 42.5},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToFloat64(tt.input)
if result != tt.expected {
t.Errorf("Expected %f, got %f", tt.expected, result)
}
})
}
}
func TestConvertToBool(t *testing.T) {
tests := []struct {
name string
input interface{}
expected bool
}{
{"nil", nil, false},
{"bool true", true, true},
{"bool false", false, false},
{"int positive", 1, true},
{"int zero", 0, false},
{"int negative", -1, true},
{"float positive", 1.5, true},
{"float zero", 0.0, false},
{"string true", "true", true},
{"string false", "false", false},
{"string 1", "1", true},
{"string 0", "0", false},
{"bytes true", []byte("true"), true},
{"bytes false", []byte("false"), false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToBool(tt.input)
if result != tt.expected {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}
func TestConvertToString(t *testing.T) {
now := time.Now()
tests := []struct {
name string
input interface{}
expected string
}{
{"nil", nil, ""},
{"string", "hello", "hello"},
{"int", 42, "42"},
{"int32", int32(42), "42"},
{"int64", int64(42), "42"},
{"float32", float32(42.5), "42.5"},
{"float64", float64(42.5), "42.5"},
{"bool true", true, "true"},
{"bool false", false, "false"},
{"time", now, now.String()},
{"bytes", []byte("hello"), "hello"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToString(tt.input)
if result != tt.expected {
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
}
})
}
}
func TestConvertToTime(t *testing.T) {
layout := "2006-01-02 15:04:05"
timeStr := "2024-01-01 10:00:00"
expectedTime, _ := time.Parse(layout, timeStr)
tests := []struct {
name string
input interface{}
layout string
expected time.Time
}{
{"nil", nil, layout, time.Time{}},
{"time.Time", expectedTime, layout, expectedTime},
{"int64 unix", int64(1704103200), layout, time.Unix(1704103200, 0)},
{"string", timeStr, layout, expectedTime},
{"bytes", []byte(timeStr), layout, expectedTime},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := convertToTime(tt.input, tt.layout)
if !result.Equal(tt.expected) {
t.Errorf("Expected %v, got %v", tt.expected, result)
}
})
}
}

2
go.mod
View File

@ -1,3 +1,3 @@
module b612.me/stardb
go 1.16
go 1.16

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=

563
orm.go Normal file
View File

@ -0,0 +1,563 @@
package stardb
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
)
// Orm maps query results to a struct or slice of structs
// Usage:
//
// var user User
// rows.Orm(&user) // single row
// var users []User
// rows.Orm(&users) // multiple rows
func (r *StarRows) Orm(target interface{}) error {
if !r.parsed {
if err := r.parse(); err != nil {
return err
}
}
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
if targetType.Kind() != reflect.Ptr {
return errors.New("target must be a pointer")
}
targetType = targetType.Elem()
targetValue = targetValue.Elem()
// Handle slice/array
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
elementType := targetType.Elem()
result := reflect.New(targetType).Elem()
if r.Length() == 0 {
targetValue.Set(result)
return nil
}
for i := 0; i < r.Length(); i++ {
element := reflect.New(elementType)
if err := r.setStructFieldsFromRow(element.Interface(), "db", i); err != nil {
return err
}
result = reflect.Append(result, element.Elem())
}
targetValue.Set(result)
return nil
}
// Handle single struct
if r.Length() == 0 {
return nil
}
return r.setStructFieldsFromRow(target, "db", 0)
}
// QueryX executes a query with named parameter binding
// Usage: QueryX(&user, "SELECT * FROM users WHERE id = ?", ":id")
func (s *StarDB) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) {
return s.queryX(nil, target, query, args...)
}
// QueryXContext executes a query with context and named parameter binding
func (s *StarDB) QueryXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
return s.queryX(ctx, target, query, args...)
}
// queryX is the internal implementation
func (s *StarDB) queryX(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
fieldValues, err := getStructFieldValues(target, "db")
if err != nil {
return nil, err
}
// Replace named parameters with actual values
processedArgs := make([]interface{}, len(args))
for i, arg := range args {
if str, ok := arg.(string); ok {
if strings.HasPrefix(str, ":") {
fieldName := str[1:]
if val, exists := fieldValues[fieldName]; exists {
processedArgs[i] = val
} else {
processedArgs[i] = ""
}
} else if strings.HasPrefix(str, `\:`) {
processedArgs[i] = str[1:]
} else {
processedArgs[i] = arg
}
} else {
processedArgs[i] = arg
}
}
return s.query(ctx, query, processedArgs...)
}
// QueryXS executes queries for multiple structs
func (s *StarDB) QueryXS(targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
return s.queryXS(nil, targets, query, args...)
}
// QueryXSContext executes queries with context for multiple structs
func (s *StarDB) QueryXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
return s.queryXS(ctx, targets, query, args...)
}
// queryXS is the internal implementation
func (s *StarDB) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
var results []*StarRows
targetType := reflect.TypeOf(targets)
targetValue := reflect.ValueOf(targets)
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
for i := 0; i < targetValue.Len(); i++ {
result, err := s.queryX(ctx, targetValue.Index(i).Interface(), query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
} else {
result, err := s.queryX(ctx, targets, query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
return results, nil
}
// ExecX executes a statement with named parameter binding
func (s *StarDB) ExecX(target interface{}, query string, args ...interface{}) (sql.Result, error) {
return s.execX(nil, target, query, args...)
}
// ExecXContext executes a statement with context and named parameter binding
func (s *StarDB) ExecXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
return s.execX(ctx, target, query, args...)
}
// execX is the internal implementation
func (s *StarDB) execX(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
fieldValues, err := getStructFieldValues(target, "db")
if err != nil {
return nil, err
}
// Replace named parameters with actual values
processedArgs := make([]interface{}, len(args))
for i, arg := range args {
if str, ok := arg.(string); ok {
if strings.HasPrefix(str, ":") {
fieldName := str[1:]
if val, exists := fieldValues[fieldName]; exists {
processedArgs[i] = val
} else {
processedArgs[i] = ""
}
} else if strings.HasPrefix(str, `\:`) {
processedArgs[i] = str[1:]
} else {
processedArgs[i] = arg
}
} else {
processedArgs[i] = arg
}
}
return s.exec(ctx, query, processedArgs...)
}
// ExecXS executes statements for multiple structs
func (s *StarDB) ExecXS(targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
return s.execXS(nil, targets, query, args...)
}
// ExecXSContext executes statements with context for multiple structs
func (s *StarDB) ExecXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
return s.execXS(ctx, targets, query, args...)
}
// execXS is the internal implementation
func (s *StarDB) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
var results []sql.Result
targetType := reflect.TypeOf(targets)
targetValue := reflect.ValueOf(targets)
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
for i := 0; i < targetValue.Len(); i++ {
result, err := s.execX(ctx, targetValue.Index(i).Interface(), query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
} else {
result, err := s.execX(ctx, targets, query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
return results, nil
}
// Insert inserts a struct into the database
// Usage: Insert(&user, "users", "id") // id is auto-increment
func (s *StarDB) Insert(target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
return s.insert(nil, target, tableName, autoIncrementFields...)
}
// InsertContext inserts a struct with context
func (s *StarDB) InsertContext(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
return s.insert(ctx, target, tableName, autoIncrementFields...)
}
// insert is the internal implementation
func (s *StarDB) insert(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
query, params, err := buildInsertSQL(target, tableName, autoIncrementFields...)
if err != nil {
return nil, err
}
args := []interface{}{}
for _, param := range params {
args = append(args, param)
}
return s.execX(ctx, target, query, args...)
}
// Update updates a struct in the database
// Usage: Update(&user, "users", "id") // id is primary key
func (s *StarDB) Update(target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
return s.update(nil, target, tableName, primaryKeys...)
}
// UpdateContext updates a struct with context
func (s *StarDB) UpdateContext(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
return s.update(ctx, target, tableName, primaryKeys...)
}
// update is the internal implementation
func (s *StarDB) update(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
query, params, err := buildUpdateSQL(target, tableName, primaryKeys...)
if err != nil {
return nil, err
}
args := []interface{}{}
for _, param := range params {
args = append(args, param)
}
return s.execX(ctx, target, query, args...)
}
// buildInsertSQL builds an INSERT SQL statement
func buildInsertSQL(target interface{}, tableName string, autoIncrementFields ...string) (string, []string, error) {
fieldNames, err := getStructFieldNames(target, "db")
if err != nil {
return "", []string{}, err
}
var columns []string
var placeholders []string
var params []string
for _, fieldName := range fieldNames {
// Skip auto-increment fields
isAutoIncrement := false
for _, autoField := range autoIncrementFields {
if fieldName == autoField {
isAutoIncrement = true
break
}
}
if isAutoIncrement {
continue
}
columns = append(columns, fieldName)
placeholders = append(placeholders, "?")
params = append(params, ":"+fieldName)
}
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
tableName,
strings.Join(columns, ", "),
strings.Join(placeholders, ", "))
return query, params, nil
}
// buildUpdateSQL builds an UPDATE SQL statement
func buildUpdateSQL(target interface{}, tableName string, primaryKeys ...string) (string, []string, error) {
fieldNames, err := getStructFieldNames(target, "db")
if err != nil {
return "", []string{}, err
}
var setClauses []string
var params []string
// Build SET clause
for _, fieldName := range fieldNames {
setClauses = append(setClauses, fmt.Sprintf("%s = ?", fieldName))
params = append(params, ":"+fieldName)
}
// Build WHERE clause
var whereClauses []string
for _, pk := range primaryKeys {
whereClauses = append(whereClauses, fmt.Sprintf("%s = ?", pk))
params = append(params, ":"+pk)
}
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s",
tableName,
strings.Join(setClauses, ", "),
strings.Join(whereClauses, " AND "))
return query, params, nil
}
// Transaction ORM methods
// QueryX executes a query with named parameter binding in a transaction
func (t *StarTx) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) {
return t.queryX(nil, target, query, args...)
}
// QueryXContext executes a query with context in a transaction
func (t *StarTx) QueryXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
return t.queryX(ctx, target, query, args...)
}
// queryX is the internal implementation
func (t *StarTx) queryX(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
fieldValues, err := getStructFieldValues(target, "db")
if err != nil {
return nil, err
}
processedArgs := make([]interface{}, len(args))
for i, arg := range args {
if str, ok := arg.(string); ok {
if strings.HasPrefix(str, ":") {
fieldName := str[1:]
if val, exists := fieldValues[fieldName]; exists {
processedArgs[i] = val
} else {
processedArgs[i] = ""
}
} else if strings.HasPrefix(str, `\:`) {
processedArgs[i] = str[1:]
} else {
processedArgs[i] = arg
}
} else {
processedArgs[i] = arg
}
}
return t.query(ctx, query, processedArgs...)
}
// QueryXS executes queries for multiple structs in a transaction
func (t *StarTx) QueryXS(targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
return t.queryXS(nil, targets, query, args...)
}
// QueryXSContext executes queries with context for multiple structs in a transaction
func (t *StarTx) QueryXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
return t.queryXS(ctx, targets, query, args...)
}
// queryXS is the internal implementation
func (t *StarTx) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
var results []*StarRows
targetType := reflect.TypeOf(targets)
targetValue := reflect.ValueOf(targets)
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
for i := 0; i < targetValue.Len(); i++ {
result, err := t.queryX(ctx, targetValue.Index(i).Interface(), query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
} else {
result, err := t.queryX(ctx, targets, query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
return results, nil
}
// ExecX executes a statement with named parameter binding in a transaction
func (t *StarTx) ExecX(target interface{}, query string, args ...interface{}) (sql.Result, error) {
return t.execX(nil, target, query, args...)
}
// ExecXContext executes a statement with context in a transaction
func (t *StarTx) ExecXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
return t.execX(ctx, target, query, args...)
}
// execX is the internal implementation
func (t *StarTx) execX(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
fieldValues, err := getStructFieldValues(target, "db")
if err != nil {
return nil, err
}
processedArgs := make([]interface{}, len(args))
for i, arg := range args {
if str, ok := arg.(string); ok {
if strings.HasPrefix(str, ":") {
fieldName := str[1:]
if val, exists := fieldValues[fieldName]; exists {
processedArgs[i] = val
} else {
processedArgs[i] = ""
}
} else if strings.HasPrefix(str, `\:`) {
processedArgs[i] = str[1:]
} else {
processedArgs[i] = arg
}
} else {
processedArgs[i] = arg
}
}
return t.exec(ctx, query, processedArgs...)
}
// ExecXS executes statements for multiple structs in a transaction
func (t *StarTx) ExecXS(targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
return t.execXS(nil, targets, query, args...)
}
// ExecXSContext executes statements with context for multiple structs in a transaction
func (t *StarTx) ExecXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
return t.execXS(ctx, targets, query, args...)
}
// execXS is the internal implementation
func (t *StarTx) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
var results []sql.Result
targetType := reflect.TypeOf(targets)
targetValue := reflect.ValueOf(targets)
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
for i := 0; i < targetValue.Len(); i++ {
result, err := t.execX(ctx, targetValue.Index(i).Interface(), query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
} else {
result, err := t.execX(ctx, targets, query, args...)
if err != nil {
return results, err
}
results = append(results, result)
}
return results, nil
}
// Insert inserts a struct in a transaction
func (t *StarTx) Insert(target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
return t.insert(nil, target, tableName, autoIncrementFields...)
}
// InsertContext inserts a struct with context in a transaction
func (t *StarTx) InsertContext(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
return t.insert(ctx, target, tableName, autoIncrementFields...)
}
// insert is the internal implementation
func (t *StarTx) insert(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
query, params, err := buildInsertSQL(target, tableName, autoIncrementFields...)
if err != nil {
return nil, err
}
args := []interface{}{}
for _, param := range params {
args = append(args, param)
}
return t.execX(ctx, target, query, args...)
}
// Update updates a struct in a transaction
func (t *StarTx) Update(target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
return t.update(nil, target, tableName, primaryKeys...)
}
// UpdateContext updates a struct with context in a transaction
func (t *StarTx) UpdateContext(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
return t.update(ctx, target, tableName, primaryKeys...)
}
// update is the internal implementation
func (t *StarTx) update(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
query, params, err := buildUpdateSQL(target, tableName, primaryKeys...)
if err != nil {
return nil, err
}
args := []interface{}{}
for _, param := range params {
args = append(args, param)
}
return t.execX(ctx, target, query, args...)
}

88
orm_test.go Normal file
View File

@ -0,0 +1,88 @@
package stardb
import (
"testing"
"time"
)
type User struct {
ID int64 `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
Balance float64 `db:"balance"`
Active bool `db:"active"`
CreatedAt time.Time `db:"created_at"`
}
type Profile struct {
UserID int `db:"user_id"`
Bio string `db:"bio"`
Avatar string `db:"avatar"`
}
type NestedUser struct {
ID int64 `db:"id"`
Name string `db:"name"`
Profile `db:"---"`
}
func TestBuildInsertSQL(t *testing.T) {
user := User{
ID: 1,
Name: "Test",
Email: "test@example.com",
Age: 30,
Balance: 100.0,
Active: true,
CreatedAt: time.Now(),
}
query, params, err := buildInsertSQL(&user, "users", "id")
if err != nil {
t.Fatalf("buildInsertSQL failed: %v", err)
}
expectedQuery := "INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, ?, ?, ?, ?)"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
expectedParams := []string{":name", ":email", ":age", ":balance", ":active", ":created_at"}
if len(params) != len(expectedParams) {
t.Errorf("Expected %d params, got %d", len(expectedParams), len(params))
}
for i, param := range params {
if param != expectedParams[i] {
t.Errorf("Expected param %s, got %s", expectedParams[i], param)
}
}
}
func TestBuildUpdateSQL(t *testing.T) {
user := User{
ID: 1,
Name: "Test",
Email: "test@example.com",
Age: 30,
Balance: 100.0,
Active: true,
CreatedAt: time.Now(),
}
query, params, err := buildUpdateSQL(&user, "users", "id")
if err != nil {
t.Fatalf("buildUpdateSQL failed: %v", err)
}
expectedQuery := "UPDATE users SET id = ?, name = ?, email = ?, age = ?, balance = ?, active = ?, created_at = ? WHERE id = ?"
if query != expectedQuery {
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
}
expectedParamCount := 8 // 7 fields + 1 primary key
if len(params) != expectedParamCount {
t.Errorf("Expected %d params, got %d", expectedParamCount, len(params))
}
}

511
orm_v1.go
View File

@ -1,511 +0,0 @@
package stardb
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
)
func (star *StarRows) Orm(ins interface{}) error {
//check if is slice
if !star.parsed {
if err := star.parserows(); err != nil {
return err
}
}
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins)
if t.Kind() != reflect.Ptr {
return errors.New("interface not writable")
}
//now convert to slice
t = t.Elem()
v = v.Elem()
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
//get type of slice
sigType := t.Elem()
var result reflect.Value
result = reflect.New(t).Elem()
if star.Length == 0 {
v.Set(result)
return nil
}
for i := 0; i < star.Length; i++ {
val := reflect.New(sigType)
star.setAllRefValue(val.Interface(), "db", i)
result = reflect.Append(result, val.Elem())
}
v.Set(result)
return nil
}
if star.Length == 0 {
return nil
}
return star.setAllRefValue(ins, "db", 0)
}
func (star *StarDB) queryX(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) {
kvMap, err := getAllRefValue(ins, "db")
if err != nil {
return nil, err
}
for k, v := range args {
if k == 0 {
continue
}
switch v.(type) {
case string:
str := v.(string)
if strings.Index(str, ":") == 0 {
if _, ok := kvMap[str[1:]]; ok {
args[k] = kvMap[str[1:]]
} else {
args[k] = ""
}
continue
}
if strings.Index(str, `\:`) == 0 {
args[k] = kvMap[str[1:]]
}
}
}
return star.query(ctx, args...)
}
func (star *StarDB) QueryX(ins interface{}, args ...interface{}) (*StarRows, error) {
return star.queryX(nil, ins, args)
}
func (star *StarDB) QueryXS(ins interface{}, args ...interface{}) ([]*StarRows, error) {
var starRes []*StarRows
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
//now convert to slice
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
for i := 0; i < v.Len(); i++ {
result, err := star.queryX(nil, v.Index(i).Interface(), args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
} else {
result, err := star.queryX(nil, ins, args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
return starRes, nil
}
func (star *StarDB) ExecXS(ins interface{}, args ...interface{}) ([]sql.Result, error) {
var starRes []sql.Result
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
//now convert to slice
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
for i := 0; i < v.Len(); i++ {
result, err := star.execX(nil, v.Index(i).Interface(), args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
} else {
result, err := star.execX(nil, ins, args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
return starRes, nil
}
func (star *StarDB) ExecX(ins interface{}, args ...interface{}) (sql.Result, error) {
return star.execX(nil, ins, args...)
}
func getUpdateSentence(ins interface{}, sheetName string, primaryKey ...string) (string, []string, error) {
Keys, err := getAllRefKey(ins, "db")
if err != nil {
return "", []string{}, err
}
var mystr string
for k, v := range Keys {
mystr += fmt.Sprintf("%s=? ", v)
Keys[k] = ":" + v
}
mystr = fmt.Sprintf("update %s set %s where ", sheetName, mystr)
var whereSlice []string
for _, v := range primaryKey {
whereSlice = append(whereSlice, v+"=?")
Keys = append(Keys, ":"+v)
}
mystr += strings.Join(whereSlice, " and ")
return mystr, Keys, nil
}
func getInsertSentence(ins interface{}, sheetName string, autoIncrease ...string) (string, []string, error) {
Keys, err := getAllRefKey(ins, "db")
if err != nil {
return "", []string{}, err
}
var mystr, rps string
var rtnKeys []string
cns:
for _, v := range Keys {
for _, vs := range autoIncrease {
if v == vs {
rps += "null,"
continue cns
}
}
rtnKeys = append(rtnKeys, ":"+v)
rps += "?,"
}
mystr = fmt.Sprintf("insert into %s (%s) values (%s) ", sheetName, strings.Join(Keys, ","), rps[:len(rps)-1])
return mystr, rtnKeys, nil
}
func (star *StarDB) execX(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) {
kvMap, err := getAllRefValue(ins, "db")
if err != nil {
return nil, err
}
for k, v := range args {
if k == 0 {
continue
}
switch v.(type) {
case string:
str := v.(string)
if strings.Index(str, ":") == 0 {
if _, ok := kvMap[str[1:]]; ok {
args[k] = kvMap[str[1:]]
} else {
args[k] = ""
}
continue
}
if strings.Index(str, `\:`) == 0 {
args[k] = kvMap[str[1:]]
}
}
}
return star.exec(ctx, args...)
}
func (star *StarDB) Update(ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) {
return star.updateinsert(nil, true, ins, sheetName, primaryKey...)
}
func (star *StarDB) UpdateContext(ctx context.Context, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) {
return star.updateinsert(ctx, true, ins, sheetName, primaryKey...)
}
func (star *StarDB) Insert(ins interface{}, sheetName string, autoCreaseKey ...string) (sql.Result, error) {
return star.updateinsert(nil, false, ins, sheetName, autoCreaseKey...)
}
func (star *StarDB) InsertContext(ctx context.Context, ins interface{}, sheetName string, autoCreaseKey ...string) (sql.Result, error) {
return star.updateinsert(ctx, false, ins, sheetName, autoCreaseKey...)
}
func (star *StarDB) updateinsert(ctx context.Context, isUpdate bool, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) {
var sqlStr string
var para []string
var err error
if isUpdate {
sqlStr, para, err = getUpdateSentence(ins, sheetName, primaryKey...)
} else {
sqlStr, para, err = getInsertSentence(ins, sheetName, primaryKey...)
}
if err != nil {
return nil, err
}
tmpStr := append([]interface{}{}, sqlStr)
for _, v := range para {
tmpStr = append(tmpStr, v)
}
return star.execX(ctx, ins, tmpStr...)
}
func (star *StarDB) QueryXContext(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) {
return star.queryX(ctx, ins, args)
}
func (star *StarDB) QueryXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]*StarRows, error) {
var starRes []*StarRows
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
//now convert to slice
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
for i := 0; i < v.Len(); i++ {
result, err := star.queryX(ctx, v.Index(i).Interface(), args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
} else {
result, err := star.queryX(ctx, ins, args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
return starRes, nil
}
func (star *StarDB) ExecXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]sql.Result, error) {
var starRes []sql.Result
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
//now convert to slice
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
for i := 0; i < v.Len(); i++ {
result, err := star.execX(ctx, v.Index(i).Interface(), args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
} else {
result, err := star.execX(ctx, ins, args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
return starRes, nil
}
func (star *StarDB) ExecXContext(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) {
return star.execX(ctx, ins, args...)
}
func (star *StarTx) queryX(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) {
kvMap, err := getAllRefValue(ins, "db")
if err != nil {
return nil, err
}
for k, v := range args {
if k == 0 {
continue
}
switch v.(type) {
case string:
str := v.(string)
if strings.Index(str, ":") == 0 {
if _, ok := kvMap[str[1:]]; ok {
args[k] = kvMap[str[1:]]
} else {
args[k] = ""
}
continue
}
if strings.Index(str, `\:`) == 0 {
args[k] = kvMap[str[1:]]
}
}
}
return star.query(ctx, args...)
}
func (star *StarTx) QueryX(ins interface{}, args ...interface{}) (*StarRows, error) {
return star.queryX(nil, ins, args)
}
func (star *StarTx) QueryXS(ins interface{}, args ...interface{}) ([]*StarRows, error) {
var starRes []*StarRows
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
//now convert to slice
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
for i := 0; i < v.Len(); i++ {
result, err := star.queryX(nil, v.Index(i).Interface(), args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
} else {
result, err := star.queryX(nil, ins, args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
return starRes, nil
}
func (star *StarTx) Update(ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) {
return star.updateinsert(nil, true, ins, sheetName, primaryKey...)
}
func (star *StarTx) UpdateContext(ctx context.Context, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) {
return star.updateinsert(ctx, true, ins, sheetName, primaryKey...)
}
func (star *StarTx) Insert(ins interface{}, sheetName string, autoCreaseKey ...string) (sql.Result, error) {
return star.updateinsert(nil, false, ins, sheetName, autoCreaseKey...)
}
func (star *StarTx) InsertContext(ctx context.Context, ins interface{}, sheetName string, autoCreaseKey ...string) (sql.Result, error) {
return star.updateinsert(ctx, false, ins, sheetName, autoCreaseKey...)
}
func (star *StarTx) updateinsert(ctx context.Context, isUpdate bool, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) {
var sqlStr string
var para []string
var err error
if isUpdate {
sqlStr, para, err = getUpdateSentence(ins, sheetName, primaryKey...)
} else {
sqlStr, para, err = getInsertSentence(ins, sheetName, primaryKey...)
}
if err != nil {
return nil, err
}
tmpStr := append([]interface{}{}, sqlStr)
for _, v := range para {
tmpStr = append(tmpStr, v)
}
return star.execX(ctx, ins, tmpStr...)
}
func (star *StarTx) ExecXS(ins interface{}, args ...interface{}) ([]sql.Result, error) {
var starRes []sql.Result
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
//now convert to slice
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
for i := 0; i < v.Len(); i++ {
result, err := star.execX(nil, v.Index(i).Interface(), args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
} else {
result, err := star.execX(nil, ins, args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
return starRes, nil
}
func (star *StarTx) ExecX(ins interface{}, args ...interface{}) (sql.Result, error) {
return star.execX(nil, ins, args...)
}
func (star *StarTx) execX(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) {
kvMap, err := getAllRefValue(ins, "db")
if err != nil {
return nil, err
}
for k, v := range args {
if k == 0 {
continue
}
switch v.(type) {
case string:
str := v.(string)
if strings.Index(str, ":") == 0 {
if _, ok := kvMap[str[1:]]; ok {
args[k] = kvMap[str[1:]]
} else {
args[k] = ""
}
continue
}
if strings.Index(str, `\:`) == 0 {
args[k] = kvMap[str[1:]]
}
}
}
return star.exec(ctx, args...)
}
func (star *StarTx) QueryXContext(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) {
return star.queryX(ctx, ins, args)
}
func (star *StarTx) QueryXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]*StarRows, error) {
var starRes []*StarRows
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
//now convert to slice
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
for i := 0; i < v.Len(); i++ {
result, err := star.queryX(ctx, v.Index(i).Interface(), args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
} else {
result, err := star.queryX(ctx, ins, args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
return starRes, nil
}
func (star *StarTx) ExecXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]sql.Result, error) {
var starRes []sql.Result
t := reflect.TypeOf(ins)
v := reflect.ValueOf(ins)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
//now convert to slice
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
for i := 0; i < v.Len(); i++ {
result, err := star.execX(ctx, v.Index(i).Interface(), args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
} else {
result, err := star.execX(ctx, ins, args...)
if err != nil {
return starRes, err
}
starRes = append(starRes, result)
}
return starRes, nil
}
func (star *StarTx) ExecXContext(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) {
return star.execX(ctx, ins, args...)
}

54
pool.go Normal file
View File

@ -0,0 +1,54 @@
package stardb
import (
"time"
)
// PoolConfig represents database connection pool configuration
type PoolConfig struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
// DefaultPoolConfig returns default pool configuration
func DefaultPoolConfig() *PoolConfig {
return &PoolConfig{
MaxOpenConns: 25,
MaxIdleConns: 5,
ConnMaxLifetime: time.Hour,
ConnMaxIdleTime: 10 * time.Minute,
}
}
// SetPoolConfig applies pool configuration to the database
func (s *StarDB) SetPoolConfig(config *PoolConfig) {
if config.MaxOpenConns > 0 {
s.db.SetMaxOpenConns(config.MaxOpenConns)
}
if config.MaxIdleConns > 0 {
s.db.SetMaxIdleConns(config.MaxIdleConns)
}
if config.ConnMaxLifetime > 0 {
s.db.SetConnMaxLifetime(config.ConnMaxLifetime)
}
if config.ConnMaxIdleTime > 0 {
s.db.SetConnMaxIdleTime(config.ConnMaxIdleTime)
}
}
// OpenWithPool opens a database connection with pool configuration
func OpenWithPool(driver, connStr string, config *PoolConfig) (*StarDB, error) {
db := NewStarDB()
if err := db.Open(driver, connStr); err != nil {
return nil, err
}
if config == nil {
config = DefaultPoolConfig()
}
db.SetPoolConfig(config)
return db, nil
}

View File

@ -6,231 +6,252 @@ import (
"time"
)
func (star *StarRows) setAllRefValue(stc interface{}, skey string, rows int) error {
t := reflect.TypeOf(stc)
v := reflect.ValueOf(stc)
if t.Kind() == reflect.Ptr {
v = v.Elem()
}
if t.Kind() != reflect.Ptr && !v.CanSet() {
return errors.New("interface{} is not writable")
}
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if v.Kind() != reflect.Struct {
return errors.New("interface{} is not a struct")
// setStructFieldsFromRow sets struct fields from a row result using reflection
func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, rowIndex int) error {
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
if targetType.Kind() == reflect.Ptr {
targetValue = targetValue.Elem()
}
for i := 0; i < t.NumField(); i++ {
tp := t.Field(i)
srFrd := v.Field(i)
seg := tp.Tag.Get(skey)
if targetType.Kind() != reflect.Ptr && !targetValue.CanSet() {
return errors.New("target is not writable")
}
if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct {
if seg == "" {
if targetType.Kind() == reflect.Ptr {
targetType = targetType.Elem()
}
if targetValue.Kind() != reflect.Struct {
return errors.New("target is not a struct")
}
for i := 0; i < targetType.NumField(); i++ {
field := targetType.Field(i)
fieldValue := targetValue.Field(i)
tagValue := field.Tag.Get(tagKey)
// Handle nested structs
if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct {
if tagValue == "" {
continue
}
if seg == "---" {
sp := reflect.New(reflect.TypeOf(srFrd.Interface()).Elem()).Interface()
star.setAllRefValue(sp, skey, rows)
v.Field(i).Set(reflect.ValueOf(sp))
if tagValue == "---" {
nestedPtr := reflect.New(reflect.TypeOf(fieldValue.Interface()).Elem()).Interface()
r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex)
targetValue.Field(i).Set(reflect.ValueOf(nestedPtr))
continue
}
}
if srFrd.Kind() == reflect.Struct {
if seg == "" {
if fieldValue.Kind() == reflect.Struct {
if tagValue == "" {
continue
}
if seg == "---" {
sp := reflect.New(reflect.TypeOf(v.Field(i).Interface())).Interface()
star.setAllRefValue(sp, skey, rows)
v.Field(i).Set(reflect.ValueOf(sp).Elem())
if tagValue == "---" {
nestedPtr := reflect.New(reflect.TypeOf(targetValue.Field(i).Interface())).Interface()
r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex)
targetValue.Field(i).Set(reflect.ValueOf(nestedPtr).Elem())
continue
}
}
if seg == "" {
if tagValue == "" {
continue
}
if _, ok := star.Row(rows).columnref[seg]; !ok {
// Check if column exists
if _, ok := r.Row(rowIndex).columnIndex[tagValue]; !ok {
continue
}
myInt64 := star.Row(rows).MustInt64(seg)
myUint64 := star.Row(rows).MustUint64(seg)
switch v.Field(i).Kind() {
case reflect.String:
v.Field(i).SetString(star.Row(rows).MustString(seg))
case reflect.Int64:
v.Field(i).SetInt(myInt64)
case reflect.Int32:
v.Field(i).SetInt(int64(star.Row(rows).MustInt32(seg)))
case reflect.Int16:
v.Field(i).SetInt(int64(int16(myInt64)))
case reflect.Int8:
v.Field(i).SetInt(int64(int8(myInt64)))
case reflect.Uint64:
v.Field(i).SetUint(myUint64)
case reflect.Uint32:
v.Field(i).SetUint(uint64(uint32(myUint64)))
case reflect.Uint16:
v.Field(i).SetUint(uint64(uint16(myUint64)))
case reflect.Uint8:
v.Field(i).SetUint(uint64(uint8(myUint64)))
case reflect.Bool:
v.Field(i).SetBool(star.Row(rows).MustBool(seg))
case reflect.Float64:
v.Field(i).SetFloat(star.Row(rows).MustFloat64(seg))
case reflect.Float32:
v.Field(i).SetFloat(float64(star.Row(rows).MustFloat32(seg)))
case reflect.Interface, reflect.Struct, reflect.Ptr:
inf := star.Row(rows).Result[star.columnref[seg]]
switch vtype := inf.(type) {
case time.Time:
v.Field(i).Set(reflect.ValueOf(vtype))
}
default:
}
// Set field value based on type
r.setFieldValue(fieldValue, tagValue, rowIndex)
}
return nil
}
func setRefValue(stc interface{}, skey, key string, value interface{}) error {
t := reflect.TypeOf(stc)
v := reflect.ValueOf(stc).Elem()
if t.Kind() != reflect.Ptr || !v.CanSet() {
return errors.New("interface{} is not writable")
}
if v.Kind() != reflect.Struct {
return errors.New("interface{} is not a struct")
}
t = t.Elem()
for i := 0; i < t.NumField(); i++ {
tp := t.Field(i)
seg := tp.Tag.Get(skey)
if seg == "" || key != seg {
continue
// setFieldValue sets a single field value
func (r *StarRows) setFieldValue(fieldValue reflect.Value, columnName string, rowIndex int) {
row := r.Row(rowIndex)
switch fieldValue.Kind() {
case reflect.String:
fieldValue.SetString(row.MustString(columnName))
case reflect.Int:
fieldValue.SetInt(int64(row.MustInt(columnName)))
case reflect.Int8:
fieldValue.SetInt(int64(int8(row.MustInt64(columnName))))
case reflect.Int16:
fieldValue.SetInt(int64(int16(row.MustInt64(columnName))))
case reflect.Int32:
fieldValue.SetInt(int64(row.MustInt32(columnName)))
case reflect.Int64:
fieldValue.SetInt(row.MustInt64(columnName))
case reflect.Uint:
fieldValue.SetUint(uint64(row.MustUint64(columnName)))
case reflect.Uint8:
fieldValue.SetUint(uint64(uint8(row.MustUint64(columnName))))
case reflect.Uint16:
fieldValue.SetUint(uint64(uint16(row.MustUint64(columnName))))
case reflect.Uint32:
fieldValue.SetUint(uint64(uint32(row.MustUint64(columnName))))
case reflect.Uint64:
fieldValue.SetUint(row.MustUint64(columnName))
case reflect.Bool:
fieldValue.SetBool(row.MustBool(columnName))
case reflect.Float32:
fieldValue.SetFloat(float64(row.MustFloat32(columnName)))
case reflect.Float64:
fieldValue.SetFloat(row.MustFloat64(columnName))
case reflect.Interface, reflect.Struct, reflect.Ptr:
// Handle special types like time.Time
colIndex := r.columnIndex[columnName]
val := row.Result()[colIndex]
if t, ok := val.(time.Time); ok {
fieldValue.Set(reflect.ValueOf(t))
}
v.Field(i).Set(reflect.ValueOf(value))
}
return nil
}
func getAllRefValue(stc interface{}, skey string) (map[string]interface{}, error) {
// getStructFieldValues extracts all field values from a struct
func getStructFieldValues(target interface{}, tagKey string) (map[string]interface{}, error) {
result := make(map[string]interface{})
t := reflect.TypeOf(stc)
v := reflect.ValueOf(stc)
if t.Kind() == reflect.Ptr {
if v.IsNil() {
return nil, errors.New("ptr interface{} is nil")
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
if targetType.Kind() == reflect.Ptr {
if targetValue.IsNil() {
return nil, errors.New("pointer target is nil")
}
t = t.Elem()
v = v.Elem()
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
if v.Kind() != reflect.Struct {
return nil, errors.New("interface{} is not a struct")
if targetValue.Kind() != reflect.Struct {
return nil, errors.New("target is not a struct")
}
for i := 0; i < t.NumField(); i++ {
tp := t.Field(i)
srFrd := v.Field(i)
seg := tp.Tag.Get(skey)
if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct {
if srFrd.IsNil() {
for i := 0; i < targetType.NumField(); i++ {
field := targetType.Field(i)
fieldValue := targetValue.Field(i)
tagValue := field.Tag.Get(tagKey)
// Handle nested pointer structs
if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct {
if fieldValue.IsNil() {
continue
}
if seg == "---" {
res, err := getAllRefValue(reflect.ValueOf(srFrd.Elem().Interface()).Interface(), skey)
if tagValue == "---" {
nestedValues, err := getStructFieldValues(reflect.ValueOf(fieldValue.Elem().Interface()).Interface(), tagKey)
if err != nil {
return result, err
}
for k, v := range res {
for k, v := range nestedValues {
result[k] = v
}
continue
}
}
if v.Field(i).Kind() == reflect.Struct {
res, err := getAllRefValue(v.Field(i).Interface(), skey)
if seg == "---" {
// Handle nested structs
if targetValue.Field(i).Kind() == reflect.Struct {
if tagValue == "---" {
nestedValues, err := getStructFieldValues(targetValue.Field(i).Interface(), tagKey)
if err != nil {
return result, err
}
for k, v := range res {
for k, v := range nestedValues {
result[k] = v
}
continue
}
}
if seg == "" {
if tagValue == "" {
continue
}
value := v.Field(i)
if !value.CanInterface() {
if !fieldValue.CanInterface() {
continue
}
result[seg] = value.Interface()
result[tagValue] = fieldValue.Interface()
}
return result, nil
}
func getAllRefKey(stc interface{}, skey string) ([]string, error) {
// getStructFieldNames extracts all field names (tag values) from a struct
func getStructFieldNames(target interface{}, tagKey string) ([]string, error) {
var result []string
_, isStruct := isWritableStruct(stc)
if !isStruct {
return []string{}, errors.New("interface{} is not a struct")
if !isStruct(target) {
return []string{}, errors.New("target is not a struct")
}
t := reflect.TypeOf(stc)
v := reflect.ValueOf(stc)
if t.Kind() == reflect.Ptr {
if v.IsNil() {
return []string{}, errors.New("ptr interface{} is nil")
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
if targetType.Kind() == reflect.Ptr {
if targetValue.IsNil() {
return []string{}, errors.New("pointer target is nil")
}
t = t.Elem()
v = v.Elem()
targetType = targetType.Elem()
targetValue = targetValue.Elem()
}
for i := 0; i < t.NumField(); i++ {
srFrd := v.Field(i)
profile := t.Field(i)
seg := profile.Tag.Get(skey)
if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct {
if srFrd.IsNil() {
for i := 0; i < targetType.NumField(); i++ {
fieldValue := targetValue.Field(i)
field := targetType.Field(i)
tagValue := field.Tag.Get(tagKey)
// Handle nested pointer structs
if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct {
if fieldValue.IsNil() {
continue
}
if seg == "---" {
res, err := getAllRefKey(reflect.ValueOf(srFrd.Elem().Interface()).Interface(), skey)
if tagValue == "---" {
nestedNames, err := getStructFieldNames(reflect.ValueOf(fieldValue.Elem().Interface()).Interface(), tagKey)
if err != nil {
return result, err
}
for _, v := range res {
result = append(result, v)
}
result = append(result, nestedNames...)
continue
}
}
if v.Field(i).Kind() == reflect.Struct && seg == "---" {
res, err := getAllRefKey(v.Field(i).Interface(), skey)
// Handle nested structs
if targetValue.Field(i).Kind() == reflect.Struct && tagValue == "---" {
nestedNames, err := getStructFieldNames(targetValue.Field(i).Interface(), tagKey)
if err != nil {
return result, err
}
for _, v := range res {
result = append(result, v)
}
result = append(result, nestedNames...)
continue
}
if seg != "" {
result = append(result, seg)
if tagValue != "" {
result = append(result, tagValue)
}
}
return result, nil
}
func isWritableStruct(stc interface{}) (isWritable bool, isStruct bool) {
t := reflect.TypeOf(stc)
v := reflect.ValueOf(stc)
if t.Kind() == reflect.Ptr || v.CanSet() {
isWritable = true
}
if v.Kind() == reflect.Struct {
isStruct = true
}
return
// isWritable checks if a value is writable
func isWritable(target interface{}) bool {
targetType := reflect.TypeOf(target)
targetValue := reflect.ValueOf(target)
return targetType.Kind() == reflect.Ptr || targetValue.CanSet()
}
// isStruct checks if a value is a struct
func isStruct(target interface{}) bool {
targetValue := reflect.ValueOf(target)
if targetValue.Kind() == reflect.Ptr {
targetValue = targetValue.Elem()
}
return targetValue.Kind() == reflect.Struct
}

View File

@ -1,49 +0,0 @@
package stardb
import (
"fmt"
"testing"
)
type Useless struct {
Leader string `db:"leader"`
Usable bool `db:"use"`
O *Whoami `db:"---"`
}
type Whoami struct {
Hehe string `db:"hehe"`
}
func TestUpInOrm(t *testing.T) {
var hehe = Useless{
Leader: "no",
Usable: false,
}
sqlstr, param, err := getUpdateSentence(hehe, "ryz", "leader")
fmt.Println(sqlstr, param, err)
sqlstr, param, err = getInsertSentence(hehe, "ryz", "use")
fmt.Println(sqlstr, param, err)
}
func Test_SetRefVal(t *testing.T) {
var hehe = Useless{
Leader: "no",
}
fmt.Printf("%+v\n", hehe)
fmt.Println(setRefValue(&hehe, "db", "leader", "sb"))
fmt.Printf("%+v\n", hehe)
fmt.Printf("%+v\n", hehe)
fmt.Println(getAllRefKey(hehe, "db"))
fmt.Println(getAllRefValue(hehe, "db"))
}
func Test_Ref(t *testing.T) {
oooooo := Useless{
Leader: "Heheeee",
}
oooooo.O = &Whoami{"fuck"}
fmt.Println(getAllRefKey(oooooo, "db"))
fmt.Println(getAllRefValue(oooooo, "db"))
fmt.Println(getAllRefValue(&oooooo, "db"))
}

286
result.go Normal file
View File

@ -0,0 +1,286 @@
package stardb
import (
"fmt"
"reflect"
"time"
)
// StarResult represents a single row result
type StarResult struct {
result []interface{}
columns []string
columnIndex map[string]int
columnsType []reflect.Type
}
// Result returns the raw result slice
func (r *StarResult) Result() []interface{} {
return r.result
}
// Columns returns column names
func (r *StarResult) Columns() []string {
return r.columns
}
// ColumnsType returns column types
func (r *StarResult) ColumnsType() []reflect.Type {
return r.columnsType
}
// IsNil checks if a column value is nil
func (r *StarResult) IsNil(name string) bool {
index, ok := r.columnIndex[name]
if !ok {
return true
}
return r.result[index] == nil
}
// MustString returns column value as string
func (r *StarResult) MustString(name string) string {
index, ok := r.columnIndex[name]
if !ok {
return ""
}
return convertToString(r.result[index])
}
// MustInt returns column value as int
func (r *StarResult) MustInt(name string) int {
return int(r.MustInt64(name))
}
// MustInt32 returns column value as int32
func (r *StarResult) MustInt32(name string) int32 {
return int32(r.MustInt64(name))
}
// MustInt64 returns column value as int64
func (r *StarResult) MustInt64(name string) int64 {
index, ok := r.columnIndex[name]
if !ok {
return 0
}
return convertToInt64(r.result[index])
}
// MustUint64 returns column value as uint64
func (r *StarResult) MustUint64(name string) uint64 {
index, ok := r.columnIndex[name]
if !ok {
return 0
}
return convertToUint64(r.result[index])
}
// MustFloat32 returns column value as float32
func (r *StarResult) MustFloat32(name string) float32 {
return float32(r.MustFloat64(name))
}
// MustFloat64 returns column value as float64
func (r *StarResult) MustFloat64(name string) float64 {
index, ok := r.columnIndex[name]
if !ok {
return 0
}
return convertToFloat64(r.result[index])
}
// MustBool returns column value as bool
func (r *StarResult) MustBool(name string) bool {
index, ok := r.columnIndex[name]
if !ok {
return false
}
return convertToBool(r.result[index])
}
// MustBytes returns column value as []byte
func (r *StarResult) MustBytes(name string) []byte {
index, ok := r.columnIndex[name]
if !ok {
return []byte{}
}
if b, ok := r.result[index].([]byte); ok {
return b
}
return []byte(r.MustString(name))
}
// MustDate returns column value as time.Time
func (r *StarResult) MustDate(name, layout string) time.Time {
index, ok := r.columnIndex[name]
if !ok {
return time.Time{}
}
return convertToTime(r.result[index], layout)
}
// Scan scans the result into provided pointers
// Usage: row.Scan(&id, &name, &email)
func (r *StarResult) Scan(dest ...interface{}) error {
if len(dest) != len(r.result) {
return fmt.Errorf("expected %d destination arguments, got %d", len(r.result), len(dest))
}
for i, d := range dest {
if err := convertAssign(d, r.result[i]); err != nil {
return fmt.Errorf("error scanning column %d: %v", i, err)
}
}
return nil
}
// convertAssign assigns src to dest
func convertAssign(dest, src interface{}) error {
switch d := dest.(type) {
case *string:
*d = convertToString(src)
case *int:
*d = int(convertToInt64(src))
case *int32:
*d = int32(convertToInt64(src))
case *int64:
*d = convertToInt64(src)
case *uint64:
*d = convertToUint64(src)
case *float32:
*d = float32(convertToFloat64(src))
case *float64:
*d = convertToFloat64(src)
case *bool:
*d = convertToBool(src)
case *time.Time:
if t, ok := src.(time.Time); ok {
*d = t
}
case *[]byte:
if b, ok := src.([]byte); ok {
*d = b
}
default:
return fmt.Errorf("unsupported Scan type: %T", dest)
}
return nil
}
// StarResultCol represents a single column result
type StarResultCol struct {
result []interface{}
}
// Result returns the raw result slice
func (c *StarResultCol) Result() []interface{} {
return c.result
}
// Len returns the number of values
func (c *StarResultCol) Len() int {
return len(c.result)
}
// IsNil returns slice of nil checks for each row
func (c *StarResultCol) IsNil() []bool {
result := make([]bool, len(c.result))
for i, v := range c.result {
result[i] = (v == nil)
}
return result
}
// MustString returns all values as strings
func (c *StarResultCol) MustString() []string {
result := make([]string, len(c.result))
for i, v := range c.result {
result[i] = convertToString(v)
}
return result
}
// MustInt returns all values as ints
func (c *StarResultCol) MustInt() []int {
result := make([]int, len(c.result))
for i, v := range c.result {
result[i] = int(convertToInt64(v))
}
return result
}
// MustInt32 returns all values as int32
func (c *StarResultCol) MustInt32() []int32 {
result := make([]int32, len(c.result))
for i, v := range c.result {
result[i] = int32(convertToInt64(v))
}
return result
}
// MustInt64 returns all values as int64
func (c *StarResultCol) MustInt64() []int64 {
result := make([]int64, len(c.result))
for i, v := range c.result {
result[i] = convertToInt64(v)
}
return result
}
// MustUint64 returns all values as uint64
func (c *StarResultCol) MustUint64() []uint64 {
result := make([]uint64, len(c.result))
for i, v := range c.result {
result[i] = convertToUint64(v)
}
return result
}
// MustFloat32 returns all values as float32
func (c *StarResultCol) MustFloat32() []float32 {
result := make([]float32, len(c.result))
for i, v := range c.result {
result[i] = float32(convertToFloat64(v))
}
return result
}
// MustFloat64 returns all values as float64
func (c *StarResultCol) MustFloat64() []float64 {
result := make([]float64, len(c.result))
for i, v := range c.result {
result[i] = convertToFloat64(v)
}
return result
}
// MustBool returns all values as bool
func (c *StarResultCol) MustBool() []bool {
result := make([]bool, len(c.result))
for i, v := range c.result {
result[i] = convertToBool(v)
}
return result
}
// MustBytes returns all values as []byte
func (c *StarResultCol) MustBytes() [][]byte {
result := make([][]byte, len(c.result))
for i, v := range c.result {
if b, ok := v.([]byte); ok {
result[i] = b
} else {
result[i] = []byte(convertToString(v))
}
}
return result
}
// MustDate returns all values as time.Time
func (c *StarResultCol) MustDate(layout string) []time.Time {
result := make([]time.Time, len(c.result))
for i, v := range c.result {
result[i] = convertToTime(v, layout)
}
return result
}

51
result_safe.go Normal file
View File

@ -0,0 +1,51 @@
package stardb
import (
"errors"
"fmt"
"strconv"
)
// GetString returns column value as string with error
func (r *StarResult) GetString(name string) (string, error) {
index, ok := r.columnIndex[name]
if !ok {
return "", errors.New("column not found: " + name)
}
return ConvertToStringSafe(r.Result()[index])
}
// GetInt64 returns column value as int64 with error
func (r *StarResult) GetInt64(name string) (int64, error) {
index, ok := r.columnIndex[name]
if !ok {
return 0, errors.New("column not found: " + name)
}
return ConvertToInt64Safe(r.Result()[index])
}
// GetFloat64 returns column value as float64 with error
func (r *StarResult) GetFloat64(name string) (float64, error) {
index, ok := r.columnIndex[name]
if !ok {
return 0, errors.New("column not found: " + name)
}
switch v := r.Result()[index].(type) {
case nil:
return 0, nil
case float64:
return v, nil
case float32:
return float64(v), nil
case int, int32, int64, uint64:
val, err := ConvertToInt64Safe(v)
return float64(val), err
case string:
return strconv.ParseFloat(v, 64)
case []byte:
return strconv.ParseFloat(string(v), 64)
default:
return 0, fmt.Errorf("cannot convert %T to float64", v)
}
}

168
rows.go Normal file
View File

@ -0,0 +1,168 @@
package stardb
import (
"database/sql"
"reflect"
"strconv"
"time"
)
// StarRows represents a result set from a query
type StarRows struct {
rows *sql.Rows
db *StarDB
length int
stringResult []map[string]string
columns []string
columnsType []reflect.Type
columnIndex map[string]int
data [][]interface{}
parsed bool
}
// Length returns the number of rows
func (r *StarRows) Length() int {
return r.length
}
// StringResult returns all rows as string maps
func (r *StarRows) StringResult() []map[string]string {
return r.stringResult
}
// Columns returns column names
func (r *StarRows) Columns() []string {
return r.columns
}
// ColumnsType returns column types
func (r *StarRows) ColumnsType() []reflect.Type {
return r.columnsType
}
// Close closes the result set
func (r *StarRows) Close() error {
return r.rows.Close()
}
// Rescan re-parses the result set
func (r *StarRows) Rescan() error {
return r.parse()
}
// Row returns a specific row by index
func (r *StarRows) Row(index int) *StarResult {
result := &StarResult{}
if index >= len(r.data) {
return result
}
result.result = r.data[index]
result.columns = r.columns
result.columnsType = r.columnsType
result.columnIndex = r.columnIndex
return result
}
// Col returns all values for a specific column
func (r *StarRows) Col(name string) *StarResultCol {
result := &StarResultCol{}
if _, ok := r.columnIndex[name]; !ok {
return result
}
colIndex := r.columnIndex[name]
for _, row := range r.data {
result.result = append(result.result, row[colIndex])
}
return result
}
// parse parses the sql.Rows into internal data structures
func (r *StarRows) parse() error {
if r.parsed {
return nil
}
defer func() {
r.parsed = true
}()
r.data = [][]interface{}{}
r.columnIndex = make(map[string]int)
r.stringResult = []map[string]string{}
var err error
r.columns, err = r.rows.Columns()
if err != nil {
return err
}
columnTypes, err := r.rows.ColumnTypes()
if err != nil {
return err
}
for _, colType := range columnTypes {
r.columnsType = append(r.columnsType, colType.ScanType())
}
// Build column index map
for i, colName := range r.columns {
r.columnIndex[colName] = i
}
// Prepare scan arguments
scanArgs := make([]interface{}, len(r.columns))
values := make([]interface{}, len(r.columns))
for i := range values {
scanArgs[i] = &values[i]
}
// Scan all rows
for r.rows.Next() {
if err := r.rows.Scan(scanArgs...); err != nil {
return err
}
record := make(map[string]string)
rowCopy := make([]interface{}, len(values))
for i, val := range values {
rowCopy[i] = val
record[r.columns[i]] = convertToString(val)
}
r.data = append(r.data, rowCopy)
r.stringResult = append(r.stringResult, record)
}
r.length = len(r.stringResult)
return nil
}
// convertToString converts any value to string
func convertToString(val interface{}) string {
switch v := val.(type) {
case nil:
return ""
case string:
return v
case int:
return strconv.Itoa(v)
case int32:
return strconv.FormatInt(int64(v), 10)
case int64:
return strconv.FormatInt(v, 10)
case float32:
return strconv.FormatFloat(float64(v), 'f', -1, 32)
case float64:
return strconv.FormatFloat(v, 'f', -1, 64)
case bool:
return strconv.FormatBool(v)
case time.Time:
return v.String()
case []byte:
return string(v)
default:
return ""
}
}

296
stardb.go Normal file
View File

@ -0,0 +1,296 @@
package stardb
import (
"context"
"database/sql"
"errors"
)
// StarDB is a simple wrapper around sql.DB providing enhanced functionality
type StarDB struct {
db *sql.DB
ManualScan bool // If true, rows won't be automatically parsed
}
// NewStarDB creates a new StarDB instance
func NewStarDB() *StarDB {
return &StarDB{}
}
// NewStarDBWithDB creates a new StarDB instance with an existing *sql.DB
func NewStarDBWithDB(db *sql.DB) *StarDB {
return &StarDB{db: db}
}
// DB returns the underlying *sql.DB
func (s *StarDB) DB() *sql.DB {
return s.db
}
// SetDB sets the underlying *sql.DB
func (s *StarDB) SetDB(db *sql.DB) {
s.db = db
}
// Open opens a new database connection
func (s *StarDB) Open(driver, connStr string) error {
var err error
s.db, err = sql.Open(driver, connStr)
return err
}
// Close closes the database connection
func (s *StarDB) Close() error {
return s.db.Close()
}
// Ping verifies the database connection is alive
func (s *StarDB) Ping() error {
return s.db.Ping()
}
// PingContext verifies the database connection with context
func (s *StarDB) PingContext(ctx context.Context) error {
return s.db.PingContext(ctx)
}
// Stats returns database statistics
func (s *StarDB) Stats() sql.DBStats {
return s.db.Stats()
}
// SetMaxOpenConns sets the maximum number of open connections
func (s *StarDB) SetMaxOpenConns(n int) {
s.db.SetMaxOpenConns(n)
}
// SetMaxIdleConns sets the maximum number of idle connections
func (s *StarDB) SetMaxIdleConns(n int) {
s.db.SetMaxIdleConns(n)
}
// Conn returns a single connection from the pool
func (s *StarDB) Conn(ctx context.Context) (*sql.Conn, error) {
return s.db.Conn(ctx)
}
// Query executes a query that returns rows
// Usage: Query("SELECT * FROM users WHERE id = ?", 1)
func (s *StarDB) Query(query string, args ...interface{}) (*StarRows, error) {
return s.query(nil, query, args...)
}
// QueryContext executes a query with context
func (s *StarDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
return s.query(ctx, query, args...)
}
// query is the internal query implementation
func (s *StarDB) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
if err := s.db.Ping(); err != nil {
return nil, err
}
var rows *sql.Rows
var err error
if ctx == nil {
rows, err = s.db.Query(query, args...)
} else {
rows, err = s.db.QueryContext(ctx, query, args...)
}
if err != nil {
return nil, err
}
starRows := &StarRows{
rows: rows,
db: s,
}
if !s.ManualScan {
err = starRows.parse()
}
return starRows, err
}
// Exec executes a query that doesn't return rows
// Usage: Exec("INSERT INTO users (name) VALUES (?)", "John")
func (s *StarDB) Exec(query string, args ...interface{}) (sql.Result, error) {
return s.exec(nil, query, args...)
}
// ExecContext executes a query with context
func (s *StarDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return s.exec(ctx, query, args...)
}
// exec is the internal exec implementation
func (s *StarDB) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if err := s.db.Ping(); err != nil {
return nil, err
}
if ctx == nil {
return s.db.Exec(query, args...)
}
return s.db.ExecContext(ctx, query, args...)
}
// Prepare creates a prepared statement
func (s *StarDB) Prepare(query string) (*StarStmt, error) {
stmt, err := s.db.Prepare(query)
if err != nil {
return nil, err
}
return &StarStmt{stmt: stmt, db: s}, nil
}
// PrepareContext creates a prepared statement with context
func (s *StarDB) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
stmt, err := s.db.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &StarStmt{stmt: stmt, db: s}, nil
}
// QueryStmt executes a prepared statement query
// Usage: QueryStmt("SELECT * FROM users WHERE id = ?", 1)
func (s *StarDB) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
if query == "" {
return nil, errors.New("query string cannot be empty")
}
stmt, err := s.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Query(args...)
}
// QueryStmtContext executes a prepared statement query with context
func (s *StarDB) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
if query == "" {
return nil, errors.New("query string cannot be empty")
}
stmt, err := s.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.QueryContext(ctx, args...)
}
// ExecStmt executes a prepared statement
// Usage: ExecStmt("INSERT INTO users (name) VALUES (?)", "John")
func (s *StarDB) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
if query == "" {
return nil, errors.New("query string cannot be empty")
}
stmt, err := s.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Exec(args...)
}
// ExecStmtContext executes a prepared statement with context
func (s *StarDB) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if query == "" {
return nil, errors.New("query string cannot be empty")
}
stmt, err := s.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.ExecContext(ctx, args...)
}
// Begin starts a transaction
func (s *StarDB) Begin() (*StarTx, error) {
tx, err := s.db.Begin()
if err != nil {
return nil, err
}
return &StarTx{tx: tx, db: s}, nil
}
// BeginTx starts a transaction with options
func (s *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, error) {
tx, err := s.db.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
return &StarTx{tx: tx, db: s}, nil
}
// StarStmt represents a prepared statement
type StarStmt struct {
stmt *sql.Stmt
db *StarDB
}
// Query executes a prepared statement query
func (s *StarStmt) Query(args ...interface{}) (*StarRows, error) {
return s.query(nil, args...)
}
// QueryContext executes a prepared statement query with context
func (s *StarStmt) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) {
return s.query(ctx, args...)
}
// query is the internal query implementation
func (s *StarStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) {
var rows *sql.Rows
var err error
if ctx == nil {
rows, err = s.stmt.Query(args...)
} else {
rows, err = s.stmt.QueryContext(ctx, args...)
}
if err != nil {
return nil, err
}
starRows := &StarRows{
rows: rows,
db: s.db,
}
if !s.db.ManualScan {
err = starRows.parse()
}
return starRows, err
}
// Exec executes a prepared statement
func (s *StarStmt) Exec(args ...interface{}) (sql.Result, error) {
return s.exec(nil, args...)
}
// ExecContext executes a prepared statement with context
func (s *StarStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
return s.exec(ctx, args...)
}
// exec is the internal exec implementation
func (s *StarStmt) exec(ctx context.Context, args ...interface{}) (sql.Result, error) {
if ctx == nil {
return s.stmt.Exec(args...)
}
return s.stmt.ExecContext(ctx, args...)
}
// Close closes the prepared statement
func (s *StarStmt) Close() error {
return s.stmt.Close()
}

File diff suppressed because it is too large Load Diff

563
testing/batch_test.go Normal file
View File

@ -0,0 +1,563 @@
package testing
import (
"context"
"testing"
"time"
"b612.me/stardb"
_ "modernc.org/sqlite"
)
type TestUser struct {
ID int64 `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
CreatedAt time.Time `db:"created_at"`
}
func setupBatchTestDB(t *testing.T) *stardb.StarDB {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
_, err = db.Exec(`
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT NOT NULL,
age INTEGER,
created_at DATETIME
)
`)
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
return db
}
func TestStarDB_BatchInsert_Basic(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
{"Charlie", "charlie@example.com", 35},
}
result, err := db.BatchInsert("users", columns, values)
if err != nil {
t.Fatalf("BatchInsert failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 3 {
t.Errorf("Expected 3 rows affected, got %d", affected)
}
// Verify insertion
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
count := rows.Row(0).MustInt("count")
if count != 3 {
t.Errorf("Expected 3 rows in database, got %d", count)
}
}
func TestStarDB_BatchInsert_Single(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
}
result, err := db.BatchInsert("users", columns, values)
if err != nil {
t.Fatalf("BatchInsert failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_BatchInsert_Empty(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
values := [][]interface{}{}
_, err := db.BatchInsert("users", columns, values)
if err == nil {
t.Error("Expected error with empty values, got nil")
}
}
func TestStarDB_BatchInsert_Large(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
var values [][]interface{}
// Insert 100 rows
for i := 0; i < 100; i++ {
values = append(values, []interface{}{
"User" + string(rune(i)),
"user" + string(rune(i)) + "@example.com",
20 + i%50,
})
}
result, err := db.BatchInsert("users", columns, values)
if err != nil {
t.Fatalf("BatchInsert failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 100 {
t.Errorf("Expected 100 rows affected, got %d", affected)
}
// Verify
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
count := rows.Row(0).MustInt("count")
if count != 100 {
t.Errorf("Expected 100 rows in database, got %d", count)
}
}
func TestStarDB_BatchInsertContext(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
ctx := context.Background()
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
}
result, err := db.BatchInsertContext(ctx, "users", columns, values)
if err != nil {
t.Fatalf("BatchInsertContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 2 {
t.Errorf("Expected 2 rows affected, got %d", affected)
}
}
func TestStarDB_BatchInsertContext_Timeout(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()
time.Sleep(10 * time.Millisecond) // Ensure timeout
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
}
_, err := db.BatchInsertContext(ctx, "users", columns, values)
if err == nil {
t.Error("Expected timeout error, got nil")
}
}
func TestStarDB_BatchInsertStructs_Basic(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
{Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()},
}
result, err := db.BatchInsertStructs("users", users, "id")
if err != nil {
t.Fatalf("BatchInsertStructs failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 3 {
t.Errorf("Expected 3 rows affected, got %d", affected)
}
// Verify insertion
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if rows.Length() != 3 {
t.Errorf("Expected 3 rows, got %d", rows.Length())
}
// Verify first user
name := rows.Row(0).MustString("name")
if name != "Alice" {
t.Errorf("Expected first user 'Alice', got '%s'", name)
}
}
func TestStarDB_BatchInsertStructs_Single(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
}
result, err := db.BatchInsertStructs("users", users, "id")
if err != nil {
t.Fatalf("BatchInsertStructs failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_BatchInsertStructs_Empty(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
users := []TestUser{}
_, err := db.BatchInsertStructs("users", users, "id")
if err == nil {
t.Error("Expected error with empty slice, got nil")
}
}
func TestStarDB_BatchInsertStructs_NotSlice(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
user := TestUser{Name: "Alice", Email: "alice@example.com", Age: 25}
_, err := db.BatchInsertStructs("users", user, "id")
if err == nil {
t.Error("Expected error with non-slice, got nil")
}
}
func TestStarDB_BatchInsertStructs_Pointer(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
}
result, err := db.BatchInsertStructs("users", &users, "id")
if err != nil {
t.Fatalf("BatchInsertStructs with pointer failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 2 {
t.Errorf("Expected 2 rows affected, got %d", affected)
}
}
func TestStarDB_BatchInsertStructsContext(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
ctx := context.Background()
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
}
result, err := db.BatchInsertStructsContext(ctx, "users", users, "id")
if err != nil {
t.Fatalf("BatchInsertStructsContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 2 {
t.Errorf("Expected 2 rows affected, got %d", affected)
}
}
func TestStarDB_BatchInsertStructs_Large(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
var users []TestUser
for i := 0; i < 50; i++ {
users = append(users, TestUser{
Name: "User" + string(rune(i)),
Email: "user" + string(rune(i)) + "@example.com",
Age: 20 + i%30,
CreatedAt: time.Now(),
})
}
result, err := db.BatchInsertStructs("users", users, "id")
if err != nil {
t.Fatalf("BatchInsertStructs failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 50 {
t.Errorf("Expected 50 rows affected, got %d", affected)
}
// Verify
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
count := rows.Row(0).MustInt("count")
if count != 50 {
t.Errorf("Expected 50 rows in database, got %d", count)
}
}
func TestStarDB_BatchInsert_VerifyData(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
columns := []string{"name", "email", "age"}
values := [][]interface{}{
{"Alice", "alice@example.com", 25},
{"Bob", "bob@example.com", 30},
}
_, err := db.BatchInsert("users", columns, values)
if err != nil {
t.Fatalf("BatchInsert failed: %v", err)
}
// Verify data integrity
rows, err := db.Query("SELECT name, email, age FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
// Check Alice
row0 := rows.Row(0)
if row0.MustString("name") != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", row0.MustString("name"))
}
if row0.MustString("email") != "alice@example.com" {
t.Errorf("Expected email 'alice@example.com', got '%s'", row0.MustString("email"))
}
if row0.MustInt("age") != 25 {
t.Errorf("Expected age 25, got %d", row0.MustInt("age"))
}
// Check Bob
row1 := rows.Row(1)
if row1.MustString("name") != "Bob" {
t.Errorf("Expected name 'Bob', got '%s'", row1.MustString("name"))
}
if row1.MustString("email") != "bob@example.com" {
t.Errorf("Expected email 'bob@example.com', got '%s'", row1.MustString("email"))
}
if row1.MustInt("age") != 30 {
t.Errorf("Expected age 30, got %d", row1.MustInt("age"))
}
}
func TestStarDB_BatchInsertStructs_VerifyData(t *testing.T) {
db := setupBatchTestDB(t)
defer db.Close()
now := time.Now()
users := []TestUser{
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: now},
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: now},
}
_, err := db.BatchInsertStructs("users", users, "id")
if err != nil {
t.Fatalf("BatchInsertStructs failed: %v", err)
}
// Query and verify with ORM
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var resultUsers []TestUser
err = rows.Orm(&resultUsers)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if len(resultUsers) != 2 {
t.Fatalf("Expected 2 users, got %d", len(resultUsers))
}
// Verify Alice
if resultUsers[0].Name != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", resultUsers[0].Name)
}
if resultUsers[0].Email != "alice@example.com" {
t.Errorf("Expected email 'alice@example.com', got '%s'", resultUsers[0].Email)
}
if resultUsers[0].Age != 25 {
t.Errorf("Expected age 25, got %d", resultUsers[0].Age)
}
// Verify Bob
if resultUsers[1].Name != "Bob" {
t.Errorf("Expected name 'Bob', got '%s'", resultUsers[1].Name)
}
if resultUsers[1].Email != "bob@example.com" {
t.Errorf("Expected email 'bob@example.com', got '%s'", resultUsers[1].Email)
}
if resultUsers[1].Age != 30 {
t.Errorf("Expected age 30, got %d", resultUsers[1].Age)
}
}
// Benchmark tests
func BenchmarkBatchInsert_10(b *testing.B) {
db := setupBatchTestDB(&testing.T{})
defer db.Close()
columns := []string{"name", "email", "age"}
var values [][]interface{}
for i := 0; i < 10; i++ {
values = append(values, []interface{}{"User", "user@example.com", 25})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
db.BatchInsert("users", columns, values)
db.Exec("DELETE FROM users")
}
}
func BenchmarkBatchInsert_100(b *testing.B) {
db := setupBatchTestDB(&testing.T{})
defer db.Close()
columns := []string{"name", "email", "age"}
var values [][]interface{}
for i := 0; i < 100; i++ {
values = append(values, []interface{}{"User", "user@example.com", 25})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
db.BatchInsert("users", columns, values)
db.Exec("DELETE FROM users")
}
}
func BenchmarkBatchInsertStructs_10(b *testing.B) {
db := setupBatchTestDB(&testing.T{})
defer db.Close()
var users []TestUser
for i := 0; i < 10; i++ {
users = append(users, TestUser{
Name: "User",
Email: "user@example.com",
Age: 25,
CreatedAt: time.Now(),
})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
db.BatchInsertStructs("users", users, "id")
db.Exec("DELETE FROM users")
}
}
func BenchmarkBatchInsertStructs_100(b *testing.B) {
db := setupBatchTestDB(&testing.T{})
defer db.Close()
var users []TestUser
for i := 0; i < 100; i++ {
users = append(users, TestUser{
Name: "User",
Email: "user@example.com",
Age: 25,
CreatedAt: time.Now(),
})
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
db.BatchInsertStructs("users", users, "id")
db.Exec("DELETE FROM users")
}
}

23
testing/go.mod Normal file
View File

@ -0,0 +1,23 @@
module b612.me/stardb/testing
go 1.25.6
require (
b612.me/stardb v0.0.0
modernc.org/sqlite v1.46.1
)
require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/ncruces/go-strftime v1.0.0 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
golang.org/x/sys v0.37.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
modernc.org/memory v1.11.0 // indirect
)
replace b612.me/stardb => ../

53
testing/go.sum Normal file
View File

@ -0,0 +1,53 @@
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=

691
testing/orm_test.go Normal file
View File

@ -0,0 +1,691 @@
package testing
import (
"context"
"testing"
"time"
)
type User struct {
ID int64 `db:"id"`
Name string `db:"name"`
Email string `db:"email"`
Age int `db:"age"`
Balance float64 `db:"balance"`
Active bool `db:"active"`
CreatedAt time.Time `db:"created_at"`
}
type Profile struct {
UserID int `db:"user_id"`
Bio string `db:"bio"`
Avatar string `db:"avatar"`
}
type NestedUser struct {
ID int64 `db:"id"`
Name string `db:"name"`
Profile `db:"---"`
}
func TestStarRows_Orm_Single(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user User
err = rows.Orm(&user)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if user.Name != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", user.Name)
}
if user.Email != "alice@example.com" {
t.Errorf("Expected email 'alice@example.com', got '%s'", user.Email)
}
if user.Age != 25 {
t.Errorf("Expected age 25, got %d", user.Age)
}
if user.Balance != 100.50 {
t.Errorf("Expected balance 100.50, got %f", user.Balance)
}
if !user.Active {
t.Errorf("Expected user to be active")
}
}
func TestStarRows_Orm_Multiple(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var users []User
err = rows.Orm(&users)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if len(users) != 3 {
t.Fatalf("Expected 3 users, got %d", len(users))
}
expectedNames := []string{"Alice", "Bob", "Charlie"}
for i, user := range users {
if user.Name != expectedNames[i] {
t.Errorf("Expected name '%s', got '%s'", expectedNames[i], user.Name)
}
}
}
func TestStarRows_Orm_Empty(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "NonExistent")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var users []User
err = rows.Orm(&users)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if len(users) != 0 {
t.Errorf("Expected 0 users, got %d", len(users))
}
}
func TestStarRows_Orm_NotPointer(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user User
err = rows.Orm(user) // Not a pointer
if err == nil {
t.Errorf("Expected error when passing non-pointer, got nil")
}
}
func TestStarDB_Insert(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
user := User{
Name: "David",
Email: "david@example.com",
Age: 40,
Balance: 400.00,
Active: true,
CreatedAt: time.Now(),
}
result, err := db.Insert(&user, "users", "id")
if err != nil {
t.Fatalf("Insert failed: %v", err)
}
lastID, err := result.LastInsertId()
if err != nil {
t.Fatalf("LastInsertId failed: %v", err)
}
if lastID <= 0 {
t.Errorf("Expected positive last insert ID, got %d", lastID)
}
// Verify insertion
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "David")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var insertedUser User
err = rows.Orm(&insertedUser)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if insertedUser.Name != "David" {
t.Errorf("Expected name 'David', got '%s'", insertedUser.Name)
}
if insertedUser.Email != "david@example.com" {
t.Errorf("Expected email 'david@example.com', got '%s'", insertedUser.Email)
}
if insertedUser.Age != 40 {
t.Errorf("Expected age 40, got %d", insertedUser.Age)
}
}
func TestStarDB_InsertContext(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
ctx := context.Background()
user := User{
Name: "Eve",
Email: "eve@example.com",
Age: 28,
Balance: 250.00,
Active: true,
CreatedAt: time.Now(),
}
result, err := db.InsertContext(ctx, &user, "users", "id")
if err != nil {
t.Fatalf("InsertContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_Update(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
// First, get the user
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user User
err = rows.Orm(&user)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
// Update the user
user.Age = 26
user.Balance = 150.75
result, err := db.Update(&user, "users", "id")
if err != nil {
t.Fatalf("Update failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
// Verify the update
rows2, err := db.Query("SELECT * FROM users WHERE id = ?", user.ID)
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows2.Close()
var updatedUser User
err = rows2.Orm(&updatedUser)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if updatedUser.Age != 26 {
t.Errorf("Expected age 26, got %d", updatedUser.Age)
}
if updatedUser.Balance != 150.75 {
t.Errorf("Expected balance 150.75, got %f", updatedUser.Balance)
}
}
func TestStarDB_UpdateContext(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
ctx := context.Background()
// Get user
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Bob")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
var user User
err = rows.Orm(&user)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
// Update
user.Active = false
result, err := db.UpdateContext(ctx, &user, "users", "id")
if err != nil {
t.Fatalf("UpdateContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_QueryX(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
user := User{
Name: "Alice",
}
rows, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":name")
if err != nil {
t.Fatalf("QueryX failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
var result User
err = rows.Orm(&result)
if err != nil {
t.Fatalf("Orm failed: %v", err)
}
if result.Name != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", result.Name)
}
}
func TestStarDB_QueryXContext(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
ctx := context.Background()
user := User{
Age: 30,
}
rows, err := db.QueryXContext(ctx, &user, "SELECT * FROM users WHERE age = ?", ":age")
if err != nil {
t.Fatalf("QueryXContext failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
}
func TestStarDB_QueryXS(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
users := []User{
{Name: "Alice"},
{Name: "Bob"},
}
results, err := db.QueryXS(&users, "SELECT * FROM users WHERE name = ?", ":name")
if err != nil {
t.Fatalf("QueryXS failed: %v", err)
}
if len(results) != 2 {
t.Errorf("Expected 2 results, got %d", len(results))
}
for i, rows := range results {
if rows.Length() != 1 {
t.Errorf("Expected 1 row in result %d, got %d", i, rows.Length())
}
}
}
func TestStarDB_ExecX(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
user := User{
Name: "Alice",
Age: 99,
}
result, err := db.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name")
if err != nil {
t.Fatalf("ExecX failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
// Verify
rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
age := rows.Row(0).MustInt("age")
if age != 99 {
t.Errorf("Expected age 99, got %d", age)
}
}
func TestStarDB_ExecXContext(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
ctx := context.Background()
user := User{
Name: "Bob",
Active: false,
}
result, err := db.ExecXContext(ctx, &user, "UPDATE users SET active = ? WHERE name = ?", ":active", ":name")
if err != nil {
t.Fatalf("ExecXContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_ExecXS(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
users := []User{
{Name: "Alice", Age: 26},
{Name: "Bob", Age: 31},
}
results, err := db.ExecXS(&users, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name")
if err != nil {
t.Fatalf("ExecXS failed: %v", err)
}
if len(results) != 2 {
t.Errorf("Expected 2 results, got %d", len(results))
}
for i, result := range results {
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed for result %d: %v", i, err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected in result %d, got %d", i, affected)
}
}
}
func TestStarTx_Insert(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
user := User{
Name: "Frank",
Email: "frank@example.com",
Age: 45,
Balance: 500.00,
Active: true,
CreatedAt: time.Now(),
}
result, err := tx.Insert(&user, "users", "id")
if err != nil {
tx.Rollback()
t.Fatalf("Tx.Insert failed: %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit failed: %v", err)
}
lastID, err := result.LastInsertId()
if err != nil {
t.Fatalf("LastInsertId failed: %v", err)
}
if lastID <= 0 {
t.Errorf("Expected positive last insert ID, got %d", lastID)
}
// Verify
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Frank")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
}
func TestStarTx_Update(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
// Get user
rows, err := tx.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
tx.Rollback()
t.Fatalf("Query failed: %v", err)
}
var user User
err = rows.Orm(&user)
rows.Close()
if err != nil {
tx.Rollback()
t.Fatalf("Orm failed: %v", err)
}
// Update
user.Age = 27
result, err := tx.Update(&user, "users", "id")
if err != nil {
tx.Rollback()
t.Fatalf("Tx.Update failed: %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarTx_QueryX(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
defer tx.Rollback()
user := User{
Name: "Charlie",
}
rows, err := tx.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":name")
if err != nil {
t.Fatalf("Tx.QueryX failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
}
func TestStarTx_ExecX(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
user := User{
Name: "Charlie",
Age: 36,
}
result, err := tx.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name")
if err != nil {
tx.Rollback()
t.Fatalf("Tx.ExecX failed: %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarTx_Rollback(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
user := User{
Name: "Rollback User",
Email: "rollback@example.com",
Age: 50,
Balance: 600.00,
Active: true,
CreatedAt: time.Now(),
}
_, err = tx.Insert(&user, "users", "id")
if err != nil {
tx.Rollback()
t.Fatalf("Tx.Insert failed: %v", err)
}
// Rollback instead of commit
err = tx.Rollback()
if err != nil {
t.Fatalf("Rollback failed: %v", err)
}
// Verify the insert was rolled back
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Rollback User")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if rows.Length() != 0 {
t.Errorf("Expected 0 rows after rollback, got %d", rows.Length())
}
}
func TestNamedParameterEscape(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
user := User{
Name: "Alice",
}
// Test escaped colon
rows, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", `\:name`)
if err != nil {
t.Fatalf("QueryX with escaped parameter failed: %v", err)
}
defer rows.Close()
// Should use literal ":name" string, not the field value
if rows.Length() != 0 {
t.Errorf("Expected 0 rows with literal ':name', got %d", rows.Length())
}
}

272
testing/pool_test.go Normal file
View File

@ -0,0 +1,272 @@
package testing
import (
"testing"
"time"
"b612.me/stardb"
_ "modernc.org/sqlite"
)
func TestDefaultPoolConfig(t *testing.T) {
config := stardb.DefaultPoolConfig()
if config.MaxOpenConns != 25 {
t.Errorf("Expected MaxOpenConns 25, got %d", config.MaxOpenConns)
}
if config.MaxIdleConns != 5 {
t.Errorf("Expected MaxIdleConns 5, got %d", config.MaxIdleConns)
}
if config.ConnMaxLifetime != time.Hour {
t.Errorf("Expected ConnMaxLifetime 1 hour, got %v", config.ConnMaxLifetime)
}
if config.ConnMaxIdleTime != 10*time.Minute {
t.Errorf("Expected ConnMaxIdleTime 10 minutes, got %v", config.ConnMaxIdleTime)
}
}
func TestStarDB_SetPoolConfig(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
config := &stardb.PoolConfig{
MaxOpenConns: 50,
MaxIdleConns: 10,
ConnMaxLifetime: 30 * time.Minute,
ConnMaxIdleTime: 5 * time.Minute,
}
db.SetPoolConfig(config)
stats := db.Stats()
if stats.MaxOpenConnections != 50 {
t.Errorf("Expected MaxOpenConnections 50, got %d", stats.MaxOpenConnections)
}
}
func TestStarDB_SetPoolConfig_Partial(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Only set MaxOpenConns
config := &stardb.PoolConfig{
MaxOpenConns: 100,
}
db.SetPoolConfig(config)
stats := db.Stats()
if stats.MaxOpenConnections != 100 {
t.Errorf("Expected MaxOpenConnections 100, got %d", stats.MaxOpenConnections)
}
}
func TestStarDB_SetPoolConfig_Zero(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Zero values should be ignored
config := &stardb.PoolConfig{
MaxOpenConns: 0,
MaxIdleConns: 0,
}
db.SetPoolConfig(config)
// Should not panic or error
err = db.Ping()
if err != nil {
t.Errorf("Ping failed after SetPoolConfig with zero values: %v", err)
}
}
func TestOpenWithPool_Default(t *testing.T) {
db, err := stardb.OpenWithPool("sqlite", ":memory:", nil)
if err != nil {
t.Fatalf("OpenWithPool failed: %v", err)
}
defer db.Close()
err = db.Ping()
if err != nil {
t.Errorf("Ping failed: %v", err)
}
stats := db.Stats()
if stats.MaxOpenConnections != 25 {
t.Errorf("Expected default MaxOpenConnections 25, got %d", stats.MaxOpenConnections)
}
}
func TestOpenWithPool_Custom(t *testing.T) {
config := &stardb.PoolConfig{
MaxOpenConns: 15,
MaxIdleConns: 3,
ConnMaxLifetime: 20 * time.Minute,
ConnMaxIdleTime: 3 * time.Minute,
}
db, err := stardb.OpenWithPool("sqlite", ":memory:", config)
if err != nil {
t.Fatalf("OpenWithPool failed: %v", err)
}
defer db.Close()
err = db.Ping()
if err != nil {
t.Errorf("Ping failed: %v", err)
}
stats := db.Stats()
if stats.MaxOpenConnections != 15 {
t.Errorf("Expected MaxOpenConnections 15, got %d", stats.MaxOpenConnections)
}
}
func TestOpenWithPool_InvalidDriver(t *testing.T) {
config := stardb.DefaultPoolConfig()
db, err := stardb.OpenWithPool("invalid_driver", "invalid_conn", config)
if err == nil {
db.Close()
t.Error("Expected error with invalid driver, got nil")
}
}
func TestOpenWithPool_Query(t *testing.T) {
db, err := stardb.OpenWithPool("sqlite", ":memory:", nil)
if err != nil {
t.Fatalf("OpenWithPool failed: %v", err)
}
defer db.Close()
// Create table
_, err = db.Exec(`CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)`)
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
// Insert data
_, err = db.Exec(`INSERT INTO test (name) VALUES (?)`, "Alice")
if err != nil {
t.Fatalf("Failed to insert data: %v", err)
}
// Query data
rows, err := db.Query("SELECT * FROM test")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
}
func TestPoolConfig_AllFields(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
config := &stardb.PoolConfig{
MaxOpenConns: 30,
MaxIdleConns: 8,
ConnMaxLifetime: 45 * time.Minute,
ConnMaxIdleTime: 7 * time.Minute,
}
db.SetPoolConfig(config)
// Verify by checking stats
stats := db.Stats()
if stats.MaxOpenConnections != 30 {
t.Errorf("Expected MaxOpenConnections 30, got %d", stats.MaxOpenConnections)
}
// Test that connections work
err = db.Ping()
if err != nil {
t.Errorf("Ping failed after SetPoolConfig: %v", err)
}
}
func TestPoolConfig_NegativeValues(t *testing.T) {
db := stardb.NewStarDB()
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
defer db.Close()
// Negative values should be ignored
config := &stardb.PoolConfig{
MaxOpenConns: -1,
MaxIdleConns: -1,
}
db.SetPoolConfig(config)
// Should not panic or error
err = db.Ping()
if err != nil {
t.Errorf("Ping failed after SetPoolConfig with negative values: %v", err)
}
}
func TestOpenWithPool_MultipleConnections(t *testing.T) {
config := &stardb.PoolConfig{
MaxOpenConns: 5,
MaxIdleConns: 2,
}
db, err := stardb.OpenWithPool("sqlite", ":memory:", config)
if err != nil {
t.Fatalf("OpenWithPool failed: %v", err)
}
defer db.Close()
// Create table
_, err = db.Exec(`CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)`)
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
// Perform multiple operations
for i := 0; i < 10; i++ {
_, err = db.Exec(`INSERT INTO test (value) VALUES (?)`, "test")
if err != nil {
t.Errorf("Insert %d failed: %v", i, err)
}
}
// Verify
rows, err := db.Query("SELECT COUNT(*) as count FROM test")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
count := rows.Row(0).MustInt("count")
if count != 10 {
t.Errorf("Expected 10 rows, got %d", count)
}
}

231
testing/result_test.go Normal file
View File

@ -0,0 +1,231 @@
package testing
import (
"testing"
"time"
)
func TestStarResult_MustString(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
name := row.MustString("name")
if name != "Alice" {
t.Errorf("Expected 'Alice', got '%s'", name)
}
email := row.MustString("email")
if email != "alice@example.com" {
t.Errorf("Expected 'alice@example.com', got '%s'", email)
}
// Test non-existent column
nonexistent := row.MustString("nonexistent")
if nonexistent != "" {
t.Errorf("Expected empty string for non-existent column, got '%s'", nonexistent)
}
}
func TestStarResult_MustInt(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Bob")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
age := row.MustInt("age")
if age != 30 {
t.Errorf("Expected age 30, got %d", age)
}
id := row.MustInt("id")
if id <= 0 {
t.Errorf("Expected positive id, got %d", id)
}
}
func TestStarResult_MustInt64(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Charlie")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
age := row.MustInt64("age")
if age != 35 {
t.Errorf("Expected age 35, got %d", age)
}
}
func TestStarResult_MustFloat64(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
balance := row.MustFloat64("balance")
if balance != 100.50 {
t.Errorf("Expected balance 100.50, got %f", balance)
}
}
func TestStarResult_MustBool(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
// Alice - active
row := rows.Row(0)
active := row.MustBool("active")
if !active {
t.Errorf("Expected Alice to be active")
}
// Charlie - inactive
row = rows.Row(2)
active = row.MustBool("active")
if active {
t.Errorf("Expected Charlie to be inactive")
}
}
func TestStarResult_IsNil(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
// Insert a row with NULL value
_, err := db.Exec("INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, NULL, ?, ?, ?)",
"David", "david@example.com", 150.0, true, time.Now())
if err != nil {
t.Fatalf("Insert failed: %v", err)
}
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "David")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
row := rows.Row(0)
if !row.IsNil("age") {
t.Errorf("Expected age to be NULL")
}
if row.IsNil("name") {
t.Errorf("Expected name to not be NULL")
}
}
func TestStarResultCol_MustString(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
col := rows.Col("name")
names := col.MustString()
expected := []string{"Alice", "Bob", "Charlie"}
for i, name := range names {
if name != expected[i] {
t.Errorf("Expected name '%s', got '%s'", expected[i], name)
}
}
}
func TestStarResultCol_MustInt(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
col := rows.Col("age")
ages := col.MustInt()
expected := []int{25, 30, 35}
for i, age := range ages {
if age != expected[i] {
t.Errorf("Expected age %d, got %d", expected[i], age)
}
}
}
func TestStarResultCol_MustFloat64(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
col := rows.Col("balance")
balances := col.MustFloat64()
expected := []float64{100.50, 200.75, 300.25}
for i, balance := range balances {
if balance != expected[i] {
t.Errorf("Expected balance %f, got %f", expected[i], balance)
}
}
}
func TestStarResultCol_MustBool(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
col := rows.Col("active")
actives := col.MustBool()
expected := []bool{true, true, false}
for i, active := range actives {
if active != expected[i] {
t.Errorf("Expected active %v, got %v", expected[i], active)
}
}
}

103
testing/rows_test.go Normal file
View File

@ -0,0 +1,103 @@
package testing
import (
"testing"
)
func TestStarRows_Row(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
// Test first row
row := rows.Row(0)
name := row.MustString("name")
if name != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", name)
}
// Test out of bounds
row = rows.Row(999)
if len(row.Result()) != 0 {
t.Errorf("Expected empty result for out of bounds index")
}
}
func TestStarRows_Col(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users ORDER BY name")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
// Test column extraction
col := rows.Col("name")
names := col.MustString()
if len(names) != 3 {
t.Errorf("Expected 3 names, got %d", len(names))
}
if names[0] != "Alice" {
t.Errorf("Expected first name 'Alice', got '%s'", names[0])
}
// Test non-existent column
col = rows.Col("nonexistent")
if len(col.Result()) != 0 {
t.Errorf("Expected empty result for non-existent column")
}
}
func TestStarRows_Rescan(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
db.ManualScan = true
rows, err := db.Query("SELECT * FROM users")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
err = rows.Rescan()
if err != nil {
t.Fatalf("Rescan failed: %v", err)
}
if rows.Length() != 3 {
t.Errorf("Expected 3 rows, got %d", rows.Length())
}
}
func TestStarRows_StringResult(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT name, age FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if len(rows.StringResult()) != 1 {
t.Fatalf("Expected 1 string result, got %d", len(rows.StringResult()))
}
record := rows.StringResult()[0]
if record["name"] != "Alice" {
t.Errorf("Expected name 'Alice', got '%s'", record["name"])
}
if record["age"] != "25" {
t.Errorf("Expected age '25', got '%s'", record["age"])
}
}

233
testing/stardb_test.go Normal file
View File

@ -0,0 +1,233 @@
package testing
import (
"context"
"testing"
"time"
"b612.me/stardb"
_ "modernc.org/sqlite"
)
func TestStarDB_Open(t *testing.T) {
db := &stardb.StarDB{}
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Errorf("Open failed: %v", err)
}
defer db.Close()
err = db.Ping()
if err != nil {
t.Errorf("Ping failed: %v", err)
}
}
func TestStarDB_Query(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.Query("SELECT * FROM users WHERE age > ?", 25)
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
if rows.Length() != 2 {
t.Errorf("Expected 2 rows, got %d", rows.Length())
}
if len(rows.Columns()) != 7 {
t.Errorf("Expected 7 columns, got %d", len(rows.Columns()))
}
}
func TestStarDB_QueryContext(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("QueryContext failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
}
func TestStarDB_Exec(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
result, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 26, "Alice")
if err != nil {
t.Fatalf("Exec failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_ExecContext(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
ctx := context.Background()
result, err := db.ExecContext(ctx, "DELETE FROM users WHERE name = ?", "Charlie")
if err != nil {
t.Fatalf("ExecContext failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_QueryStmt(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
rows, err := db.QueryStmt("SELECT * FROM users WHERE age > ?", 25)
if err != nil {
t.Fatalf("QueryStmt failed: %v", err)
}
defer rows.Close()
if rows.Length() != 2 {
t.Errorf("Expected 2 rows, got %d", rows.Length())
}
}
func TestStarDB_ExecStmt(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
result, err := db.ExecStmt("UPDATE users SET age = ? WHERE name = ?", 27, "Bob")
if err != nil {
t.Fatalf("ExecStmt failed: %v", err)
}
affected, err := result.RowsAffected()
if err != nil {
t.Fatalf("RowsAffected failed: %v", err)
}
if affected != 1 {
t.Errorf("Expected 1 row affected, got %d", affected)
}
}
func TestStarDB_Prepare(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
stmt, err := db.Prepare("SELECT * FROM users WHERE name = ?")
if err != nil {
t.Fatalf("Prepare failed: %v", err)
}
defer stmt.Close()
rows, err := stmt.Query("Alice")
if err != nil {
t.Fatalf("Stmt.Query failed: %v", err)
}
defer rows.Close()
if rows.Length() != 1 {
t.Errorf("Expected 1 row, got %d", rows.Length())
}
}
func TestStarDB_Transaction(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
_, err = tx.Exec("UPDATE users SET age = ? WHERE name = ?", 28, "Alice")
if err != nil {
tx.Rollback()
t.Fatalf("Tx.Exec failed: %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit failed: %v", err)
}
// Verify the change
rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
age := rows.Row(0).MustInt("age")
if age != 28 {
t.Errorf("Expected age 28, got %d", age)
}
}
func TestStarDB_TransactionRollback(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin failed: %v", err)
}
_, err = tx.Exec("UPDATE users SET age = ? WHERE name = ?", 99, "Alice")
if err != nil {
t.Fatalf("Tx.Exec failed: %v", err)
}
err = tx.Rollback()
if err != nil {
t.Fatalf("Rollback failed: %v", err)
}
// Verify the change was rolled back
rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice")
if err != nil {
t.Fatalf("Query failed: %v", err)
}
defer rows.Close()
age := rows.Row(0).MustInt("age")
if age == 99 {
t.Errorf("Expected age to be rolled back, but got %d", age)
}
}
func TestStarDB_SetMaxConnections(t *testing.T) {
db := setupTestDB(t)
defer db.Close()
db.SetMaxOpenConns(10)
db.SetMaxIdleConns(5)
stats := db.Stats()
if stats.MaxOpenConnections != 10 {
t.Errorf("Expected MaxOpenConnections 10, got %d", stats.MaxOpenConnections)
}
}

47
testing/testing.go Normal file
View File

@ -0,0 +1,47 @@
package testing
import (
"testing"
"b612.me/stardb"
_ "modernc.org/sqlite"
)
// setupTestDB creates a test database with sample data
// This function is only available when building with -tags=testing
func setupTestDB(t *testing.T) *stardb.StarDB {
db := &stardb.StarDB{}
err := db.Open("sqlite", ":memory:")
if err != nil {
t.Fatalf("Failed to open database: %v", err)
}
// Create test table
_, err = db.Exec(`
CREATE TABLE users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT NOT NULL,
age INTEGER,
balance REAL,
active BOOLEAN,
created_at DATETIME
)
`)
if err != nil {
t.Fatalf("Failed to create table: %v", err)
}
// Insert test data
_, err = db.Exec(`
INSERT INTO users (name, email, age, balance, active, created_at) VALUES
('Alice', 'alice@example.com', 25, 100.50, 1, '2024-01-01 10:00:00'),
('Bob', 'bob@example.com', 30, 200.75, 1, '2024-01-02 11:00:00'),
('Charlie', 'charlie@example.com', 35, 300.25, 0, '2024-01-03 12:00:00')
`)
if err != nil {
t.Fatalf("Failed to insert test data: %v", err)
}
return db
}

156
tx.go Normal file
View File

@ -0,0 +1,156 @@
package stardb
import (
"context"
"database/sql"
"errors"
)
// StarTx represents a database transaction
type StarTx struct {
tx *sql.Tx
db *StarDB
}
// Query executes a query within the transaction
func (t *StarTx) Query(query string, args ...interface{}) (*StarRows, error) {
return t.query(nil, query, args...)
}
// QueryContext executes a query with context within the transaction
func (t *StarTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
return t.query(ctx, query, args...)
}
// query is the internal query implementation
func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
if err := t.db.Ping(); err != nil {
return nil, err
}
var rows *sql.Rows
var err error
if ctx == nil {
rows, err = t.tx.Query(query, args...)
} else {
rows, err = t.tx.QueryContext(ctx, query, args...)
}
if err != nil {
return nil, err
}
starRows := &StarRows{
rows: rows,
db: t.db,
}
if !t.db.ManualScan {
err = starRows.parse()
}
return starRows, err
}
// Exec executes a query within the transaction
func (t *StarTx) Exec(query string, args ...interface{}) (sql.Result, error) {
return t.exec(nil, query, args...)
}
// ExecContext executes a query with context within the transaction
func (t *StarTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
return t.exec(ctx, query, args...)
}
// exec is the internal exec implementation
func (t *StarTx) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if err := t.db.Ping(); err != nil {
return nil, err
}
if ctx == nil {
return t.tx.Exec(query, args...)
}
return t.tx.ExecContext(ctx, query, args...)
}
// Prepare creates a prepared statement within the transaction
func (t *StarTx) Prepare(query string) (*StarStmt, error) {
stmt, err := t.tx.Prepare(query)
if err != nil {
return nil, err
}
return &StarStmt{stmt: stmt, db: t.db}, nil
}
// PrepareContext creates a prepared statement with context
func (t *StarTx) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
stmt, err := t.tx.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
return &StarStmt{stmt: stmt, db: t.db}, nil
}
// QueryStmt executes a prepared statement query within the transaction
func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
if query == "" {
return nil, errors.New("query string cannot be empty")
}
stmt, err := t.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Query(args...)
}
// QueryStmtContext executes a prepared statement query with context
func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
if query == "" {
return nil, errors.New("query string cannot be empty")
}
stmt, err := t.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.QueryContext(ctx, args...)
}
// ExecStmt executes a prepared statement within the transaction
func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
if query == "" {
return nil, errors.New("query string cannot be empty")
}
stmt, err := t.Prepare(query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.Exec(args...)
}
// ExecStmtContext executes a prepared statement with context
func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
if query == "" {
return nil, errors.New("query string cannot be empty")
}
stmt, err := t.PrepareContext(ctx, query)
if err != nil {
return nil, err
}
defer stmt.Close()
return stmt.ExecContext(ctx, args...)
}
// Commit commits the transaction
func (t *StarTx) Commit() error {
return t.tx.Commit()
}
// Rollback rolls back the transaction
func (t *StarTx) Rollback() error {
return t.tx.Rollback()
}