From 88569eb17616e862a358e309711edb965a327ff7 Mon Sep 17 00:00:00 2001 From: starainrt Date: Sat, 7 Mar 2026 19:27:44 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 19 + .idea/stardb.iml | 10 +- LICENSE.txt | 201 +++++++ README.MD | 535 +++++++++++++++++ batch.go | 112 ++++ builder.go | 93 +++ builder_test.go | 462 ++++++++++++++ converter.go | 176 ++++++ converter_safe.go | 68 +++ converter_test.go | 183 ++++++ go.mod | 2 +- go.sum | 2 + orm.go | 563 ++++++++++++++++++ orm_test.go | 88 +++ orm_v1.go | 511 ---------------- pool.go | 54 ++ reflect.go | 321 +++++----- reflect_test.go | 49 -- result.go | 286 +++++++++ result_safe.go | 51 ++ rows.go | 168 ++++++ stardb.go | 296 +++++++++ stardb_v1.go | 1291 ---------------------------------------- testing/batch_test.go | 563 ++++++++++++++++++ testing/go.mod | 23 + testing/go.sum | 53 ++ testing/orm_test.go | 691 +++++++++++++++++++++ testing/pool_test.go | 272 +++++++++ testing/result_test.go | 231 +++++++ testing/rows_test.go | 103 ++++ testing/stardb_test.go | 233 ++++++++ testing/testing.go | 47 ++ tx.go | 156 +++++ 33 files changed, 5910 insertions(+), 2003 deletions(-) create mode 100644 .gitignore create mode 100644 LICENSE.txt create mode 100644 README.MD create mode 100644 batch.go create mode 100644 builder.go create mode 100644 builder_test.go create mode 100644 converter.go create mode 100644 converter_safe.go create mode 100644 converter_test.go create mode 100644 go.sum create mode 100644 orm.go create mode 100644 orm_test.go delete mode 100644 orm_v1.go create mode 100644 pool.go delete mode 100644 reflect_test.go create mode 100644 result.go create mode 100644 result_safe.go create mode 100644 rows.go create mode 100644 stardb.go delete mode 100644 stardb_v1.go create mode 100644 testing/batch_test.go create mode 100644 testing/go.mod create mode 100644 testing/go.sum create mode 100644 testing/orm_test.go create mode 100644 testing/pool_test.go create mode 100644 testing/result_test.go create mode 100644 testing/rows_test.go create mode 100644 testing/stardb_test.go create mode 100644 testing/testing.go create mode 100644 tx.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cab69e8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +# Binaries +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test + +# Go build cache +*.out + +# Dependency directories +vendor/ + +# IDE +.idea/ + +# OS +.DS_Store \ No newline at end of file diff --git a/.idea/stardb.iml b/.idea/stardb.iml index 5e764c4..8806f39 100644 --- a/.idea/stardb.iml +++ b/.idea/stardb.iml @@ -1,6 +1,14 @@ - + + + + + diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.MD b/README.MD new file mode 100644 index 0000000..abd5046 --- /dev/null +++ b/README.MD @@ -0,0 +1,535 @@ +# StarDB + +一个轻量级的 Go 数据库封装库,个人学习用,提供简洁的 API 和 ORM 功能。 + +[![Go Version](https://img.shields.io/badge/Go-%3E%3D%201.16-blue)](https://golang.org/) +[![License](https://img.shields.io/badge/license-Apache%20License%202.0-blue)](LICENSE) + +## ✨ 特性 + +- ✅ **零第三方依赖** - 仅使用 Go 标准库 +- ✅ **类型安全** - 提供类型安全的结果转换方法 +- ✅ **简单 ORM** - 支持结构体与数据库的自动映射 +- ✅ **Context 支持** - 所有操作都支持 context +- ✅ **事务支持** - 完整的事务操作支持 +- ✅ **预编译语句** - 支持预编译语句以提升性能 +- ✅ **批量操作** - 高效的批量插入功能 +- ✅ **连接池管理** - 便捷的连接池配置 +- ✅ **查询构建器** - 链式调用构建 SQL 查询 + +## 📦 安装 + +```bash +go get b612.me/stardb +``` + +## 🚀 快速开始 + +### 基本使用 + +```go +package main + +import ( + "b612.me/stardb" + _ "github.com/mattn/go-sqlite3" +) + +func main() { + // 创建数据库实例 + db := stardb.NewStarDB() + err := db.Open("sqlite3", "test.db") + if err != nil { + panic(err) + } + defer db.Close() + + // 执行查询 + rows, err := db.Query("SELECT * FROM users WHERE age > ?", 18) + if err != nil { + panic(err) + } + defer rows.Close() + + // 遍历结果 + for i := 0; i < rows.Length(); i++ { + row := rows.Row(i) + name := row.MustString("name") + age := row.MustInt("age") + println(name, age) + } +} +``` + +### ORM 使用 + +```go +package main + +import ( + "b612.me/stardb" + "time" + _ "github.com/mattn/go-sqlite3" +) + +// 定义结构体 +type User struct { + ID int64 `db:"id"` + Name string `db:"name"` + Email string `db:"email"` + Age int `db:"age"` + Active bool `db:"active"` + CreatedAt time.Time `db:"created_at"` +} + +func main() { + db := stardb.NewStarDB() + db.Open("sqlite3", "test.db") + defer db.Close() + + // 查询单个对象 + rows, _ := db.Query("SELECT * FROM users WHERE id = ?", 1) + defer rows.Close() + + var user User + rows.Orm(&user) + println(user.Name) + + // 查询多个对象 + rows2, _ := db.Query("SELECT * FROM users WHERE age > ?", 18) + defer rows2.Close() + + var users []User + rows2.Orm(&users) + + for _, u := range users { + println(u.Name, u.Age) + } +} +``` + +### 插入和更新 + +```go +// 插入数据 +user := User{ + Name: "Alice", + Email: "alice@example.com", + Age: 25, + Active: true, + CreatedAt: time.Now(), +} + +result, err := db.Insert(&user, "users", "id") // "id" 是自增字段 +if err != nil { + panic(err) +} + +lastID, _ := result.LastInsertId() +println("插入的 ID:", lastID) + +// 更新数据 +user.Age = 26 +result, err = db.Update(&user, "users", "id") // "id" 是主键 +if err != nil { + panic(err) +} + +affected, _ := result.RowsAffected() +println("更新的行数:", affected) +``` + +### 批量插入 + +```go +// 方式 1:使用原始数据 +columns := []string{"name", "email", "age"} +values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + {"Bob", "bob@example.com", 30}, + {"Charlie", "charlie@example.com", 35}, +} + +result, err := db.BatchInsert("users", columns, values) +if err != nil { + panic(err) +} + +// 方式 2:使用结构体 +users := []User{ + {Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()}, + {Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()}, + {Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()}, +} + +result, err = db.BatchInsertStructs("users", users, "id") +if err != nil { + panic(err) +} + +affected, _ := result.RowsAffected() +println("批量插入行数:", affected) +``` + +### 事务操作 + +```go +// 开始事务 +tx, err := db.Begin() +if err != nil { + panic(err) +} + +// 执行操作 +_, err = tx.Exec("INSERT INTO users (name, email) VALUES (?, ?)", "Alice", "alice@example.com") +if err != nil { + tx.Rollback() + panic(err) +} + +_, err = tx.Exec("UPDATE accounts SET balance = balance - ? WHERE user_id = ?", 100, 1) +if err != nil { + tx.Rollback() + panic(err) +} + +// 提交事务 +err = tx.Commit() +if err != nil { + panic(err) +} +``` + +### 预编译语句 + +```go +// 创建预编译语句 +stmt, err := db.Prepare("SELECT * FROM users WHERE age > ?") +if err != nil { + panic(err) +} +defer stmt.Close() + +// 多次执行 +rows1, _ := stmt.Query(18) +defer rows1.Close() + +rows2, _ := stmt.Query(25) +defer rows2.Close() + +rows3, _ := stmt.Query(30) +defer rows3.Close() +``` + +### 查询构建器 + +```go +// 使用查询构建器 +rows, err := stardb.NewQueryBuilder("users"). + Select("id", "name", "email"). + Where("age > ?", 18). + Where("active = ?", true). + OrderBy("name ASC"). + Limit(10). + Offset(0). + Query(db) + +if err != nil { + panic(err) +} +defer rows.Close() + +var users []User +rows.Orm(&users) +``` + +### 连接池配置 + +```go +// 方式 1:使用默认配置 +db, err := stardb.OpenWithPool("sqlite3", "test.db", nil) +if err != nil { + panic(err) +} +defer db.Close() + +// 方式 2:自定义配置 +config := &stardb.PoolConfig{ + MaxOpenConns: 50, // 最大打开连接数 + MaxIdleConns: 10, // 最大空闲连接数 + ConnMaxLifetime: 1 * time.Hour, // 连接最大生命周期 + ConnMaxIdleTime: 10 * time.Minute, // 连接最大空闲时间 +} + +db, err = stardb.OpenWithPool("mysql", "user:pass@tcp(localhost:3306)/dbname", config) +if err != nil { + panic(err) +} +defer db.Close() + +// 方式 3:手动设置 +db := stardb.NewStarDB() +db.Open("postgres", "postgres://user:pass@localhost/dbname") +db.SetPoolConfig(config) +``` + +### 命名参数绑定 + +```go +user := User{ + Name: "Alice", + Age: 25, +} + +// 使用 :fieldname 语法绑定结构体字段 +rows, err := db.QueryX(&user, + "SELECT * FROM users WHERE name = ? AND age > ?", + ":name", ":age") +if err != nil { + panic(err) +} +defer rows.Close() +``` + +## 📖 详细文档 + +### 结果转换方法 + +StarDB 提供了丰富的类型转换方法: + +```go +row := rows.Row(0) + +// 字符串 +name := row.MustString("name") + +// 整数 +age := row.MustInt("age") +id := row.MustInt64("id") +count := row.MustInt32("count") +uid := row.MustUint64("uid") + +// 浮点数 +price := row.MustFloat64("price") +rate := row.MustFloat32("rate") + +// 布尔值 +active := row.MustBool("active") + +// 字节数组 +data := row.MustBytes("data") + +// 时间 +createdAt := row.MustDate("created_at", "2006-01-02 15:04:05") + +// 检查 NULL +isNull := row.IsNil("optional_field") +``` + +### 列操作 + +```go +// 获取某一列的所有值 +col := rows.Col("name") + +names := col.MustString() // []string +ages := col.MustInt() // []int +prices := col.MustFloat64() // []float64 +actives := col.MustBool() // []bool +``` + +### Context 支持 + +所有操作都有对应的 Context 版本: + +```go +ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) +defer cancel() + +// 查询 +rows, err := db.QueryContext(ctx, "SELECT * FROM users") + +// 执行 +result, err := db.ExecContext(ctx, "UPDATE users SET age = ?", 26) + +// 事务 +tx, err := db.BeginTx(ctx, nil) + +// 批量插入 +result, err := db.BatchInsertContext(ctx, "users", columns, values) +``` + +### 错误处理 + +```go +// 使用 Must* 方法(忽略错误,返回零值) +name := row.MustString("name") + +// 使用 Get* 方法(返回错误) +name, err := row.GetString("name") +if err != nil { + // 处理错误 +} +``` + +## 🔧 高级用法 + +### 嵌套结构体 + +```go +type Profile struct { + Bio string `db:"bio"` + Avatar string `db:"avatar"` +} + +type User struct { + ID int64 `db:"id"` + Name string `db:"name"` + Profile Profile `db:"---"` // 使用 "---" 标记嵌套结构体 +} + +rows, _ := db.Query("SELECT id, name, bio, avatar FROM users") +defer rows.Close() + +var user User +rows.Orm(&user) + +println(user.Name) +println(user.Profile.Bio) +``` + +### 手动扫描模式 + +```go +db := stardb.NewStarDB() +db.ManualScan = true // 启用手动扫描模式 +db.Open("sqlite3", "test.db") + +rows, _ := db.Query("SELECT * FROM users") +defer rows.Close() + +// 手动触发解析 +rows.Rescan() + +// 现在可以访问数据 +println(rows.Length()) +``` + +### 直接访问底层 *sql.DB + +```go +db := stardb.NewStarDB() +db.Open("sqlite3", "test.db") + +// 获取底层 *sql.DB +rawDB := db.DB() +rawDB.SetMaxOpenConns(100) + +// 或者使用已有的 *sql.DB +sqlDB, _ := sql.Open("sqlite3", "test.db") +db := stardb.NewStarDBWithDB(sqlDB) +``` + +## 🎯 性能优化建议 + +### 1. 使用预编译语句 + +对于重复执行的查询,使用预编译语句可以提升 20-50% 的性能: + +```go +stmt, _ := db.Prepare("SELECT * FROM users WHERE id = ?") +defer stmt.Close() + +for _, id := range userIDs { + rows, _ := stmt.Query(id) + // 处理结果 + rows.Close() +} +``` + +### 2. 批量操作 + +批量插入比单条插入快 2-3 倍: + +```go +// ❌ 慢 +for _, user := range users { + db.Exec("INSERT INTO users (name) VALUES (?)", user.Name) +} + +// ✅ 快 +db.BatchInsertStructs("users", users, "id") +``` + +### 3. 合理配置连接池 + +```go +config := &stardb.PoolConfig{ + MaxOpenConns: 25, // 根据数据库服务器调整 + MaxIdleConns: 5, // 保持少量空闲连接 + ConnMaxLifetime: 1 * time.Hour, + ConnMaxIdleTime: 10 * time.Minute, +} +db.SetPoolConfig(config) +``` + +### 4. 使用事务 + +将多个操作放在一个事务中可以显著提升性能: + +```go +tx, _ := db.Begin() +for _, user := range users { + tx.Exec("INSERT INTO users (name) VALUES (?)", user.Name) +} +tx.Commit() +``` + +## 🧪 测试 + +项目包含完整的单元测试: + +```bash +# 运行所有测试 +cd testing +go test -v + +# 运行特定测试 +go test -v -run TestStarDB_Query + +# 查看覆盖率 +go test -cover + +# 生成覆盖率报告 +go test -coverprofile=coverage.out +go tool cover -html=coverage.out + +# 运行基准测试 +go test -bench=. -benchmem +``` + +## 📊 支持的数据库 + +StarDB 支持所有实现了 `database/sql` 接口的数据库驱动: + +| 数据库 | 驱动 | 导入 | +|--------|------|------| +| SQLite | modernc.org/sqlite | `_ "modernc.org/sqlite"` | +| MySQL | github.com/go-sql-driver/mysql | `_ "github.com/go-sql-driver/mysql"` | +| PostgreSQL | github.com/lib/pq | `_ "github.com/lib/pq"` | +| SQL Server | github.com/denisenkom/go-mssqldb | `_ "github.com/denisenkom/go-mssqldb"` | +| Oracle | github.com/sijms/go-ora/v2 | `_ "github.com/sijms/go-ora/v2"` | + + +## 📄 许可证 + +本项目采用 Apache 2.0 许可证 - 详见 [LICENSE](LICENSE) 文件 + +## 🙏 致谢 + +- 感谢 Go 标准库提供的 `database/sql` 包 +- 灵感来源于 xorm、gorm 等优秀的 ORM 框架 + +## 📮 联系方式 + +- 项目主页: https://git.b612.me/b612/stardb.git \ No newline at end of file diff --git a/batch.go b/batch.go new file mode 100644 index 0000000..3607b03 --- /dev/null +++ b/batch.go @@ -0,0 +1,112 @@ +package stardb + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "strings" +) + +// BatchInsert performs batch insert operation +// Usage: BatchInsert("users", []string{"name", "age"}, [][]interface{}{{"Alice", 25}, {"Bob", 30}}) +func (s *StarDB) BatchInsert(tableName string, columns []string, values [][]interface{}) (sql.Result, error) { + return s.batchInsert(nil, tableName, columns, values) +} + +// BatchInsertContext performs batch insert with context +func (s *StarDB) BatchInsertContext(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) { + return s.batchInsert(ctx, tableName, columns, values) +} + +// batchInsert is the internal implementation +func (s *StarDB) batchInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) { + if len(values) == 0 { + return nil, fmt.Errorf("no values to insert") + } + + // Build placeholders: (?, ?), (?, ?), ... + placeholderGroup := "(" + strings.Repeat("?, ", len(columns)-1) + "?)" + placeholders := strings.Repeat(placeholderGroup+", ", len(values)-1) + placeholderGroup + + // Build SQL + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", + tableName, + strings.Join(columns, ", "), + placeholders) + + // Flatten values + var args []interface{} + for _, row := range values { + args = append(args, row...) + } + + return s.exec(ctx, query, args...) +} + +// BatchInsertStructs performs batch insert using structs +func (s *StarDB) BatchInsertStructs(tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) { + return s.batchInsertStructs(nil, tableName, structs, autoIncrementFields...) +} + +// BatchInsertStructsContext performs batch insert using structs with context +func (s *StarDB) BatchInsertStructsContext(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) { + return s.batchInsertStructs(ctx, tableName, structs, autoIncrementFields...) +} + +// batchInsertStructs is the internal implementation +func (s *StarDB) batchInsertStructs(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) { + // Get slice of structs + targetValue := reflect.ValueOf(structs) + if targetValue.Kind() == reflect.Ptr { + targetValue = targetValue.Elem() + } + + if targetValue.Kind() != reflect.Slice && targetValue.Kind() != reflect.Array { + return nil, fmt.Errorf("structs must be a slice or array") + } + + if targetValue.Len() == 0 { + return nil, fmt.Errorf("no structs to insert") + } + + // Get field names from first struct + firstStruct := targetValue.Index(0).Interface() + fieldNames, err := getStructFieldNames(firstStruct, "db") + if err != nil { + return nil, err + } + + // Filter out auto-increment fields + var columns []string + for _, fieldName := range fieldNames { + isAutoIncrement := false + for _, autoField := range autoIncrementFields { + if fieldName == autoField { + isAutoIncrement = true + break + } + } + if !isAutoIncrement { + columns = append(columns, fieldName) + } + } + + // Extract values from all structs + var values [][]interface{} + for i := 0; i < targetValue.Len(); i++ { + structVal := targetValue.Index(i).Interface() + fieldValues, err := getStructFieldValues(structVal, "db") + if err != nil { + return nil, err + } + + var row []interface{} + for _, col := range columns { + row = append(row, fieldValues[col]) + } + values = append(values, row) + } + + return s.batchInsert(ctx, tableName, columns, values) +} diff --git a/builder.go b/builder.go new file mode 100644 index 0000000..7cf5762 --- /dev/null +++ b/builder.go @@ -0,0 +1,93 @@ +package stardb + +import ( + "fmt" + "strings" +) + +// QueryBuilder helps build SQL queries +type QueryBuilder struct { + table string + columns []string + where []string + whereArgs []interface{} + orderBy string + limit int + offset int +} + +// NewQueryBuilder creates a new query builder +func NewQueryBuilder(table string) *QueryBuilder { + return &QueryBuilder{ + table: table, + columns: []string{"*"}, + } +} + +// Select sets the columns to select +func (qb *QueryBuilder) Select(columns ...string) *QueryBuilder { + qb.columns = columns + return qb +} + +// Where adds a WHERE condition +func (qb *QueryBuilder) Where(condition string, args ...interface{}) *QueryBuilder { + qb.where = append(qb.where, condition) + qb.whereArgs = append(qb.whereArgs, args...) + return qb +} + +// OrderBy sets the ORDER BY clause +func (qb *QueryBuilder) OrderBy(orderBy string) *QueryBuilder { + qb.orderBy = orderBy + return qb +} + +// Limit sets the LIMIT +func (qb *QueryBuilder) Limit(limit int) *QueryBuilder { + qb.limit = limit + return qb +} + +// Offset sets the OFFSET +func (qb *QueryBuilder) Offset(offset int) *QueryBuilder { + qb.offset = offset + return qb +} + +// Build builds the SQL query and returns query string and args +func (qb *QueryBuilder) Build() (string, []interface{}) { + var parts []string + + // SELECT + parts = append(parts, fmt.Sprintf("SELECT %s FROM %s", + strings.Join(qb.columns, ", "), qb.table)) + + // WHERE + if len(qb.where) > 0 { + parts = append(parts, "WHERE "+strings.Join(qb.where, " AND ")) + } + + // ORDER BY + if qb.orderBy != "" { + parts = append(parts, "ORDER BY "+qb.orderBy) + } + + // LIMIT + if qb.limit > 0 { + parts = append(parts, fmt.Sprintf("LIMIT %d", qb.limit)) + } + + // OFFSET + if qb.offset > 0 { + parts = append(parts, fmt.Sprintf("OFFSET %d", qb.offset)) + } + + return strings.Join(parts, " "), qb.whereArgs +} + +// Query executes the query +func (qb *QueryBuilder) Query(db *StarDB) (*StarRows, error) { + query, args := qb.Build() + return db.Query(query, args...) +} diff --git a/builder_test.go b/builder_test.go new file mode 100644 index 0000000..c838553 --- /dev/null +++ b/builder_test.go @@ -0,0 +1,462 @@ +package stardb + +import ( + "reflect" + "testing" +) + +func TestNewQueryBuilder(t *testing.T) { + qb := NewQueryBuilder("users") + + if qb.table != "users" { + t.Errorf("Expected table 'users', got '%s'", qb.table) + } + + if len(qb.columns) != 1 || qb.columns[0] != "*" { + t.Errorf("Expected default columns ['*'], got %v", qb.columns) + } +} + +func TestQueryBuilder_Select(t *testing.T) { + qb := NewQueryBuilder("users").Select("id", "name", "email") + + expected := []string{"id", "name", "email"} + if !reflect.DeepEqual(qb.columns, expected) { + t.Errorf("Expected columns %v, got %v", expected, qb.columns) + } +} + +func TestQueryBuilder_Where(t *testing.T) { + qb := NewQueryBuilder("users"). + Where("age > ?", 18). + Where("active = ?", true) + + if len(qb.where) != 2 { + t.Errorf("Expected 2 where conditions, got %d", len(qb.where)) + } + + if qb.where[0] != "age > ?" { + t.Errorf("Expected first where 'age > ?', got '%s'", qb.where[0]) + } + + if qb.where[1] != "active = ?" { + t.Errorf("Expected second where 'active = ?', got '%s'", qb.where[1]) + } + + if len(qb.whereArgs) != 2 { + t.Errorf("Expected 2 where args, got %d", len(qb.whereArgs)) + } + + if qb.whereArgs[0] != 18 { + t.Errorf("Expected first arg 18, got %v", qb.whereArgs[0]) + } + + if qb.whereArgs[1] != true { + t.Errorf("Expected second arg true, got %v", qb.whereArgs[1]) + } +} + +func TestQueryBuilder_OrderBy(t *testing.T) { + qb := NewQueryBuilder("users").OrderBy("name ASC") + + if qb.orderBy != "name ASC" { + t.Errorf("Expected orderBy 'name ASC', got '%s'", qb.orderBy) + } +} + +func TestQueryBuilder_Limit(t *testing.T) { + qb := NewQueryBuilder("users").Limit(10) + + if qb.limit != 10 { + t.Errorf("Expected limit 10, got %d", qb.limit) + } +} + +func TestQueryBuilder_Offset(t *testing.T) { + qb := NewQueryBuilder("users").Offset(20) + + if qb.offset != 20 { + t.Errorf("Expected offset 20, got %d", qb.offset) + } +} + +func TestQueryBuilder_Build_Simple(t *testing.T) { + qb := NewQueryBuilder("users") + query, args := qb.Build() + + expectedQuery := "SELECT * FROM users" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_Build_WithSelect(t *testing.T) { + qb := NewQueryBuilder("users").Select("id", "name", "email") + query, args := qb.Build() + + expectedQuery := "SELECT id, name, email FROM users" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_Build_WithWhere(t *testing.T) { + qb := NewQueryBuilder("users"). + Where("age > ?", 18). + Where("active = ?", true) + + query, args := qb.Build() + + expectedQuery := "SELECT * FROM users WHERE age > ? AND active = ?" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 2 { + t.Errorf("Expected 2 args, got %d", len(args)) + } + + if args[0] != 18 { + t.Errorf("Expected first arg 18, got %v", args[0]) + } + + if args[1] != true { + t.Errorf("Expected second arg true, got %v", args[1]) + } +} + +func TestQueryBuilder_Build_WithOrderBy(t *testing.T) { + qb := NewQueryBuilder("users").OrderBy("created_at DESC") + query, args := qb.Build() + + expectedQuery := "SELECT * FROM users ORDER BY created_at DESC" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_Build_WithLimit(t *testing.T) { + qb := NewQueryBuilder("users").Limit(10) + query, args := qb.Build() + + expectedQuery := "SELECT * FROM users LIMIT 10" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_Build_WithOffset(t *testing.T) { + qb := NewQueryBuilder("users").Offset(20) + query, args := qb.Build() + + expectedQuery := "SELECT * FROM users OFFSET 20" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_Build_WithLimitAndOffset(t *testing.T) { + qb := NewQueryBuilder("users").Limit(10).Offset(20) + query, args := qb.Build() + + expectedQuery := "SELECT * FROM users LIMIT 10 OFFSET 20" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_Build_Complex(t *testing.T) { + qb := NewQueryBuilder("users"). + Select("id", "name", "email", "age"). + Where("age > ?", 18). + Where("active = ?", true). + Where("country = ?", "US"). + OrderBy("name ASC"). + Limit(10). + Offset(20) + + query, args := qb.Build() + + expectedQuery := "SELECT id, name, email, age FROM users WHERE age > ? AND active = ? AND country = ? ORDER BY name ASC LIMIT 10 OFFSET 20" + if query != expectedQuery { + t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query) + } + + expectedArgs := []interface{}{18, true, "US"} + if len(args) != len(expectedArgs) { + t.Errorf("Expected %d args, got %d", len(expectedArgs), len(args)) + } + + for i, expected := range expectedArgs { + if args[i] != expected { + t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i]) + } + } +} + +func TestQueryBuilder_Build_MultipleWhere(t *testing.T) { + qb := NewQueryBuilder("orders"). + Where("user_id = ?", 123). + Where("status IN (?, ?, ?)", "pending", "processing", "shipped"). + Where("created_at > ?", "2024-01-01") + + query, args := qb.Build() + + expectedQuery := "SELECT * FROM orders WHERE user_id = ? AND status IN (?, ?, ?) AND created_at > ?" + if query != expectedQuery { + t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query) + } + + if len(args) != 5 { + t.Errorf("Expected 5 args, got %d", len(args)) + } + + expectedArgs := []interface{}{123, "pending", "processing", "shipped", "2024-01-01"} + for i, expected := range expectedArgs { + if args[i] != expected { + t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i]) + } + } +} + +func TestQueryBuilder_Chaining(t *testing.T) { + // Test method chaining returns the same builder + qb := NewQueryBuilder("users") + + qb2 := qb.Select("id", "name") + if qb != qb2 { + t.Error("Select should return the same builder instance") + } + + qb3 := qb.Where("age > ?", 18) + if qb != qb3 { + t.Error("Where should return the same builder instance") + } + + qb4 := qb.OrderBy("name ASC") + if qb != qb4 { + t.Error("OrderBy should return the same builder instance") + } + + qb5 := qb.Limit(10) + if qb != qb5 { + t.Error("Limit should return the same builder instance") + } + + qb6 := qb.Offset(20) + if qb != qb6 { + t.Error("Offset should return the same builder instance") + } +} + +func TestQueryBuilder_EmptyWhere(t *testing.T) { + qb := NewQueryBuilder("users").Select("id", "name") + query, args := qb.Build() + + expectedQuery := "SELECT id, name FROM users" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_OnlyLimit(t *testing.T) { + qb := NewQueryBuilder("users").Limit(5) + query, args := qb.Build() + + expectedQuery := "SELECT * FROM users LIMIT 5" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_OnlyOffset(t *testing.T) { + qb := NewQueryBuilder("users").Offset(10) + query, args := qb.Build() + + expectedQuery := "SELECT * FROM users OFFSET 10" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_ZeroLimit(t *testing.T) { + qb := NewQueryBuilder("users").Limit(0) + query, args := qb.Build() + + // Limit 0 should not be included + expectedQuery := "SELECT * FROM users" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_ZeroOffset(t *testing.T) { + qb := NewQueryBuilder("users").Offset(0) + query, args := qb.Build() + + // Offset 0 should not be included + expectedQuery := "SELECT * FROM users" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_MultipleSelect(t *testing.T) { + // Last Select should override previous ones + qb := NewQueryBuilder("users"). + Select("id", "name"). + Select("email", "age") + + query, _ := qb.Build() + + expectedQuery := "SELECT email, age FROM users" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } +} + +func TestQueryBuilder_MultipleOrderBy(t *testing.T) { + // Last OrderBy should override previous ones + qb := NewQueryBuilder("users"). + OrderBy("name ASC"). + OrderBy("created_at DESC") + + query, _ := qb.Build() + + expectedQuery := "SELECT * FROM users ORDER BY created_at DESC" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } +} + +func TestQueryBuilder_ComplexWhereConditions(t *testing.T) { + qb := NewQueryBuilder("products"). + Select("id", "name", "price"). + Where("category_id = ?", 5). + Where("price BETWEEN ? AND ?", 10.0, 100.0). + Where("stock > ?", 0). + Where("name LIKE ?", "%phone%"). + OrderBy("price DESC"). + Limit(20) + + query, args := qb.Build() + + expectedQuery := "SELECT id, name, price FROM products WHERE category_id = ? AND price BETWEEN ? AND ? AND stock > ? AND name LIKE ? ORDER BY price DESC LIMIT 20" + if query != expectedQuery { + t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query) + } + + expectedArgs := []interface{}{5, 10.0, 100.0, 0, "%phone%"} + if len(args) != len(expectedArgs) { + t.Errorf("Expected %d args, got %d", len(expectedArgs), len(args)) + } + + for i, expected := range expectedArgs { + if args[i] != expected { + t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i]) + } + } +} + +func TestQueryBuilder_SingleColumn(t *testing.T) { + qb := NewQueryBuilder("users").Select("COUNT(*)") + query, args := qb.Build() + + expectedQuery := "SELECT COUNT(*) FROM users" + if query != expectedQuery { + t.Errorf("Expected query '%s', got '%s'", expectedQuery, query) + } + + if len(args) != 0 { + t.Errorf("Expected 0 args, got %d", len(args)) + } +} + +func TestQueryBuilder_JoinLikeWhere(t *testing.T) { + // Test that WHERE can handle JOIN-like conditions + qb := NewQueryBuilder("users"). + Select("users.id", "users.name", "orders.total"). + Where("users.id = orders.user_id"). + Where("orders.status = ?", "completed") + + query, args := qb.Build() + + expectedQuery := "SELECT users.id, users.name, orders.total FROM users WHERE users.id = orders.user_id AND orders.status = ?" + if query != expectedQuery { + t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query) + } + + if len(args) != 1 { + t.Errorf("Expected 1 arg, got %d", len(args)) + } + + if args[0] != "completed" { + t.Errorf("Expected arg 'completed', got %v", args[0]) + } +} + +// Benchmark tests +func BenchmarkQueryBuilder_Simple(b *testing.B) { + for i := 0; i < b.N; i++ { + qb := NewQueryBuilder("users") + qb.Build() + } +} + +func BenchmarkQueryBuilder_Complex(b *testing.B) { + for i := 0; i < b.N; i++ { + qb := NewQueryBuilder("users"). + Select("id", "name", "email", "age"). + Where("age > ?", 18). + Where("active = ?", true). + Where("country = ?", "US"). + OrderBy("name ASC"). + Limit(10). + Offset(20) + qb.Build() + } +} diff --git a/converter.go b/converter.go new file mode 100644 index 0000000..1e0020e --- /dev/null +++ b/converter.go @@ -0,0 +1,176 @@ +package stardb + +import ( + "strconv" + "time" +) + +// convertToInt64 converts any value to int64 +func convertToInt64(val interface{}) int64 { + switch v := val.(type) { + case nil: + return 0 + case int: + return int64(v) + case int32: + return int64(v) + case int64: + return v + case uint64: + return int64(v) + case float32: + return int64(v) + case float64: + return int64(v) + case string: + result, _ := strconv.ParseInt(v, 10, 64) + return result + case bool: + if v { + return 1 + } + return 0 + case time.Time: + return v.Unix() + case []byte: + result, _ := strconv.ParseInt(string(v), 10, 64) + return result + default: + return 0 + } +} + +// convertToUint64 converts any value to uint64 +func convertToUint64(val interface{}) uint64 { + switch v := val.(type) { + case nil: + return 0 + case int: + return uint64(v) + case int32: + return uint64(v) + case int64: + return uint64(v) + case uint64: + return v + case float32: + return uint64(v) + case float64: + return uint64(v) + case string: + result, _ := strconv.ParseUint(v, 10, 64) + return result + case bool: + if v { + return 1 + } + return 0 + case time.Time: + return uint64(v.Unix()) + case []byte: + result, _ := strconv.ParseUint(string(v), 10, 64) + return result + default: + return 0 + } +} + +// convertToFloat64 converts any value to float64 +func convertToFloat64(val interface{}) float64 { + switch v := val.(type) { + case nil: + return 0 + case int: + return float64(v) + case int32: + return float64(v) + case int64: + return float64(v) + case uint64: + return float64(v) + case float32: + return float64(v) + case float64: + return v + case string: + result, _ := strconv.ParseFloat(v, 64) + return result + case bool: + if v { + return 1 + } + return 0 + case time.Time: + return float64(v.Unix()) + case []byte: + result, _ := strconv.ParseFloat(string(v), 64) + return result + default: + return 0 + } +} + +// convertToBool converts any value to bool +// Non-zero numbers are considered true +func convertToBool(val interface{}) bool { + switch v := val.(type) { + case nil: + return false + case bool: + return v + case int: + return v != 0 + case int32: + return v != 0 + case int64: + return v != 0 + case uint64: + return v != 0 + case float32: + return v != 0 + case float64: + return v != 0 + case string: + result, _ := strconv.ParseBool(v) + return result + case []byte: + result, _ := strconv.ParseBool(string(v)) + return result + default: + return false + } +} + +// convertToTime converts any value to time.Time +func convertToTime(val interface{}, layout string) time.Time { + switch v := val.(type) { + case nil: + return time.Time{} + case time.Time: + return v + case int: + return time.Unix(int64(v), 0) + case int32: + return time.Unix(int64(v), 0) + case int64: + return time.Unix(v, 0) + case uint64: + return time.Unix(int64(v), 0) + case float32: + sec := int64(v) + nsec := int64((v - float32(sec)) * 1e9) + return time.Unix(sec, nsec) + case float64: + sec := int64(v) + nsec := int64((v - float64(sec)) * 1e9) + return time.Unix(sec, nsec) + case string: + result, _ := time.Parse(layout, v) + return result + case []byte: + result, _ := time.Parse(layout, string(v)) + return result + default: + return time.Time{} + } +} diff --git a/converter_safe.go b/converter_safe.go new file mode 100644 index 0000000..9793c82 --- /dev/null +++ b/converter_safe.go @@ -0,0 +1,68 @@ +package stardb + +import ( + "fmt" + "strconv" + "time" +) + +// ConvertToInt64Safe converts any value to int64 with error handling +func ConvertToInt64Safe(val interface{}) (int64, error) { + switch v := val.(type) { + case nil: + return 0, nil + case int: + return int64(v), nil + case int32: + return int64(v), nil + case int64: + return v, nil + case uint64: + return int64(v), nil + case float32: + return int64(v), nil + case float64: + return int64(v), nil + case string: + return strconv.ParseInt(v, 10, 64) + case bool: + if v { + return 1, nil + } + return 0, nil + case time.Time: + return v.Unix(), nil + case []byte: + return strconv.ParseInt(string(v), 10, 64) + default: + return 0, fmt.Errorf("cannot convert %T to int64", val) + } +} + +// ConvertToStringSafe converts any value to string with error handling +func ConvertToStringSafe(val interface{}) (string, error) { + switch v := val.(type) { + case nil: + return "", nil + case string: + return v, nil + case int: + return strconv.Itoa(v), nil + case int32: + return strconv.FormatInt(int64(v), 10), nil + case int64: + return strconv.FormatInt(v, 10), nil + case float32: + return strconv.FormatFloat(float64(v), 'f', -1, 32), nil + case float64: + return strconv.FormatFloat(v, 'f', -1, 64), nil + case bool: + return strconv.FormatBool(v), nil + case time.Time: + return v.String(), nil + case []byte: + return string(v), nil + default: + return "", fmt.Errorf("cannot convert %T to string", val) + } +} diff --git a/converter_test.go b/converter_test.go new file mode 100644 index 0000000..ff7a772 --- /dev/null +++ b/converter_test.go @@ -0,0 +1,183 @@ +package stardb + +import ( + "testing" + "time" +) + +func TestConvertToInt64(t *testing.T) { + tests := []struct { + name string + input interface{} + expected int64 + }{ + {"nil", nil, 0}, + {"int", 42, 42}, + {"int32", int32(42), 42}, + {"int64", int64(42), 42}, + {"uint64", uint64(42), 42}, + {"float32", float32(42.7), 42}, + {"float64", float64(42.7), 42}, + {"string", "42", 42}, + {"bool true", true, 1}, + {"bool false", false, 0}, + {"bytes", []byte("42"), 42}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertToInt64(tt.input) + if result != tt.expected { + t.Errorf("Expected %d, got %d", tt.expected, result) + } + }) + } +} + +func TestConvertToUint64(t *testing.T) { + tests := []struct { + name string + input interface{} + expected uint64 + }{ + {"nil", nil, 0}, + {"int", 42, 42}, + {"int32", int32(42), 42}, + {"int64", int64(42), 42}, + {"uint64", uint64(42), 42}, + {"float32", float32(42.7), 42}, + {"float64", float64(42.7), 42}, + {"string", "42", 42}, + {"bool true", true, 1}, + {"bool false", false, 0}, + {"bytes", []byte("42"), 42}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertToUint64(tt.input) + if result != tt.expected { + t.Errorf("Expected %d, got %d", tt.expected, result) + } + }) + } +} + +func TestConvertToFloat64(t *testing.T) { + tests := []struct { + name string + input interface{} + expected float64 + }{ + {"nil", nil, 0.0}, + {"int", 42, 42.0}, + {"int32", int32(42), 42.0}, + {"int64", int64(42), 42.0}, + {"uint64", uint64(42), 42.0}, + {"float32", float32(42.5), 42.5}, + {"float64", float64(42.5), 42.5}, + {"string", "42.5", 42.5}, + {"bool true", true, 1.0}, + {"bool false", false, 0.0}, + {"bytes", []byte("42.5"), 42.5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertToFloat64(tt.input) + if result != tt.expected { + t.Errorf("Expected %f, got %f", tt.expected, result) + } + }) + } +} + +func TestConvertToBool(t *testing.T) { + tests := []struct { + name string + input interface{} + expected bool + }{ + {"nil", nil, false}, + {"bool true", true, true}, + {"bool false", false, false}, + {"int positive", 1, true}, + {"int zero", 0, false}, + {"int negative", -1, true}, + {"float positive", 1.5, true}, + {"float zero", 0.0, false}, + {"string true", "true", true}, + {"string false", "false", false}, + {"string 1", "1", true}, + {"string 0", "0", false}, + {"bytes true", []byte("true"), true}, + {"bytes false", []byte("false"), false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertToBool(tt.input) + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestConvertToString(t *testing.T) { + now := time.Now() + tests := []struct { + name string + input interface{} + expected string + }{ + {"nil", nil, ""}, + {"string", "hello", "hello"}, + {"int", 42, "42"}, + {"int32", int32(42), "42"}, + {"int64", int64(42), "42"}, + {"float32", float32(42.5), "42.5"}, + {"float64", float64(42.5), "42.5"}, + {"bool true", true, "true"}, + {"bool false", false, "false"}, + {"time", now, now.String()}, + {"bytes", []byte("hello"), "hello"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertToString(tt.input) + if result != tt.expected { + t.Errorf("Expected '%s', got '%s'", tt.expected, result) + } + }) + } +} + +func TestConvertToTime(t *testing.T) { + layout := "2006-01-02 15:04:05" + timeStr := "2024-01-01 10:00:00" + expectedTime, _ := time.Parse(layout, timeStr) + + tests := []struct { + name string + input interface{} + layout string + expected time.Time + }{ + {"nil", nil, layout, time.Time{}}, + {"time.Time", expectedTime, layout, expectedTime}, + {"int64 unix", int64(1704103200), layout, time.Unix(1704103200, 0)}, + {"string", timeStr, layout, expectedTime}, + {"bytes", []byte(timeStr), layout, expectedTime}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertToTime(tt.input, tt.layout) + if !result.Equal(tt.expected) { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} diff --git a/go.mod b/go.mod index bc39add..80cd1f0 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module b612.me/stardb -go 1.16 +go 1.16 \ No newline at end of file diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..684933a --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk= +github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= diff --git a/orm.go b/orm.go new file mode 100644 index 0000000..2453b88 --- /dev/null +++ b/orm.go @@ -0,0 +1,563 @@ +package stardb + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "strings" +) + +// Orm maps query results to a struct or slice of structs +// Usage: +// +// var user User +// rows.Orm(&user) // single row +// var users []User +// rows.Orm(&users) // multiple rows +func (r *StarRows) Orm(target interface{}) error { + if !r.parsed { + if err := r.parse(); err != nil { + return err + } + } + + targetType := reflect.TypeOf(target) + targetValue := reflect.ValueOf(target) + + if targetType.Kind() != reflect.Ptr { + return errors.New("target must be a pointer") + } + + targetType = targetType.Elem() + targetValue = targetValue.Elem() + + // Handle slice/array + if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array { + elementType := targetType.Elem() + result := reflect.New(targetType).Elem() + + if r.Length() == 0 { + targetValue.Set(result) + return nil + } + + for i := 0; i < r.Length(); i++ { + element := reflect.New(elementType) + if err := r.setStructFieldsFromRow(element.Interface(), "db", i); err != nil { + return err + } + result = reflect.Append(result, element.Elem()) + } + + targetValue.Set(result) + return nil + } + + // Handle single struct + if r.Length() == 0 { + return nil + } + + return r.setStructFieldsFromRow(target, "db", 0) +} + +// QueryX executes a query with named parameter binding +// Usage: QueryX(&user, "SELECT * FROM users WHERE id = ?", ":id") +func (s *StarDB) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) { + return s.queryX(nil, target, query, args...) +} + +// QueryXContext executes a query with context and named parameter binding +func (s *StarDB) QueryXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) { + return s.queryX(ctx, target, query, args...) +} + +// queryX is the internal implementation +func (s *StarDB) queryX(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) { + fieldValues, err := getStructFieldValues(target, "db") + if err != nil { + return nil, err + } + + // Replace named parameters with actual values + processedArgs := make([]interface{}, len(args)) + for i, arg := range args { + if str, ok := arg.(string); ok { + if strings.HasPrefix(str, ":") { + fieldName := str[1:] + if val, exists := fieldValues[fieldName]; exists { + processedArgs[i] = val + } else { + processedArgs[i] = "" + } + } else if strings.HasPrefix(str, `\:`) { + processedArgs[i] = str[1:] + } else { + processedArgs[i] = arg + } + } else { + processedArgs[i] = arg + } + } + + return s.query(ctx, query, processedArgs...) +} + +// QueryXS executes queries for multiple structs +func (s *StarDB) QueryXS(targets interface{}, query string, args ...interface{}) ([]*StarRows, error) { + return s.queryXS(nil, targets, query, args...) +} + +// QueryXSContext executes queries with context for multiple structs +func (s *StarDB) QueryXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) { + return s.queryXS(ctx, targets, query, args...) +} + +// queryXS is the internal implementation +func (s *StarDB) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) { + var results []*StarRows + + targetType := reflect.TypeOf(targets) + targetValue := reflect.ValueOf(targets) + + if targetType.Kind() == reflect.Ptr { + targetType = targetType.Elem() + targetValue = targetValue.Elem() + } + + if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array { + for i := 0; i < targetValue.Len(); i++ { + result, err := s.queryX(ctx, targetValue.Index(i).Interface(), query, args...) + if err != nil { + return results, err + } + results = append(results, result) + } + } else { + result, err := s.queryX(ctx, targets, query, args...) + if err != nil { + return results, err + } + results = append(results, result) + } + + return results, nil +} + +// ExecX executes a statement with named parameter binding +func (s *StarDB) ExecX(target interface{}, query string, args ...interface{}) (sql.Result, error) { + return s.execX(nil, target, query, args...) +} + +// ExecXContext executes a statement with context and named parameter binding +func (s *StarDB) ExecXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) { + return s.execX(ctx, target, query, args...) +} + +// execX is the internal implementation +func (s *StarDB) execX(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) { + fieldValues, err := getStructFieldValues(target, "db") + if err != nil { + return nil, err + } + + // Replace named parameters with actual values + processedArgs := make([]interface{}, len(args)) + for i, arg := range args { + if str, ok := arg.(string); ok { + if strings.HasPrefix(str, ":") { + fieldName := str[1:] + if val, exists := fieldValues[fieldName]; exists { + processedArgs[i] = val + } else { + processedArgs[i] = "" + } + } else if strings.HasPrefix(str, `\:`) { + processedArgs[i] = str[1:] + } else { + processedArgs[i] = arg + } + } else { + processedArgs[i] = arg + } + } + + return s.exec(ctx, query, processedArgs...) +} + +// ExecXS executes statements for multiple structs +func (s *StarDB) ExecXS(targets interface{}, query string, args ...interface{}) ([]sql.Result, error) { + return s.execXS(nil, targets, query, args...) +} + +// ExecXSContext executes statements with context for multiple structs +func (s *StarDB) ExecXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) { + return s.execXS(ctx, targets, query, args...) +} + +// execXS is the internal implementation +func (s *StarDB) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) { + var results []sql.Result + + targetType := reflect.TypeOf(targets) + targetValue := reflect.ValueOf(targets) + + if targetType.Kind() == reflect.Ptr { + targetType = targetType.Elem() + targetValue = targetValue.Elem() + } + + if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array { + for i := 0; i < targetValue.Len(); i++ { + result, err := s.execX(ctx, targetValue.Index(i).Interface(), query, args...) + if err != nil { + return results, err + } + results = append(results, result) + } + } else { + result, err := s.execX(ctx, targets, query, args...) + if err != nil { + return results, err + } + results = append(results, result) + } + + return results, nil +} + +// Insert inserts a struct into the database +// Usage: Insert(&user, "users", "id") // id is auto-increment +func (s *StarDB) Insert(target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) { + return s.insert(nil, target, tableName, autoIncrementFields...) +} + +// InsertContext inserts a struct with context +func (s *StarDB) InsertContext(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) { + return s.insert(ctx, target, tableName, autoIncrementFields...) +} + +// insert is the internal implementation +func (s *StarDB) insert(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) { + query, params, err := buildInsertSQL(target, tableName, autoIncrementFields...) + if err != nil { + return nil, err + } + + args := []interface{}{} + for _, param := range params { + args = append(args, param) + } + + return s.execX(ctx, target, query, args...) +} + +// Update updates a struct in the database +// Usage: Update(&user, "users", "id") // id is primary key +func (s *StarDB) Update(target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) { + return s.update(nil, target, tableName, primaryKeys...) +} + +// UpdateContext updates a struct with context +func (s *StarDB) UpdateContext(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) { + return s.update(ctx, target, tableName, primaryKeys...) +} + +// update is the internal implementation +func (s *StarDB) update(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) { + query, params, err := buildUpdateSQL(target, tableName, primaryKeys...) + if err != nil { + return nil, err + } + + args := []interface{}{} + for _, param := range params { + args = append(args, param) + } + + return s.execX(ctx, target, query, args...) +} + +// buildInsertSQL builds an INSERT SQL statement +func buildInsertSQL(target interface{}, tableName string, autoIncrementFields ...string) (string, []string, error) { + fieldNames, err := getStructFieldNames(target, "db") + if err != nil { + return "", []string{}, err + } + + var columns []string + var placeholders []string + var params []string + + for _, fieldName := range fieldNames { + // Skip auto-increment fields + isAutoIncrement := false + for _, autoField := range autoIncrementFields { + if fieldName == autoField { + isAutoIncrement = true + break + } + } + if isAutoIncrement { + continue + } + + columns = append(columns, fieldName) + placeholders = append(placeholders, "?") + params = append(params, ":"+fieldName) + } + + query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", + tableName, + strings.Join(columns, ", "), + strings.Join(placeholders, ", ")) + + return query, params, nil +} + +// buildUpdateSQL builds an UPDATE SQL statement +func buildUpdateSQL(target interface{}, tableName string, primaryKeys ...string) (string, []string, error) { + fieldNames, err := getStructFieldNames(target, "db") + if err != nil { + return "", []string{}, err + } + + var setClauses []string + var params []string + + // Build SET clause + for _, fieldName := range fieldNames { + setClauses = append(setClauses, fmt.Sprintf("%s = ?", fieldName)) + params = append(params, ":"+fieldName) + } + + // Build WHERE clause + var whereClauses []string + for _, pk := range primaryKeys { + whereClauses = append(whereClauses, fmt.Sprintf("%s = ?", pk)) + params = append(params, ":"+pk) + } + + query := fmt.Sprintf("UPDATE %s SET %s WHERE %s", + tableName, + strings.Join(setClauses, ", "), + strings.Join(whereClauses, " AND ")) + + return query, params, nil +} + +// Transaction ORM methods + +// QueryX executes a query with named parameter binding in a transaction +func (t *StarTx) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) { + return t.queryX(nil, target, query, args...) +} + +// QueryXContext executes a query with context in a transaction +func (t *StarTx) QueryXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) { + return t.queryX(ctx, target, query, args...) +} + +// queryX is the internal implementation +func (t *StarTx) queryX(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) { + fieldValues, err := getStructFieldValues(target, "db") + if err != nil { + return nil, err + } + + processedArgs := make([]interface{}, len(args)) + for i, arg := range args { + if str, ok := arg.(string); ok { + if strings.HasPrefix(str, ":") { + fieldName := str[1:] + if val, exists := fieldValues[fieldName]; exists { + processedArgs[i] = val + } else { + processedArgs[i] = "" + } + } else if strings.HasPrefix(str, `\:`) { + processedArgs[i] = str[1:] + } else { + processedArgs[i] = arg + } + } else { + processedArgs[i] = arg + } + } + + return t.query(ctx, query, processedArgs...) +} + +// QueryXS executes queries for multiple structs in a transaction +func (t *StarTx) QueryXS(targets interface{}, query string, args ...interface{}) ([]*StarRows, error) { + return t.queryXS(nil, targets, query, args...) +} + +// QueryXSContext executes queries with context for multiple structs in a transaction +func (t *StarTx) QueryXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) { + return t.queryXS(ctx, targets, query, args...) +} + +// queryXS is the internal implementation +func (t *StarTx) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) { + var results []*StarRows + + targetType := reflect.TypeOf(targets) + targetValue := reflect.ValueOf(targets) + + if targetType.Kind() == reflect.Ptr { + targetType = targetType.Elem() + targetValue = targetValue.Elem() + } + + if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array { + for i := 0; i < targetValue.Len(); i++ { + result, err := t.queryX(ctx, targetValue.Index(i).Interface(), query, args...) + if err != nil { + return results, err + } + results = append(results, result) + } + } else { + result, err := t.queryX(ctx, targets, query, args...) + if err != nil { + return results, err + } + results = append(results, result) + } + + return results, nil +} + +// ExecX executes a statement with named parameter binding in a transaction +func (t *StarTx) ExecX(target interface{}, query string, args ...interface{}) (sql.Result, error) { + return t.execX(nil, target, query, args...) +} + +// ExecXContext executes a statement with context in a transaction +func (t *StarTx) ExecXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) { + return t.execX(ctx, target, query, args...) +} + +// execX is the internal implementation +func (t *StarTx) execX(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) { + fieldValues, err := getStructFieldValues(target, "db") + if err != nil { + return nil, err + } + + processedArgs := make([]interface{}, len(args)) + for i, arg := range args { + if str, ok := arg.(string); ok { + if strings.HasPrefix(str, ":") { + fieldName := str[1:] + if val, exists := fieldValues[fieldName]; exists { + processedArgs[i] = val + } else { + processedArgs[i] = "" + } + } else if strings.HasPrefix(str, `\:`) { + processedArgs[i] = str[1:] + } else { + processedArgs[i] = arg + } + } else { + processedArgs[i] = arg + } + } + + return t.exec(ctx, query, processedArgs...) +} + +// ExecXS executes statements for multiple structs in a transaction +func (t *StarTx) ExecXS(targets interface{}, query string, args ...interface{}) ([]sql.Result, error) { + return t.execXS(nil, targets, query, args...) +} + +// ExecXSContext executes statements with context for multiple structs in a transaction +func (t *StarTx) ExecXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) { + return t.execXS(ctx, targets, query, args...) +} + +// execXS is the internal implementation +func (t *StarTx) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) { + var results []sql.Result + + targetType := reflect.TypeOf(targets) + targetValue := reflect.ValueOf(targets) + + if targetType.Kind() == reflect.Ptr { + targetType = targetType.Elem() + targetValue = targetValue.Elem() + } + + if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array { + for i := 0; i < targetValue.Len(); i++ { + result, err := t.execX(ctx, targetValue.Index(i).Interface(), query, args...) + if err != nil { + return results, err + } + results = append(results, result) + } + } else { + result, err := t.execX(ctx, targets, query, args...) + if err != nil { + return results, err + } + results = append(results, result) + } + + return results, nil +} + +// Insert inserts a struct in a transaction +func (t *StarTx) Insert(target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) { + return t.insert(nil, target, tableName, autoIncrementFields...) +} + +// InsertContext inserts a struct with context in a transaction +func (t *StarTx) InsertContext(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) { + return t.insert(ctx, target, tableName, autoIncrementFields...) +} + +// insert is the internal implementation +func (t *StarTx) insert(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) { + query, params, err := buildInsertSQL(target, tableName, autoIncrementFields...) + if err != nil { + return nil, err + } + + args := []interface{}{} + for _, param := range params { + args = append(args, param) + } + + return t.execX(ctx, target, query, args...) +} + +// Update updates a struct in a transaction +func (t *StarTx) Update(target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) { + return t.update(nil, target, tableName, primaryKeys...) +} + +// UpdateContext updates a struct with context in a transaction +func (t *StarTx) UpdateContext(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) { + return t.update(ctx, target, tableName, primaryKeys...) +} + +// update is the internal implementation +func (t *StarTx) update(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) { + query, params, err := buildUpdateSQL(target, tableName, primaryKeys...) + if err != nil { + return nil, err + } + + args := []interface{}{} + for _, param := range params { + args = append(args, param) + } + + return t.execX(ctx, target, query, args...) +} diff --git a/orm_test.go b/orm_test.go new file mode 100644 index 0000000..61c684c --- /dev/null +++ b/orm_test.go @@ -0,0 +1,88 @@ +package stardb + +import ( + "testing" + "time" +) + +type User struct { + ID int64 `db:"id"` + Name string `db:"name"` + Email string `db:"email"` + Age int `db:"age"` + Balance float64 `db:"balance"` + Active bool `db:"active"` + CreatedAt time.Time `db:"created_at"` +} + +type Profile struct { + UserID int `db:"user_id"` + Bio string `db:"bio"` + Avatar string `db:"avatar"` +} + +type NestedUser struct { + ID int64 `db:"id"` + Name string `db:"name"` + Profile `db:"---"` +} + +func TestBuildInsertSQL(t *testing.T) { + user := User{ + ID: 1, + Name: "Test", + Email: "test@example.com", + Age: 30, + Balance: 100.0, + Active: true, + CreatedAt: time.Now(), + } + + query, params, err := buildInsertSQL(&user, "users", "id") + if err != nil { + t.Fatalf("buildInsertSQL failed: %v", err) + } + + expectedQuery := "INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, ?, ?, ?, ?)" + if query != expectedQuery { + t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query) + } + + expectedParams := []string{":name", ":email", ":age", ":balance", ":active", ":created_at"} + if len(params) != len(expectedParams) { + t.Errorf("Expected %d params, got %d", len(expectedParams), len(params)) + } + + for i, param := range params { + if param != expectedParams[i] { + t.Errorf("Expected param %s, got %s", expectedParams[i], param) + } + } +} + +func TestBuildUpdateSQL(t *testing.T) { + user := User{ + ID: 1, + Name: "Test", + Email: "test@example.com", + Age: 30, + Balance: 100.0, + Active: true, + CreatedAt: time.Now(), + } + + query, params, err := buildUpdateSQL(&user, "users", "id") + if err != nil { + t.Fatalf("buildUpdateSQL failed: %v", err) + } + + expectedQuery := "UPDATE users SET id = ?, name = ?, email = ?, age = ?, balance = ?, active = ?, created_at = ? WHERE id = ?" + if query != expectedQuery { + t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query) + } + + expectedParamCount := 8 // 7 fields + 1 primary key + if len(params) != expectedParamCount { + t.Errorf("Expected %d params, got %d", expectedParamCount, len(params)) + } +} diff --git a/orm_v1.go b/orm_v1.go deleted file mode 100644 index 961c24c..0000000 --- a/orm_v1.go +++ /dev/null @@ -1,511 +0,0 @@ -package stardb - -import ( - "context" - "database/sql" - "errors" - "fmt" - "reflect" - "strings" -) - -func (star *StarRows) Orm(ins interface{}) error { - //check if is slice - if !star.parsed { - if err := star.parserows(); err != nil { - return err - } - } - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins) - if t.Kind() != reflect.Ptr { - return errors.New("interface not writable") - } - //now convert to slice - t = t.Elem() - v = v.Elem() - if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { - //get type of slice - sigType := t.Elem() - var result reflect.Value - result = reflect.New(t).Elem() - if star.Length == 0 { - v.Set(result) - return nil - } - for i := 0; i < star.Length; i++ { - val := reflect.New(sigType) - star.setAllRefValue(val.Interface(), "db", i) - result = reflect.Append(result, val.Elem()) - } - v.Set(result) - return nil - } - if star.Length == 0 { - return nil - } - return star.setAllRefValue(ins, "db", 0) -} - -func (star *StarDB) queryX(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) { - kvMap, err := getAllRefValue(ins, "db") - if err != nil { - return nil, err - } - for k, v := range args { - if k == 0 { - continue - } - switch v.(type) { - case string: - str := v.(string) - if strings.Index(str, ":") == 0 { - if _, ok := kvMap[str[1:]]; ok { - args[k] = kvMap[str[1:]] - } else { - args[k] = "" - } - continue - } - if strings.Index(str, `\:`) == 0 { - args[k] = kvMap[str[1:]] - } - } - } - return star.query(ctx, args...) -} -func (star *StarDB) QueryX(ins interface{}, args ...interface{}) (*StarRows, error) { - return star.queryX(nil, ins, args) -} -func (star *StarDB) QueryXS(ins interface{}, args ...interface{}) ([]*StarRows, error) { - var starRes []*StarRows - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins) - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - //now convert to slice - if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { - for i := 0; i < v.Len(); i++ { - result, err := star.queryX(nil, v.Index(i).Interface(), args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - } else { - result, err := star.queryX(nil, ins, args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - return starRes, nil -} - -func (star *StarDB) ExecXS(ins interface{}, args ...interface{}) ([]sql.Result, error) { - var starRes []sql.Result - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins) - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - //now convert to slice - if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { - for i := 0; i < v.Len(); i++ { - result, err := star.execX(nil, v.Index(i).Interface(), args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - } else { - result, err := star.execX(nil, ins, args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - return starRes, nil -} - -func (star *StarDB) ExecX(ins interface{}, args ...interface{}) (sql.Result, error) { - return star.execX(nil, ins, args...) -} - -func getUpdateSentence(ins interface{}, sheetName string, primaryKey ...string) (string, []string, error) { - Keys, err := getAllRefKey(ins, "db") - if err != nil { - return "", []string{}, err - } - var mystr string - for k, v := range Keys { - mystr += fmt.Sprintf("%s=? ", v) - Keys[k] = ":" + v - } - mystr = fmt.Sprintf("update %s set %s where ", sheetName, mystr) - var whereSlice []string - for _, v := range primaryKey { - whereSlice = append(whereSlice, v+"=?") - Keys = append(Keys, ":"+v) - } - mystr += strings.Join(whereSlice, " and ") - return mystr, Keys, nil -} - -func getInsertSentence(ins interface{}, sheetName string, autoIncrease ...string) (string, []string, error) { - Keys, err := getAllRefKey(ins, "db") - if err != nil { - return "", []string{}, err - } - var mystr, rps string - var rtnKeys []string -cns: - for _, v := range Keys { - for _, vs := range autoIncrease { - if v == vs { - rps += "null," - continue cns - } - } - rtnKeys = append(rtnKeys, ":"+v) - rps += "?," - } - mystr = fmt.Sprintf("insert into %s (%s) values (%s) ", sheetName, strings.Join(Keys, ","), rps[:len(rps)-1]) - return mystr, rtnKeys, nil -} - -func (star *StarDB) execX(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) { - kvMap, err := getAllRefValue(ins, "db") - if err != nil { - return nil, err - } - for k, v := range args { - if k == 0 { - continue - } - switch v.(type) { - case string: - str := v.(string) - if strings.Index(str, ":") == 0 { - if _, ok := kvMap[str[1:]]; ok { - args[k] = kvMap[str[1:]] - } else { - args[k] = "" - } - continue - } - if strings.Index(str, `\:`) == 0 { - args[k] = kvMap[str[1:]] - } - } - } - return star.exec(ctx, args...) -} - -func (star *StarDB) Update(ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { - return star.updateinsert(nil, true, ins, sheetName, primaryKey...) -} - -func (star *StarDB) UpdateContext(ctx context.Context, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { - return star.updateinsert(ctx, true, ins, sheetName, primaryKey...) -} - -func (star *StarDB) Insert(ins interface{}, sheetName string, autoCreaseKey ...string) (sql.Result, error) { - return star.updateinsert(nil, false, ins, sheetName, autoCreaseKey...) -} - -func (star *StarDB) InsertContext(ctx context.Context, ins interface{}, sheetName string, autoCreaseKey ...string) (sql.Result, error) { - return star.updateinsert(ctx, false, ins, sheetName, autoCreaseKey...) -} - -func (star *StarDB) updateinsert(ctx context.Context, isUpdate bool, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { - var sqlStr string - var para []string - var err error - if isUpdate { - sqlStr, para, err = getUpdateSentence(ins, sheetName, primaryKey...) - } else { - sqlStr, para, err = getInsertSentence(ins, sheetName, primaryKey...) - } - if err != nil { - return nil, err - } - tmpStr := append([]interface{}{}, sqlStr) - for _, v := range para { - tmpStr = append(tmpStr, v) - } - return star.execX(ctx, ins, tmpStr...) -} -func (star *StarDB) QueryXContext(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) { - return star.queryX(ctx, ins, args) -} -func (star *StarDB) QueryXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]*StarRows, error) { - var starRes []*StarRows - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins) - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - //now convert to slice - if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { - for i := 0; i < v.Len(); i++ { - result, err := star.queryX(ctx, v.Index(i).Interface(), args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - } else { - result, err := star.queryX(ctx, ins, args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - return starRes, nil -} - -func (star *StarDB) ExecXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]sql.Result, error) { - var starRes []sql.Result - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins) - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - //now convert to slice - if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { - for i := 0; i < v.Len(); i++ { - result, err := star.execX(ctx, v.Index(i).Interface(), args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - } else { - result, err := star.execX(ctx, ins, args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - return starRes, nil -} - -func (star *StarDB) ExecXContext(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) { - return star.execX(ctx, ins, args...) -} - -func (star *StarTx) queryX(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) { - kvMap, err := getAllRefValue(ins, "db") - if err != nil { - return nil, err - } - for k, v := range args { - if k == 0 { - continue - } - switch v.(type) { - case string: - str := v.(string) - if strings.Index(str, ":") == 0 { - if _, ok := kvMap[str[1:]]; ok { - args[k] = kvMap[str[1:]] - } else { - args[k] = "" - } - continue - } - if strings.Index(str, `\:`) == 0 { - args[k] = kvMap[str[1:]] - } - } - } - return star.query(ctx, args...) -} -func (star *StarTx) QueryX(ins interface{}, args ...interface{}) (*StarRows, error) { - return star.queryX(nil, ins, args) -} - -func (star *StarTx) QueryXS(ins interface{}, args ...interface{}) ([]*StarRows, error) { - var starRes []*StarRows - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins) - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - //now convert to slice - if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { - for i := 0; i < v.Len(); i++ { - result, err := star.queryX(nil, v.Index(i).Interface(), args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - } else { - result, err := star.queryX(nil, ins, args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - return starRes, nil -} -func (star *StarTx) Update(ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { - return star.updateinsert(nil, true, ins, sheetName, primaryKey...) -} - -func (star *StarTx) UpdateContext(ctx context.Context, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { - return star.updateinsert(ctx, true, ins, sheetName, primaryKey...) -} - -func (star *StarTx) Insert(ins interface{}, sheetName string, autoCreaseKey ...string) (sql.Result, error) { - return star.updateinsert(nil, false, ins, sheetName, autoCreaseKey...) -} - -func (star *StarTx) InsertContext(ctx context.Context, ins interface{}, sheetName string, autoCreaseKey ...string) (sql.Result, error) { - return star.updateinsert(ctx, false, ins, sheetName, autoCreaseKey...) -} - -func (star *StarTx) updateinsert(ctx context.Context, isUpdate bool, ins interface{}, sheetName string, primaryKey ...string) (sql.Result, error) { - var sqlStr string - var para []string - var err error - if isUpdate { - sqlStr, para, err = getUpdateSentence(ins, sheetName, primaryKey...) - } else { - sqlStr, para, err = getInsertSentence(ins, sheetName, primaryKey...) - } - if err != nil { - return nil, err - } - tmpStr := append([]interface{}{}, sqlStr) - for _, v := range para { - tmpStr = append(tmpStr, v) - } - return star.execX(ctx, ins, tmpStr...) -} -func (star *StarTx) ExecXS(ins interface{}, args ...interface{}) ([]sql.Result, error) { - var starRes []sql.Result - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins) - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - //now convert to slice - if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { - for i := 0; i < v.Len(); i++ { - result, err := star.execX(nil, v.Index(i).Interface(), args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - } else { - result, err := star.execX(nil, ins, args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - return starRes, nil -} - -func (star *StarTx) ExecX(ins interface{}, args ...interface{}) (sql.Result, error) { - return star.execX(nil, ins, args...) -} -func (star *StarTx) execX(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) { - kvMap, err := getAllRefValue(ins, "db") - if err != nil { - return nil, err - } - for k, v := range args { - if k == 0 { - continue - } - switch v.(type) { - case string: - str := v.(string) - if strings.Index(str, ":") == 0 { - if _, ok := kvMap[str[1:]]; ok { - args[k] = kvMap[str[1:]] - } else { - args[k] = "" - } - continue - } - if strings.Index(str, `\:`) == 0 { - args[k] = kvMap[str[1:]] - } - } - } - return star.exec(ctx, args...) -} - -func (star *StarTx) QueryXContext(ctx context.Context, ins interface{}, args ...interface{}) (*StarRows, error) { - return star.queryX(ctx, ins, args) -} -func (star *StarTx) QueryXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]*StarRows, error) { - var starRes []*StarRows - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins) - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - //now convert to slice - if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { - for i := 0; i < v.Len(); i++ { - result, err := star.queryX(ctx, v.Index(i).Interface(), args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - } else { - result, err := star.queryX(ctx, ins, args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - return starRes, nil -} - -func (star *StarTx) ExecXSContext(ctx context.Context, ins interface{}, args ...interface{}) ([]sql.Result, error) { - var starRes []sql.Result - t := reflect.TypeOf(ins) - v := reflect.ValueOf(ins) - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - //now convert to slice - if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { - for i := 0; i < v.Len(); i++ { - result, err := star.execX(ctx, v.Index(i).Interface(), args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - } else { - result, err := star.execX(ctx, ins, args...) - if err != nil { - return starRes, err - } - starRes = append(starRes, result) - } - return starRes, nil -} - -func (star *StarTx) ExecXContext(ctx context.Context, ins interface{}, args ...interface{}) (sql.Result, error) { - return star.execX(ctx, ins, args...) -} diff --git a/pool.go b/pool.go new file mode 100644 index 0000000..e608ae4 --- /dev/null +++ b/pool.go @@ -0,0 +1,54 @@ +package stardb + +import ( + "time" +) + +// PoolConfig represents database connection pool configuration +type PoolConfig struct { + MaxOpenConns int + MaxIdleConns int + ConnMaxLifetime time.Duration + ConnMaxIdleTime time.Duration +} + +// DefaultPoolConfig returns default pool configuration +func DefaultPoolConfig() *PoolConfig { + return &PoolConfig{ + MaxOpenConns: 25, + MaxIdleConns: 5, + ConnMaxLifetime: time.Hour, + ConnMaxIdleTime: 10 * time.Minute, + } +} + +// SetPoolConfig applies pool configuration to the database +func (s *StarDB) SetPoolConfig(config *PoolConfig) { + if config.MaxOpenConns > 0 { + s.db.SetMaxOpenConns(config.MaxOpenConns) + } + if config.MaxIdleConns > 0 { + s.db.SetMaxIdleConns(config.MaxIdleConns) + } + if config.ConnMaxLifetime > 0 { + s.db.SetConnMaxLifetime(config.ConnMaxLifetime) + } + if config.ConnMaxIdleTime > 0 { + s.db.SetConnMaxIdleTime(config.ConnMaxIdleTime) + } +} + +// OpenWithPool opens a database connection with pool configuration +func OpenWithPool(driver, connStr string, config *PoolConfig) (*StarDB, error) { + db := NewStarDB() + if err := db.Open(driver, connStr); err != nil { + return nil, err + } + + if config == nil { + config = DefaultPoolConfig() + } + db.SetPoolConfig(config) + + return db, nil +} diff --git a/reflect.go b/reflect.go index 3c8aeba..78a73f4 100644 --- a/reflect.go +++ b/reflect.go @@ -6,231 +6,252 @@ import ( "time" ) -func (star *StarRows) setAllRefValue(stc interface{}, skey string, rows int) error { - t := reflect.TypeOf(stc) - v := reflect.ValueOf(stc) - if t.Kind() == reflect.Ptr { - v = v.Elem() - } - if t.Kind() != reflect.Ptr && !v.CanSet() { - return errors.New("interface{} is not writable") - } - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - if v.Kind() != reflect.Struct { - return errors.New("interface{} is not a struct") +// setStructFieldsFromRow sets struct fields from a row result using reflection +func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, rowIndex int) error { + targetType := reflect.TypeOf(target) + targetValue := reflect.ValueOf(target) + + if targetType.Kind() == reflect.Ptr { + targetValue = targetValue.Elem() } - for i := 0; i < t.NumField(); i++ { - tp := t.Field(i) - srFrd := v.Field(i) - seg := tp.Tag.Get(skey) + if targetType.Kind() != reflect.Ptr && !targetValue.CanSet() { + return errors.New("target is not writable") + } - if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct { - if seg == "" { + if targetType.Kind() == reflect.Ptr { + targetType = targetType.Elem() + } + + if targetValue.Kind() != reflect.Struct { + return errors.New("target is not a struct") + } + + for i := 0; i < targetType.NumField(); i++ { + field := targetType.Field(i) + fieldValue := targetValue.Field(i) + tagValue := field.Tag.Get(tagKey) + + // Handle nested structs + if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct { + if tagValue == "" { continue } - if seg == "---" { - sp := reflect.New(reflect.TypeOf(srFrd.Interface()).Elem()).Interface() - star.setAllRefValue(sp, skey, rows) - v.Field(i).Set(reflect.ValueOf(sp)) + if tagValue == "---" { + nestedPtr := reflect.New(reflect.TypeOf(fieldValue.Interface()).Elem()).Interface() + r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex) + targetValue.Field(i).Set(reflect.ValueOf(nestedPtr)) continue } } - if srFrd.Kind() == reflect.Struct { - if seg == "" { + + if fieldValue.Kind() == reflect.Struct { + if tagValue == "" { continue } - if seg == "---" { - sp := reflect.New(reflect.TypeOf(v.Field(i).Interface())).Interface() - star.setAllRefValue(sp, skey, rows) - v.Field(i).Set(reflect.ValueOf(sp).Elem()) + if tagValue == "---" { + nestedPtr := reflect.New(reflect.TypeOf(targetValue.Field(i).Interface())).Interface() + r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex) + targetValue.Field(i).Set(reflect.ValueOf(nestedPtr).Elem()) continue } } - if seg == "" { + + if tagValue == "" { continue } - if _, ok := star.Row(rows).columnref[seg]; !ok { + + // Check if column exists + if _, ok := r.Row(rowIndex).columnIndex[tagValue]; !ok { continue } - myInt64 := star.Row(rows).MustInt64(seg) - myUint64 := star.Row(rows).MustUint64(seg) - switch v.Field(i).Kind() { - case reflect.String: - v.Field(i).SetString(star.Row(rows).MustString(seg)) - case reflect.Int64: - v.Field(i).SetInt(myInt64) - case reflect.Int32: - v.Field(i).SetInt(int64(star.Row(rows).MustInt32(seg))) - case reflect.Int16: - v.Field(i).SetInt(int64(int16(myInt64))) - case reflect.Int8: - v.Field(i).SetInt(int64(int8(myInt64))) - case reflect.Uint64: - v.Field(i).SetUint(myUint64) - case reflect.Uint32: - v.Field(i).SetUint(uint64(uint32(myUint64))) - case reflect.Uint16: - v.Field(i).SetUint(uint64(uint16(myUint64))) - case reflect.Uint8: - v.Field(i).SetUint(uint64(uint8(myUint64))) - case reflect.Bool: - v.Field(i).SetBool(star.Row(rows).MustBool(seg)) - case reflect.Float64: - v.Field(i).SetFloat(star.Row(rows).MustFloat64(seg)) - case reflect.Float32: - v.Field(i).SetFloat(float64(star.Row(rows).MustFloat32(seg))) - case reflect.Interface, reflect.Struct, reflect.Ptr: - inf := star.Row(rows).Result[star.columnref[seg]] - switch vtype := inf.(type) { - case time.Time: - v.Field(i).Set(reflect.ValueOf(vtype)) - } - default: - } + // Set field value based on type + r.setFieldValue(fieldValue, tagValue, rowIndex) } return nil } -func setRefValue(stc interface{}, skey, key string, value interface{}) error { - t := reflect.TypeOf(stc) - v := reflect.ValueOf(stc).Elem() - if t.Kind() != reflect.Ptr || !v.CanSet() { - return errors.New("interface{} is not writable") - } - if v.Kind() != reflect.Struct { - return errors.New("interface{} is not a struct") - } - t = t.Elem() - for i := 0; i < t.NumField(); i++ { - tp := t.Field(i) - seg := tp.Tag.Get(skey) - if seg == "" || key != seg { - continue +// setFieldValue sets a single field value +func (r *StarRows) setFieldValue(fieldValue reflect.Value, columnName string, rowIndex int) { + row := r.Row(rowIndex) + + switch fieldValue.Kind() { + case reflect.String: + fieldValue.SetString(row.MustString(columnName)) + case reflect.Int: + fieldValue.SetInt(int64(row.MustInt(columnName))) + case reflect.Int8: + fieldValue.SetInt(int64(int8(row.MustInt64(columnName)))) + case reflect.Int16: + fieldValue.SetInt(int64(int16(row.MustInt64(columnName)))) + case reflect.Int32: + fieldValue.SetInt(int64(row.MustInt32(columnName))) + case reflect.Int64: + fieldValue.SetInt(row.MustInt64(columnName)) + case reflect.Uint: + fieldValue.SetUint(uint64(row.MustUint64(columnName))) + case reflect.Uint8: + fieldValue.SetUint(uint64(uint8(row.MustUint64(columnName)))) + case reflect.Uint16: + fieldValue.SetUint(uint64(uint16(row.MustUint64(columnName)))) + case reflect.Uint32: + fieldValue.SetUint(uint64(uint32(row.MustUint64(columnName)))) + case reflect.Uint64: + fieldValue.SetUint(row.MustUint64(columnName)) + case reflect.Bool: + fieldValue.SetBool(row.MustBool(columnName)) + case reflect.Float32: + fieldValue.SetFloat(float64(row.MustFloat32(columnName))) + case reflect.Float64: + fieldValue.SetFloat(row.MustFloat64(columnName)) + case reflect.Interface, reflect.Struct, reflect.Ptr: + // Handle special types like time.Time + colIndex := r.columnIndex[columnName] + val := row.Result()[colIndex] + if t, ok := val.(time.Time); ok { + fieldValue.Set(reflect.ValueOf(t)) } - v.Field(i).Set(reflect.ValueOf(value)) } - return nil } -func getAllRefValue(stc interface{}, skey string) (map[string]interface{}, error) { +// getStructFieldValues extracts all field values from a struct +func getStructFieldValues(target interface{}, tagKey string) (map[string]interface{}, error) { result := make(map[string]interface{}) - t := reflect.TypeOf(stc) - v := reflect.ValueOf(stc) - if t.Kind() == reflect.Ptr { - if v.IsNil() { - return nil, errors.New("ptr interface{} is nil") + targetType := reflect.TypeOf(target) + targetValue := reflect.ValueOf(target) + + if targetType.Kind() == reflect.Ptr { + if targetValue.IsNil() { + return nil, errors.New("pointer target is nil") } - t = t.Elem() - v = v.Elem() + targetType = targetType.Elem() + targetValue = targetValue.Elem() } - if v.Kind() != reflect.Struct { - return nil, errors.New("interface{} is not a struct") + + if targetValue.Kind() != reflect.Struct { + return nil, errors.New("target is not a struct") } - for i := 0; i < t.NumField(); i++ { - tp := t.Field(i) - srFrd := v.Field(i) - seg := tp.Tag.Get(skey) - if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct { - if srFrd.IsNil() { + + for i := 0; i < targetType.NumField(); i++ { + field := targetType.Field(i) + fieldValue := targetValue.Field(i) + tagValue := field.Tag.Get(tagKey) + + // Handle nested pointer structs + if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct { + if fieldValue.IsNil() { continue } - if seg == "---" { - res, err := getAllRefValue(reflect.ValueOf(srFrd.Elem().Interface()).Interface(), skey) + if tagValue == "---" { + nestedValues, err := getStructFieldValues(reflect.ValueOf(fieldValue.Elem().Interface()).Interface(), tagKey) if err != nil { return result, err } - for k, v := range res { + for k, v := range nestedValues { result[k] = v } continue } } - if v.Field(i).Kind() == reflect.Struct { - res, err := getAllRefValue(v.Field(i).Interface(), skey) - if seg == "---" { + + // Handle nested structs + if targetValue.Field(i).Kind() == reflect.Struct { + if tagValue == "---" { + nestedValues, err := getStructFieldValues(targetValue.Field(i).Interface(), tagKey) if err != nil { return result, err } - for k, v := range res { + for k, v := range nestedValues { result[k] = v } continue } } - if seg == "" { + + if tagValue == "" { continue } - value := v.Field(i) - if !value.CanInterface() { + + if !fieldValue.CanInterface() { continue } - result[seg] = value.Interface() + + result[tagValue] = fieldValue.Interface() } + return result, nil } -func getAllRefKey(stc interface{}, skey string) ([]string, error) { +// getStructFieldNames extracts all field names (tag values) from a struct +func getStructFieldNames(target interface{}, tagKey string) ([]string, error) { var result []string - _, isStruct := isWritableStruct(stc) - if !isStruct { - return []string{}, errors.New("interface{} is not a struct") + + if !isStruct(target) { + return []string{}, errors.New("target is not a struct") } - t := reflect.TypeOf(stc) - v := reflect.ValueOf(stc) - if t.Kind() == reflect.Ptr { - if v.IsNil() { - return []string{}, errors.New("ptr interface{} is nil") + + targetType := reflect.TypeOf(target) + targetValue := reflect.ValueOf(target) + + if targetType.Kind() == reflect.Ptr { + if targetValue.IsNil() { + return []string{}, errors.New("pointer target is nil") } - t = t.Elem() - v = v.Elem() + targetType = targetType.Elem() + targetValue = targetValue.Elem() } - for i := 0; i < t.NumField(); i++ { - srFrd := v.Field(i) - profile := t.Field(i) - seg := profile.Tag.Get(skey) - if srFrd.Kind() == reflect.Ptr && reflect.TypeOf(srFrd.Interface()).Elem().Kind() == reflect.Struct { - if srFrd.IsNil() { + + for i := 0; i < targetType.NumField(); i++ { + fieldValue := targetValue.Field(i) + field := targetType.Field(i) + tagValue := field.Tag.Get(tagKey) + + // Handle nested pointer structs + if fieldValue.Kind() == reflect.Ptr && reflect.TypeOf(fieldValue.Interface()).Elem().Kind() == reflect.Struct { + if fieldValue.IsNil() { continue } - if seg == "---" { - res, err := getAllRefKey(reflect.ValueOf(srFrd.Elem().Interface()).Interface(), skey) + if tagValue == "---" { + nestedNames, err := getStructFieldNames(reflect.ValueOf(fieldValue.Elem().Interface()).Interface(), tagKey) if err != nil { return result, err } - for _, v := range res { - result = append(result, v) - } + result = append(result, nestedNames...) continue } } - if v.Field(i).Kind() == reflect.Struct && seg == "---" { - res, err := getAllRefKey(v.Field(i).Interface(), skey) + + // Handle nested structs + if targetValue.Field(i).Kind() == reflect.Struct && tagValue == "---" { + nestedNames, err := getStructFieldNames(targetValue.Field(i).Interface(), tagKey) if err != nil { return result, err } - for _, v := range res { - result = append(result, v) - } + result = append(result, nestedNames...) + continue } - if seg != "" { - result = append(result, seg) + + if tagValue != "" { + result = append(result, tagValue) } } + return result, nil } -func isWritableStruct(stc interface{}) (isWritable bool, isStruct bool) { - t := reflect.TypeOf(stc) - v := reflect.ValueOf(stc) - if t.Kind() == reflect.Ptr || v.CanSet() { - isWritable = true - } - if v.Kind() == reflect.Struct { - isStruct = true - } - return +// isWritable checks if a value is writable +func isWritable(target interface{}) bool { + targetType := reflect.TypeOf(target) + targetValue := reflect.ValueOf(target) + return targetType.Kind() == reflect.Ptr || targetValue.CanSet() +} + +// isStruct checks if a value is a struct +func isStruct(target interface{}) bool { + targetValue := reflect.ValueOf(target) + if targetValue.Kind() == reflect.Ptr { + targetValue = targetValue.Elem() + } + return targetValue.Kind() == reflect.Struct } diff --git a/reflect_test.go b/reflect_test.go deleted file mode 100644 index 887d66b..0000000 --- a/reflect_test.go +++ /dev/null @@ -1,49 +0,0 @@ -package stardb - -import ( - "fmt" - "testing" -) - -type Useless struct { - Leader string `db:"leader"` - Usable bool `db:"use"` - O *Whoami `db:"---"` -} - -type Whoami struct { - Hehe string `db:"hehe"` -} - -func TestUpInOrm(t *testing.T) { - var hehe = Useless{ - Leader: "no", - Usable: false, - } - sqlstr, param, err := getUpdateSentence(hehe, "ryz", "leader") - fmt.Println(sqlstr, param, err) - sqlstr, param, err = getInsertSentence(hehe, "ryz", "use") - fmt.Println(sqlstr, param, err) -} - -func Test_SetRefVal(t *testing.T) { - var hehe = Useless{ - Leader: "no", - } - fmt.Printf("%+v\n", hehe) - fmt.Println(setRefValue(&hehe, "db", "leader", "sb")) - fmt.Printf("%+v\n", hehe) - fmt.Printf("%+v\n", hehe) - fmt.Println(getAllRefKey(hehe, "db")) - fmt.Println(getAllRefValue(hehe, "db")) -} - -func Test_Ref(t *testing.T) { - oooooo := Useless{ - Leader: "Heheeee", - } - oooooo.O = &Whoami{"fuck"} - fmt.Println(getAllRefKey(oooooo, "db")) - fmt.Println(getAllRefValue(oooooo, "db")) - fmt.Println(getAllRefValue(&oooooo, "db")) -} diff --git a/result.go b/result.go new file mode 100644 index 0000000..24f6346 --- /dev/null +++ b/result.go @@ -0,0 +1,286 @@ +package stardb + +import ( + "fmt" + "reflect" + "time" +) + +// StarResult represents a single row result +type StarResult struct { + result []interface{} + columns []string + columnIndex map[string]int + columnsType []reflect.Type +} + +// Result returns the raw result slice +func (r *StarResult) Result() []interface{} { + return r.result +} + +// Columns returns column names +func (r *StarResult) Columns() []string { + return r.columns +} + +// ColumnsType returns column types +func (r *StarResult) ColumnsType() []reflect.Type { + return r.columnsType +} + +// IsNil checks if a column value is nil +func (r *StarResult) IsNil(name string) bool { + index, ok := r.columnIndex[name] + if !ok { + return true + } + return r.result[index] == nil +} + +// MustString returns column value as string +func (r *StarResult) MustString(name string) string { + index, ok := r.columnIndex[name] + if !ok { + return "" + } + return convertToString(r.result[index]) +} + +// MustInt returns column value as int +func (r *StarResult) MustInt(name string) int { + return int(r.MustInt64(name)) +} + +// MustInt32 returns column value as int32 +func (r *StarResult) MustInt32(name string) int32 { + return int32(r.MustInt64(name)) +} + +// MustInt64 returns column value as int64 +func (r *StarResult) MustInt64(name string) int64 { + index, ok := r.columnIndex[name] + if !ok { + return 0 + } + return convertToInt64(r.result[index]) +} + +// MustUint64 returns column value as uint64 +func (r *StarResult) MustUint64(name string) uint64 { + index, ok := r.columnIndex[name] + if !ok { + return 0 + } + return convertToUint64(r.result[index]) +} + +// MustFloat32 returns column value as float32 +func (r *StarResult) MustFloat32(name string) float32 { + return float32(r.MustFloat64(name)) +} + +// MustFloat64 returns column value as float64 +func (r *StarResult) MustFloat64(name string) float64 { + index, ok := r.columnIndex[name] + if !ok { + return 0 + } + return convertToFloat64(r.result[index]) +} + +// MustBool returns column value as bool +func (r *StarResult) MustBool(name string) bool { + index, ok := r.columnIndex[name] + if !ok { + return false + } + return convertToBool(r.result[index]) +} + +// MustBytes returns column value as []byte +func (r *StarResult) MustBytes(name string) []byte { + index, ok := r.columnIndex[name] + if !ok { + return []byte{} + } + if b, ok := r.result[index].([]byte); ok { + return b + } + return []byte(r.MustString(name)) +} + +// MustDate returns column value as time.Time +func (r *StarResult) MustDate(name, layout string) time.Time { + index, ok := r.columnIndex[name] + if !ok { + return time.Time{} + } + return convertToTime(r.result[index], layout) +} + +// Scan scans the result into provided pointers +// Usage: row.Scan(&id, &name, &email) +func (r *StarResult) Scan(dest ...interface{}) error { + if len(dest) != len(r.result) { + return fmt.Errorf("expected %d destination arguments, got %d", len(r.result), len(dest)) + } + + for i, d := range dest { + if err := convertAssign(d, r.result[i]); err != nil { + return fmt.Errorf("error scanning column %d: %v", i, err) + } + } + return nil +} + +// convertAssign assigns src to dest +func convertAssign(dest, src interface{}) error { + switch d := dest.(type) { + case *string: + *d = convertToString(src) + case *int: + *d = int(convertToInt64(src)) + case *int32: + *d = int32(convertToInt64(src)) + case *int64: + *d = convertToInt64(src) + case *uint64: + *d = convertToUint64(src) + case *float32: + *d = float32(convertToFloat64(src)) + case *float64: + *d = convertToFloat64(src) + case *bool: + *d = convertToBool(src) + case *time.Time: + if t, ok := src.(time.Time); ok { + *d = t + } + case *[]byte: + if b, ok := src.([]byte); ok { + *d = b + } + default: + return fmt.Errorf("unsupported Scan type: %T", dest) + } + return nil +} + +// StarResultCol represents a single column result +type StarResultCol struct { + result []interface{} +} + +// Result returns the raw result slice +func (c *StarResultCol) Result() []interface{} { + return c.result +} + +// Len returns the number of values +func (c *StarResultCol) Len() int { + return len(c.result) +} + +// IsNil returns slice of nil checks for each row +func (c *StarResultCol) IsNil() []bool { + result := make([]bool, len(c.result)) + for i, v := range c.result { + result[i] = (v == nil) + } + return result +} + +// MustString returns all values as strings +func (c *StarResultCol) MustString() []string { + result := make([]string, len(c.result)) + for i, v := range c.result { + result[i] = convertToString(v) + } + return result +} + +// MustInt returns all values as ints +func (c *StarResultCol) MustInt() []int { + result := make([]int, len(c.result)) + for i, v := range c.result { + result[i] = int(convertToInt64(v)) + } + return result +} + +// MustInt32 returns all values as int32 +func (c *StarResultCol) MustInt32() []int32 { + result := make([]int32, len(c.result)) + for i, v := range c.result { + result[i] = int32(convertToInt64(v)) + } + return result +} + +// MustInt64 returns all values as int64 +func (c *StarResultCol) MustInt64() []int64 { + result := make([]int64, len(c.result)) + for i, v := range c.result { + result[i] = convertToInt64(v) + } + return result +} + +// MustUint64 returns all values as uint64 +func (c *StarResultCol) MustUint64() []uint64 { + result := make([]uint64, len(c.result)) + for i, v := range c.result { + result[i] = convertToUint64(v) + } + return result +} + +// MustFloat32 returns all values as float32 +func (c *StarResultCol) MustFloat32() []float32 { + result := make([]float32, len(c.result)) + for i, v := range c.result { + result[i] = float32(convertToFloat64(v)) + } + return result +} + +// MustFloat64 returns all values as float64 +func (c *StarResultCol) MustFloat64() []float64 { + result := make([]float64, len(c.result)) + for i, v := range c.result { + result[i] = convertToFloat64(v) + } + return result +} + +// MustBool returns all values as bool +func (c *StarResultCol) MustBool() []bool { + result := make([]bool, len(c.result)) + for i, v := range c.result { + result[i] = convertToBool(v) + } + return result +} + +// MustBytes returns all values as []byte +func (c *StarResultCol) MustBytes() [][]byte { + result := make([][]byte, len(c.result)) + for i, v := range c.result { + if b, ok := v.([]byte); ok { + result[i] = b + } else { + result[i] = []byte(convertToString(v)) + } + } + return result +} + +// MustDate returns all values as time.Time +func (c *StarResultCol) MustDate(layout string) []time.Time { + result := make([]time.Time, len(c.result)) + for i, v := range c.result { + result[i] = convertToTime(v, layout) + } + return result +} diff --git a/result_safe.go b/result_safe.go new file mode 100644 index 0000000..63dd486 --- /dev/null +++ b/result_safe.go @@ -0,0 +1,51 @@ +package stardb + +import ( + "errors" + "fmt" + "strconv" +) + +// GetString returns column value as string with error +func (r *StarResult) GetString(name string) (string, error) { + index, ok := r.columnIndex[name] + if !ok { + return "", errors.New("column not found: " + name) + } + return ConvertToStringSafe(r.Result()[index]) +} + +// GetInt64 returns column value as int64 with error +func (r *StarResult) GetInt64(name string) (int64, error) { + index, ok := r.columnIndex[name] + if !ok { + return 0, errors.New("column not found: " + name) + } + return ConvertToInt64Safe(r.Result()[index]) +} + +// GetFloat64 returns column value as float64 with error +func (r *StarResult) GetFloat64(name string) (float64, error) { + index, ok := r.columnIndex[name] + if !ok { + return 0, errors.New("column not found: " + name) + } + + switch v := r.Result()[index].(type) { + case nil: + return 0, nil + case float64: + return v, nil + case float32: + return float64(v), nil + case int, int32, int64, uint64: + val, err := ConvertToInt64Safe(v) + return float64(val), err + case string: + return strconv.ParseFloat(v, 64) + case []byte: + return strconv.ParseFloat(string(v), 64) + default: + return 0, fmt.Errorf("cannot convert %T to float64", v) + } +} diff --git a/rows.go b/rows.go new file mode 100644 index 0000000..4af1ee1 --- /dev/null +++ b/rows.go @@ -0,0 +1,168 @@ +package stardb + +import ( + "database/sql" + "reflect" + "strconv" + "time" +) + +// StarRows represents a result set from a query +type StarRows struct { + rows *sql.Rows + db *StarDB + length int + stringResult []map[string]string + columns []string + columnsType []reflect.Type + columnIndex map[string]int + data [][]interface{} + parsed bool +} + +// Length returns the number of rows +func (r *StarRows) Length() int { + return r.length +} + +// StringResult returns all rows as string maps +func (r *StarRows) StringResult() []map[string]string { + return r.stringResult +} + +// Columns returns column names +func (r *StarRows) Columns() []string { + return r.columns +} + +// ColumnsType returns column types +func (r *StarRows) ColumnsType() []reflect.Type { + return r.columnsType +} + +// Close closes the result set +func (r *StarRows) Close() error { + return r.rows.Close() +} + +// Rescan re-parses the result set +func (r *StarRows) Rescan() error { + return r.parse() +} + +// Row returns a specific row by index +func (r *StarRows) Row(index int) *StarResult { + result := &StarResult{} + if index >= len(r.data) { + return result + } + result.result = r.data[index] + result.columns = r.columns + result.columnsType = r.columnsType + result.columnIndex = r.columnIndex + return result +} + +// Col returns all values for a specific column +func (r *StarRows) Col(name string) *StarResultCol { + result := &StarResultCol{} + if _, ok := r.columnIndex[name]; !ok { + return result + } + + colIndex := r.columnIndex[name] + for _, row := range r.data { + result.result = append(result.result, row[colIndex]) + } + return result +} + +// parse parses the sql.Rows into internal data structures +func (r *StarRows) parse() error { + if r.parsed { + return nil + } + defer func() { + r.parsed = true + }() + + r.data = [][]interface{}{} + r.columnIndex = make(map[string]int) + r.stringResult = []map[string]string{} + + var err error + r.columns, err = r.rows.Columns() + if err != nil { + return err + } + + columnTypes, err := r.rows.ColumnTypes() + if err != nil { + return err + } + + for _, colType := range columnTypes { + r.columnsType = append(r.columnsType, colType.ScanType()) + } + + // Build column index map + for i, colName := range r.columns { + r.columnIndex[colName] = i + } + + // Prepare scan arguments + scanArgs := make([]interface{}, len(r.columns)) + values := make([]interface{}, len(r.columns)) + for i := range values { + scanArgs[i] = &values[i] + } + + // Scan all rows + for r.rows.Next() { + if err := r.rows.Scan(scanArgs...); err != nil { + return err + } + + record := make(map[string]string) + rowCopy := make([]interface{}, len(values)) + + for i, val := range values { + rowCopy[i] = val + record[r.columns[i]] = convertToString(val) + } + + r.data = append(r.data, rowCopy) + r.stringResult = append(r.stringResult, record) + } + + r.length = len(r.stringResult) + return nil +} + +// convertToString converts any value to string +func convertToString(val interface{}) string { + switch v := val.(type) { + case nil: + return "" + case string: + return v + case int: + return strconv.Itoa(v) + case int32: + return strconv.FormatInt(int64(v), 10) + case int64: + return strconv.FormatInt(v, 10) + case float32: + return strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + return strconv.FormatFloat(v, 'f', -1, 64) + case bool: + return strconv.FormatBool(v) + case time.Time: + return v.String() + case []byte: + return string(v) + default: + return "" + } +} diff --git a/stardb.go b/stardb.go new file mode 100644 index 0000000..db43d6d --- /dev/null +++ b/stardb.go @@ -0,0 +1,296 @@ +package stardb + +import ( + "context" + "database/sql" + "errors" +) + +// StarDB is a simple wrapper around sql.DB providing enhanced functionality +type StarDB struct { + db *sql.DB + ManualScan bool // If true, rows won't be automatically parsed +} + +// NewStarDB creates a new StarDB instance +func NewStarDB() *StarDB { + return &StarDB{} +} + +// NewStarDBWithDB creates a new StarDB instance with an existing *sql.DB +func NewStarDBWithDB(db *sql.DB) *StarDB { + return &StarDB{db: db} +} + +// DB returns the underlying *sql.DB +func (s *StarDB) DB() *sql.DB { + return s.db +} + +// SetDB sets the underlying *sql.DB +func (s *StarDB) SetDB(db *sql.DB) { + s.db = db +} + +// Open opens a new database connection +func (s *StarDB) Open(driver, connStr string) error { + var err error + s.db, err = sql.Open(driver, connStr) + return err +} + +// Close closes the database connection +func (s *StarDB) Close() error { + return s.db.Close() +} + +// Ping verifies the database connection is alive +func (s *StarDB) Ping() error { + return s.db.Ping() +} + +// PingContext verifies the database connection with context +func (s *StarDB) PingContext(ctx context.Context) error { + return s.db.PingContext(ctx) +} + +// Stats returns database statistics +func (s *StarDB) Stats() sql.DBStats { + return s.db.Stats() +} + +// SetMaxOpenConns sets the maximum number of open connections +func (s *StarDB) SetMaxOpenConns(n int) { + s.db.SetMaxOpenConns(n) +} + +// SetMaxIdleConns sets the maximum number of idle connections +func (s *StarDB) SetMaxIdleConns(n int) { + s.db.SetMaxIdleConns(n) +} + +// Conn returns a single connection from the pool +func (s *StarDB) Conn(ctx context.Context) (*sql.Conn, error) { + return s.db.Conn(ctx) +} + +// Query executes a query that returns rows +// Usage: Query("SELECT * FROM users WHERE id = ?", 1) +func (s *StarDB) Query(query string, args ...interface{}) (*StarRows, error) { + return s.query(nil, query, args...) +} + +// QueryContext executes a query with context +func (s *StarDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { + return s.query(ctx, query, args...) +} + +// query is the internal query implementation +func (s *StarDB) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { + if err := s.db.Ping(); err != nil { + return nil, err + } + + var rows *sql.Rows + var err error + + if ctx == nil { + rows, err = s.db.Query(query, args...) + } else { + rows, err = s.db.QueryContext(ctx, query, args...) + } + + if err != nil { + return nil, err + } + + starRows := &StarRows{ + rows: rows, + db: s, + } + + if !s.ManualScan { + err = starRows.parse() + } + + return starRows, err +} + +// Exec executes a query that doesn't return rows +// Usage: Exec("INSERT INTO users (name) VALUES (?)", "John") +func (s *StarDB) Exec(query string, args ...interface{}) (sql.Result, error) { + return s.exec(nil, query, args...) +} + +// ExecContext executes a query with context +func (s *StarDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return s.exec(ctx, query, args...) +} + +// exec is the internal exec implementation +func (s *StarDB) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + if err := s.db.Ping(); err != nil { + return nil, err + } + + if ctx == nil { + return s.db.Exec(query, args...) + } + return s.db.ExecContext(ctx, query, args...) +} + +// Prepare creates a prepared statement +func (s *StarDB) Prepare(query string) (*StarStmt, error) { + stmt, err := s.db.Prepare(query) + if err != nil { + return nil, err + } + return &StarStmt{stmt: stmt, db: s}, nil +} + +// PrepareContext creates a prepared statement with context +func (s *StarDB) PrepareContext(ctx context.Context, query string) (*StarStmt, error) { + stmt, err := s.db.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return &StarStmt{stmt: stmt, db: s}, nil +} + +// QueryStmt executes a prepared statement query +// Usage: QueryStmt("SELECT * FROM users WHERE id = ?", 1) +func (s *StarDB) QueryStmt(query string, args ...interface{}) (*StarRows, error) { + if query == "" { + return nil, errors.New("query string cannot be empty") + } + stmt, err := s.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.Query(args...) +} + +// QueryStmtContext executes a prepared statement query with context +func (s *StarDB) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { + if query == "" { + return nil, errors.New("query string cannot be empty") + } + stmt, err := s.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.QueryContext(ctx, args...) +} + +// ExecStmt executes a prepared statement +// Usage: ExecStmt("INSERT INTO users (name) VALUES (?)", "John") +func (s *StarDB) ExecStmt(query string, args ...interface{}) (sql.Result, error) { + if query == "" { + return nil, errors.New("query string cannot be empty") + } + stmt, err := s.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.Exec(args...) +} + +// ExecStmtContext executes a prepared statement with context +func (s *StarDB) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + if query == "" { + return nil, errors.New("query string cannot be empty") + } + stmt, err := s.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.ExecContext(ctx, args...) +} + +// Begin starts a transaction +func (s *StarDB) Begin() (*StarTx, error) { + tx, err := s.db.Begin() + if err != nil { + return nil, err + } + return &StarTx{tx: tx, db: s}, nil +} + +// BeginTx starts a transaction with options +func (s *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, error) { + tx, err := s.db.BeginTx(ctx, opts) + if err != nil { + return nil, err + } + return &StarTx{tx: tx, db: s}, nil +} + +// StarStmt represents a prepared statement +type StarStmt struct { + stmt *sql.Stmt + db *StarDB +} + +// Query executes a prepared statement query +func (s *StarStmt) Query(args ...interface{}) (*StarRows, error) { + return s.query(nil, args...) +} + +// QueryContext executes a prepared statement query with context +func (s *StarStmt) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) { + return s.query(ctx, args...) +} + +// query is the internal query implementation +func (s *StarStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) { + var rows *sql.Rows + var err error + + if ctx == nil { + rows, err = s.stmt.Query(args...) + } else { + rows, err = s.stmt.QueryContext(ctx, args...) + } + + if err != nil { + return nil, err + } + + starRows := &StarRows{ + rows: rows, + db: s.db, + } + + if !s.db.ManualScan { + err = starRows.parse() + } + + return starRows, err +} + +// Exec executes a prepared statement +func (s *StarStmt) Exec(args ...interface{}) (sql.Result, error) { + return s.exec(nil, args...) +} + +// ExecContext executes a prepared statement with context +func (s *StarStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { + return s.exec(ctx, args...) +} + +// exec is the internal exec implementation +func (s *StarStmt) exec(ctx context.Context, args ...interface{}) (sql.Result, error) { + if ctx == nil { + return s.stmt.Exec(args...) + } + return s.stmt.ExecContext(ctx, args...) +} + +// Close closes the prepared statement +func (s *StarStmt) Close() error { + return s.stmt.Close() +} diff --git a/stardb_v1.go b/stardb_v1.go deleted file mode 100644 index 0da8483..0000000 --- a/stardb_v1.go +++ /dev/null @@ -1,1291 +0,0 @@ -package stardb - -import ( - "context" - "database/sql" - "errors" - "reflect" - "strconv" - "time" -) - -// StarDB 一个简单封装的DB库 -type StarDB struct { - Db *sql.DB - ManualScan bool -} - -type StarTx struct { - Db *StarDB - Tx *sql.Tx -} - -// StarRows 为查询结果集(按行) -type StarRows struct { - Rows *sql.Rows - Length int - StringResult []map[string]string - Columns []string - ColumnsType []reflect.Type - columnref map[string]int - result [][]interface{} - parsed bool -} - -type StarDBStmt struct { - Stmt *sql.Stmt - Db *StarDB -} - -// StarResult 为查询结果集(总) -type StarResult struct { - Result []interface{} - Columns []string - columnref map[string]int - ColumnsType []reflect.Type -} - -// StarResultCol 为查询结果集(按列) -type StarResultCol struct { - Result []interface{} -} - -// MustBytes 列查询结果转Bytes -func (star *StarResultCol) MustBytes() [][]byte { - var res [][]byte - for _, v := range star.Result { - res = append(res, v.([]byte)) - } - return res -} - -// MustBool 列查询结果转Bool -func (star *StarResultCol) MustBool() []bool { - var res []bool - var tmp bool - for _, v := range star.Result { - switch vtype := v.(type) { - case nil: - tmp = false - case bool: - tmp = vtype - case float64: - if vtype > 0 { - tmp = true - } else { - tmp = false - } - case float32: - if vtype > 0 { - tmp = true - } else { - tmp = false - } - case int: - if vtype > 0 { - tmp = true - } else { - tmp = false - } - case int32: - if vtype > 0 { - tmp = true - } else { - tmp = false - } - case int64: - if vtype > 0 { - tmp = true - } else { - tmp = false - } - case uint64: - if vtype > 0 { - tmp = true - } else { - tmp = false - } - case string: - tmp, _ = strconv.ParseBool(vtype) - default: - tmp, _ = strconv.ParseBool(string(vtype.([]byte))) - } - res = append(res, tmp) - } - return res -} - -// MustFloat32 列查询结果转Float32 -func (star *StarResultCol) MustFloat32() []float32 { - var res []float32 - var tmp float32 - for _, v := range star.Result { - switch vtype := v.(type) { - case nil: - tmp = 0 - case float32: - tmp = vtype - case float64: - tmp = float32(vtype) - case string: - tmps, _ := strconv.ParseFloat(vtype, 32) - tmp = float32(tmps) - case int: - tmp = float32(vtype) - case int32: - tmp = float32(vtype) - case int64: - tmp = float32(vtype) - case uint64: - tmp = float32(vtype) - case time.Time: - tmp = float32(vtype.Unix()) - default: - tmpt := string(vtype.([]byte)) - tmps, _ := strconv.ParseFloat(tmpt, 32) - tmp = float32(tmps) - } - res = append(res, tmp) - } - return res -} - -// MustFloat64 列查询结果转Float64 -func (star *StarResultCol) MustFloat64() []float64 { - var res []float64 - var tmp float64 - for _, v := range star.Result { - switch vtype := v.(type) { - case nil: - tmp = 0 - case float64: - tmp = vtype - case float32: - tmp = float64(vtype) - case string: - tmp, _ = strconv.ParseFloat(vtype, 64) - case int: - tmp = float64(vtype) - case int32: - tmp = float64(vtype) - case int64: - tmp = float64(vtype) - case uint64: - tmp = float64(vtype) - case time.Time: - tmp = float64(vtype.Unix()) - default: - tmpt := string(vtype.([]byte)) - tmps, _ := strconv.ParseFloat(tmpt, 64) - tmp = float64(tmps) - } - res = append(res, tmp) - } - return res -} - -// MustString 列查询结果转String -func (star *StarResultCol) MustString() []string { - var res []string - var tmp string - for _, v := range star.Result { - switch vtype := v.(type) { - case nil: - tmp = "" - case string: - tmp = vtype - case int64: - tmp = strconv.FormatInt(vtype, 10) - case int32: - tmp = strconv.Itoa(int(vtype)) - case bool: - tmp = strconv.FormatBool(vtype) - case float64: - tmp = strconv.FormatFloat(vtype, 'f', 10, 64) - case float32: - tmp = strconv.FormatFloat(float64(vtype), 'f', 10, 32) - case int: - tmp = strconv.Itoa(vtype) - case uint64: - tmp = strconv.FormatUint(vtype, 10) - case time.Time: - tmp = vtype.String() - default: - tmp = string(vtype.([]byte)) - } - res = append(res, tmp) - } - return res -} - -// MustInt32 列查询结果转Int32 -func (star *StarResultCol) MustInt32() []int32 { - var res []int32 - var tmp int32 - for _, v := range star.Result { - switch vtype := v.(type) { - case nil: - tmp = 0 - case float64: - tmp = int32(vtype) - case float32: - tmp = int32(vtype) - case string: - tmps, _ := strconv.ParseInt(vtype, 10, 32) - tmp = int32(tmps) - case int: - tmp = int32(vtype) - case int64: - tmp = int32(vtype) - case uint64: - tmp = int32(vtype) - case int32: - tmp = vtype - case time.Time: - tmp = int32(vtype.Unix()) - default: - tmpt := string(vtype.([]byte)) - tmps, _ := strconv.ParseInt(tmpt, 10, 32) - tmp = int32(tmps) - } - res = append(res, tmp) - } - return res -} - -// MustInt64 列查询结果转Int64 -func (star *StarResultCol) MustInt64() []int64 { - var res []int64 - var tmp int64 - for _, v := range star.Result { - switch vtype := v.(type) { - case nil: - tmp = 0 - case float64: - tmp = int64(vtype) - case float32: - tmp = int64(vtype) - case string: - tmps, _ := strconv.ParseInt(vtype, 10, 64) - tmp = int64(tmps) - case int: - tmp = int64(vtype) - case int32: - tmp = int64(vtype) - case uint64: - tmp = int64(vtype) - case int64: - tmp = vtype - case time.Time: - tmp = vtype.Unix() - default: - tmpt := string(vtype.([]byte)) - tmp, _ = strconv.ParseInt(tmpt, 10, 64) - } - res = append(res, tmp) - } - return res -} - -// MustUint64 列查询结果转Int64 -func (star *StarResultCol) MustUint64() []uint64 { - var res []uint64 - var tmp uint64 - for _, v := range star.Result { - switch vtype := v.(type) { - case nil: - tmp = 0 - case float64: - tmp = uint64(vtype) - case float32: - tmp = uint64(vtype) - case string: - tmp, _ = strconv.ParseUint(vtype, 10, 64) - case int: - tmp = uint64(vtype) - case int32: - tmp = uint64(vtype) - case int64: - tmp = uint64(vtype) - case uint64: - tmp = vtype - case time.Time: - tmp = uint64(vtype.Unix()) - default: - tmpt := string(vtype.([]byte)) - tmp, _ = strconv.ParseUint(tmpt, 10, 64) - } - res = append(res, tmp) - } - return res -} - -// MustInt 列查询结果转Int -func (star *StarResultCol) MustInt() []int { - var res []int - var tmp int - for _, v := range star.Result { - switch vtype := v.(type) { - case nil: - tmp = 0 - case float64: - tmp = int(vtype) - case float32: - tmp = int(vtype) - case string: - tmps, _ := strconv.ParseInt(vtype, 10, 64) - tmp = int(tmps) - case int: - tmp = vtype - case int32: - tmp = int(vtype) - case int64: - tmp = int(vtype) - case uint64: - tmp = int(vtype) - case time.Time: - tmp = int(vtype.Unix()) - default: - tmpt := string(vtype.([]byte)) - tmps, _ := strconv.ParseInt(tmpt, 10, 64) - tmp = int(tmps) - } - res = append(res, tmp) - } - return res -} - -// MustDate 列查询结果转Date(time.Time) -func (star *StarResultCol) MustDate(layout string) []time.Time { - var res []time.Time - var tmp time.Time - for _, v := range star.Result { - switch vtype := v.(type) { - case nil: - tmp = time.Time{} - case float64: - tmp = time.Unix(int64(vtype), int64(vtype-float64(int64(vtype)))*1000000000) - case float32: - tmp = time.Unix(int64(vtype), int64(vtype-float32(int64(vtype)))*1000000000) - case string: - tmp, _ = time.Parse(layout, vtype) - case int: - tmp = time.Unix(int64(vtype), 0) - case int32: - tmp = time.Unix(int64(vtype), 0) - case int64: - tmp = time.Unix(vtype, 0) - case uint64: - tmp = time.Unix(int64(vtype), 0) - case time.Time: - tmp = vtype - default: - tmpt := string(vtype.([]byte)) - tmp, _ = time.Parse(layout, tmpt) - } - res = append(res, tmp) - } - return res -} - -// IsNil 检测是不是nil 列查询结果是不是nil -func (star *StarResultCol) IsNil(name string) []bool { - var res []bool - var tmp bool - for _, v := range star.Result { - switch v.(type) { - case nil: - tmp = true - default: - tmp = false - } - res = append(res, tmp) - } - return res -} - -// IsNil 检测是不是nil -func (star *StarResult) IsNil(name string) bool { - num, ok := star.columnref[name] - if !ok { - return false - } - tmp := star.Result[num] - switch tmp.(type) { - case nil: - return true - default: - return false - } -} - -// MustDate 列查询结果转Date -func (star *StarResult) MustDate(name, layout string) time.Time { - var res time.Time - num, ok := star.columnref[name] - if !ok { - return time.Time{} - } - tmp := star.Result[num] - switch vtype := tmp.(type) { - case nil: - res = time.Time{} - case float64: - res = time.Unix(int64(vtype), int64(vtype-float64(int64(vtype)))*1000000000) - case float32: - res = time.Unix(int64(vtype), int64(vtype-float32(int64(vtype)))*1000000000) - case string: - res, _ = time.Parse(layout, vtype) - case int: - res = time.Unix(int64(vtype), 0) - case int32: - res = time.Unix(int64(vtype), 0) - case int64: - res = time.Unix(vtype, 0) - case uint64: - res = time.Unix(int64(vtype), 0) - case time.Time: - res = vtype - default: - res, _ = time.Parse(layout, string(tmp.([]byte))) - } - return res -} - -// MustInt64 列查询结果转int64 -func (star *StarResult) MustInt64(name string) int64 { - var res int64 - num, ok := star.columnref[name] - if !ok { - return 0 - } - tmp := star.Result[num] - switch vtype := tmp.(type) { - case nil: - res = 0 - case float64: - res = int64(vtype) - case float32: - res = int64(vtype) - case string: - res, _ = strconv.ParseInt(vtype, 10, 64) - case int: - res = int64(vtype) - case int32: - res = int64(vtype) - case uint64: - res = int64(vtype) - case int64: - res = vtype - case time.Time: - res = int64(vtype.Unix()) - default: - res, _ = strconv.ParseInt(string(tmp.([]byte)), 10, 64) - } - return res -} - -// MustInt32 列查询结果转Int32 -func (star *StarResult) MustInt32(name string) int32 { - var res int32 - num, ok := star.columnref[name] - if !ok { - return 0 - } - tmp := star.Result[num] - switch vtype := tmp.(type) { - case nil: - res = 0 - case float64: - res = int32(vtype) - case float32: - res = int32(vtype) - case string: - ress, _ := strconv.ParseInt(vtype, 10, 32) - res = int32(ress) - case int: - res = int32(vtype) - case int32: - res = vtype - case int64: - res = int32(vtype) - case uint64: - res = int32(vtype) - case time.Time: - res = int32(vtype.Unix()) - default: - ress, _ := strconv.ParseInt(string(tmp.([]byte)), 10, 32) - res = int32(ress) - } - return res -} - -// MustUint 列查询结果转uint -func (star *StarResult) MustUint64(name string) uint64 { - var res uint64 - num, ok := star.columnref[name] - if !ok { - return 0 - } - tmp := star.Result[num] - switch vtype := tmp.(type) { - case nil: - res = 0 - case float64: - res = uint64(vtype) - case float32: - res = uint64(vtype) - case string: - res, _ = strconv.ParseUint(vtype, 10, 64) - case uint64: - res = vtype - case int32: - res = uint64(vtype) - case int64: - res = uint64(vtype) - case time.Time: - res = uint64(vtype.Unix()) - default: - res, _ = strconv.ParseUint(string(tmp.([]byte)), 10, 64) - } - return res -} - -// MustString 列查询结果转string -func (star *StarResult) MustString(name string) string { - var res string - num, ok := star.columnref[name] - if !ok { - return "" - } - switch vtype := star.Result[num].(type) { - case nil: - res = "" - case string: - res = vtype - case int64: - res = strconv.FormatInt(vtype, 10) - case int32: - res = strconv.Itoa(int(vtype)) - case bool: - res = strconv.FormatBool(vtype) - case float64: - res = strconv.FormatFloat(vtype, 'f', 10, 64) - case float32: - res = strconv.FormatFloat(float64(vtype), 'f', 10, 32) - case int: - res = strconv.Itoa(vtype) - case uint64: - res = strconv.FormatUint(vtype, 10) - case time.Time: - res = vtype.String() - default: - res = string(vtype.([]byte)) - } - return res -} - -// MustFloat64 列查询结果转float64 -func (star *StarResult) MustFloat64(name string) float64 { - var res float64 - num, ok := star.columnref[name] - if !ok { - return 0 - } - switch vtype := star.Result[num].(type) { - case nil: - res = 0 - case string: - res, _ = strconv.ParseFloat(vtype, 64) - case float64: - res = vtype - case int: - res = float64(vtype) - case int64: - res = float64(vtype) - case int32: - res = float64(vtype) - case float32: - res = float64(vtype) - case uint64: - res = float64(vtype) - case time.Time: - res = float64(vtype.Unix()) - default: - res, _ = strconv.ParseFloat(string(vtype.([]byte)), 64) - } - return res -} - -// MustFloat32 列查询结果转float32 -func (star *StarResult) MustFloat32(name string) float32 { - var res float32 - num, ok := star.columnref[name] - if !ok { - return 0 - } - switch vtype := star.Result[num].(type) { - case nil: - res = 0 - case string: - tmp, _ := strconv.ParseFloat(vtype, 32) - res = float32(tmp) - case float64: - res = float32(vtype) - case float32: - res = vtype - case int: - res = float32(vtype) - case int64: - res = float32(vtype) - case int32: - res = float32(vtype) - case uint64: - res = float32(vtype) - case time.Time: - res = float32(vtype.Unix()) - default: - tmp, _ := strconv.ParseFloat(string(vtype.([]byte)), 32) - res = float32(tmp) - } - return res -} - -// MustInt 列查询结果转int -func (star *StarResult) MustInt(name string) int { - var res int - num, ok := star.columnref[name] - if !ok { - return 0 - } - tmp := star.Result[num] - switch vtype := tmp.(type) { - case nil: - res = 0 - case float64: - res = int(vtype) - case float32: - res = int(vtype) - case string: - ress, _ := strconv.ParseInt(vtype, 10, 64) - res = int(ress) - case int: - res = vtype - case int32: - res = int(vtype) - case int64: - res = int(vtype) - case uint64: - res = int(vtype) - case time.Time: - res = int(vtype.Unix()) - default: - ress, _ := strconv.ParseInt(string(tmp.([]byte)), 10, 64) - res = int(ress) - } - return res -} - -// MustBool 列查询结果转bool -func (star *StarResult) MustBool(name string) bool { - var res bool - num, ok := star.columnref[name] - if !ok { - return false - } - tmp := star.Result[num] - switch vtype := tmp.(type) { - case nil: - res = false - case bool: - res = vtype - case float64: - if vtype > 0 { - res = true - } else { - res = false - } - case float32: - if vtype > 0 { - res = true - } else { - res = false - } - case int: - if vtype > 0 { - res = true - } else { - res = false - } - case int32: - if vtype > 0 { - res = true - } else { - res = false - } - case int64: - if vtype > 0 { - res = true - } else { - res = false - } - case uint64: - if vtype > 0 { - res = true - } else { - res = false - } - case string: - res, _ = strconv.ParseBool(vtype) - default: - res, _ = strconv.ParseBool(string(vtype.([]byte))) - } - return res -} - -// MustBytes 列查询结果转byte -func (star *StarResult) MustBytes(name string) []byte { - num, ok := star.columnref[name] - if !ok { - return []byte{} - } - res := star.Result[num].([]byte) - return res -} - -// Rescan 重新分析结果集 -func (star *StarRows) Rescan() { - star.parserows() -} - -// Col 选择需要进行操作的数据结果列 -func (star *StarRows) Col(name string) *StarResultCol { - result := new(StarResultCol) - if _, ok := star.columnref[name]; !ok { - return result - } - var rescol []interface{} - for _, v := range star.result { - rescol = append(rescol, v[star.columnref[name]]) - } - result.Result = rescol - return result -} - -// Row 选择需要进行操作的数据结果行 -func (star *StarRows) Row(id int) *StarResult { - result := new(StarResult) - if id+1 > len(star.result) { - return result - } - result.Result = star.result[id] - result.Columns = star.Columns - result.ColumnsType = star.ColumnsType - result.columnref = star.columnref - return result -} - -// Close 关闭打开的结果集 -func (star *StarRows) Close() error { - return star.Rows.Close() -} - -func (star *StarRows) parserows() error { - defer func() { - star.parsed = true - }() - star.result = [][]interface{}{} - star.columnref = make(map[string]int) - star.StringResult = []map[string]string{} - star.Columns, _ = star.Rows.Columns() - types, _ := star.Rows.ColumnTypes() - for _, v := range types { - star.ColumnsType = append(star.ColumnsType, v.ScanType()) - } - scanArgs := make([]interface{}, len(star.Columns)) - values := make([]interface{}, len(star.Columns)) - for i := range values { - star.columnref[star.Columns[i]] = i - scanArgs[i] = &values[i] - } - for star.Rows.Next() { - if err := star.Rows.Scan(scanArgs...); err != nil { - return err - } - record := make(map[string]string) - var rescopy []interface{} - for i, col := range values { - rescopy = append(rescopy, col) - switch vtype := col.(type) { - case float32: - record[star.Columns[i]] = strconv.FormatFloat(float64(vtype), 'f', -1, 64) - case float64: - record[star.Columns[i]] = strconv.FormatFloat(vtype, 'f', -1, 64) - case int64: - record[star.Columns[i]] = strconv.FormatInt(vtype, 10) - case int32: - record[star.Columns[i]] = strconv.FormatInt(int64(vtype), 10) - case int: - record[star.Columns[i]] = strconv.Itoa(vtype) - case string: - record[star.Columns[i]] = vtype - case bool: - record[star.Columns[i]] = strconv.FormatBool(vtype) - case time.Time: - record[star.Columns[i]] = vtype.String() - case nil: - record[star.Columns[i]] = "" - default: - record[star.Columns[i]] = string(vtype.([]byte)) - } - } - star.result = append(star.result, rescopy) - star.StringResult = append(star.StringResult, record) - } - star.Length = len(star.StringResult) - return nil -} - -func (star *StarDB) Begin() (*StarTx, error) { - tx, err := star.Db.Begin() - if err != nil { - return nil, err - } - stx := new(StarTx) - stx.Db = star - stx.Tx = tx - return stx, err -} - -func (star *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, error) { - tx, err := star.Db.BeginTx(ctx, opts) - if err != nil { - return nil, err - } - stx := new(StarTx) - stx.Db = star - stx.Tx = tx - return stx, err -} - -func (star *StarTx) Query(args ...interface{}) (*StarRows, error) { - return star.query(nil, args...) -} - -func (star *StarTx) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) { - return star.query(ctx, args...) -} -func (star *StarTx) ExecStmt(args ...interface{}) (sql.Result, error) { - if len(args) <= 1 { - return nil, errors.New("parameter not enough") - } - stmt, err := star.Prepare(args[0].(string)) - if err != nil { - return nil, err - } - defer stmt.Close() - return stmt.Exec(args[1:]...) -} - -func (star *StarTx) ExecStmtContext(ctx context.Context, args ...interface{}) (sql.Result, error) { - if len(args) <= 1 { - return nil, errors.New("parameter not enough") - } - stmt, err := star.PrepareContext(ctx, args[0].(string)) - if err != nil { - return nil, err - } - defer stmt.Close() - return stmt.ExecContext(ctx, args[1:]...) -} - -func (star *StarTx) QueryStmt(args ...interface{}) (*StarRows, error) { - if len(args) <= 1 { - return nil, errors.New("parameter not enough") - } - stmt, err := star.Prepare(args[0].(string)) - if err != nil { - return nil, err - } - defer stmt.Close() - return stmt.Query(args[1:]...) -} - -func (star *StarTx) QueryStmtContext(ctx context.Context, args ...interface{}) (*StarRows, error) { - if len(args) <= 1 { - return nil, errors.New("parameter not enough") - } - stmt, err := star.PrepareContext(ctx, args[0].(string)) - if err != nil { - return nil, err - } - defer stmt.Close() - return stmt.QueryContext(ctx, args[1:]...) -} - -func (star *StarTx) query(ctx context.Context, args ...interface{}) (*StarRows, error) { - var err error - var rows *sql.Rows - effect := new(StarRows) - if err = star.Db.Ping(); err != nil { - return effect, err - } - if len(args) == 0 { - return effect, errors.New("no args") - } - var para []interface{} - for k, v := range args { - if k != 0 { - switch vtype := v.(type) { - default: - para = append(para, vtype) - } - } - } - if ctx == nil { - if rows, err = star.Tx.Query(args[0].(string), para...); err != nil { - return effect, err - } - } else { - if rows, err = star.Tx.QueryContext(ctx, args[0].(string), para...); err != nil { - return effect, err - } - } - effect.Rows = rows - if !star.Db.ManualScan { - err = effect.parserows() - } - return effect, err -} - -func (star *StarDB) Query(args ...interface{}) (*StarRows, error) { - return star.query(nil, args...) -} - -func (star *StarDB) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) { - return star.query(ctx, args...) -} - -func (star *StarDB) QueryStmt(args ...interface{}) (*StarRows, error) { - if len(args) <= 1 { - return nil, errors.New("parameter not enough") - } - stmt, err := star.Prepare(args[0].(string)) - if err != nil { - return nil, err - } - defer stmt.Close() - return stmt.Query(args[1:]...) -} - -func (star *StarDB) QueryStmtContext(ctx context.Context, args ...interface{}) (*StarRows, error) { - if len(args) <= 1 { - return nil, errors.New("parameter not enough") - } - stmt, err := star.PrepareContext(ctx, args[0].(string)) - if err != nil { - return nil, err - } - defer stmt.Close() - return stmt.QueryContext(ctx, args[1:]...) -} - -func (star *StarDB) ExecStmt(args ...interface{}) (sql.Result, error) { - if len(args) <= 1 { - return nil, errors.New("parameter not enough") - } - stmt, err := star.Prepare(args[0].(string)) - if err != nil { - return nil, err - } - defer stmt.Close() - return stmt.Exec(args[1:]...) -} - -func (star *StarDB) ExecStmtContext(ctx context.Context, args ...interface{}) (sql.Result, error) { - if len(args) <= 1 { - return nil, errors.New("parameter not enough") - } - stmt, err := star.PrepareContext(ctx, args[0].(string)) - if err != nil { - return nil, err - } - defer stmt.Close() - return stmt.ExecContext(ctx, args[1:]...) -} - -func (star *StarDBStmt) Query(args ...interface{}) (*StarRows, error) { - return star.query(nil, args...) -} - -func (star *StarDBStmt) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) { - return star.query(ctx, args...) -} - -func (star *StarDBStmt) Exec(args ...interface{}) (sql.Result, error) { - return star.exec(nil, args...) -} - -func (star *StarDBStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { - return star.exec(ctx, args...) -} - -func (star *StarDBStmt) Close() error { - return star.Stmt.Close() -} - -func (star *StarDBStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) { - var err error - var rows *sql.Rows - effect := new(StarRows) - if len(args) == 0 { - return effect, errors.New("no args") - } - if ctx == nil { - if rows, err = star.Stmt.Query(args...); err != nil { - return effect, err - } - } else { - if rows, err = star.Stmt.QueryContext(ctx, args...); err != nil { - return effect, err - } - } - effect.Rows = rows - if !star.Db.ManualScan { - err = effect.parserows() - } - return effect, err -} - -func (star *StarDBStmt) exec(ctx context.Context, args ...interface{}) (sql.Result, error) { - if len(args) == 0 { - return nil, errors.New("no args") - } - if ctx == nil { - return star.Stmt.Exec(args...) - } - return star.Stmt.ExecContext(ctx, args...) -} - -func (star *StarDB) Prepare(sqlStr string) (*StarDBStmt, error) { - stmt := new(StarDBStmt) - stmtS, err := star.Db.Prepare(sqlStr) - if err != nil { - return nil, err - } - stmt.Stmt = stmtS - stmt.Db = star - return stmt, err -} - -func (star *StarDB) PrepareContext(ctx context.Context, sqlStr string) (*StarDBStmt, error) { - stmt := new(StarDBStmt) - stmtS, err := star.Db.PrepareContext(ctx, sqlStr) - if err != nil { - return nil, err - } - stmt.Stmt = stmtS - stmt.Db = star - return stmt, err -} - -func (star *StarTx) Prepare(sqlStr string) (*StarDBStmt, error) { - stmt := new(StarDBStmt) - stmtS, err := star.Tx.Prepare(sqlStr) - if err != nil { - return nil, err - } - stmt.Stmt = stmtS - stmt.Db = star.Db - return stmt, err -} - -func (star *StarTx) PrepareContext(ctx context.Context, sqlStr string) (*StarDBStmt, error) { - stmt := new(StarDBStmt) - stmtS, err := star.Tx.PrepareContext(ctx, sqlStr) - if err != nil { - return nil, err - } - stmt.Db = star.Db - stmt.Stmt = stmtS - return stmt, err -} - -// Query 进行Query操作 -func (star *StarDB) query(ctx context.Context, args ...interface{}) (*StarRows, error) { - var err error - var rows *sql.Rows - effect := new(StarRows) - if err = star.Db.Ping(); err != nil { - return effect, err - } - if len(args) == 0 { - return effect, errors.New("no args") - } - var para []interface{} - for k, v := range args { - if k != 0 { - switch vtype := v.(type) { - default: - para = append(para, vtype) - } - } - } - if ctx == nil { - if rows, err = star.Db.Query(args[0].(string), para...); err != nil { - return effect, err - } - } else { - if rows, err = star.Db.QueryContext(ctx, args[0].(string), para...); err != nil { - return effect, err - } - } - effect.Rows = rows - if !star.ManualScan { - err = effect.parserows() - } - return effect, err -} - -// Open 打开一个新的数据库 -func (star *StarDB) Open(Method, ConnStr string) error { - var err error - star.Db, err = sql.Open(Method, ConnStr) - return err -} - -// Close 关闭打开的数据库 -func (star *StarDB) Close() error { - return star.Db.Close() -} - -func (star *StarDB) Ping() error { - return star.Db.Ping() -} - -func (star *StarDB) Stats() sql.DBStats { - return star.Db.Stats() -} - -func (star *StarDB) SetMaxOpenConns(n int) { - star.Db.SetMaxOpenConns(n) -} - -func (star *StarDB) SetMaxIdleConns(n int) { - star.Db.SetMaxIdleConns(n) -} - -func (star *StarDB) PingContext(ctx context.Context) error { - return star.Db.PingContext(ctx) -} - -func (star *StarDB) Conn(ctx context.Context) (*sql.Conn, error) { - return star.Db.Conn(ctx) -} - -func (star *StarDB) Exec(args ...interface{}) (sql.Result, error) { - return star.exec(nil, args...) -} -func (star *StarDB) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { - return star.exec(ctx, args...) -} - -// Exec 执行Exec操作 -func (star *StarDB) exec(ctx context.Context, args ...interface{}) (sql.Result, error) { - var err error - if err = star.Db.Ping(); err != nil { - return nil, err - } - if len(args) == 0 { - return nil, errors.New("no args") - } - var para []interface{} - for k, v := range args { - if k != 0 { - switch vtype := v.(type) { - default: - para = append(para, vtype) - } - } - } - if ctx == nil { - return star.Db.Exec(args[0].(string), para...) - } - return star.Db.ExecContext(ctx, args[0].(string), para...) -} - -func (star *StarTx) Exec(args ...interface{}) (sql.Result, error) { - return star.exec(nil, args...) -} -func (star *StarTx) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) { - return star.exec(ctx, args...) -} - -func (star *StarTx) exec(ctx context.Context, args ...interface{}) (sql.Result, error) { - var err error - if err = star.Db.Ping(); err != nil { - return nil, err - } - if len(args) == 0 { - return nil, errors.New("no args") - } - var para []interface{} - for k, v := range args { - if k != 0 { - switch vtype := v.(type) { - default: - para = append(para, vtype) - } - } - } - if ctx == nil { - return star.Tx.Exec(args[0].(string), para...) - } - return star.Tx.ExecContext(ctx, args[0].(string), para...) -} - -func (star *StarTx) Commit() error { - return star.Tx.Commit() -} - -func (star *StarTx) Rollback() error { - return star.Tx.Rollback() -} - -// FetchAll 把结果集全部转为key-value型数据 -func FetchAll(rows *sql.Rows) (error, map[int]map[string]string) { - var ii int = 0 - records := make(map[int]map[string]string) - columns, err := rows.Columns() - if err != nil { - return err, records - } - scanArgs := make([]interface{}, len(columns)) - values := make([]interface{}, len(columns)) - for i := range values { - scanArgs[i] = &values[i] - } - for rows.Next() { - if err := rows.Scan(scanArgs...); err != nil { - return err, records - } - record := make(map[string]string) - for i, col := range values { - switch vtype := col.(type) { - case float64: - record[columns[i]] = strconv.FormatFloat(vtype, 'f', -1, 64) - case int64: - record[columns[i]] = strconv.FormatInt(vtype, 10) - case string: - record[columns[i]] = vtype - case nil: - record[columns[i]] = "" - default: - record[columns[i]] = string(vtype.([]byte)) - } - } - records[ii] = record - ii++ - } - return nil, records -} diff --git a/testing/batch_test.go b/testing/batch_test.go new file mode 100644 index 0000000..a8a78a3 --- /dev/null +++ b/testing/batch_test.go @@ -0,0 +1,563 @@ +package testing + +import ( + "context" + "testing" + "time" + + "b612.me/stardb" + _ "modernc.org/sqlite" +) + +type TestUser struct { + ID int64 `db:"id"` + Name string `db:"name"` + Email string `db:"email"` + Age int `db:"age"` + CreatedAt time.Time `db:"created_at"` +} + +func setupBatchTestDB(t *testing.T) *stardb.StarDB { + db := stardb.NewStarDB() + err := db.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + + _, err = db.Exec(` + CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + email TEXT NOT NULL, + age INTEGER, + created_at DATETIME + ) + `) + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + return db +} + +func TestStarDB_BatchInsert_Basic(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + {"Bob", "bob@example.com", 30}, + {"Charlie", "charlie@example.com", 35}, + } + + result, err := db.BatchInsert("users", columns, values) + if err != nil { + t.Fatalf("BatchInsert failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 3 { + t.Errorf("Expected 3 rows affected, got %d", affected) + } + + // Verify insertion + rows, err := db.Query("SELECT COUNT(*) as count FROM users") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + count := rows.Row(0).MustInt("count") + if count != 3 { + t.Errorf("Expected 3 rows in database, got %d", count) + } +} + +func TestStarDB_BatchInsert_Single(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + } + + result, err := db.BatchInsert("users", columns, values) + if err != nil { + t.Fatalf("BatchInsert failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } +} + +func TestStarDB_BatchInsert_Empty(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + columns := []string{"name", "email", "age"} + values := [][]interface{}{} + + _, err := db.BatchInsert("users", columns, values) + if err == nil { + t.Error("Expected error with empty values, got nil") + } +} + +func TestStarDB_BatchInsert_Large(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + columns := []string{"name", "email", "age"} + var values [][]interface{} + + // Insert 100 rows + for i := 0; i < 100; i++ { + values = append(values, []interface{}{ + "User" + string(rune(i)), + "user" + string(rune(i)) + "@example.com", + 20 + i%50, + }) + } + + result, err := db.BatchInsert("users", columns, values) + if err != nil { + t.Fatalf("BatchInsert failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 100 { + t.Errorf("Expected 100 rows affected, got %d", affected) + } + + // Verify + rows, err := db.Query("SELECT COUNT(*) as count FROM users") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + count := rows.Row(0).MustInt("count") + if count != 100 { + t.Errorf("Expected 100 rows in database, got %d", count) + } +} + +func TestStarDB_BatchInsertContext(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + ctx := context.Background() + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + {"Bob", "bob@example.com", 30}, + } + + result, err := db.BatchInsertContext(ctx, "users", columns, values) + if err != nil { + t.Fatalf("BatchInsertContext failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 2 { + t.Errorf("Expected 2 rows affected, got %d", affected) + } +} + +func TestStarDB_BatchInsertContext_Timeout(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + + time.Sleep(10 * time.Millisecond) // Ensure timeout + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + } + + _, err := db.BatchInsertContext(ctx, "users", columns, values) + if err == nil { + t.Error("Expected timeout error, got nil") + } +} + +func TestStarDB_BatchInsertStructs_Basic(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + users := []TestUser{ + {Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()}, + {Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()}, + {Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()}, + } + + result, err := db.BatchInsertStructs("users", users, "id") + if err != nil { + t.Fatalf("BatchInsertStructs failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 3 { + t.Errorf("Expected 3 rows affected, got %d", affected) + } + + // Verify insertion + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 3 { + t.Errorf("Expected 3 rows, got %d", rows.Length()) + } + + // Verify first user + name := rows.Row(0).MustString("name") + if name != "Alice" { + t.Errorf("Expected first user 'Alice', got '%s'", name) + } +} + +func TestStarDB_BatchInsertStructs_Single(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + users := []TestUser{ + {Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()}, + } + + result, err := db.BatchInsertStructs("users", users, "id") + if err != nil { + t.Fatalf("BatchInsertStructs failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } +} + +func TestStarDB_BatchInsertStructs_Empty(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + users := []TestUser{} + + _, err := db.BatchInsertStructs("users", users, "id") + if err == nil { + t.Error("Expected error with empty slice, got nil") + } +} + +func TestStarDB_BatchInsertStructs_NotSlice(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + user := TestUser{Name: "Alice", Email: "alice@example.com", Age: 25} + + _, err := db.BatchInsertStructs("users", user, "id") + if err == nil { + t.Error("Expected error with non-slice, got nil") + } +} + +func TestStarDB_BatchInsertStructs_Pointer(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + users := []TestUser{ + {Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()}, + {Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()}, + } + + result, err := db.BatchInsertStructs("users", &users, "id") + if err != nil { + t.Fatalf("BatchInsertStructs with pointer failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 2 { + t.Errorf("Expected 2 rows affected, got %d", affected) + } +} + +func TestStarDB_BatchInsertStructsContext(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + ctx := context.Background() + users := []TestUser{ + {Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()}, + {Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()}, + } + + result, err := db.BatchInsertStructsContext(ctx, "users", users, "id") + if err != nil { + t.Fatalf("BatchInsertStructsContext failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 2 { + t.Errorf("Expected 2 rows affected, got %d", affected) + } +} + +func TestStarDB_BatchInsertStructs_Large(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + var users []TestUser + for i := 0; i < 50; i++ { + users = append(users, TestUser{ + Name: "User" + string(rune(i)), + Email: "user" + string(rune(i)) + "@example.com", + Age: 20 + i%30, + CreatedAt: time.Now(), + }) + } + + result, err := db.BatchInsertStructs("users", users, "id") + if err != nil { + t.Fatalf("BatchInsertStructs failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 50 { + t.Errorf("Expected 50 rows affected, got %d", affected) + } + + // Verify + rows, err := db.Query("SELECT COUNT(*) as count FROM users") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + count := rows.Row(0).MustInt("count") + if count != 50 { + t.Errorf("Expected 50 rows in database, got %d", count) + } +} + +func TestStarDB_BatchInsert_VerifyData(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + columns := []string{"name", "email", "age"} + values := [][]interface{}{ + {"Alice", "alice@example.com", 25}, + {"Bob", "bob@example.com", 30}, + } + + _, err := db.BatchInsert("users", columns, values) + if err != nil { + t.Fatalf("BatchInsert failed: %v", err) + } + + // Verify data integrity + rows, err := db.Query("SELECT name, email, age FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + // Check Alice + row0 := rows.Row(0) + if row0.MustString("name") != "Alice" { + t.Errorf("Expected name 'Alice', got '%s'", row0.MustString("name")) + } + if row0.MustString("email") != "alice@example.com" { + t.Errorf("Expected email 'alice@example.com', got '%s'", row0.MustString("email")) + } + if row0.MustInt("age") != 25 { + t.Errorf("Expected age 25, got %d", row0.MustInt("age")) + } + + // Check Bob + row1 := rows.Row(1) + if row1.MustString("name") != "Bob" { + t.Errorf("Expected name 'Bob', got '%s'", row1.MustString("name")) + } + if row1.MustString("email") != "bob@example.com" { + t.Errorf("Expected email 'bob@example.com', got '%s'", row1.MustString("email")) + } + if row1.MustInt("age") != 30 { + t.Errorf("Expected age 30, got %d", row1.MustInt("age")) + } +} + +func TestStarDB_BatchInsertStructs_VerifyData(t *testing.T) { + db := setupBatchTestDB(t) + defer db.Close() + + now := time.Now() + users := []TestUser{ + {Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: now}, + {Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: now}, + } + + _, err := db.BatchInsertStructs("users", users, "id") + if err != nil { + t.Fatalf("BatchInsertStructs failed: %v", err) + } + + // Query and verify with ORM + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var resultUsers []TestUser + err = rows.Orm(&resultUsers) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + if len(resultUsers) != 2 { + t.Fatalf("Expected 2 users, got %d", len(resultUsers)) + } + + // Verify Alice + if resultUsers[0].Name != "Alice" { + t.Errorf("Expected name 'Alice', got '%s'", resultUsers[0].Name) + } + if resultUsers[0].Email != "alice@example.com" { + t.Errorf("Expected email 'alice@example.com', got '%s'", resultUsers[0].Email) + } + if resultUsers[0].Age != 25 { + t.Errorf("Expected age 25, got %d", resultUsers[0].Age) + } + + // Verify Bob + if resultUsers[1].Name != "Bob" { + t.Errorf("Expected name 'Bob', got '%s'", resultUsers[1].Name) + } + if resultUsers[1].Email != "bob@example.com" { + t.Errorf("Expected email 'bob@example.com', got '%s'", resultUsers[1].Email) + } + if resultUsers[1].Age != 30 { + t.Errorf("Expected age 30, got %d", resultUsers[1].Age) + } +} + +// Benchmark tests +func BenchmarkBatchInsert_10(b *testing.B) { + db := setupBatchTestDB(&testing.T{}) + defer db.Close() + + columns := []string{"name", "email", "age"} + var values [][]interface{} + for i := 0; i < 10; i++ { + values = append(values, []interface{}{"User", "user@example.com", 25}) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + db.BatchInsert("users", columns, values) + db.Exec("DELETE FROM users") + } +} + +func BenchmarkBatchInsert_100(b *testing.B) { + db := setupBatchTestDB(&testing.T{}) + defer db.Close() + + columns := []string{"name", "email", "age"} + var values [][]interface{} + for i := 0; i < 100; i++ { + values = append(values, []interface{}{"User", "user@example.com", 25}) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + db.BatchInsert("users", columns, values) + db.Exec("DELETE FROM users") + } +} + +func BenchmarkBatchInsertStructs_10(b *testing.B) { + db := setupBatchTestDB(&testing.T{}) + defer db.Close() + + var users []TestUser + for i := 0; i < 10; i++ { + users = append(users, TestUser{ + Name: "User", + Email: "user@example.com", + Age: 25, + CreatedAt: time.Now(), + }) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + db.BatchInsertStructs("users", users, "id") + db.Exec("DELETE FROM users") + } +} + +func BenchmarkBatchInsertStructs_100(b *testing.B) { + db := setupBatchTestDB(&testing.T{}) + defer db.Close() + + var users []TestUser + for i := 0; i < 100; i++ { + users = append(users, TestUser{ + Name: "User", + Email: "user@example.com", + Age: 25, + CreatedAt: time.Now(), + }) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + db.BatchInsertStructs("users", users, "id") + db.Exec("DELETE FROM users") + } +} diff --git a/testing/go.mod b/testing/go.mod new file mode 100644 index 0000000..e0d31a5 --- /dev/null +++ b/testing/go.mod @@ -0,0 +1,23 @@ +module b612.me/stardb/testing + +go 1.25.6 + +require ( + b612.me/stardb v0.0.0 + modernc.org/sqlite v1.46.1 +) + +require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect + golang.org/x/sys v0.37.0 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect +) + +replace b612.me/stardb => ../ diff --git a/testing/go.sum b/testing/go.sum new file mode 100644 index 0000000..2131c5c --- /dev/null +++ b/testing/go.sum @@ -0,0 +1,53 @@ +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= +github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA= +golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ= +golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU= +modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= diff --git a/testing/orm_test.go b/testing/orm_test.go new file mode 100644 index 0000000..d856d2d --- /dev/null +++ b/testing/orm_test.go @@ -0,0 +1,691 @@ +package testing + +import ( + "context" + "testing" + "time" +) + +type User struct { + ID int64 `db:"id"` + Name string `db:"name"` + Email string `db:"email"` + Age int `db:"age"` + Balance float64 `db:"balance"` + Active bool `db:"active"` + CreatedAt time.Time `db:"created_at"` +} + +type Profile struct { + UserID int `db:"user_id"` + Bio string `db:"bio"` + Avatar string `db:"avatar"` +} + +type NestedUser struct { + ID int64 `db:"id"` + Name string `db:"name"` + Profile `db:"---"` +} + +func TestStarRows_Orm_Single(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var user User + err = rows.Orm(&user) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + if user.Name != "Alice" { + t.Errorf("Expected name 'Alice', got '%s'", user.Name) + } + + if user.Email != "alice@example.com" { + t.Errorf("Expected email 'alice@example.com', got '%s'", user.Email) + } + + if user.Age != 25 { + t.Errorf("Expected age 25, got %d", user.Age) + } + + if user.Balance != 100.50 { + t.Errorf("Expected balance 100.50, got %f", user.Balance) + } + + if !user.Active { + t.Errorf("Expected user to be active") + } +} + +func TestStarRows_Orm_Multiple(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var users []User + err = rows.Orm(&users) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + if len(users) != 3 { + t.Fatalf("Expected 3 users, got %d", len(users)) + } + + expectedNames := []string{"Alice", "Bob", "Charlie"} + for i, user := range users { + if user.Name != expectedNames[i] { + t.Errorf("Expected name '%s', got '%s'", expectedNames[i], user.Name) + } + } +} + +func TestStarRows_Orm_Empty(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "NonExistent") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var users []User + err = rows.Orm(&users) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + if len(users) != 0 { + t.Errorf("Expected 0 users, got %d", len(users)) + } +} + +func TestStarRows_Orm_NotPointer(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var user User + err = rows.Orm(user) // Not a pointer + if err == nil { + t.Errorf("Expected error when passing non-pointer, got nil") + } +} + +func TestStarDB_Insert(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + user := User{ + Name: "David", + Email: "david@example.com", + Age: 40, + Balance: 400.00, + Active: true, + CreatedAt: time.Now(), + } + + result, err := db.Insert(&user, "users", "id") + if err != nil { + t.Fatalf("Insert failed: %v", err) + } + + lastID, err := result.LastInsertId() + if err != nil { + t.Fatalf("LastInsertId failed: %v", err) + } + + if lastID <= 0 { + t.Errorf("Expected positive last insert ID, got %d", lastID) + } + + // Verify insertion + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "David") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var insertedUser User + err = rows.Orm(&insertedUser) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + if insertedUser.Name != "David" { + t.Errorf("Expected name 'David', got '%s'", insertedUser.Name) + } + + if insertedUser.Email != "david@example.com" { + t.Errorf("Expected email 'david@example.com', got '%s'", insertedUser.Email) + } + + if insertedUser.Age != 40 { + t.Errorf("Expected age 40, got %d", insertedUser.Age) + } +} + +func TestStarDB_InsertContext(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + ctx := context.Background() + user := User{ + Name: "Eve", + Email: "eve@example.com", + Age: 28, + Balance: 250.00, + Active: true, + CreatedAt: time.Now(), + } + + result, err := db.InsertContext(ctx, &user, "users", "id") + if err != nil { + t.Fatalf("InsertContext failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } +} + +func TestStarDB_Update(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + // First, get the user + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var user User + err = rows.Orm(&user) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + // Update the user + user.Age = 26 + user.Balance = 150.75 + + result, err := db.Update(&user, "users", "id") + if err != nil { + t.Fatalf("Update failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } + + // Verify the update + rows2, err := db.Query("SELECT * FROM users WHERE id = ?", user.ID) + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows2.Close() + + var updatedUser User + err = rows2.Orm(&updatedUser) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + if updatedUser.Age != 26 { + t.Errorf("Expected age 26, got %d", updatedUser.Age) + } + + if updatedUser.Balance != 150.75 { + t.Errorf("Expected balance 150.75, got %f", updatedUser.Balance) + } +} + +func TestStarDB_UpdateContext(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + ctx := context.Background() + + // Get user + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Bob") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + var user User + err = rows.Orm(&user) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + // Update + user.Active = false + result, err := db.UpdateContext(ctx, &user, "users", "id") + if err != nil { + t.Fatalf("UpdateContext failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } +} + +func TestStarDB_QueryX(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + user := User{ + Name: "Alice", + } + + rows, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":name") + if err != nil { + t.Fatalf("QueryX failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 1 { + t.Errorf("Expected 1 row, got %d", rows.Length()) + } + + var result User + err = rows.Orm(&result) + if err != nil { + t.Fatalf("Orm failed: %v", err) + } + + if result.Name != "Alice" { + t.Errorf("Expected name 'Alice', got '%s'", result.Name) + } +} + +func TestStarDB_QueryXContext(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + ctx := context.Background() + user := User{ + Age: 30, + } + + rows, err := db.QueryXContext(ctx, &user, "SELECT * FROM users WHERE age = ?", ":age") + if err != nil { + t.Fatalf("QueryXContext failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 1 { + t.Errorf("Expected 1 row, got %d", rows.Length()) + } +} + +func TestStarDB_QueryXS(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + users := []User{ + {Name: "Alice"}, + {Name: "Bob"}, + } + + results, err := db.QueryXS(&users, "SELECT * FROM users WHERE name = ?", ":name") + if err != nil { + t.Fatalf("QueryXS failed: %v", err) + } + + if len(results) != 2 { + t.Errorf("Expected 2 results, got %d", len(results)) + } + + for i, rows := range results { + if rows.Length() != 1 { + t.Errorf("Expected 1 row in result %d, got %d", i, rows.Length()) + } + } +} + +func TestStarDB_ExecX(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + user := User{ + Name: "Alice", + Age: 99, + } + + result, err := db.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name") + if err != nil { + t.Fatalf("ExecX failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } + + // Verify + rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + age := rows.Row(0).MustInt("age") + if age != 99 { + t.Errorf("Expected age 99, got %d", age) + } +} + +func TestStarDB_ExecXContext(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + ctx := context.Background() + user := User{ + Name: "Bob", + Active: false, + } + + result, err := db.ExecXContext(ctx, &user, "UPDATE users SET active = ? WHERE name = ?", ":active", ":name") + if err != nil { + t.Fatalf("ExecXContext failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } +} + +func TestStarDB_ExecXS(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + users := []User{ + {Name: "Alice", Age: 26}, + {Name: "Bob", Age: 31}, + } + + results, err := db.ExecXS(&users, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name") + if err != nil { + t.Fatalf("ExecXS failed: %v", err) + } + + if len(results) != 2 { + t.Errorf("Expected 2 results, got %d", len(results)) + } + + for i, result := range results { + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed for result %d: %v", i, err) + } + if affected != 1 { + t.Errorf("Expected 1 row affected in result %d, got %d", i, affected) + } + } +} + +func TestStarTx_Insert(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + + user := User{ + Name: "Frank", + Email: "frank@example.com", + Age: 45, + Balance: 500.00, + Active: true, + CreatedAt: time.Now(), + } + + result, err := tx.Insert(&user, "users", "id") + if err != nil { + tx.Rollback() + t.Fatalf("Tx.Insert failed: %v", err) + } + + err = tx.Commit() + if err != nil { + t.Fatalf("Commit failed: %v", err) + } + + lastID, err := result.LastInsertId() + if err != nil { + t.Fatalf("LastInsertId failed: %v", err) + } + + if lastID <= 0 { + t.Errorf("Expected positive last insert ID, got %d", lastID) + } + + // Verify + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Frank") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 1 { + t.Errorf("Expected 1 row, got %d", rows.Length()) + } +} + +func TestStarTx_Update(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + + // Get user + rows, err := tx.Query("SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + tx.Rollback() + t.Fatalf("Query failed: %v", err) + } + + var user User + err = rows.Orm(&user) + rows.Close() + if err != nil { + tx.Rollback() + t.Fatalf("Orm failed: %v", err) + } + + // Update + user.Age = 27 + result, err := tx.Update(&user, "users", "id") + if err != nil { + tx.Rollback() + t.Fatalf("Tx.Update failed: %v", err) + } + + err = tx.Commit() + if err != nil { + t.Fatalf("Commit failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } +} + +func TestStarTx_QueryX(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + defer tx.Rollback() + + user := User{ + Name: "Charlie", + } + + rows, err := tx.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":name") + if err != nil { + t.Fatalf("Tx.QueryX failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 1 { + t.Errorf("Expected 1 row, got %d", rows.Length()) + } +} + +func TestStarTx_ExecX(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + + user := User{ + Name: "Charlie", + Age: 36, + } + + result, err := tx.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name") + if err != nil { + tx.Rollback() + t.Fatalf("Tx.ExecX failed: %v", err) + } + + err = tx.Commit() + if err != nil { + t.Fatalf("Commit failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } +} + +func TestStarTx_Rollback(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + + user := User{ + Name: "Rollback User", + Email: "rollback@example.com", + Age: 50, + Balance: 600.00, + Active: true, + CreatedAt: time.Now(), + } + + _, err = tx.Insert(&user, "users", "id") + if err != nil { + tx.Rollback() + t.Fatalf("Tx.Insert failed: %v", err) + } + + // Rollback instead of commit + err = tx.Rollback() + if err != nil { + t.Fatalf("Rollback failed: %v", err) + } + + // Verify the insert was rolled back + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Rollback User") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 0 { + t.Errorf("Expected 0 rows after rollback, got %d", rows.Length()) + } +} + +func TestNamedParameterEscape(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + user := User{ + Name: "Alice", + } + + // Test escaped colon + rows, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", `\:name`) + if err != nil { + t.Fatalf("QueryX with escaped parameter failed: %v", err) + } + defer rows.Close() + + // Should use literal ":name" string, not the field value + if rows.Length() != 0 { + t.Errorf("Expected 0 rows with literal ':name', got %d", rows.Length()) + } +} diff --git a/testing/pool_test.go b/testing/pool_test.go new file mode 100644 index 0000000..d9a9552 --- /dev/null +++ b/testing/pool_test.go @@ -0,0 +1,272 @@ +package testing + +import ( + "testing" + "time" + + "b612.me/stardb" + _ "modernc.org/sqlite" +) + +func TestDefaultPoolConfig(t *testing.T) { + config := stardb.DefaultPoolConfig() + + if config.MaxOpenConns != 25 { + t.Errorf("Expected MaxOpenConns 25, got %d", config.MaxOpenConns) + } + + if config.MaxIdleConns != 5 { + t.Errorf("Expected MaxIdleConns 5, got %d", config.MaxIdleConns) + } + + if config.ConnMaxLifetime != time.Hour { + t.Errorf("Expected ConnMaxLifetime 1 hour, got %v", config.ConnMaxLifetime) + } + + if config.ConnMaxIdleTime != 10*time.Minute { + t.Errorf("Expected ConnMaxIdleTime 10 minutes, got %v", config.ConnMaxIdleTime) + } +} + +func TestStarDB_SetPoolConfig(t *testing.T) { + db := stardb.NewStarDB() + err := db.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + config := &stardb.PoolConfig{ + MaxOpenConns: 50, + MaxIdleConns: 10, + ConnMaxLifetime: 30 * time.Minute, + ConnMaxIdleTime: 5 * time.Minute, + } + + db.SetPoolConfig(config) + + stats := db.Stats() + if stats.MaxOpenConnections != 50 { + t.Errorf("Expected MaxOpenConnections 50, got %d", stats.MaxOpenConnections) + } +} + +func TestStarDB_SetPoolConfig_Partial(t *testing.T) { + db := stardb.NewStarDB() + err := db.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Only set MaxOpenConns + config := &stardb.PoolConfig{ + MaxOpenConns: 100, + } + + db.SetPoolConfig(config) + + stats := db.Stats() + if stats.MaxOpenConnections != 100 { + t.Errorf("Expected MaxOpenConnections 100, got %d", stats.MaxOpenConnections) + } +} + +func TestStarDB_SetPoolConfig_Zero(t *testing.T) { + db := stardb.NewStarDB() + err := db.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Zero values should be ignored + config := &stardb.PoolConfig{ + MaxOpenConns: 0, + MaxIdleConns: 0, + } + + db.SetPoolConfig(config) + + // Should not panic or error + err = db.Ping() + if err != nil { + t.Errorf("Ping failed after SetPoolConfig with zero values: %v", err) + } +} + +func TestOpenWithPool_Default(t *testing.T) { + db, err := stardb.OpenWithPool("sqlite", ":memory:", nil) + if err != nil { + t.Fatalf("OpenWithPool failed: %v", err) + } + defer db.Close() + + err = db.Ping() + if err != nil { + t.Errorf("Ping failed: %v", err) + } + + stats := db.Stats() + if stats.MaxOpenConnections != 25 { + t.Errorf("Expected default MaxOpenConnections 25, got %d", stats.MaxOpenConnections) + } +} + +func TestOpenWithPool_Custom(t *testing.T) { + config := &stardb.PoolConfig{ + MaxOpenConns: 15, + MaxIdleConns: 3, + ConnMaxLifetime: 20 * time.Minute, + ConnMaxIdleTime: 3 * time.Minute, + } + + db, err := stardb.OpenWithPool("sqlite", ":memory:", config) + if err != nil { + t.Fatalf("OpenWithPool failed: %v", err) + } + defer db.Close() + + err = db.Ping() + if err != nil { + t.Errorf("Ping failed: %v", err) + } + + stats := db.Stats() + if stats.MaxOpenConnections != 15 { + t.Errorf("Expected MaxOpenConnections 15, got %d", stats.MaxOpenConnections) + } +} + +func TestOpenWithPool_InvalidDriver(t *testing.T) { + config := stardb.DefaultPoolConfig() + + db, err := stardb.OpenWithPool("invalid_driver", "invalid_conn", config) + if err == nil { + db.Close() + t.Error("Expected error with invalid driver, got nil") + } +} + +func TestOpenWithPool_Query(t *testing.T) { + db, err := stardb.OpenWithPool("sqlite", ":memory:", nil) + if err != nil { + t.Fatalf("OpenWithPool failed: %v", err) + } + defer db.Close() + + // Create table + _, err = db.Exec(`CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)`) + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + // Insert data + _, err = db.Exec(`INSERT INTO test (name) VALUES (?)`, "Alice") + if err != nil { + t.Fatalf("Failed to insert data: %v", err) + } + + // Query data + rows, err := db.Query("SELECT * FROM test") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 1 { + t.Errorf("Expected 1 row, got %d", rows.Length()) + } +} + +func TestPoolConfig_AllFields(t *testing.T) { + db := stardb.NewStarDB() + err := db.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + config := &stardb.PoolConfig{ + MaxOpenConns: 30, + MaxIdleConns: 8, + ConnMaxLifetime: 45 * time.Minute, + ConnMaxIdleTime: 7 * time.Minute, + } + + db.SetPoolConfig(config) + + // Verify by checking stats + stats := db.Stats() + if stats.MaxOpenConnections != 30 { + t.Errorf("Expected MaxOpenConnections 30, got %d", stats.MaxOpenConnections) + } + + // Test that connections work + err = db.Ping() + if err != nil { + t.Errorf("Ping failed after SetPoolConfig: %v", err) + } +} + +func TestPoolConfig_NegativeValues(t *testing.T) { + db := stardb.NewStarDB() + err := db.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Negative values should be ignored + config := &stardb.PoolConfig{ + MaxOpenConns: -1, + MaxIdleConns: -1, + } + + db.SetPoolConfig(config) + + // Should not panic or error + err = db.Ping() + if err != nil { + t.Errorf("Ping failed after SetPoolConfig with negative values: %v", err) + } +} + +func TestOpenWithPool_MultipleConnections(t *testing.T) { + config := &stardb.PoolConfig{ + MaxOpenConns: 5, + MaxIdleConns: 2, + } + + db, err := stardb.OpenWithPool("sqlite", ":memory:", config) + if err != nil { + t.Fatalf("OpenWithPool failed: %v", err) + } + defer db.Close() + + // Create table + _, err = db.Exec(`CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)`) + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + // Perform multiple operations + for i := 0; i < 10; i++ { + _, err = db.Exec(`INSERT INTO test (value) VALUES (?)`, "test") + if err != nil { + t.Errorf("Insert %d failed: %v", i, err) + } + } + + // Verify + rows, err := db.Query("SELECT COUNT(*) as count FROM test") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + count := rows.Row(0).MustInt("count") + if count != 10 { + t.Errorf("Expected 10 rows, got %d", count) + } +} diff --git a/testing/result_test.go b/testing/result_test.go new file mode 100644 index 0000000..0d56fde --- /dev/null +++ b/testing/result_test.go @@ -0,0 +1,231 @@ +package testing + +import ( + "testing" + "time" +) + +func TestStarResult_MustString(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + row := rows.Row(0) + + name := row.MustString("name") + if name != "Alice" { + t.Errorf("Expected 'Alice', got '%s'", name) + } + + email := row.MustString("email") + if email != "alice@example.com" { + t.Errorf("Expected 'alice@example.com', got '%s'", email) + } + + // Test non-existent column + nonexistent := row.MustString("nonexistent") + if nonexistent != "" { + t.Errorf("Expected empty string for non-existent column, got '%s'", nonexistent) + } +} + +func TestStarResult_MustInt(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Bob") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + row := rows.Row(0) + + age := row.MustInt("age") + if age != 30 { + t.Errorf("Expected age 30, got %d", age) + } + + id := row.MustInt("id") + if id <= 0 { + t.Errorf("Expected positive id, got %d", id) + } +} + +func TestStarResult_MustInt64(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Charlie") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + row := rows.Row(0) + + age := row.MustInt64("age") + if age != 35 { + t.Errorf("Expected age 35, got %d", age) + } +} + +func TestStarResult_MustFloat64(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + row := rows.Row(0) + + balance := row.MustFloat64("balance") + if balance != 100.50 { + t.Errorf("Expected balance 100.50, got %f", balance) + } +} + +func TestStarResult_MustBool(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + // Alice - active + row := rows.Row(0) + active := row.MustBool("active") + if !active { + t.Errorf("Expected Alice to be active") + } + + // Charlie - inactive + row = rows.Row(2) + active = row.MustBool("active") + if active { + t.Errorf("Expected Charlie to be inactive") + } +} + +func TestStarResult_IsNil(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + // Insert a row with NULL value + _, err := db.Exec("INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, NULL, ?, ?, ?)", + "David", "david@example.com", 150.0, true, time.Now()) + if err != nil { + t.Fatalf("Insert failed: %v", err) + } + + rows, err := db.Query("SELECT * FROM users WHERE name = ?", "David") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + row := rows.Row(0) + + if !row.IsNil("age") { + t.Errorf("Expected age to be NULL") + } + + if row.IsNil("name") { + t.Errorf("Expected name to not be NULL") + } +} + +func TestStarResultCol_MustString(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + col := rows.Col("name") + names := col.MustString() + + expected := []string{"Alice", "Bob", "Charlie"} + for i, name := range names { + if name != expected[i] { + t.Errorf("Expected name '%s', got '%s'", expected[i], name) + } + } +} + +func TestStarResultCol_MustInt(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + col := rows.Col("age") + ages := col.MustInt() + + expected := []int{25, 30, 35} + for i, age := range ages { + if age != expected[i] { + t.Errorf("Expected age %d, got %d", expected[i], age) + } + } +} + +func TestStarResultCol_MustFloat64(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + col := rows.Col("balance") + balances := col.MustFloat64() + + expected := []float64{100.50, 200.75, 300.25} + for i, balance := range balances { + if balance != expected[i] { + t.Errorf("Expected balance %f, got %f", expected[i], balance) + } + } +} + +func TestStarResultCol_MustBool(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + col := rows.Col("active") + actives := col.MustBool() + + expected := []bool{true, true, false} + for i, active := range actives { + if active != expected[i] { + t.Errorf("Expected active %v, got %v", expected[i], active) + } + } +} diff --git a/testing/rows_test.go b/testing/rows_test.go new file mode 100644 index 0000000..8f18aa0 --- /dev/null +++ b/testing/rows_test.go @@ -0,0 +1,103 @@ +package testing + +import ( + "testing" +) + +func TestStarRows_Row(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + // Test first row + row := rows.Row(0) + name := row.MustString("name") + if name != "Alice" { + t.Errorf("Expected name 'Alice', got '%s'", name) + } + + // Test out of bounds + row = rows.Row(999) + if len(row.Result()) != 0 { + t.Errorf("Expected empty result for out of bounds index") + } +} + +func TestStarRows_Col(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users ORDER BY name") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + // Test column extraction + col := rows.Col("name") + names := col.MustString() + + if len(names) != 3 { + t.Errorf("Expected 3 names, got %d", len(names)) + } + + if names[0] != "Alice" { + t.Errorf("Expected first name 'Alice', got '%s'", names[0]) + } + + // Test non-existent column + col = rows.Col("nonexistent") + if len(col.Result()) != 0 { + t.Errorf("Expected empty result for non-existent column") + } +} + +func TestStarRows_Rescan(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.ManualScan = true + rows, err := db.Query("SELECT * FROM users") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + err = rows.Rescan() + if err != nil { + t.Fatalf("Rescan failed: %v", err) + } + + if rows.Length() != 3 { + t.Errorf("Expected 3 rows, got %d", rows.Length()) + } +} + +func TestStarRows_StringResult(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT name, age FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + if len(rows.StringResult()) != 1 { + t.Fatalf("Expected 1 string result, got %d", len(rows.StringResult())) + } + + record := rows.StringResult()[0] + if record["name"] != "Alice" { + t.Errorf("Expected name 'Alice', got '%s'", record["name"]) + } + + if record["age"] != "25" { + t.Errorf("Expected age '25', got '%s'", record["age"]) + } +} diff --git a/testing/stardb_test.go b/testing/stardb_test.go new file mode 100644 index 0000000..ffa0261 --- /dev/null +++ b/testing/stardb_test.go @@ -0,0 +1,233 @@ +package testing + +import ( + "context" + "testing" + "time" + + "b612.me/stardb" + _ "modernc.org/sqlite" +) + +func TestStarDB_Open(t *testing.T) { + db := &stardb.StarDB{} + err := db.Open("sqlite", ":memory:") + if err != nil { + t.Errorf("Open failed: %v", err) + } + defer db.Close() + + err = db.Ping() + if err != nil { + t.Errorf("Ping failed: %v", err) + } +} + +func TestStarDB_Query(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.Query("SELECT * FROM users WHERE age > ?", 25) + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 2 { + t.Errorf("Expected 2 rows, got %d", rows.Length()) + } + + if len(rows.Columns()) != 7 { + t.Errorf("Expected 7 columns, got %d", len(rows.Columns())) + } +} + +func TestStarDB_QueryContext(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("QueryContext failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 1 { + t.Errorf("Expected 1 row, got %d", rows.Length()) + } +} + +func TestStarDB_Exec(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + result, err := db.Exec("UPDATE users SET age = ? WHERE name = ?", 26, "Alice") + if err != nil { + t.Fatalf("Exec failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } +} + +func TestStarDB_ExecContext(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + ctx := context.Background() + result, err := db.ExecContext(ctx, "DELETE FROM users WHERE name = ?", "Charlie") + if err != nil { + t.Fatalf("ExecContext failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } +} + +func TestStarDB_QueryStmt(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + rows, err := db.QueryStmt("SELECT * FROM users WHERE age > ?", 25) + if err != nil { + t.Fatalf("QueryStmt failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 2 { + t.Errorf("Expected 2 rows, got %d", rows.Length()) + } +} + +func TestStarDB_ExecStmt(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + result, err := db.ExecStmt("UPDATE users SET age = ? WHERE name = ?", 27, "Bob") + if err != nil { + t.Fatalf("ExecStmt failed: %v", err) + } + + affected, err := result.RowsAffected() + if err != nil { + t.Fatalf("RowsAffected failed: %v", err) + } + + if affected != 1 { + t.Errorf("Expected 1 row affected, got %d", affected) + } +} + +func TestStarDB_Prepare(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + stmt, err := db.Prepare("SELECT * FROM users WHERE name = ?") + if err != nil { + t.Fatalf("Prepare failed: %v", err) + } + defer stmt.Close() + + rows, err := stmt.Query("Alice") + if err != nil { + t.Fatalf("Stmt.Query failed: %v", err) + } + defer rows.Close() + + if rows.Length() != 1 { + t.Errorf("Expected 1 row, got %d", rows.Length()) + } +} + +func TestStarDB_Transaction(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + + _, err = tx.Exec("UPDATE users SET age = ? WHERE name = ?", 28, "Alice") + if err != nil { + tx.Rollback() + t.Fatalf("Tx.Exec failed: %v", err) + } + + err = tx.Commit() + if err != nil { + t.Fatalf("Commit failed: %v", err) + } + + // Verify the change + rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + age := rows.Row(0).MustInt("age") + if age != 28 { + t.Errorf("Expected age 28, got %d", age) + } +} + +func TestStarDB_TransactionRollback(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatalf("Begin failed: %v", err) + } + + _, err = tx.Exec("UPDATE users SET age = ? WHERE name = ?", 99, "Alice") + if err != nil { + t.Fatalf("Tx.Exec failed: %v", err) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("Rollback failed: %v", err) + } + + // Verify the change was rolled back + rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice") + if err != nil { + t.Fatalf("Query failed: %v", err) + } + defer rows.Close() + + age := rows.Row(0).MustInt("age") + if age == 99 { + t.Errorf("Expected age to be rolled back, but got %d", age) + } +} + +func TestStarDB_SetMaxConnections(t *testing.T) { + db := setupTestDB(t) + defer db.Close() + + db.SetMaxOpenConns(10) + db.SetMaxIdleConns(5) + + stats := db.Stats() + if stats.MaxOpenConnections != 10 { + t.Errorf("Expected MaxOpenConnections 10, got %d", stats.MaxOpenConnections) + } +} diff --git a/testing/testing.go b/testing/testing.go new file mode 100644 index 0000000..f6abbd0 --- /dev/null +++ b/testing/testing.go @@ -0,0 +1,47 @@ +package testing + +import ( + "testing" + + "b612.me/stardb" + _ "modernc.org/sqlite" +) + +// setupTestDB creates a test database with sample data +// This function is only available when building with -tags=testing +func setupTestDB(t *testing.T) *stardb.StarDB { + db := &stardb.StarDB{} + err := db.Open("sqlite", ":memory:") + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + + // Create test table + _, err = db.Exec(` + CREATE TABLE users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + email TEXT NOT NULL, + age INTEGER, + balance REAL, + active BOOLEAN, + created_at DATETIME + ) + `) + if err != nil { + t.Fatalf("Failed to create table: %v", err) + } + + // Insert test data + _, err = db.Exec(` + INSERT INTO users (name, email, age, balance, active, created_at) VALUES + ('Alice', 'alice@example.com', 25, 100.50, 1, '2024-01-01 10:00:00'), + ('Bob', 'bob@example.com', 30, 200.75, 1, '2024-01-02 11:00:00'), + ('Charlie', 'charlie@example.com', 35, 300.25, 0, '2024-01-03 12:00:00') + `) + if err != nil { + t.Fatalf("Failed to insert test data: %v", err) + } + + return db +} diff --git a/tx.go b/tx.go new file mode 100644 index 0000000..541f453 --- /dev/null +++ b/tx.go @@ -0,0 +1,156 @@ +package stardb + +import ( + "context" + "database/sql" + "errors" +) + +// StarTx represents a database transaction +type StarTx struct { + tx *sql.Tx + db *StarDB +} + +// Query executes a query within the transaction +func (t *StarTx) Query(query string, args ...interface{}) (*StarRows, error) { + return t.query(nil, query, args...) +} + +// QueryContext executes a query with context within the transaction +func (t *StarTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { + return t.query(ctx, query, args...) +} + +// query is the internal query implementation +func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { + if err := t.db.Ping(); err != nil { + return nil, err + } + + var rows *sql.Rows + var err error + + if ctx == nil { + rows, err = t.tx.Query(query, args...) + } else { + rows, err = t.tx.QueryContext(ctx, query, args...) + } + + if err != nil { + return nil, err + } + + starRows := &StarRows{ + rows: rows, + db: t.db, + } + + if !t.db.ManualScan { + err = starRows.parse() + } + + return starRows, err +} + +// Exec executes a query within the transaction +func (t *StarTx) Exec(query string, args ...interface{}) (sql.Result, error) { + return t.exec(nil, query, args...) +} + +// ExecContext executes a query with context within the transaction +func (t *StarTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + return t.exec(ctx, query, args...) +} + +// exec is the internal exec implementation +func (t *StarTx) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + if err := t.db.Ping(); err != nil { + return nil, err + } + + if ctx == nil { + return t.tx.Exec(query, args...) + } + return t.tx.ExecContext(ctx, query, args...) +} + +// Prepare creates a prepared statement within the transaction +func (t *StarTx) Prepare(query string) (*StarStmt, error) { + stmt, err := t.tx.Prepare(query) + if err != nil { + return nil, err + } + return &StarStmt{stmt: stmt, db: t.db}, nil +} + +// PrepareContext creates a prepared statement with context +func (t *StarTx) PrepareContext(ctx context.Context, query string) (*StarStmt, error) { + stmt, err := t.tx.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + return &StarStmt{stmt: stmt, db: t.db}, nil +} + +// QueryStmt executes a prepared statement query within the transaction +func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) { + if query == "" { + return nil, errors.New("query string cannot be empty") + } + stmt, err := t.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.Query(args...) +} + +// QueryStmtContext executes a prepared statement query with context +func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) { + if query == "" { + return nil, errors.New("query string cannot be empty") + } + stmt, err := t.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.QueryContext(ctx, args...) +} + +// ExecStmt executes a prepared statement within the transaction +func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) { + if query == "" { + return nil, errors.New("query string cannot be empty") + } + stmt, err := t.Prepare(query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.Exec(args...) +} + +// ExecStmtContext executes a prepared statement with context +func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { + if query == "" { + return nil, errors.New("query string cannot be empty") + } + stmt, err := t.PrepareContext(ctx, query) + if err != nil { + return nil, err + } + defer stmt.Close() + return stmt.ExecContext(ctx, args...) +} + +// Commit commits the transaction +func (t *StarTx) Commit() error { + return t.tx.Commit() +} + +// Rollback rolls back the transaction +func (t *StarTx) Rollback() error { + return t.tx.Rollback() +}