Compare commits
No commits in common. "master" and "v1.0.0" have entirely different histories.
24
.gitignore
vendored
24
.gitignore
vendored
@ -1,24 +0,0 @@
|
|||||||
# Binaries
|
|
||||||
*.exe
|
|
||||||
*.exe~
|
|
||||||
*.dll
|
|
||||||
*.so
|
|
||||||
*.dylib
|
|
||||||
*.test
|
|
||||||
|
|
||||||
# Go build cache
|
|
||||||
*.out
|
|
||||||
|
|
||||||
# Dependency directories
|
|
||||||
vendor/
|
|
||||||
|
|
||||||
# IDE
|
|
||||||
.idea
|
|
||||||
|
|
||||||
# OS
|
|
||||||
.DS_Store
|
|
||||||
|
|
||||||
# Agent local governance files
|
|
||||||
.sentrux/
|
|
||||||
agent_readme.md
|
|
||||||
target.md
|
|
||||||
8
.idea/.gitignore
generated
vendored
8
.idea/.gitignore
generated
vendored
@ -1,8 +0,0 @@
|
|||||||
# 默认忽略的文件
|
|
||||||
/shelf/
|
|
||||||
/workspace.xml
|
|
||||||
# 数据源本地存储已忽略文件
|
|
||||||
/dataSources/
|
|
||||||
/dataSources.local.xml
|
|
||||||
# 基于编辑器的 HTTP 客户端请求
|
|
||||||
/httpRequests/
|
|
||||||
8
.idea/modules.xml
generated
8
.idea/modules.xml
generated
@ -1,8 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="ProjectModuleManager">
|
|
||||||
<modules>
|
|
||||||
<module fileurl="file://$PROJECT_DIR$/.idea/stardb.iml" filepath="$PROJECT_DIR$/.idea/stardb.iml" />
|
|
||||||
</modules>
|
|
||||||
</component>
|
|
||||||
</project>
|
|
||||||
17
.idea/stardb.iml
generated
17
.idea/stardb.iml
generated
@ -1,17 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<module type="WEB_MODULE" version="4">
|
|
||||||
<component name="Go" enabled="true">
|
|
||||||
<buildTags>
|
|
||||||
<option name="customFlags">
|
|
||||||
<array>
|
|
||||||
<option value="testing" />
|
|
||||||
</array>
|
|
||||||
</option>
|
|
||||||
</buildTags>
|
|
||||||
</component>
|
|
||||||
<component name="NewModuleRootManager">
|
|
||||||
<content url="file://$MODULE_DIR$" />
|
|
||||||
<orderEntry type="inheritedJdk" />
|
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
|
||||||
</component>
|
|
||||||
</module>
|
|
||||||
6
.idea/vcs.xml
generated
6
.idea/vcs.xml
generated
@ -1,6 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="VcsDirectoryMappings">
|
|
||||||
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
||||||
59
CHANGELOG.MD
59
CHANGELOG.MD
@ -1,59 +0,0 @@
|
|||||||
# Changelog
|
|
||||||
|
|
||||||
本文档记录 StarDB 的主要变更。
|
|
||||||
|
|
||||||
## [Unreleased] - 2026-03-20
|
|
||||||
|
|
||||||
### Added
|
|
||||||
- 新增可判定错误类型(`errors.Is` 友好):
|
|
||||||
- 生命周期:`ErrDBNotInitialized` `ErrTxNotInitialized` `ErrStmtNotInitialized`
|
|
||||||
- 参数/目标校验:`ErrQueryEmpty` `ErrTargetNil` `ErrTargetNotPointer` 等
|
|
||||||
- 映射与批量写入:`ErrColumnNotFound` `ErrNoInsertValues` `ErrBatchRowValueCountMismatch` 等
|
|
||||||
- 新增流式查询能力(DB / Tx / Stmt):
|
|
||||||
- `QueryRaw` / `QueryRawContext`
|
|
||||||
- `ScanEach` / `ScanEachContext`
|
|
||||||
- `ScanEachORM` / `ScanEachORMContext`
|
|
||||||
- 新增 NULL 安全取值:
|
|
||||||
- `GetNullString` `GetNullInt64` `GetNullFloat64` `GetNullBool` `GetNullTime`
|
|
||||||
- 新增 ORM 行为开关:
|
|
||||||
- `SetStrictORM(true)` 启用严格列检查
|
|
||||||
- `ClearReflectCache()` 清理反射缓存
|
|
||||||
- 新增 SQL 运行时可观测能力:
|
|
||||||
- Hook:`SetSQLHooks` `SetSQLBeforeHook` `SetSQLAfterHook`
|
|
||||||
- 慢 SQL 阈值:`SetSQLSlowThreshold`
|
|
||||||
- 指纹:`SetSQLFingerprintEnabled` `SetSQLFingerprintMode` `SetSQLFingerprintKeepComments`
|
|
||||||
- 指纹计数:`SetSQLFingerprintCounterEnabled` `SQLFingerprintCounters` `ResetSQLFingerprintCounters`
|
|
||||||
- Context 元信息:`SQLHookMetaFromContext` `BatchExecMetaFromContext`
|
|
||||||
- 新增占位符方言适配:
|
|
||||||
- `SetPlaceholderStyle(PlaceholderQuestion|PlaceholderDollar)`(`?` / `$1,$2...`)
|
|
||||||
- 新增批量插入分片控制:
|
|
||||||
- `SetBatchInsertMaxRows`
|
|
||||||
- `SetBatchInsertMaxParams`
|
|
||||||
- 常见驱动参数上限自动识别(SQLite / PostgreSQL / MySQL / SQL Server)
|
|
||||||
|
|
||||||
### Changed
|
|
||||||
- 批量写入在开启分片或触发参数阈值时,改为事务内多分片执行,降低单条 SQL 过大风险。
|
|
||||||
- 分片批量写入结果语义明确:
|
|
||||||
- `RowsAffected()` 返回分片累计值
|
|
||||||
- `LastInsertId()` 返回最后一个分片的 insert id
|
|
||||||
- 内部结构按模块归档到 `internal/`,保持外部 API 稳定:
|
|
||||||
- `internal/convert`
|
|
||||||
- `internal/scanutil`
|
|
||||||
- `internal/sqlplaceholder`
|
|
||||||
- `internal/sqlruntime`
|
|
||||||
- README 重写为面向使用场景的说明,补齐能力边界、接入顺序和 API 细节。
|
|
||||||
|
|
||||||
### Behavior Notes
|
|
||||||
- 默认查询 `Query` 仍为内存模式(解析到 `StarRows`)。
|
|
||||||
- 关闭内存预读时,使用 `QueryRaw` / `ScanEach` / `ScanEachORM`。
|
|
||||||
- SQL Hook、指纹与指纹计数默认关闭,需显式开启。
|
|
||||||
- 批量分片关闭条件:`maxRows <= 0` 且 `maxParams <= 0` 且未命中驱动自动阈值。
|
|
||||||
|
|
||||||
### Tests
|
|
||||||
- 新增/补强测试覆盖:
|
|
||||||
- 流式查询与流式 ORM
|
|
||||||
- NULL 安全取值
|
|
||||||
- 严格 ORM 行为
|
|
||||||
- 占位符转换
|
|
||||||
- SQL Hook、慢 SQL 阈值、指纹模式、注释保留开关、指纹计数
|
|
||||||
- BatchInsert 分片(按行数/参数)、失败回滚与结果语义
|
|
||||||
201
LICENSE
201
LICENSE
@ -1,201 +0,0 @@
|
|||||||
Apache License
|
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
|
||||||
the copyright owner that is granting the License.
|
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
|
||||||
exercising permissions granted by this License.
|
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
APPENDIX: How to apply the Apache License to your work.
|
|
||||||
|
|
||||||
To apply the Apache License to your work, attach the following
|
|
||||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
|
||||||
replaced with your own identifying information. (Don't include
|
|
||||||
the brackets!) The text should be enclosed in the appropriate
|
|
||||||
comment syntax for the file format. We also recommend that a
|
|
||||||
file or class name and description of purpose be included on the
|
|
||||||
same "printed page" as the copyright notice for easier
|
|
||||||
identification within third-party archives.
|
|
||||||
|
|
||||||
Copyright [yyyy] [name of copyright owner]
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
560
README.MD
560
README.MD
@ -1,560 +0,0 @@
|
|||||||
# StarDB
|
|
||||||
|
|
||||||
StarDB 是对 Go `database/sql` 的轻量封装,目标是把常见的数据库操作做得更直白:
|
|
||||||
- 少量 API 覆盖日常 CRUD、事务、批量写入、结构体映射。
|
|
||||||
- 兼容原生 `database/sql` 心智,不引入重量级依赖。
|
|
||||||
- 在可读性、可调试性和性能之间做实用平衡。
|
|
||||||
|
|
||||||
适合:
|
|
||||||
- 想保留 SQL 控制权,但不想反复写样板代码。
|
|
||||||
- 需要轻量 ORM 映射(不是全功能 ORM)。
|
|
||||||
- 需要在生产里追踪 SQL(可选 Hook + 慢 SQL 阈值)。
|
|
||||||
|
|
||||||
不适合:
|
|
||||||
- 需要完整领域模型关系管理、自动迁移、复杂查询 DSL 的项目。
|
|
||||||
|
|
||||||
## 安装
|
|
||||||
|
|
||||||
```bash
|
|
||||||
go get b612.me/stardb
|
|
||||||
```
|
|
||||||
|
|
||||||
要求:
|
|
||||||
- Go `>= 1.16`
|
|
||||||
- 自行选择并导入数据库驱动(本库只封装 `database/sql`)
|
|
||||||
|
|
||||||
## 常见 DSN 示例
|
|
||||||
|
|
||||||
下面示例都可以直接用于 `db.Open(driver, dsn)`,替换为实际账号、密码、库名即可。
|
|
||||||
|
|
||||||
### MySQL(`github.com/go-sql-driver/mysql`)
|
|
||||||
|
|
||||||
```go
|
|
||||||
import _ "github.com/go-sql-driver/mysql"
|
|
||||||
|
|
||||||
dsn := "app:secret@tcp(127.0.0.1:3306)/demo?charset=utf8mb4&parseTime=true&loc=Local"
|
|
||||||
if err := db.Open("mysql", dsn); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
常用参数说明:
|
|
||||||
- `charset=utf8mb4`:避免字符集问题。
|
|
||||||
- `parseTime=true`:把 `DATETIME/TIMESTAMP` 解析为 `time.Time`。
|
|
||||||
- `loc=Local`:指定时间解析时区(也可改成 `Asia/Shanghai`)。
|
|
||||||
|
|
||||||
### PostgreSQL(`github.com/lib/pq`)
|
|
||||||
|
|
||||||
```go
|
|
||||||
import _ "github.com/lib/pq"
|
|
||||||
|
|
||||||
dsn := "host=127.0.0.1 port=5432 user=postgres password=secret dbname=demo sslmode=disable"
|
|
||||||
if err := db.Open("postgres", dsn); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
也可以用 URL 形式:
|
|
||||||
|
|
||||||
```go
|
|
||||||
urlDSN := "postgres://postgres:secret@127.0.0.1:5432/demo?sslmode=disable"
|
|
||||||
if err := db.Open("postgres", urlDSN); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### SQLite(`modernc.org/sqlite`)
|
|
||||||
|
|
||||||
```go
|
|
||||||
import _ "modernc.org/sqlite"
|
|
||||||
|
|
||||||
// 文件数据库
|
|
||||||
if err := db.Open("sqlite", "file:demo.db"); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 内存数据库(适合测试)
|
|
||||||
if err := db.Open("sqlite", "file::memory:?cache=shared"); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Windows 路径建议使用 `file:C:/data/demo.db` 这种写法,跨平台更稳。
|
|
||||||
|
|
||||||
## 能力概览
|
|
||||||
|
|
||||||
| 能力 | 主要 API | 说明 |
|
|
||||||
|---|---|---|
|
|
||||||
| 连接与连接池 | `Open` `Close` `Ping` `SetPoolConfig` | 保留原生 `sql.DB` 用法 |
|
|
||||||
| 常规查询 | `Query` `QueryContext` | 自动解析为 `StarRows` |
|
|
||||||
| 流式查询 | `QueryRaw` `ScanEach` | 大结果集不必全量进内存 |
|
|
||||||
| 流式 ORM | `ScanEachORM` | 逐行映射结构体 |
|
|
||||||
| 安全取值 | `Get*` `GetNull*` | 明确错误与 NULL 语义 |
|
|
||||||
| 结构体 ORM | `rows.Orm` | 支持单个、切片、数组映射 |
|
|
||||||
| 命名参数 | `QueryX` `ExecX` | `:field` 绑定结构体字段 |
|
|
||||||
| 结构体写入 | `Insert` `Update` | 通过 `db` tag 生成 SQL |
|
|
||||||
| 批量写入 | `BatchInsert` `BatchInsertStructs` `SetBatchInsertMaxRows` `SetBatchInsertMaxParams` | 多行插入,支持按行数/参数阈值分片 |
|
|
||||||
| 事务 | `Begin/Commit/Rollback` `WithTx` | 手动或托管事务 |
|
|
||||||
| 可观测性 | `SetSQLHooks` `SetSQLSlowThreshold` `SetSQLFingerprintEnabled` `SetSQLFingerprintMode` `SetSQLFingerprintKeepComments` `SetSQLFingerprintCounterEnabled` `SQLFingerprintCounters` `ResetSQLFingerprintCounters` `SQLHookMetaFromContext` `BatchExecMetaFromContext` | Before/After Hook,默认关闭,支持指纹策略、命中计数与批量分片元信息 |
|
|
||||||
| 方言占位符 | `SetPlaceholderStyle` | `?` / `$1,$2...` |
|
|
||||||
| 查询构建 | `QueryBuilder` | 支持 `JOIN/GROUP BY/HAVING` |
|
|
||||||
|
|
||||||
## 场景选型
|
|
||||||
|
|
||||||
| 场景 | 首选 API | 说明 |
|
|
||||||
|---|---|---|
|
|
||||||
| 中小结果集查询 | `Query` + `rows.Orm` | 读取方便,开发效率高 |
|
|
||||||
| 大结果集查询 | `ScanEach` / `ScanEachORM` | 逐行处理,避免全量缓存 |
|
|
||||||
| 需要底层 `Scan` 控制 | `QueryRaw` | 直接返回 `*sql.Rows` |
|
|
||||||
| 批量写入 | `BatchInsert` + 分片阈值 | 控制单条 SQL 大小与参数数量 |
|
|
||||||
| SQL 可观测 | `SetSQLHooks` + `SetSQLSlowThreshold` + 指纹配置 | 支持慢 SQL、指纹、分片元信息 |
|
|
||||||
|
|
||||||
## 快速开始
|
|
||||||
|
|
||||||
```go
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
|
|
||||||
"b612.me/stardb"
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Age int `db:"age"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
if err := db.Open("sqlite", "test.db"); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT id, name, age FROM users WHERE age >= ?", 18)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var users []User
|
|
||||||
if err := rows.Orm(&users); err != nil {
|
|
||||||
log.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("users: %d", len(users))
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## 接入流程
|
|
||||||
|
|
||||||
按下面顺序接入,可在开发阶段先固定运行边界:
|
|
||||||
|
|
||||||
1. 建立连接并设置连接池。
|
|
||||||
2. 在查询路径中区分内存模式与流式模式。
|
|
||||||
3. 在批量写入路径设置分片阈值(按行数和参数数)。
|
|
||||||
4. 启用 SQL Hook、慢 SQL 阈值、指纹策略。
|
|
||||||
5. 在调用侧统一使用 `errors.Is` 判定错误类别。
|
|
||||||
|
|
||||||
一个常用初始化示例:
|
|
||||||
|
|
||||||
```go
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
if err := db.Open("mysql", dsn); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetPoolConfig(&stardb.PoolConfig{
|
|
||||||
MaxOpenConns: 25,
|
|
||||||
MaxIdleConns: 5,
|
|
||||||
ConnMaxLifetime: time.Hour,
|
|
||||||
ConnMaxIdleTime: 10 * time.Minute,
|
|
||||||
})
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxRows(500)
|
|
||||||
db.SetBatchInsertMaxParams(60000)
|
|
||||||
|
|
||||||
db.SetSQLSlowThreshold(200 * time.Millisecond)
|
|
||||||
db.SetSQLFingerprintEnabled(true)
|
|
||||||
db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals)
|
|
||||||
db.SetSQLFingerprintKeepComments(false)
|
|
||||||
```
|
|
||||||
|
|
||||||
## API 指南
|
|
||||||
|
|
||||||
### 1) 连接与连接池
|
|
||||||
|
|
||||||
```go
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
_ = db.Open("mysql", "user:pass@tcp(localhost:3306)/app")
|
|
||||||
|
|
||||||
db.SetPoolConfig(&stardb.PoolConfig{
|
|
||||||
MaxOpenConns: 25,
|
|
||||||
MaxIdleConns: 5,
|
|
||||||
ConnMaxLifetime: time.Hour,
|
|
||||||
ConnMaxIdleTime: 10 * time.Minute,
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
也可以直接拿到底层连接:
|
|
||||||
|
|
||||||
```go
|
|
||||||
raw := db.DB()
|
|
||||||
raw.SetMaxOpenConns(50)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2) 三种查询模式
|
|
||||||
|
|
||||||
#### A. 内存模式(默认)
|
|
||||||
|
|
||||||
`Query` 会把结果解析为 `StarRows`,适合中小结果集:
|
|
||||||
|
|
||||||
```go
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE active = ?", true)
|
|
||||||
if err != nil { /* ... */ }
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
for i := 0; i < rows.Length(); i++ {
|
|
||||||
row := rows.Row(i)
|
|
||||||
_ = row.MustString("name")
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
#### B. 原生流式模式
|
|
||||||
|
|
||||||
`QueryRaw` 返回 `*sql.Rows`,完全按原生 `Scan` 处理:
|
|
||||||
|
|
||||||
```go
|
|
||||||
rawRows, err := db.QueryRaw("SELECT id, name FROM users")
|
|
||||||
if err != nil { /* ... */ }
|
|
||||||
defer rawRows.Close()
|
|
||||||
```
|
|
||||||
|
|
||||||
#### C. 回调流式模式(常用)
|
|
||||||
|
|
||||||
`ScanEach` 逐行回调,避免全量缓存:
|
|
||||||
|
|
||||||
关闭内存预读时,使用 `QueryRaw` / `ScanEach` / `ScanEachORM`,不使用 `Query`。
|
|
||||||
|
|
||||||
```go
|
|
||||||
err := db.ScanEach("SELECT id, name FROM users", func(row *stardb.StarResult) error {
|
|
||||||
id := row.MustInt64("id")
|
|
||||||
name := row.MustString("name")
|
|
||||||
_ = id
|
|
||||||
_ = name
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
可通过 `stardb.ErrScanStopped` 提前终止:
|
|
||||||
|
|
||||||
```go
|
|
||||||
count := 0
|
|
||||||
_ = db.ScanEach("SELECT * FROM users", func(row *stardb.StarResult) error {
|
|
||||||
count++
|
|
||||||
if count >= 1000 {
|
|
||||||
return stardb.ErrScanStopped
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3) 流式 ORM(逐行映射)
|
|
||||||
|
|
||||||
`ScanEachORM` 将每行映射到结构体,再回调。
|
|
||||||
|
|
||||||
```go
|
|
||||||
var model User
|
|
||||||
var users []User
|
|
||||||
|
|
||||||
err := db.ScanEachORM("SELECT id, name, age FROM users", &model, func(target interface{}) error {
|
|
||||||
u := *(target.(*User)) // 注意拷贝一份,target 会被复用
|
|
||||||
users = append(users, u)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
同样支持 `Tx` / `Stmt`:
|
|
||||||
- `tx.ScanEachORM(...)`
|
|
||||||
- `stmt.ScanEachORM(...)`
|
|
||||||
|
|
||||||
### 4) 结果读取与 NULL 语义
|
|
||||||
|
|
||||||
#### Must 系列(无错误,失败给零值)
|
|
||||||
- `MustString` `MustInt64` `MustFloat64` `MustBool` ...
|
|
||||||
|
|
||||||
#### 安全系列(带错误)
|
|
||||||
- `GetString` `GetInt64` `GetFloat64`
|
|
||||||
- `GetNullString` `GetNullInt64` `GetNullFloat64` `GetNullBool` `GetNullTime`
|
|
||||||
|
|
||||||
```go
|
|
||||||
name, err := row.GetString("name")
|
|
||||||
age, err := row.GetNullInt64("age")
|
|
||||||
if age.Valid {
|
|
||||||
// use age.Int64
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5) ORM 映射
|
|
||||||
|
|
||||||
```go
|
|
||||||
type User struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var u User
|
|
||||||
_ = rows.Orm(&u)
|
|
||||||
|
|
||||||
var list []User
|
|
||||||
_ = rows.Orm(&list)
|
|
||||||
```
|
|
||||||
|
|
||||||
严格列检查(字段/SQL 变更敏感场景可开启):
|
|
||||||
|
|
||||||
```go
|
|
||||||
db.SetStrictORM(true)
|
|
||||||
```
|
|
||||||
|
|
||||||
若结构体 tag 大范围调整,可清理反射缓存:
|
|
||||||
|
|
||||||
```go
|
|
||||||
stardb.ClearReflectCache()
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6) 命名参数绑定
|
|
||||||
|
|
||||||
```go
|
|
||||||
type Filter struct {
|
|
||||||
Name string `db:"name"`
|
|
||||||
MinAge int `db:"min_age"`
|
|
||||||
}
|
|
||||||
|
|
||||||
f := Filter{Name: "Alice", MinAge: 18}
|
|
||||||
rows, err := db.QueryX(&f,
|
|
||||||
"SELECT * FROM users WHERE name = ? AND age >= ?",
|
|
||||||
":name", ":min_age")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 7) 写入能力
|
|
||||||
|
|
||||||
#### Insert / Update
|
|
||||||
|
|
||||||
```go
|
|
||||||
_, _ = db.Insert(&user, "users", "id") // id 作为自增字段跳过
|
|
||||||
_, _ = db.Update(&user, "users", "id") // id 作为主键
|
|
||||||
```
|
|
||||||
|
|
||||||
#### BatchInsert
|
|
||||||
|
|
||||||
```go
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
{"Bob", "bob@example.com", 30},
|
|
||||||
}
|
|
||||||
_, _ = db.BatchInsert("users", columns, values)
|
|
||||||
```
|
|
||||||
|
|
||||||
如需避免单条 SQL 过大(参数过多),可打开分片:
|
|
||||||
|
|
||||||
```go
|
|
||||||
db.SetBatchInsertMaxRows(500) // 0 或负数表示关闭分片(默认)
|
|
||||||
db.SetBatchInsertMaxParams(60000) // 0 表示自动识别常见驱动参数上限
|
|
||||||
```
|
|
||||||
|
|
||||||
分片模式下会在一个事务里按块执行,避免部分写入成功。
|
|
||||||
自动识别当前覆盖:SQLite `999`、PostgreSQL `65535`、MySQL `65535`、SQL Server `2100`。
|
|
||||||
|
|
||||||
分片行为细节:
|
|
||||||
- 分片阈值按更严格条件生效:`min(maxRows, maxParams/列数)`(忽略未设置项)。
|
|
||||||
- 分片关闭条件:`maxRows <= 0` 且 `maxParams <= 0` 且未命中驱动自动阈值。
|
|
||||||
- 分片执行失败会回滚整个批次。
|
|
||||||
- 分片结果语义:
|
|
||||||
- `RowsAffected()` 返回所有分片累计值。
|
|
||||||
- `LastInsertId()` 返回最后一个分片的 insert id。
|
|
||||||
|
|
||||||
#### BatchInsertStructs
|
|
||||||
|
|
||||||
```go
|
|
||||||
users := []User{{Name: "Alice"}, {Name: "Bob"}}
|
|
||||||
_, _ = db.BatchInsertStructs("users", users, "id")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 8) 事务
|
|
||||||
|
|
||||||
#### 手动事务
|
|
||||||
|
|
||||||
```go
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil { /* ... */ }
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
if _, err := tx.Exec("UPDATE users SET age = age + 1 WHERE id = ?", 1); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return tx.Commit()
|
|
||||||
```
|
|
||||||
|
|
||||||
#### 托管事务(常用)
|
|
||||||
|
|
||||||
```go
|
|
||||||
err := db.WithTx(func(tx *stardb.StarTx) error {
|
|
||||||
if _, err := tx.Exec("UPDATE users SET age = ? WHERE id = ?", 26, 1); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if _, err := tx.Exec("INSERT INTO logs (msg) VALUES (?)", "age updated"); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
`WithTx` 规则:
|
|
||||||
- `fn` 返回 `nil` -> `Commit`
|
|
||||||
- `fn` 返回错误 -> `Rollback`
|
|
||||||
- `fn` panic -> `Rollback` 后继续抛出 panic
|
|
||||||
|
|
||||||
### 9) SQL Hook 与慢 SQL 阈值
|
|
||||||
|
|
||||||
默认关闭;仅在显式设置时生效。
|
|
||||||
|
|
||||||
```go
|
|
||||||
db.SetSQLSlowThreshold(200 * time.Millisecond)
|
|
||||||
db.SetSQLFingerprintEnabled(true) // 可选:在 Hook context 附带 SQL 指纹
|
|
||||||
db.SetSQLFingerprintMode(stardb.SQLFingerprintMaskLiterals) // 可选:指纹里脱敏数字/字符串字面量
|
|
||||||
db.SetSQLFingerprintKeepComments(false) // 默认 false:指纹不保留 SQL 注释
|
|
||||||
db.SetSQLFingerprintCounterEnabled(true) // 可选:记录指纹命中次数(内存级)
|
|
||||||
|
|
||||||
db.SetSQLHooks(
|
|
||||||
func(ctx context.Context, query string, args []interface{}) {
|
|
||||||
// before
|
|
||||||
},
|
|
||||||
func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) {
|
|
||||||
// after
|
|
||||||
if hookMeta, ok := stardb.SQLHookMetaFromContext(ctx); ok {
|
|
||||||
_ = hookMeta.Fingerprint
|
|
||||||
}
|
|
||||||
if meta, ok := stardb.BatchExecMetaFromContext(ctx); ok {
|
|
||||||
// chunked batch insert metadata
|
|
||||||
// meta.ChunkIndex / meta.ChunkCount / meta.ChunkRows ...
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
阈值行为:
|
|
||||||
- `threshold <= 0`:`After` 每次都触发
|
|
||||||
- `threshold > 0`:仅在“慢于阈值”或“执行出错”时触发
|
|
||||||
- 打开 `SetSQLFingerprintEnabled(true)` 后,可从 `SQLHookMetaFromContext` 获取 SQL 指纹
|
|
||||||
- 指纹模式:`SQLFingerprintBasic`(默认,仅归一化)/ `SQLFingerprintMaskLiterals`(归一化 + 字面量脱敏)
|
|
||||||
- `SetSQLFingerprintKeepComments(true)` 可保留注释文本(默认关闭,利于聚合)
|
|
||||||
- `SetSQLFingerprintCounterEnabled(true)` 后,可通过 `SQLFingerprintCounters()` 查看命中次数,`ResetSQLFingerprintCounters()` 清空
|
|
||||||
- 若是分片批量写入,Hook 可通过 `BatchExecMetaFromContext` 读取分片元信息
|
|
||||||
|
|
||||||
Hook 上下文字段说明:
|
|
||||||
- `SQLHookMetaFromContext(ctx)`:
|
|
||||||
- `Fingerprint`:按配置生成的 SQL 指纹。
|
|
||||||
- `BatchExecMetaFromContext(ctx)`(仅分片批量写入):
|
|
||||||
- `ChunkIndex`:当前分片序号(从 1 开始)
|
|
||||||
- `ChunkCount`:总分片数
|
|
||||||
- `ChunkRows`:当前分片行数
|
|
||||||
- `TotalRows`:本次批量总行数
|
|
||||||
- `ColumnCount`:本次写入列数
|
|
||||||
|
|
||||||
### 10) 占位符方言
|
|
||||||
|
|
||||||
```go
|
|
||||||
db.SetPlaceholderStyle(stardb.PlaceholderQuestion) // 默认
|
|
||||||
// 或
|
|
||||||
db.SetPlaceholderStyle(stardb.PlaceholderDollar) // ? -> $1,$2...
|
|
||||||
```
|
|
||||||
|
|
||||||
### 11) QueryBuilder
|
|
||||||
|
|
||||||
```go
|
|
||||||
query, args := stardb.NewQueryBuilder("users u").
|
|
||||||
Select("u.id", "u.name", "COUNT(o.id) AS order_count").
|
|
||||||
Join("LEFT JOIN orders o ON o.user_id = u.id").
|
|
||||||
Where("u.active = ?", true).
|
|
||||||
GroupBy("u.id", "u.name").
|
|
||||||
Having("COUNT(o.id) > ?", 2).
|
|
||||||
OrderBy("order_count DESC").
|
|
||||||
Limit(20).
|
|
||||||
Offset(0).
|
|
||||||
Build()
|
|
||||||
|
|
||||||
_ = query
|
|
||||||
_ = args
|
|
||||||
```
|
|
||||||
|
|
||||||
## 错误处理
|
|
||||||
|
|
||||||
库内置可判定错误,调用侧使用 `errors.Is` 做分支处理:
|
|
||||||
|
|
||||||
```go
|
|
||||||
if errors.Is(err, stardb.ErrDBNotInitialized) {
|
|
||||||
// 未初始化
|
|
||||||
}
|
|
||||||
if errors.Is(err, stardb.ErrColumnNotFound) {
|
|
||||||
// 字段/列不匹配
|
|
||||||
}
|
|
||||||
if errors.Is(err, stardb.ErrNoInsertValues) {
|
|
||||||
// 批量插入空数据
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
常见错误类别:
|
|
||||||
- 生命周期:`ErrDBNotInitialized` `ErrTxNotInitialized` `ErrStmtNotInitialized`
|
|
||||||
- 参数校验:`ErrQueryEmpty` `ErrTargetNil` `ErrTargetNotPointer` ...
|
|
||||||
- 映射问题:`ErrColumnNotFound` `ErrFieldNotFound`
|
|
||||||
- 批量写入:`ErrNoInsertColumns` `ErrNoInsertValues` `ErrBatchRowValueCountMismatch`
|
|
||||||
- 流式回调:`ErrScanFuncNil` `ErrScanORMFuncNil`
|
|
||||||
|
|
||||||
## 使用边界
|
|
||||||
|
|
||||||
1. 这是轻量封装,不是全功能 ORM。
|
|
||||||
- 不做模型关系管理(has-many/association)
|
|
||||||
- 不做自动迁移
|
|
||||||
- 不做复杂查询 DSL
|
|
||||||
|
|
||||||
2. 大结果集优先用流式 API。
|
|
||||||
- `Query` 适合中小结果集
|
|
||||||
- `ScanEach` / `ScanEachORM` 更稳
|
|
||||||
|
|
||||||
3. 日志 Hook 按需打开。
|
|
||||||
- 生产环境最好配合慢 SQL 阈值,减少噪音
|
|
||||||
|
|
||||||
4. `ScanEachORM` 回调里的 target 会复用。
|
|
||||||
- 需要持久化时请拷贝结构体值
|
|
||||||
|
|
||||||
## 测试、竞态与基准
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 根模块
|
|
||||||
go test ./...
|
|
||||||
|
|
||||||
go test -race ./...
|
|
||||||
|
|
||||||
go test -run ^$ -bench BenchmarkQueryBuilder_ -benchmem ./...
|
|
||||||
|
|
||||||
# testing 子模块(集成测试/基准)
|
|
||||||
cd testing
|
|
||||||
go test ./...
|
|
||||||
go test -race ./...
|
|
||||||
go test -run ^$ -bench "Benchmark(QueryX|Orm|ScanEach|BatchInsert)" -benchmem
|
|
||||||
```
|
|
||||||
|
|
||||||
## 支持数据库驱动
|
|
||||||
|
|
||||||
本库兼容所有实现 `database/sql` 的驱动。常见示例:
|
|
||||||
- SQLite: `_ "modernc.org/sqlite"`
|
|
||||||
- MySQL: `_ "github.com/go-sql-driver/mysql"`
|
|
||||||
- PostgreSQL: `_ "github.com/lib/pq"`
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
Apache License 2.0
|
|
||||||
338
batch.go
338
batch.go
@ -1,338 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
type multiSQLResult struct {
|
|
||||||
results []sql.Result
|
|
||||||
}
|
|
||||||
|
|
||||||
type batchExecMetaKey struct{}
|
|
||||||
|
|
||||||
// BatchExecMeta contains chunk execution metadata for batch insert operations.
|
|
||||||
// It is attached to context for chunked execution and can be read in SQL hooks.
|
|
||||||
type BatchExecMeta struct {
|
|
||||||
ChunkIndex int
|
|
||||||
ChunkCount int
|
|
||||||
ChunkRows int
|
|
||||||
TotalRows int
|
|
||||||
ColumnCount int
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchExecMetaFromContext extracts batch chunk metadata from context.
|
|
||||||
func BatchExecMetaFromContext(ctx context.Context) (BatchExecMeta, bool) {
|
|
||||||
if ctx == nil {
|
|
||||||
return BatchExecMeta{}, false
|
|
||||||
}
|
|
||||||
meta, ok := ctx.Value(batchExecMetaKey{}).(BatchExecMeta)
|
|
||||||
return meta, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func withBatchExecMeta(ctx context.Context, meta BatchExecMeta) context.Context {
|
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
return context.WithValue(ctx, batchExecMetaKey{}, meta)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m multiSQLResult) LastInsertId() (int64, error) {
|
|
||||||
if len(m.results) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return m.results[len(m.results)-1].LastInsertId()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m multiSQLResult) RowsAffected() (int64, error) {
|
|
||||||
var total int64
|
|
||||||
for _, result := range m.results {
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
total += affected
|
|
||||||
}
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBatchInsertMaxRows configures max row count per INSERT statement for batch APIs.
|
|
||||||
// <= 0 disables splitting and keeps single-statement behavior.
|
|
||||||
func (s *StarDB) SetBatchInsertMaxRows(maxRows int) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if maxRows < 0 {
|
|
||||||
maxRows = 0
|
|
||||||
}
|
|
||||||
atomic.StoreInt64(&s.batchInsertMaxRows, int64(maxRows))
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchInsertMaxRows returns current batch split threshold.
|
|
||||||
// 0 means disabled.
|
|
||||||
func (s *StarDB) BatchInsertMaxRows() int {
|
|
||||||
if s == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
value := atomic.LoadInt64(&s.batchInsertMaxRows)
|
|
||||||
if value <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return int(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBatchInsertMaxParams configures max bind parameter count per INSERT statement for batch APIs.
|
|
||||||
// <= 0 means auto mode (use built-in defaults for known drivers) or no limit for unknown drivers.
|
|
||||||
func (s *StarDB) SetBatchInsertMaxParams(maxParams int) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if maxParams < 0 {
|
|
||||||
maxParams = 0
|
|
||||||
}
|
|
||||||
atomic.StoreInt64(&s.batchInsertMaxParams, int64(maxParams))
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchInsertMaxParams returns configured max bind parameter threshold.
|
|
||||||
// 0 means auto mode.
|
|
||||||
func (s *StarDB) BatchInsertMaxParams() int {
|
|
||||||
if s == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
value := atomic.LoadInt64(&s.batchInsertMaxParams)
|
|
||||||
if value <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return int(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func detectBatchInsertMaxParams(db *sql.DB) int {
|
|
||||||
if db == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
driverType := strings.ToLower(fmt.Sprintf("%T", db.Driver()))
|
|
||||||
switch {
|
|
||||||
case strings.Contains(driverType, "sqlite"):
|
|
||||||
// Keep conservative default for wide compatibility.
|
|
||||||
return 999
|
|
||||||
case strings.Contains(driverType, "postgres"), strings.Contains(driverType, "pgx"), strings.Contains(driverType, "pq"):
|
|
||||||
return 65535
|
|
||||||
case strings.Contains(driverType, "mysql"):
|
|
||||||
return 65535
|
|
||||||
case strings.Contains(driverType, "sqlserver"), strings.Contains(driverType, "mssql"):
|
|
||||||
return 2100
|
|
||||||
default:
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func minPositive(a, b int) int {
|
|
||||||
if a <= 0 {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
if b <= 0 {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
if a < b {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarDB) batchInsertChunkSize(columnCount int) (int, error) {
|
|
||||||
maxRows := s.BatchInsertMaxRows()
|
|
||||||
maxParams := s.BatchInsertMaxParams()
|
|
||||||
if maxParams <= 0 {
|
|
||||||
maxParams = detectBatchInsertMaxParams(s.db)
|
|
||||||
}
|
|
||||||
|
|
||||||
maxRowsByParams := 0
|
|
||||||
if maxParams > 0 {
|
|
||||||
maxRowsByParams = maxParams / columnCount
|
|
||||||
if maxRowsByParams <= 0 {
|
|
||||||
return 0, ErrBatchInsertMaxParamsTooLow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return minPositive(maxRows, maxRowsByParams), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildBatchInsertQuery(tableName string, columns []string, values [][]interface{}) (string, []interface{}) {
|
|
||||||
placeholderGroup := "(" + strings.Repeat("?, ", len(columns)-1) + "?)"
|
|
||||||
placeholders := strings.Repeat(placeholderGroup+", ", len(values)-1) + placeholderGroup
|
|
||||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
|
|
||||||
tableName,
|
|
||||||
strings.Join(columns, ", "),
|
|
||||||
placeholders)
|
|
||||||
|
|
||||||
args := make([]interface{}, 0, len(values)*len(columns))
|
|
||||||
for _, row := range values {
|
|
||||||
args = append(args, row...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return query, args
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarDB) batchInsertChunked(ctx context.Context, tableName string, columns []string, values [][]interface{}, chunkSize int) (sql.Result, error) {
|
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
|
|
||||||
tx, err := s.BeginTx(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
chunkCount := (len(values) + chunkSize - 1) / chunkSize
|
|
||||||
results := make([]sql.Result, 0, chunkCount)
|
|
||||||
for start := 0; start < len(values); start += chunkSize {
|
|
||||||
end := start + chunkSize
|
|
||||||
if end > len(values) {
|
|
||||||
end = len(values)
|
|
||||||
}
|
|
||||||
chunkIndex := start/chunkSize + 1
|
|
||||||
query, args := buildBatchInsertQuery(tableName, columns, values[start:end])
|
|
||||||
chunkCtx := withBatchExecMeta(ctx, BatchExecMeta{
|
|
||||||
ChunkIndex: chunkIndex,
|
|
||||||
ChunkCount: chunkCount,
|
|
||||||
ChunkRows: end - start,
|
|
||||||
TotalRows: len(values),
|
|
||||||
ColumnCount: len(columns),
|
|
||||||
})
|
|
||||||
result, execErr := tx.exec(chunkCtx, query, args...)
|
|
||||||
if execErr != nil {
|
|
||||||
_ = tx.Rollback()
|
|
||||||
return nil, execErr
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
_ = tx.Rollback()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return multiSQLResult{results: results}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchInsert performs batch insert operation
|
|
||||||
// Usage: BatchInsert("users", []string{"name", "age"}, [][]interface{}{{"Alice", 25}, {"Bob", 30}})
|
|
||||||
func (s *StarDB) BatchInsert(tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
|
|
||||||
return s.batchInsert(nil, tableName, columns, values)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchInsertContext performs batch insert with context
|
|
||||||
func (s *StarDB) BatchInsertContext(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
|
|
||||||
return s.batchInsert(ctx, tableName, columns, values)
|
|
||||||
}
|
|
||||||
|
|
||||||
// batchInsert is the internal implementation
|
|
||||||
func (s *StarDB) batchInsert(ctx context.Context, tableName string, columns []string, values [][]interface{}) (sql.Result, error) {
|
|
||||||
if strings.TrimSpace(tableName) == "" {
|
|
||||||
return nil, ErrTableNameEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(columns) == 0 {
|
|
||||||
return nil, ErrNoInsertColumns
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(values) == 0 {
|
|
||||||
return nil, ErrNoInsertValues
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, row := range values {
|
|
||||||
if len(row) != len(columns) {
|
|
||||||
return nil, wrapBatchRowValueCountMismatch(i, len(row), len(columns))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
chunkSize, err := s.batchInsertChunkSize(len(columns))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if chunkSize > 0 && len(values) > chunkSize {
|
|
||||||
return s.batchInsertChunked(ctx, tableName, columns, values, chunkSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
query, args := buildBatchInsertQuery(tableName, columns, values)
|
|
||||||
return s.exec(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchInsertStructs performs batch insert using structs
|
|
||||||
func (s *StarDB) BatchInsertStructs(tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
|
|
||||||
return s.batchInsertStructs(nil, tableName, structs, autoIncrementFields...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchInsertStructsContext performs batch insert using structs with context
|
|
||||||
func (s *StarDB) BatchInsertStructsContext(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
|
|
||||||
return s.batchInsertStructs(ctx, tableName, structs, autoIncrementFields...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// batchInsertStructs is the internal implementation
|
|
||||||
func (s *StarDB) batchInsertStructs(ctx context.Context, tableName string, structs interface{}, autoIncrementFields ...string) (sql.Result, error) {
|
|
||||||
if structs == nil {
|
|
||||||
return nil, ErrStructsNil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get slice of structs
|
|
||||||
targetValue := reflect.ValueOf(structs)
|
|
||||||
if targetValue.Kind() == reflect.Ptr {
|
|
||||||
if targetValue.IsNil() {
|
|
||||||
return nil, ErrStructsPointerNil
|
|
||||||
}
|
|
||||||
targetValue = targetValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetValue.Kind() != reflect.Slice && targetValue.Kind() != reflect.Array {
|
|
||||||
return nil, ErrStructsNotSlice
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetValue.Len() == 0 {
|
|
||||||
return nil, ErrNoStructsToInsert
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get field names from first struct
|
|
||||||
firstStruct := targetValue.Index(0).Interface()
|
|
||||||
fieldNames, err := getStructFieldNames(firstStruct, "db")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter out auto-increment fields
|
|
||||||
var columns []string
|
|
||||||
for _, fieldName := range fieldNames {
|
|
||||||
isAutoIncrement := false
|
|
||||||
for _, autoField := range autoIncrementFields {
|
|
||||||
if fieldName == autoField {
|
|
||||||
isAutoIncrement = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !isAutoIncrement {
|
|
||||||
columns = append(columns, fieldName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract values from all structs
|
|
||||||
var values [][]interface{}
|
|
||||||
for i := 0; i < targetValue.Len(); i++ {
|
|
||||||
structVal := targetValue.Index(i).Interface()
|
|
||||||
fieldValues, err := getStructFieldValues(structVal, "db")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var row []interface{}
|
|
||||||
for _, col := range columns {
|
|
||||||
row = append(row, fieldValues[col])
|
|
||||||
}
|
|
||||||
values = append(values, row)
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.batchInsert(ctx, tableName, columns, values)
|
|
||||||
}
|
|
||||||
134
builder.go
134
builder.go
@ -1,134 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// QueryBuilder helps build SQL queries
|
|
||||||
type QueryBuilder struct {
|
|
||||||
table string
|
|
||||||
columns []string
|
|
||||||
joins []string
|
|
||||||
where []string
|
|
||||||
whereArgs []interface{}
|
|
||||||
groupBy []string
|
|
||||||
having []string
|
|
||||||
havingArgs []interface{}
|
|
||||||
orderBy string
|
|
||||||
limit int
|
|
||||||
offset int
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewQueryBuilder creates a new query builder
|
|
||||||
func NewQueryBuilder(table string) *QueryBuilder {
|
|
||||||
return &QueryBuilder{
|
|
||||||
table: table,
|
|
||||||
columns: []string{"*"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select sets the columns to select
|
|
||||||
func (qb *QueryBuilder) Select(columns ...string) *QueryBuilder {
|
|
||||||
qb.columns = columns
|
|
||||||
return qb
|
|
||||||
}
|
|
||||||
|
|
||||||
// Where adds a WHERE condition
|
|
||||||
func (qb *QueryBuilder) Where(condition string, args ...interface{}) *QueryBuilder {
|
|
||||||
qb.where = append(qb.where, condition)
|
|
||||||
qb.whereArgs = append(qb.whereArgs, args...)
|
|
||||||
return qb
|
|
||||||
}
|
|
||||||
|
|
||||||
// Join adds a JOIN clause.
|
|
||||||
func (qb *QueryBuilder) Join(clause string) *QueryBuilder {
|
|
||||||
qb.joins = append(qb.joins, clause)
|
|
||||||
return qb
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupBy sets GROUP BY columns.
|
|
||||||
func (qb *QueryBuilder) GroupBy(columns ...string) *QueryBuilder {
|
|
||||||
qb.groupBy = append(qb.groupBy, columns...)
|
|
||||||
return qb
|
|
||||||
}
|
|
||||||
|
|
||||||
// Having adds a HAVING condition.
|
|
||||||
func (qb *QueryBuilder) Having(condition string, args ...interface{}) *QueryBuilder {
|
|
||||||
qb.having = append(qb.having, condition)
|
|
||||||
qb.havingArgs = append(qb.havingArgs, args...)
|
|
||||||
return qb
|
|
||||||
}
|
|
||||||
|
|
||||||
// OrderBy sets the ORDER BY clause
|
|
||||||
func (qb *QueryBuilder) OrderBy(orderBy string) *QueryBuilder {
|
|
||||||
qb.orderBy = orderBy
|
|
||||||
return qb
|
|
||||||
}
|
|
||||||
|
|
||||||
// Limit sets the LIMIT
|
|
||||||
func (qb *QueryBuilder) Limit(limit int) *QueryBuilder {
|
|
||||||
qb.limit = limit
|
|
||||||
return qb
|
|
||||||
}
|
|
||||||
|
|
||||||
// Offset sets the OFFSET
|
|
||||||
func (qb *QueryBuilder) Offset(offset int) *QueryBuilder {
|
|
||||||
qb.offset = offset
|
|
||||||
return qb
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build builds the SQL query and returns query string and args
|
|
||||||
func (qb *QueryBuilder) Build() (string, []interface{}) {
|
|
||||||
var parts []string
|
|
||||||
|
|
||||||
// SELECT
|
|
||||||
parts = append(parts, fmt.Sprintf("SELECT %s FROM %s",
|
|
||||||
strings.Join(qb.columns, ", "), qb.table))
|
|
||||||
|
|
||||||
// JOIN
|
|
||||||
if len(qb.joins) > 0 {
|
|
||||||
parts = append(parts, strings.Join(qb.joins, " "))
|
|
||||||
}
|
|
||||||
|
|
||||||
// WHERE
|
|
||||||
if len(qb.where) > 0 {
|
|
||||||
parts = append(parts, "WHERE "+strings.Join(qb.where, " AND "))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GROUP BY
|
|
||||||
if len(qb.groupBy) > 0 {
|
|
||||||
parts = append(parts, "GROUP BY "+strings.Join(qb.groupBy, ", "))
|
|
||||||
}
|
|
||||||
|
|
||||||
// HAVING
|
|
||||||
if len(qb.having) > 0 {
|
|
||||||
parts = append(parts, "HAVING "+strings.Join(qb.having, " AND "))
|
|
||||||
}
|
|
||||||
|
|
||||||
// ORDER BY
|
|
||||||
if qb.orderBy != "" {
|
|
||||||
parts = append(parts, "ORDER BY "+qb.orderBy)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LIMIT
|
|
||||||
if qb.limit > 0 {
|
|
||||||
parts = append(parts, fmt.Sprintf("LIMIT %d", qb.limit))
|
|
||||||
}
|
|
||||||
|
|
||||||
// OFFSET
|
|
||||||
if qb.offset > 0 {
|
|
||||||
parts = append(parts, fmt.Sprintf("OFFSET %d", qb.offset))
|
|
||||||
}
|
|
||||||
|
|
||||||
args := make([]interface{}, 0, len(qb.whereArgs)+len(qb.havingArgs))
|
|
||||||
args = append(args, qb.whereArgs...)
|
|
||||||
args = append(args, qb.havingArgs...)
|
|
||||||
return strings.Join(parts, " "), args
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query executes the query
|
|
||||||
func (qb *QueryBuilder) Query(db *StarDB) (*StarRows, error) {
|
|
||||||
query, args := qb.Build()
|
|
||||||
return db.Query(query, args...)
|
|
||||||
}
|
|
||||||
521
builder_test.go
521
builder_test.go
@ -1,521 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNewQueryBuilder(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users")
|
|
||||||
|
|
||||||
if qb.table != "users" {
|
|
||||||
t.Errorf("Expected table 'users', got '%s'", qb.table)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(qb.columns) != 1 || qb.columns[0] != "*" {
|
|
||||||
t.Errorf("Expected default columns ['*'], got %v", qb.columns)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Select(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Select("id", "name", "email")
|
|
||||||
|
|
||||||
expected := []string{"id", "name", "email"}
|
|
||||||
if !reflect.DeepEqual(qb.columns, expected) {
|
|
||||||
t.Errorf("Expected columns %v, got %v", expected, qb.columns)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Where(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").
|
|
||||||
Where("age > ?", 18).
|
|
||||||
Where("active = ?", true)
|
|
||||||
|
|
||||||
if len(qb.where) != 2 {
|
|
||||||
t.Errorf("Expected 2 where conditions, got %d", len(qb.where))
|
|
||||||
}
|
|
||||||
|
|
||||||
if qb.where[0] != "age > ?" {
|
|
||||||
t.Errorf("Expected first where 'age > ?', got '%s'", qb.where[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if qb.where[1] != "active = ?" {
|
|
||||||
t.Errorf("Expected second where 'active = ?', got '%s'", qb.where[1])
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(qb.whereArgs) != 2 {
|
|
||||||
t.Errorf("Expected 2 where args, got %d", len(qb.whereArgs))
|
|
||||||
}
|
|
||||||
|
|
||||||
if qb.whereArgs[0] != 18 {
|
|
||||||
t.Errorf("Expected first arg 18, got %v", qb.whereArgs[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if qb.whereArgs[1] != true {
|
|
||||||
t.Errorf("Expected second arg true, got %v", qb.whereArgs[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_OrderBy(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").OrderBy("name ASC")
|
|
||||||
|
|
||||||
if qb.orderBy != "name ASC" {
|
|
||||||
t.Errorf("Expected orderBy 'name ASC', got '%s'", qb.orderBy)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Limit(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Limit(10)
|
|
||||||
|
|
||||||
if qb.limit != 10 {
|
|
||||||
t.Errorf("Expected limit 10, got %d", qb.limit)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Offset(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Offset(20)
|
|
||||||
|
|
||||||
if qb.offset != 20 {
|
|
||||||
t.Errorf("Expected offset 20, got %d", qb.offset)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_Simple(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users")
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT * FROM users"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_WithSelect(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Select("id", "name", "email")
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT id, name, email FROM users"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_WithWhere(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").
|
|
||||||
Where("age > ?", 18).
|
|
||||||
Where("active = ?", true)
|
|
||||||
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT * FROM users WHERE age > ? AND active = ?"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 2 {
|
|
||||||
t.Errorf("Expected 2 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
|
|
||||||
if args[0] != 18 {
|
|
||||||
t.Errorf("Expected first arg 18, got %v", args[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if args[1] != true {
|
|
||||||
t.Errorf("Expected second arg true, got %v", args[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_WithOrderBy(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").OrderBy("created_at DESC")
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT * FROM users ORDER BY created_at DESC"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_WithLimit(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Limit(10)
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT * FROM users LIMIT 10"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_WithOffset(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Offset(20)
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT * FROM users OFFSET 20"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_WithLimitAndOffset(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Limit(10).Offset(20)
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT * FROM users LIMIT 10 OFFSET 20"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_Complex(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").
|
|
||||||
Select("id", "name", "email", "age").
|
|
||||||
Where("age > ?", 18).
|
|
||||||
Where("active = ?", true).
|
|
||||||
Where("country = ?", "US").
|
|
||||||
OrderBy("name ASC").
|
|
||||||
Limit(10).
|
|
||||||
Offset(20)
|
|
||||||
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT id, name, email, age FROM users WHERE age > ? AND active = ? AND country = ? ORDER BY name ASC LIMIT 10 OFFSET 20"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedArgs := []interface{}{18, true, "US"}
|
|
||||||
if len(args) != len(expectedArgs) {
|
|
||||||
t.Errorf("Expected %d args, got %d", len(expectedArgs), len(args))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, expected := range expectedArgs {
|
|
||||||
if args[i] != expected {
|
|
||||||
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_MultipleWhere(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("orders").
|
|
||||||
Where("user_id = ?", 123).
|
|
||||||
Where("status IN (?, ?, ?)", "pending", "processing", "shipped").
|
|
||||||
Where("created_at > ?", "2024-01-01")
|
|
||||||
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT * FROM orders WHERE user_id = ? AND status IN (?, ?, ?) AND created_at > ?"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 5 {
|
|
||||||
t.Errorf("Expected 5 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedArgs := []interface{}{123, "pending", "processing", "shipped", "2024-01-01"}
|
|
||||||
for i, expected := range expectedArgs {
|
|
||||||
if args[i] != expected {
|
|
||||||
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Chaining(t *testing.T) {
|
|
||||||
// Test method chaining returns the same builder
|
|
||||||
qb := NewQueryBuilder("users")
|
|
||||||
|
|
||||||
qb2 := qb.Select("id", "name")
|
|
||||||
if qb != qb2 {
|
|
||||||
t.Error("Select should return the same builder instance")
|
|
||||||
}
|
|
||||||
|
|
||||||
qb3 := qb.Where("age > ?", 18)
|
|
||||||
if qb != qb3 {
|
|
||||||
t.Error("Where should return the same builder instance")
|
|
||||||
}
|
|
||||||
|
|
||||||
qb4 := qb.OrderBy("name ASC")
|
|
||||||
if qb != qb4 {
|
|
||||||
t.Error("OrderBy should return the same builder instance")
|
|
||||||
}
|
|
||||||
|
|
||||||
qb5 := qb.Limit(10)
|
|
||||||
if qb != qb5 {
|
|
||||||
t.Error("Limit should return the same builder instance")
|
|
||||||
}
|
|
||||||
|
|
||||||
qb6 := qb.Offset(20)
|
|
||||||
if qb != qb6 {
|
|
||||||
t.Error("Offset should return the same builder instance")
|
|
||||||
}
|
|
||||||
|
|
||||||
qb7 := qb.Join("LEFT JOIN orders o ON o.user_id = users.id")
|
|
||||||
if qb != qb7 {
|
|
||||||
t.Error("Join should return the same builder instance")
|
|
||||||
}
|
|
||||||
|
|
||||||
qb8 := qb.GroupBy("users.id")
|
|
||||||
if qb != qb8 {
|
|
||||||
t.Error("GroupBy should return the same builder instance")
|
|
||||||
}
|
|
||||||
|
|
||||||
qb9 := qb.Having("COUNT(o.id) > ?", 1)
|
|
||||||
if qb != qb9 {
|
|
||||||
t.Error("Having should return the same builder instance")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_EmptyWhere(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Select("id", "name")
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT id, name FROM users"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_OnlyLimit(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Limit(5)
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT * FROM users LIMIT 5"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_OnlyOffset(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Offset(10)
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT * FROM users OFFSET 10"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_ZeroLimit(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Limit(0)
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
// Limit 0 should not be included
|
|
||||||
expectedQuery := "SELECT * FROM users"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_ZeroOffset(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Offset(0)
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
// Offset 0 should not be included
|
|
||||||
expectedQuery := "SELECT * FROM users"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_MultipleSelect(t *testing.T) {
|
|
||||||
// Last Select should override previous ones
|
|
||||||
qb := NewQueryBuilder("users").
|
|
||||||
Select("id", "name").
|
|
||||||
Select("email", "age")
|
|
||||||
|
|
||||||
query, _ := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT email, age FROM users"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_MultipleOrderBy(t *testing.T) {
|
|
||||||
// Last OrderBy should override previous ones
|
|
||||||
qb := NewQueryBuilder("users").
|
|
||||||
OrderBy("name ASC").
|
|
||||||
OrderBy("created_at DESC")
|
|
||||||
|
|
||||||
query, _ := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT * FROM users ORDER BY created_at DESC"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_ComplexWhereConditions(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("products").
|
|
||||||
Select("id", "name", "price").
|
|
||||||
Where("category_id = ?", 5).
|
|
||||||
Where("price BETWEEN ? AND ?", 10.0, 100.0).
|
|
||||||
Where("stock > ?", 0).
|
|
||||||
Where("name LIKE ?", "%phone%").
|
|
||||||
OrderBy("price DESC").
|
|
||||||
Limit(20)
|
|
||||||
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT id, name, price FROM products WHERE category_id = ? AND price BETWEEN ? AND ? AND stock > ? AND name LIKE ? ORDER BY price DESC LIMIT 20"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedArgs := []interface{}{5, 10.0, 100.0, 0, "%phone%"}
|
|
||||||
if len(args) != len(expectedArgs) {
|
|
||||||
t.Errorf("Expected %d args, got %d", len(expectedArgs), len(args))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, expected := range expectedArgs {
|
|
||||||
if args[i] != expected {
|
|
||||||
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_SingleColumn(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users").Select("COUNT(*)")
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT COUNT(*) FROM users"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 0 {
|
|
||||||
t.Errorf("Expected 0 args, got %d", len(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_JoinLikeWhere(t *testing.T) {
|
|
||||||
// Test that WHERE can handle JOIN-like conditions
|
|
||||||
qb := NewQueryBuilder("users").
|
|
||||||
Select("users.id", "users.name", "orders.total").
|
|
||||||
Where("users.id = orders.user_id").
|
|
||||||
Where("orders.status = ?", "completed")
|
|
||||||
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT users.id, users.name, orders.total FROM users WHERE users.id = orders.user_id AND orders.status = ?"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) != 1 {
|
|
||||||
t.Errorf("Expected 1 arg, got %d", len(args))
|
|
||||||
}
|
|
||||||
|
|
||||||
if args[0] != "completed" {
|
|
||||||
t.Errorf("Expected arg 'completed', got %v", args[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_WithJoinGroupByHaving(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("users u").
|
|
||||||
Select("u.id", "u.name", "COUNT(o.id) AS order_count").
|
|
||||||
Join("LEFT JOIN orders o ON o.user_id = u.id").
|
|
||||||
Where("u.active = ?", true).
|
|
||||||
GroupBy("u.id", "u.name").
|
|
||||||
Having("COUNT(o.id) > ?", 2).
|
|
||||||
OrderBy("order_count DESC")
|
|
||||||
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT u.id, u.name, COUNT(o.id) AS order_count FROM users u LEFT JOIN orders o ON o.user_id = u.id WHERE u.active = ? GROUP BY u.id, u.name HAVING COUNT(o.id) > ? ORDER BY order_count DESC"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedArgs := []interface{}{true, 2}
|
|
||||||
if len(args) != len(expectedArgs) {
|
|
||||||
t.Fatalf("Expected %d args, got %d", len(expectedArgs), len(args))
|
|
||||||
}
|
|
||||||
for i, expected := range expectedArgs {
|
|
||||||
if args[i] != expected {
|
|
||||||
t.Errorf("Expected arg[%d] = %v, got %v", i, expected, args[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQueryBuilder_Build_HavingWithoutWhere(t *testing.T) {
|
|
||||||
qb := NewQueryBuilder("orders").
|
|
||||||
Select("user_id", "COUNT(*) AS cnt").
|
|
||||||
GroupBy("user_id").
|
|
||||||
Having("COUNT(*) >= ?", 3)
|
|
||||||
|
|
||||||
query, args := qb.Build()
|
|
||||||
|
|
||||||
expectedQuery := "SELECT user_id, COUNT(*) AS cnt FROM orders GROUP BY user_id HAVING COUNT(*) >= ?"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query '%s', got '%s'", expectedQuery, query)
|
|
||||||
}
|
|
||||||
if len(args) != 1 || args[0] != 3 {
|
|
||||||
t.Errorf("Expected args [3], got %v", args)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Benchmark tests
|
|
||||||
func BenchmarkQueryBuilder_Simple(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
qb := NewQueryBuilder("users")
|
|
||||||
qb.Build()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkQueryBuilder_Complex(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
qb := NewQueryBuilder("users").
|
|
||||||
Select("id", "name", "email", "age").
|
|
||||||
Where("age > ?", 18).
|
|
||||||
Where("active = ?", true).
|
|
||||||
Where("country = ?", "US").
|
|
||||||
OrderBy("name ASC").
|
|
||||||
Limit(10).
|
|
||||||
Offset(20)
|
|
||||||
qb.Build()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
32
converter.go
32
converter.go
@ -1,32 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
internalconv "b612.me/stardb/internal/convert"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// convertToInt64 converts any value to int64
|
|
||||||
func convertToInt64(val interface{}) int64 {
|
|
||||||
return internalconv.ToInt64(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertToUint64 converts any value to uint64
|
|
||||||
func convertToUint64(val interface{}) uint64 {
|
|
||||||
return internalconv.ToUint64(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertToFloat64 converts any value to float64
|
|
||||||
func convertToFloat64(val interface{}) float64 {
|
|
||||||
return internalconv.ToFloat64(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertToBool converts any value to bool
|
|
||||||
// Non-zero numbers are considered true
|
|
||||||
func convertToBool(val interface{}) bool {
|
|
||||||
return internalconv.ToBool(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertToTime converts any value to time.Time
|
|
||||||
func convertToTime(val interface{}, layout string) time.Time {
|
|
||||||
return internalconv.ToTime(val, layout)
|
|
||||||
}
|
|
||||||
@ -1,13 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import internalconv "b612.me/stardb/internal/convert"
|
|
||||||
|
|
||||||
// ConvertToInt64Safe converts any value to int64 with error handling
|
|
||||||
func ConvertToInt64Safe(val interface{}) (int64, error) {
|
|
||||||
return internalconv.ToInt64Safe(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConvertToStringSafe converts any value to string with error handling
|
|
||||||
func ConvertToStringSafe(val interface{}) (string, error) {
|
|
||||||
return internalconv.ToStringSafe(val)
|
|
||||||
}
|
|
||||||
@ -1,183 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConvertToInt64(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input interface{}
|
|
||||||
expected int64
|
|
||||||
}{
|
|
||||||
{"nil", nil, 0},
|
|
||||||
{"int", 42, 42},
|
|
||||||
{"int32", int32(42), 42},
|
|
||||||
{"int64", int64(42), 42},
|
|
||||||
{"uint64", uint64(42), 42},
|
|
||||||
{"float32", float32(42.7), 42},
|
|
||||||
{"float64", float64(42.7), 42},
|
|
||||||
{"string", "42", 42},
|
|
||||||
{"bool true", true, 1},
|
|
||||||
{"bool false", false, 0},
|
|
||||||
{"bytes", []byte("42"), 42},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := convertToInt64(tt.input)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("Expected %d, got %d", tt.expected, result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConvertToUint64(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input interface{}
|
|
||||||
expected uint64
|
|
||||||
}{
|
|
||||||
{"nil", nil, 0},
|
|
||||||
{"int", 42, 42},
|
|
||||||
{"int32", int32(42), 42},
|
|
||||||
{"int64", int64(42), 42},
|
|
||||||
{"uint64", uint64(42), 42},
|
|
||||||
{"float32", float32(42.7), 42},
|
|
||||||
{"float64", float64(42.7), 42},
|
|
||||||
{"string", "42", 42},
|
|
||||||
{"bool true", true, 1},
|
|
||||||
{"bool false", false, 0},
|
|
||||||
{"bytes", []byte("42"), 42},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := convertToUint64(tt.input)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("Expected %d, got %d", tt.expected, result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConvertToFloat64(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input interface{}
|
|
||||||
expected float64
|
|
||||||
}{
|
|
||||||
{"nil", nil, 0.0},
|
|
||||||
{"int", 42, 42.0},
|
|
||||||
{"int32", int32(42), 42.0},
|
|
||||||
{"int64", int64(42), 42.0},
|
|
||||||
{"uint64", uint64(42), 42.0},
|
|
||||||
{"float32", float32(42.5), 42.5},
|
|
||||||
{"float64", float64(42.5), 42.5},
|
|
||||||
{"string", "42.5", 42.5},
|
|
||||||
{"bool true", true, 1.0},
|
|
||||||
{"bool false", false, 0.0},
|
|
||||||
{"bytes", []byte("42.5"), 42.5},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := convertToFloat64(tt.input)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("Expected %f, got %f", tt.expected, result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConvertToBool(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input interface{}
|
|
||||||
expected bool
|
|
||||||
}{
|
|
||||||
{"nil", nil, false},
|
|
||||||
{"bool true", true, true},
|
|
||||||
{"bool false", false, false},
|
|
||||||
{"int positive", 1, true},
|
|
||||||
{"int zero", 0, false},
|
|
||||||
{"int negative", -1, true},
|
|
||||||
{"float positive", 1.5, true},
|
|
||||||
{"float zero", 0.0, false},
|
|
||||||
{"string true", "true", true},
|
|
||||||
{"string false", "false", false},
|
|
||||||
{"string 1", "1", true},
|
|
||||||
{"string 0", "0", false},
|
|
||||||
{"bytes true", []byte("true"), true},
|
|
||||||
{"bytes false", []byte("false"), false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := convertToBool(tt.input)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConvertToString(t *testing.T) {
|
|
||||||
now := time.Now()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input interface{}
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"nil", nil, ""},
|
|
||||||
{"string", "hello", "hello"},
|
|
||||||
{"int", 42, "42"},
|
|
||||||
{"int32", int32(42), "42"},
|
|
||||||
{"int64", int64(42), "42"},
|
|
||||||
{"float32", float32(42.5), "42.5"},
|
|
||||||
{"float64", float64(42.5), "42.5"},
|
|
||||||
{"bool true", true, "true"},
|
|
||||||
{"bool false", false, "false"},
|
|
||||||
{"time", now, now.String()},
|
|
||||||
{"bytes", []byte("hello"), "hello"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := convertToString(tt.input)
|
|
||||||
if result != tt.expected {
|
|
||||||
t.Errorf("Expected '%s', got '%s'", tt.expected, result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConvertToTime(t *testing.T) {
|
|
||||||
layout := "2006-01-02 15:04:05"
|
|
||||||
timeStr := "2024-01-01 10:00:00"
|
|
||||||
expectedTime, _ := time.Parse(layout, timeStr)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input interface{}
|
|
||||||
layout string
|
|
||||||
expected time.Time
|
|
||||||
}{
|
|
||||||
{"nil", nil, layout, time.Time{}},
|
|
||||||
{"time.Time", expectedTime, layout, expectedTime},
|
|
||||||
{"int64 unix", int64(1704103200), layout, time.Unix(1704103200, 0)},
|
|
||||||
{"string", timeStr, layout, expectedTime},
|
|
||||||
{"bytes", []byte(timeStr), layout, expectedTime},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
result := convertToTime(tt.input, tt.layout)
|
|
||||||
if !result.Equal(tt.expected) {
|
|
||||||
t.Errorf("Expected %v, got %v", tt.expected, result)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
64
errors.go
64
errors.go
@ -1,64 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
// Lifecycle errors.
|
|
||||||
ErrDBNotInitialized = errors.New("database is not initialized; call Open or SetDB first")
|
|
||||||
ErrTxNotInitialized = errors.New("transaction is not initialized")
|
|
||||||
ErrStmtNotInitialized = errors.New("statement is not initialized")
|
|
||||||
ErrStmtDBNotInitialized = errors.New("statement database context is not initialized")
|
|
||||||
|
|
||||||
// SQL input errors.
|
|
||||||
ErrQueryEmpty = errors.New("query string cannot be empty")
|
|
||||||
ErrScanStopped = errors.New("scan stopped by callback")
|
|
||||||
ErrScanFuncNil = errors.New("scan callback cannot be nil")
|
|
||||||
ErrScanORMFuncNil = errors.New("scan orm callback cannot be nil")
|
|
||||||
|
|
||||||
// Mapping and schema errors.
|
|
||||||
ErrColumnNotFound = errors.New("column not found")
|
|
||||||
ErrFieldNotFound = errors.New("field not found")
|
|
||||||
ErrRowIndexOutOfRange = errors.New("row index out of range")
|
|
||||||
|
|
||||||
// Target validation errors.
|
|
||||||
ErrTargetNil = errors.New("target cannot be nil")
|
|
||||||
ErrTargetsNil = errors.New("targets cannot be nil")
|
|
||||||
ErrTargetNotPointer = errors.New("target must be a pointer")
|
|
||||||
ErrTargetPointerNil = errors.New("target pointer cannot be nil")
|
|
||||||
ErrTargetsPointerNil = errors.New("targets pointer is nil")
|
|
||||||
ErrTargetNotStruct = errors.New("target is not a struct")
|
|
||||||
ErrTargetNotWritable = errors.New("target is not writable")
|
|
||||||
ErrPointerTargetNil = errors.New("pointer target is nil")
|
|
||||||
|
|
||||||
// SQL builder errors.
|
|
||||||
ErrTableNameEmpty = errors.New("table name cannot be empty")
|
|
||||||
ErrPrimaryKeyRequired = errors.New("at least one primary key is required")
|
|
||||||
ErrPrimaryKeyEmpty = errors.New("primary key cannot be empty")
|
|
||||||
ErrNoInsertColumns = errors.New("no columns to insert")
|
|
||||||
ErrNoInsertValues = errors.New("no values to insert")
|
|
||||||
ErrBatchInsertMaxParamsTooLow = errors.New("batch insert max params is lower than column count")
|
|
||||||
ErrNoUpdateFields = errors.New("no fields to update after excluding primary keys")
|
|
||||||
ErrBatchRowValueCountMismatch = errors.New("row values count does not match columns")
|
|
||||||
ErrStructsNil = errors.New("structs cannot be nil")
|
|
||||||
ErrStructsPointerNil = errors.New("structs pointer is nil")
|
|
||||||
ErrStructsNotSlice = errors.New("structs must be a slice or array")
|
|
||||||
ErrNoStructsToInsert = errors.New("no structs to insert")
|
|
||||||
|
|
||||||
// Transaction helper errors.
|
|
||||||
ErrTxFuncNil = errors.New("transaction callback cannot be nil")
|
|
||||||
)
|
|
||||||
|
|
||||||
func wrapColumnNotFound(column string) error {
|
|
||||||
return fmt.Errorf("%w: %s", ErrColumnNotFound, column)
|
|
||||||
}
|
|
||||||
|
|
||||||
func wrapFieldNotFound(field string) error {
|
|
||||||
return fmt.Errorf("%w: %s", ErrFieldNotFound, field)
|
|
||||||
}
|
|
||||||
|
|
||||||
func wrapBatchRowValueCountMismatch(rowIndex, got, expected int) error {
|
|
||||||
return fmt.Errorf("%w: row %d has %d values, expected %d", ErrBatchRowValueCountMismatch, rowIndex, got, expected)
|
|
||||||
}
|
|
||||||
2
go.sum
2
go.sum
@ -1,2 +0,0 @@
|
|||||||
github.com/mattn/go-sqlite3 v1.14.34 h1:3NtcvcUnFBPsuRcno8pUtupspG/GM+9nZ88zgJcp6Zk=
|
|
||||||
github.com/mattn/go-sqlite3 v1.14.34/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
|
||||||
@ -1,359 +0,0 @@
|
|||||||
package convert
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
var defaultNullTimeLayouts = []string{
|
|
||||||
time.RFC3339Nano,
|
|
||||||
time.RFC3339,
|
|
||||||
"2006-01-02 15:04:05",
|
|
||||||
"2006-01-02",
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToInt64 converts any value to int64.
|
|
||||||
func ToInt64(val interface{}) int64 {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return 0
|
|
||||||
case int:
|
|
||||||
return int64(v)
|
|
||||||
case int32:
|
|
||||||
return int64(v)
|
|
||||||
case int64:
|
|
||||||
return v
|
|
||||||
case uint64:
|
|
||||||
return int64(v)
|
|
||||||
case float32:
|
|
||||||
return int64(v)
|
|
||||||
case float64:
|
|
||||||
return int64(v)
|
|
||||||
case string:
|
|
||||||
result, _ := strconv.ParseInt(v, 10, 64)
|
|
||||||
return result
|
|
||||||
case bool:
|
|
||||||
if v {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
case time.Time:
|
|
||||||
return v.Unix()
|
|
||||||
case []byte:
|
|
||||||
result, _ := strconv.ParseInt(string(v), 10, 64)
|
|
||||||
return result
|
|
||||||
default:
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToUint64 converts any value to uint64.
|
|
||||||
func ToUint64(val interface{}) uint64 {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return 0
|
|
||||||
case int:
|
|
||||||
return uint64(v)
|
|
||||||
case int32:
|
|
||||||
return uint64(v)
|
|
||||||
case int64:
|
|
||||||
return uint64(v)
|
|
||||||
case uint64:
|
|
||||||
return v
|
|
||||||
case float32:
|
|
||||||
return uint64(v)
|
|
||||||
case float64:
|
|
||||||
return uint64(v)
|
|
||||||
case string:
|
|
||||||
result, _ := strconv.ParseUint(v, 10, 64)
|
|
||||||
return result
|
|
||||||
case bool:
|
|
||||||
if v {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
case time.Time:
|
|
||||||
return uint64(v.Unix())
|
|
||||||
case []byte:
|
|
||||||
result, _ := strconv.ParseUint(string(v), 10, 64)
|
|
||||||
return result
|
|
||||||
default:
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToFloat64 converts any value to float64.
|
|
||||||
func ToFloat64(val interface{}) float64 {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return 0
|
|
||||||
case int:
|
|
||||||
return float64(v)
|
|
||||||
case int32:
|
|
||||||
return float64(v)
|
|
||||||
case int64:
|
|
||||||
return float64(v)
|
|
||||||
case uint64:
|
|
||||||
return float64(v)
|
|
||||||
case float32:
|
|
||||||
return float64(v)
|
|
||||||
case float64:
|
|
||||||
return v
|
|
||||||
case string:
|
|
||||||
result, _ := strconv.ParseFloat(v, 64)
|
|
||||||
return result
|
|
||||||
case bool:
|
|
||||||
if v {
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
case time.Time:
|
|
||||||
return float64(v.Unix())
|
|
||||||
case []byte:
|
|
||||||
result, _ := strconv.ParseFloat(string(v), 64)
|
|
||||||
return result
|
|
||||||
default:
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToBool converts any value to bool.
|
|
||||||
func ToBool(val interface{}) bool {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return false
|
|
||||||
case bool:
|
|
||||||
return v
|
|
||||||
case int:
|
|
||||||
return v != 0
|
|
||||||
case int32:
|
|
||||||
return v != 0
|
|
||||||
case int64:
|
|
||||||
return v != 0
|
|
||||||
case uint64:
|
|
||||||
return v != 0
|
|
||||||
case float32:
|
|
||||||
return v != 0
|
|
||||||
case float64:
|
|
||||||
return v != 0
|
|
||||||
case string:
|
|
||||||
result, _ := strconv.ParseBool(v)
|
|
||||||
return result
|
|
||||||
case []byte:
|
|
||||||
result, _ := strconv.ParseBool(string(v))
|
|
||||||
return result
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToTime converts any value to time.Time.
|
|
||||||
func ToTime(val interface{}, layout string) time.Time {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return time.Time{}
|
|
||||||
case time.Time:
|
|
||||||
return v
|
|
||||||
case int:
|
|
||||||
return time.Unix(int64(v), 0)
|
|
||||||
case int32:
|
|
||||||
return time.Unix(int64(v), 0)
|
|
||||||
case int64:
|
|
||||||
return time.Unix(v, 0)
|
|
||||||
case uint64:
|
|
||||||
return time.Unix(int64(v), 0)
|
|
||||||
case float32:
|
|
||||||
sec := int64(v)
|
|
||||||
nsec := int64((v - float32(sec)) * 1e9)
|
|
||||||
return time.Unix(sec, nsec)
|
|
||||||
case float64:
|
|
||||||
sec := int64(v)
|
|
||||||
nsec := int64((v - float64(sec)) * 1e9)
|
|
||||||
return time.Unix(sec, nsec)
|
|
||||||
case string:
|
|
||||||
result, _ := time.Parse(layout, v)
|
|
||||||
return result
|
|
||||||
case []byte:
|
|
||||||
result, _ := time.Parse(layout, string(v))
|
|
||||||
return result
|
|
||||||
default:
|
|
||||||
return time.Time{}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToInt64Safe converts any value to int64 with error handling.
|
|
||||||
func ToInt64Safe(val interface{}) (int64, error) {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return 0, nil
|
|
||||||
case int:
|
|
||||||
return int64(v), nil
|
|
||||||
case int32:
|
|
||||||
return int64(v), nil
|
|
||||||
case int64:
|
|
||||||
return v, nil
|
|
||||||
case uint64:
|
|
||||||
return int64(v), nil
|
|
||||||
case float32:
|
|
||||||
return int64(v), nil
|
|
||||||
case float64:
|
|
||||||
return int64(v), nil
|
|
||||||
case string:
|
|
||||||
return strconv.ParseInt(v, 10, 64)
|
|
||||||
case bool:
|
|
||||||
if v {
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
return 0, nil
|
|
||||||
case time.Time:
|
|
||||||
return v.Unix(), nil
|
|
||||||
case []byte:
|
|
||||||
return strconv.ParseInt(string(v), 10, 64)
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("cannot convert %T to int64", val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToStringSafe converts any value to string with error handling.
|
|
||||||
func ToStringSafe(val interface{}) (string, error) {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return "", nil
|
|
||||||
case string:
|
|
||||||
return v, nil
|
|
||||||
case int:
|
|
||||||
return strconv.Itoa(v), nil
|
|
||||||
case int32:
|
|
||||||
return strconv.FormatInt(int64(v), 10), nil
|
|
||||||
case int64:
|
|
||||||
return strconv.FormatInt(v, 10), nil
|
|
||||||
case float32:
|
|
||||||
return strconv.FormatFloat(float64(v), 'f', -1, 32), nil
|
|
||||||
case float64:
|
|
||||||
return strconv.FormatFloat(v, 'f', -1, 64), nil
|
|
||||||
case bool:
|
|
||||||
return strconv.FormatBool(v), nil
|
|
||||||
case time.Time:
|
|
||||||
return v.String(), nil
|
|
||||||
case []byte:
|
|
||||||
return string(v), nil
|
|
||||||
default:
|
|
||||||
return "", fmt.Errorf("cannot convert %T to string", val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToFloat64Safe converts any value to float64 with error handling.
|
|
||||||
func ToFloat64Safe(val interface{}) (float64, error) {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return 0, nil
|
|
||||||
case float64:
|
|
||||||
return v, nil
|
|
||||||
case float32:
|
|
||||||
return float64(v), nil
|
|
||||||
case int, int32, int64, uint64:
|
|
||||||
intVal, err := ToInt64Safe(v)
|
|
||||||
return float64(intVal), err
|
|
||||||
case string:
|
|
||||||
return strconv.ParseFloat(v, 64)
|
|
||||||
case []byte:
|
|
||||||
return strconv.ParseFloat(string(v), 64)
|
|
||||||
default:
|
|
||||||
return 0, fmt.Errorf("cannot convert %T to float64", val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToBoolSafe converts any value to bool with error handling.
|
|
||||||
func ToBoolSafe(val interface{}) (bool, error) {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return false, nil
|
|
||||||
case bool:
|
|
||||||
return v, nil
|
|
||||||
case int:
|
|
||||||
return v != 0, nil
|
|
||||||
case int8:
|
|
||||||
return v != 0, nil
|
|
||||||
case int16:
|
|
||||||
return v != 0, nil
|
|
||||||
case int32:
|
|
||||||
return v != 0, nil
|
|
||||||
case int64:
|
|
||||||
return v != 0, nil
|
|
||||||
case uint:
|
|
||||||
return v != 0, nil
|
|
||||||
case uint8:
|
|
||||||
return v != 0, nil
|
|
||||||
case uint16:
|
|
||||||
return v != 0, nil
|
|
||||||
case uint32:
|
|
||||||
return v != 0, nil
|
|
||||||
case uint64:
|
|
||||||
return v != 0, nil
|
|
||||||
case float32:
|
|
||||||
return v != 0, nil
|
|
||||||
case float64:
|
|
||||||
return v != 0, nil
|
|
||||||
case string:
|
|
||||||
return ParseBoolString(v)
|
|
||||||
case []byte:
|
|
||||||
return ParseBoolString(string(v))
|
|
||||||
default:
|
|
||||||
return false, fmt.Errorf("cannot convert %T to bool", val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseBoolString parses string-like bool values.
|
|
||||||
func ParseBoolString(raw string) (bool, error) {
|
|
||||||
normalized := strings.TrimSpace(strings.ToLower(raw))
|
|
||||||
switch normalized {
|
|
||||||
case "", "0", "false", "f", "off", "no", "n":
|
|
||||||
return false, nil
|
|
||||||
case "1", "true", "t", "on", "yes", "y":
|
|
||||||
return true, nil
|
|
||||||
default:
|
|
||||||
return false, fmt.Errorf("cannot parse bool value: %q", raw)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToTimeSafe converts any value to time.Time with error handling.
|
|
||||||
func ToTimeSafe(val interface{}) (time.Time, error) {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return time.Time{}, nil
|
|
||||||
case time.Time:
|
|
||||||
return v, nil
|
|
||||||
case int:
|
|
||||||
return time.Unix(int64(v), 0), nil
|
|
||||||
case int64:
|
|
||||||
return time.Unix(v, 0), nil
|
|
||||||
case string:
|
|
||||||
return ParseTimeValue(v)
|
|
||||||
case []byte:
|
|
||||||
return ParseTimeValue(string(v))
|
|
||||||
default:
|
|
||||||
return time.Time{}, fmt.Errorf("cannot convert %T to time.Time", val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseTimeValue parses common SQL date-time formats and unix timestamp.
|
|
||||||
func ParseTimeValue(raw string) (time.Time, error) {
|
|
||||||
trimmed := strings.TrimSpace(raw)
|
|
||||||
if trimmed == "" {
|
|
||||||
return time.Time{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, layout := range defaultNullTimeLayouts {
|
|
||||||
if t, err := time.Parse(layout, trimmed); err == nil {
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ts, err := strconv.ParseInt(trimmed, 10, 64); err == nil {
|
|
||||||
return time.Unix(ts, 0), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return time.Time{}, fmt.Errorf("cannot parse time value: %q", raw)
|
|
||||||
}
|
|
||||||
@ -1,12 +0,0 @@
|
|||||||
package scanutil
|
|
||||||
|
|
||||||
// CloneScannedValue copies driver-scanned values that may be reused by driver.
|
|
||||||
// []byte is deep-copied; other types are returned as-is.
|
|
||||||
func CloneScannedValue(val interface{}) interface{} {
|
|
||||||
if b, ok := val.([]byte); ok {
|
|
||||||
copied := make([]byte, len(b))
|
|
||||||
copy(copied, b)
|
|
||||||
return copied
|
|
||||||
}
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
@ -1,119 +0,0 @@
|
|||||||
package sqlplaceholder
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConvertQuestionToDollarPlaceholders converts '?' to '$1,$2,...' in SQL text.
|
|
||||||
// It skips quoted strings, quoted identifiers and comments.
|
|
||||||
func ConvertQuestionToDollarPlaceholders(query string) string {
|
|
||||||
if query == "" || !strings.Contains(query, "?") {
|
|
||||||
return query
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
stateNormal = iota
|
|
||||||
stateSingleQuote
|
|
||||||
stateDoubleQuote
|
|
||||||
stateBacktick
|
|
||||||
stateLineComment
|
|
||||||
stateBlockComment
|
|
||||||
)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.Grow(len(query) + 8)
|
|
||||||
|
|
||||||
state := stateNormal
|
|
||||||
index := 1
|
|
||||||
|
|
||||||
for i := 0; i < len(query); i++ {
|
|
||||||
c := query[i]
|
|
||||||
|
|
||||||
switch state {
|
|
||||||
case stateNormal:
|
|
||||||
if c == '\'' {
|
|
||||||
state = stateSingleQuote
|
|
||||||
b.WriteByte(c)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '"' {
|
|
||||||
state = stateDoubleQuote
|
|
||||||
b.WriteByte(c)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '`' {
|
|
||||||
state = stateBacktick
|
|
||||||
b.WriteByte(c)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
|
|
||||||
state = stateLineComment
|
|
||||||
b.WriteByte(c)
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
|
|
||||||
state = stateBlockComment
|
|
||||||
b.WriteByte(c)
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '?' {
|
|
||||||
b.WriteByte('$')
|
|
||||||
b.WriteString(strconv.Itoa(index))
|
|
||||||
index++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
b.WriteByte(c)
|
|
||||||
|
|
||||||
case stateSingleQuote:
|
|
||||||
b.WriteByte(c)
|
|
||||||
if c == '\'' {
|
|
||||||
// SQL escaped single quote: ''
|
|
||||||
if i+1 < len(query) && query[i+1] == '\'' {
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateDoubleQuote:
|
|
||||||
b.WriteByte(c)
|
|
||||||
if c == '"' {
|
|
||||||
// escaped double quote: ""
|
|
||||||
if i+1 < len(query) && query[i+1] == '"' {
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateBacktick:
|
|
||||||
b.WriteByte(c)
|
|
||||||
if c == '`' {
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateLineComment:
|
|
||||||
b.WriteByte(c)
|
|
||||||
if c == '\n' {
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateBlockComment:
|
|
||||||
b.WriteByte(c)
|
|
||||||
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
@ -1,323 +0,0 @@
|
|||||||
package sqlruntime
|
|
||||||
|
|
||||||
import "strings"
|
|
||||||
|
|
||||||
// FingerprintSQL creates a normalized SQL fingerprint.
|
|
||||||
// mode controls literal masking; keepComments controls whether comments are preserved.
|
|
||||||
func FingerprintSQL(query string, mode int, keepComments bool) string {
|
|
||||||
prepared := query
|
|
||||||
if !keepComments {
|
|
||||||
prepared = stripSQLComments(prepared)
|
|
||||||
}
|
|
||||||
|
|
||||||
normalized := normalizeSQL(prepared)
|
|
||||||
if normalized == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
if NormalizeFingerprintMode(mode) == fingerprintModeMaskLiterals {
|
|
||||||
return maskSQLLiterals(normalized, keepComments)
|
|
||||||
}
|
|
||||||
return normalized
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeSQL(query string) string {
|
|
||||||
normalized := strings.ToLower(strings.TrimSpace(query))
|
|
||||||
if normalized == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return strings.Join(strings.Fields(normalized), " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func stripSQLComments(query string) string {
|
|
||||||
if query == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
stateNormal = iota
|
|
||||||
stateSingleQuote
|
|
||||||
stateDoubleQuote
|
|
||||||
stateBacktick
|
|
||||||
stateLineComment
|
|
||||||
stateBlockComment
|
|
||||||
)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.Grow(len(query))
|
|
||||||
state := stateNormal
|
|
||||||
|
|
||||||
for i := 0; i < len(query); i++ {
|
|
||||||
c := query[i]
|
|
||||||
|
|
||||||
switch state {
|
|
||||||
case stateNormal:
|
|
||||||
if c == '\'' {
|
|
||||||
state = stateSingleQuote
|
|
||||||
b.WriteByte(c)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '"' {
|
|
||||||
state = stateDoubleQuote
|
|
||||||
b.WriteByte(c)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '`' {
|
|
||||||
state = stateBacktick
|
|
||||||
b.WriteByte(c)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
|
|
||||||
b.WriteByte(' ')
|
|
||||||
i++
|
|
||||||
state = stateLineComment
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
|
|
||||||
b.WriteByte(' ')
|
|
||||||
i++
|
|
||||||
state = stateBlockComment
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
b.WriteByte(c)
|
|
||||||
|
|
||||||
case stateSingleQuote:
|
|
||||||
b.WriteByte(c)
|
|
||||||
if c == '\'' {
|
|
||||||
if i+1 < len(query) && query[i+1] == '\'' {
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateDoubleQuote:
|
|
||||||
b.WriteByte(c)
|
|
||||||
if c == '"' {
|
|
||||||
if i+1 < len(query) && query[i+1] == '"' {
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateBacktick:
|
|
||||||
b.WriteByte(c)
|
|
||||||
if c == '`' {
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateLineComment:
|
|
||||||
if c == '\n' {
|
|
||||||
b.WriteByte(' ')
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateBlockComment:
|
|
||||||
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
|
|
||||||
b.WriteByte(' ')
|
|
||||||
i++
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func maskSQLLiterals(query string, keepComments bool) string {
|
|
||||||
if query == "" {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
stateNormal = iota
|
|
||||||
stateSingleQuote
|
|
||||||
stateDoubleQuote
|
|
||||||
stateBacktick
|
|
||||||
stateLineComment
|
|
||||||
stateBlockComment
|
|
||||||
)
|
|
||||||
|
|
||||||
var b strings.Builder
|
|
||||||
b.Grow(len(query))
|
|
||||||
state := stateNormal
|
|
||||||
|
|
||||||
for i := 0; i < len(query); i++ {
|
|
||||||
c := query[i]
|
|
||||||
|
|
||||||
switch state {
|
|
||||||
case stateNormal:
|
|
||||||
if c == '\'' {
|
|
||||||
b.WriteByte('?')
|
|
||||||
state = stateSingleQuote
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '"' {
|
|
||||||
b.WriteByte(c)
|
|
||||||
state = stateDoubleQuote
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '`' {
|
|
||||||
b.WriteByte(c)
|
|
||||||
state = stateBacktick
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '-' && i+1 < len(query) && query[i+1] == '-' {
|
|
||||||
if keepComments {
|
|
||||||
b.WriteByte(c)
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
} else {
|
|
||||||
b.WriteByte(' ')
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
state = stateLineComment
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '/' && i+1 < len(query) && query[i+1] == '*' {
|
|
||||||
if keepComments {
|
|
||||||
b.WriteByte(c)
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
} else {
|
|
||||||
b.WriteByte(' ')
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
state = stateBlockComment
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if c == '$' {
|
|
||||||
j := i + 1
|
|
||||||
for j < len(query) && isDigit(query[j]) {
|
|
||||||
j++
|
|
||||||
}
|
|
||||||
if j > i+1 {
|
|
||||||
b.WriteByte('?')
|
|
||||||
i = j - 1
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c == '-' && i+1 < len(query) && isDigit(query[i+1]) && isNumberBoundaryBefore(query, i) {
|
|
||||||
j := scanNumber(query, i+1)
|
|
||||||
if isNumberBoundaryAfter(query, j) {
|
|
||||||
b.WriteByte('?')
|
|
||||||
i = j - 1
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if isDigit(c) && isNumberBoundaryBefore(query, i) {
|
|
||||||
j := scanNumber(query, i)
|
|
||||||
if isNumberBoundaryAfter(query, j) {
|
|
||||||
b.WriteByte('?')
|
|
||||||
i = j - 1
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
b.WriteByte(c)
|
|
||||||
|
|
||||||
case stateSingleQuote:
|
|
||||||
if c == '\'' {
|
|
||||||
if i+1 < len(query) && query[i+1] == '\'' {
|
|
||||||
i++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateDoubleQuote:
|
|
||||||
b.WriteByte(c)
|
|
||||||
if c == '"' {
|
|
||||||
if i+1 < len(query) && query[i+1] == '"' {
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateBacktick:
|
|
||||||
b.WriteByte(c)
|
|
||||||
if c == '`' {
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateLineComment:
|
|
||||||
if keepComments {
|
|
||||||
b.WriteByte(c)
|
|
||||||
}
|
|
||||||
if c == '\n' {
|
|
||||||
if !keepComments {
|
|
||||||
b.WriteByte(' ')
|
|
||||||
}
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
|
|
||||||
case stateBlockComment:
|
|
||||||
if keepComments {
|
|
||||||
b.WriteByte(c)
|
|
||||||
}
|
|
||||||
if c == '*' && i+1 < len(query) && query[i+1] == '/' {
|
|
||||||
if keepComments {
|
|
||||||
i++
|
|
||||||
b.WriteByte(query[i])
|
|
||||||
} else {
|
|
||||||
b.WriteByte(' ')
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
state = stateNormal
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return strings.Join(strings.Fields(b.String()), " ")
|
|
||||||
}
|
|
||||||
|
|
||||||
func isDigit(c byte) bool {
|
|
||||||
return c >= '0' && c <= '9'
|
|
||||||
}
|
|
||||||
|
|
||||||
func isIdentifierChar(c byte) bool {
|
|
||||||
return (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '_'
|
|
||||||
}
|
|
||||||
|
|
||||||
func isNumberBoundaryBefore(query string, index int) bool {
|
|
||||||
if index <= 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
prev := query[index-1]
|
|
||||||
return !isIdentifierChar(prev) && prev != '$' && prev != '.'
|
|
||||||
}
|
|
||||||
|
|
||||||
func isNumberBoundaryAfter(query string, index int) bool {
|
|
||||||
if index >= len(query) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
next := query[index]
|
|
||||||
return !isIdentifierChar(next) && next != '.'
|
|
||||||
}
|
|
||||||
|
|
||||||
func scanNumber(query string, start int) int {
|
|
||||||
i := start
|
|
||||||
for i < len(query) && isDigit(query[i]) {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
if i < len(query) && query[i] == '.' {
|
|
||||||
i++
|
|
||||||
for i < len(query) && isDigit(query[i]) {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if i < len(query) && (query[i] == 'e' || query[i] == 'E') {
|
|
||||||
i++
|
|
||||||
if i < len(query) && (query[i] == '+' || query[i] == '-') {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
for i < len(query) && isDigit(query[i]) {
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
@ -1,27 +0,0 @@
|
|||||||
package sqlruntime
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
// CloneHookArgs creates a shallow copy for hook consumers to avoid mutation races.
|
|
||||||
func CloneHookArgs(args []interface{}) []interface{} {
|
|
||||||
if len(args) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
copied := make([]interface{}, len(args))
|
|
||||||
copy(copied, args)
|
|
||||||
return copied
|
|
||||||
}
|
|
||||||
|
|
||||||
// ShouldRunAfterHook decides whether after-hook should run.
|
|
||||||
func ShouldRunAfterHook(hasAfterHook bool, slowThreshold, duration time.Duration, err error) bool {
|
|
||||||
if !hasAfterHook {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
if slowThreshold <= 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return duration >= slowThreshold
|
|
||||||
}
|
|
||||||
@ -1,269 +0,0 @@
|
|||||||
package sqlruntime
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
placeholderQuestion = 0
|
|
||||||
placeholderDollar = 1
|
|
||||||
|
|
||||||
fingerprintModeBasic = 0
|
|
||||||
fingerprintModeMaskLiterals = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
// NormalizePlaceholderStyle converts unknown style values to default question style.
|
|
||||||
func NormalizePlaceholderStyle(style int) int {
|
|
||||||
switch style {
|
|
||||||
case placeholderDollar:
|
|
||||||
return placeholderDollar
|
|
||||||
default:
|
|
||||||
return placeholderQuestion
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NormalizeFingerprintMode converts unknown mode values to default basic mode.
|
|
||||||
func NormalizeFingerprintMode(mode int) int {
|
|
||||||
switch mode {
|
|
||||||
case fingerprintModeMaskLiterals:
|
|
||||||
return fingerprintModeMaskLiterals
|
|
||||||
default:
|
|
||||||
return fingerprintModeBasic
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// State stores runtime SQL behavior toggles in a thread-safe manner.
|
|
||||||
type State struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
beforeHook interface{}
|
|
||||||
afterHook interface{}
|
|
||||||
placeholder int
|
|
||||||
slowThreshold time.Duration
|
|
||||||
fingerprintEnabled bool
|
|
||||||
fingerprintMode int
|
|
||||||
fingerprintKeepComments bool
|
|
||||||
fingerprintCounterEnabled bool
|
|
||||||
fingerprintCounts map[string]uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
// Options returns snapshot of current runtime options.
|
|
||||||
func (s *State) Options() (before, after interface{}, placeholder int, slowThreshold time.Duration) {
|
|
||||||
if s == nil {
|
|
||||||
return nil, nil, placeholderQuestion, 0
|
|
||||||
}
|
|
||||||
s.mu.RLock()
|
|
||||||
before = s.beforeHook
|
|
||||||
after = s.afterHook
|
|
||||||
placeholder = NormalizePlaceholderStyle(s.placeholder)
|
|
||||||
slowThreshold = s.slowThreshold
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return before, after, placeholder, slowThreshold
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hooks returns before/after hooks and slow threshold.
|
|
||||||
func (s *State) Hooks() (before, after interface{}, slowThreshold time.Duration) {
|
|
||||||
before, after, _, slowThreshold = s.Options()
|
|
||||||
return before, after, slowThreshold
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetHooks sets before/after hooks.
|
|
||||||
func (s *State) SetHooks(before, after interface{}) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
s.beforeHook = before
|
|
||||||
s.afterHook = after
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetBeforeHook sets before hook.
|
|
||||||
func (s *State) SetBeforeHook(before interface{}) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
s.beforeHook = before
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetAfterHook sets after hook.
|
|
||||||
func (s *State) SetAfterHook(after interface{}) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
s.afterHook = after
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPlaceholderStyle sets placeholder style.
|
|
||||||
func (s *State) SetPlaceholderStyle(style int) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
s.placeholder = NormalizePlaceholderStyle(style)
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// PlaceholderStyle returns placeholder style.
|
|
||||||
func (s *State) PlaceholderStyle() int {
|
|
||||||
if s == nil {
|
|
||||||
return placeholderQuestion
|
|
||||||
}
|
|
||||||
s.mu.RLock()
|
|
||||||
style := NormalizePlaceholderStyle(s.placeholder)
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return style
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSlowThreshold sets minimum duration for triggering after hook.
|
|
||||||
func (s *State) SetSlowThreshold(threshold time.Duration) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if threshold < 0 {
|
|
||||||
threshold = 0
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
s.slowThreshold = threshold
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SlowThreshold returns current slow threshold.
|
|
||||||
func (s *State) SlowThreshold() time.Duration {
|
|
||||||
if s == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
s.mu.RLock()
|
|
||||||
threshold := s.slowThreshold
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return threshold
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFingerprintEnabled toggles SQL fingerprint metadata generation for hooks.
|
|
||||||
func (s *State) SetFingerprintEnabled(enabled bool) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
s.fingerprintEnabled = enabled
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// FingerprintEnabled reports whether SQL fingerprint metadata generation is enabled.
|
|
||||||
func (s *State) FingerprintEnabled() bool {
|
|
||||||
if s == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s.mu.RLock()
|
|
||||||
enabled := s.fingerprintEnabled
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFingerprintMode sets SQL fingerprint mode.
|
|
||||||
func (s *State) SetFingerprintMode(mode int) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
s.fingerprintMode = NormalizeFingerprintMode(mode)
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// FingerprintMode returns SQL fingerprint mode.
|
|
||||||
func (s *State) FingerprintMode() int {
|
|
||||||
if s == nil {
|
|
||||||
return fingerprintModeBasic
|
|
||||||
}
|
|
||||||
s.mu.RLock()
|
|
||||||
mode := NormalizeFingerprintMode(s.fingerprintMode)
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return mode
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFingerprintKeepComments toggles comment preservation in generated SQL fingerprints.
|
|
||||||
func (s *State) SetFingerprintKeepComments(keep bool) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
s.fingerprintKeepComments = keep
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// FingerprintKeepComments reports whether comments are kept in generated SQL fingerprints.
|
|
||||||
func (s *State) FingerprintKeepComments() bool {
|
|
||||||
if s == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s.mu.RLock()
|
|
||||||
keep := s.fingerprintKeepComments
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return keep
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetFingerprintCounterEnabled toggles in-memory SQL fingerprint hit counter.
|
|
||||||
func (s *State) SetFingerprintCounterEnabled(enabled bool) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
s.fingerprintCounterEnabled = enabled
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// FingerprintCounterEnabled reports whether in-memory SQL fingerprint hit counter is enabled.
|
|
||||||
func (s *State) FingerprintCounterEnabled() bool {
|
|
||||||
if s == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
s.mu.RLock()
|
|
||||||
enabled := s.fingerprintCounterEnabled
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return enabled
|
|
||||||
}
|
|
||||||
|
|
||||||
// IncFingerprintCount increments hit count for a fingerprint.
|
|
||||||
func (s *State) IncFingerprintCount(fingerprint string) {
|
|
||||||
if s == nil || fingerprint == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
if s.fingerprintCounts == nil {
|
|
||||||
s.fingerprintCounts = make(map[string]uint64)
|
|
||||||
}
|
|
||||||
s.fingerprintCounts[fingerprint]++
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// FingerprintCountsSnapshot returns a snapshot copy of fingerprint counters.
|
|
||||||
func (s *State) FingerprintCountsSnapshot() map[string]uint64 {
|
|
||||||
if s == nil {
|
|
||||||
return map[string]uint64{}
|
|
||||||
}
|
|
||||||
s.mu.RLock()
|
|
||||||
if len(s.fingerprintCounts) == 0 {
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return map[string]uint64{}
|
|
||||||
}
|
|
||||||
out := make(map[string]uint64, len(s.fingerprintCounts))
|
|
||||||
for k, v := range s.fingerprintCounts {
|
|
||||||
out[k] = v
|
|
||||||
}
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetFingerprintCounts clears all fingerprint counters.
|
|
||||||
func (s *State) ResetFingerprintCounts() {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.mu.Lock()
|
|
||||||
s.fingerprintCounts = nil
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
615
orm.go
615
orm.go
@ -1,615 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Orm maps query results to a struct or slice of structs
|
|
||||||
// Usage:
|
|
||||||
//
|
|
||||||
// var user User
|
|
||||||
// rows.Orm(&user) // single row
|
|
||||||
// var users []User
|
|
||||||
// rows.Orm(&users) // multiple rows
|
|
||||||
func (r *StarRows) Orm(target interface{}) error {
|
|
||||||
if !r.parsed {
|
|
||||||
if err := r.parse(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if target == nil {
|
|
||||||
return ErrTargetNil
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
|
||||||
targetValue := reflect.ValueOf(target)
|
|
||||||
|
|
||||||
if targetType.Kind() != reflect.Ptr {
|
|
||||||
return ErrTargetNotPointer
|
|
||||||
}
|
|
||||||
if targetValue.IsNil() {
|
|
||||||
return ErrTargetPointerNil
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
targetValue = targetValue.Elem()
|
|
||||||
|
|
||||||
// Handle slice
|
|
||||||
if targetValue.Kind() == reflect.Slice {
|
|
||||||
elementType := targetType.Elem()
|
|
||||||
result := reflect.MakeSlice(targetType, 0, r.Length())
|
|
||||||
|
|
||||||
if r.Length() == 0 {
|
|
||||||
targetValue.Set(result)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < r.Length(); i++ {
|
|
||||||
element := reflect.New(elementType)
|
|
||||||
if err := r.setStructFieldsFromRow(element.Interface(), "db", i); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
result = reflect.Append(result, element.Elem())
|
|
||||||
}
|
|
||||||
|
|
||||||
targetValue.Set(result)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle array
|
|
||||||
if targetValue.Kind() == reflect.Array {
|
|
||||||
elementType := targetType.Elem()
|
|
||||||
|
|
||||||
if r.Length() == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.Length() > targetValue.Len() {
|
|
||||||
return fmt.Errorf("target array length %d is smaller than rows %d", targetValue.Len(), r.Length())
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < r.Length(); i++ {
|
|
||||||
element := reflect.New(elementType)
|
|
||||||
if err := r.setStructFieldsFromRow(element.Interface(), "db", i); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
targetValue.Index(i).Set(element.Elem())
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle single struct
|
|
||||||
if r.Length() == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.setStructFieldsFromRow(target, "db", 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func bindNamedArgs(args []interface{}, fieldValues map[string]interface{}) ([]interface{}, error) {
|
|
||||||
processedArgs := make([]interface{}, len(args))
|
|
||||||
for i, arg := range args {
|
|
||||||
str, ok := arg.(string)
|
|
||||||
if !ok {
|
|
||||||
processedArgs[i] = arg
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(str, `\:`) {
|
|
||||||
processedArgs[i] = str[1:]
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(str, ":") {
|
|
||||||
fieldName := str[1:]
|
|
||||||
val, exists := fieldValues[fieldName]
|
|
||||||
if !exists {
|
|
||||||
return nil, wrapFieldNotFound(fieldName)
|
|
||||||
}
|
|
||||||
processedArgs[i] = val
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
processedArgs[i] = arg
|
|
||||||
}
|
|
||||||
return processedArgs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryX executes a query with named parameter binding
|
|
||||||
// Usage: QueryX(&user, "SELECT * FROM users WHERE id = ?", ":id")
|
|
||||||
func (s *StarDB) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
return s.queryX(nil, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryXContext executes a query with context and named parameter binding
|
|
||||||
func (s *StarDB) QueryXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
return s.queryX(ctx, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// queryX is the internal implementation
|
|
||||||
func (s *StarDB) queryX(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
fieldValues, err := getStructFieldValues(target, "db")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
processedArgs, err := bindNamedArgs(args, fieldValues)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.query(ctx, query, processedArgs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryXS executes queries for multiple structs
|
|
||||||
func (s *StarDB) QueryXS(targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
|
|
||||||
return s.queryXS(nil, targets, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryXSContext executes queries with context for multiple structs
|
|
||||||
func (s *StarDB) QueryXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
|
|
||||||
return s.queryXS(ctx, targets, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// queryXS is the internal implementation
|
|
||||||
func (s *StarDB) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
|
|
||||||
var results []*StarRows
|
|
||||||
if targets == nil {
|
|
||||||
return results, ErrTargetsNil
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType := reflect.TypeOf(targets)
|
|
||||||
targetValue := reflect.ValueOf(targets)
|
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
|
||||||
if targetValue.IsNil() {
|
|
||||||
return results, ErrTargetsPointerNil
|
|
||||||
}
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
targetValue = targetValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
|
||||||
results = make([]*StarRows, 0, targetValue.Len())
|
|
||||||
for i := 0; i < targetValue.Len(); i++ {
|
|
||||||
result, err := s.queryX(ctx, targetValue.Index(i).Interface(), query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result, err := s.queryX(ctx, targets, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecX executes a statement with named parameter binding
|
|
||||||
func (s *StarDB) ExecX(target interface{}, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
return s.execX(nil, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecXContext executes a statement with context and named parameter binding
|
|
||||||
func (s *StarDB) ExecXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
return s.execX(ctx, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// execX is the internal implementation
|
|
||||||
func (s *StarDB) execX(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
fieldValues, err := getStructFieldValues(target, "db")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
processedArgs, err := bindNamedArgs(args, fieldValues)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.exec(ctx, query, processedArgs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecXS executes statements for multiple structs
|
|
||||||
func (s *StarDB) ExecXS(targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
|
|
||||||
return s.execXS(nil, targets, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecXSContext executes statements with context for multiple structs
|
|
||||||
func (s *StarDB) ExecXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
|
|
||||||
return s.execXS(ctx, targets, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// execXS is the internal implementation
|
|
||||||
func (s *StarDB) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
|
|
||||||
var results []sql.Result
|
|
||||||
if targets == nil {
|
|
||||||
return results, ErrTargetsNil
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType := reflect.TypeOf(targets)
|
|
||||||
targetValue := reflect.ValueOf(targets)
|
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
|
||||||
if targetValue.IsNil() {
|
|
||||||
return results, ErrTargetsPointerNil
|
|
||||||
}
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
targetValue = targetValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
|
||||||
results = make([]sql.Result, 0, targetValue.Len())
|
|
||||||
for i := 0; i < targetValue.Len(); i++ {
|
|
||||||
result, err := s.execX(ctx, targetValue.Index(i).Interface(), query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result, err := s.execX(ctx, targets, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert inserts a struct into the database
|
|
||||||
// Usage: Insert(&user, "users", "id") // id is auto-increment
|
|
||||||
func (s *StarDB) Insert(target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
|
|
||||||
return s.insert(nil, target, tableName, autoIncrementFields...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertContext inserts a struct with context
|
|
||||||
func (s *StarDB) InsertContext(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
|
|
||||||
return s.insert(ctx, target, tableName, autoIncrementFields...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert is the internal implementation
|
|
||||||
func (s *StarDB) insert(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
|
|
||||||
query, params, err := buildInsertSQL(target, tableName, autoIncrementFields...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
args := make([]interface{}, len(params))
|
|
||||||
for i, param := range params {
|
|
||||||
args[i] = param
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.execX(ctx, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update updates a struct in the database
|
|
||||||
// Usage: Update(&user, "users", "id") // id is primary key
|
|
||||||
func (s *StarDB) Update(target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
|
|
||||||
return s.update(nil, target, tableName, primaryKeys...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateContext updates a struct with context
|
|
||||||
func (s *StarDB) UpdateContext(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
|
|
||||||
return s.update(ctx, target, tableName, primaryKeys...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// update is the internal implementation
|
|
||||||
func (s *StarDB) update(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
|
|
||||||
query, params, err := buildUpdateSQL(target, tableName, primaryKeys...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
args := make([]interface{}, len(params))
|
|
||||||
for i, param := range params {
|
|
||||||
args[i] = param
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.execX(ctx, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildInsertSQL builds an INSERT SQL statement
|
|
||||||
func buildInsertSQL(target interface{}, tableName string, autoIncrementFields ...string) (string, []string, error) {
|
|
||||||
if strings.TrimSpace(tableName) == "" {
|
|
||||||
return "", []string{}, ErrTableNameEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
fieldNames, err := getStructFieldNames(target, "db")
|
|
||||||
if err != nil {
|
|
||||||
return "", []string{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var columns []string
|
|
||||||
var placeholders []string
|
|
||||||
var params []string
|
|
||||||
autoIncrementSet := make(map[string]struct{}, len(autoIncrementFields))
|
|
||||||
for _, autoField := range autoIncrementFields {
|
|
||||||
autoIncrementSet[autoField] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, fieldName := range fieldNames {
|
|
||||||
// Skip auto-increment fields
|
|
||||||
if _, isAutoIncrement := autoIncrementSet[fieldName]; isAutoIncrement {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
columns = append(columns, fieldName)
|
|
||||||
placeholders = append(placeholders, "?")
|
|
||||||
params = append(params, ":"+fieldName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(columns) == 0 {
|
|
||||||
return "", []string{}, ErrNoInsertColumns
|
|
||||||
}
|
|
||||||
|
|
||||||
query := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
|
|
||||||
tableName,
|
|
||||||
strings.Join(columns, ", "),
|
|
||||||
strings.Join(placeholders, ", "))
|
|
||||||
|
|
||||||
return query, params, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// buildUpdateSQL builds an UPDATE SQL statement
|
|
||||||
func buildUpdateSQL(target interface{}, tableName string, primaryKeys ...string) (string, []string, error) {
|
|
||||||
if strings.TrimSpace(tableName) == "" {
|
|
||||||
return "", []string{}, ErrTableNameEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
fieldNames, err := getStructFieldNames(target, "db")
|
|
||||||
if err != nil {
|
|
||||||
return "", []string{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(primaryKeys) == 0 {
|
|
||||||
return "", []string{}, ErrPrimaryKeyRequired
|
|
||||||
}
|
|
||||||
|
|
||||||
primaryKeySet := make(map[string]struct{}, len(primaryKeys))
|
|
||||||
for _, pk := range primaryKeys {
|
|
||||||
if pk == "" {
|
|
||||||
return "", []string{}, ErrPrimaryKeyEmpty
|
|
||||||
}
|
|
||||||
primaryKeySet[pk] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var setClauses []string
|
|
||||||
var params []string
|
|
||||||
|
|
||||||
// Build SET clause
|
|
||||||
for _, fieldName := range fieldNames {
|
|
||||||
if _, isPrimaryKey := primaryKeySet[fieldName]; isPrimaryKey {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
setClauses = append(setClauses, fmt.Sprintf("%s = ?", fieldName))
|
|
||||||
params = append(params, ":"+fieldName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(setClauses) == 0 {
|
|
||||||
return "", []string{}, ErrNoUpdateFields
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build WHERE clause
|
|
||||||
var whereClauses []string
|
|
||||||
for _, pk := range primaryKeys {
|
|
||||||
whereClauses = append(whereClauses, fmt.Sprintf("%s = ?", pk))
|
|
||||||
params = append(params, ":"+pk)
|
|
||||||
}
|
|
||||||
|
|
||||||
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s",
|
|
||||||
tableName,
|
|
||||||
strings.Join(setClauses, ", "),
|
|
||||||
strings.Join(whereClauses, " AND "))
|
|
||||||
|
|
||||||
return query, params, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Transaction ORM methods
|
|
||||||
|
|
||||||
// QueryX executes a query with named parameter binding in a transaction
|
|
||||||
func (t *StarTx) QueryX(target interface{}, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
return t.queryX(nil, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryXContext executes a query with context in a transaction
|
|
||||||
func (t *StarTx) QueryXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
return t.queryX(ctx, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// queryX is the internal implementation
|
|
||||||
func (t *StarTx) queryX(ctx context.Context, target interface{}, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
fieldValues, err := getStructFieldValues(target, "db")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
processedArgs, err := bindNamedArgs(args, fieldValues)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.query(ctx, query, processedArgs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryXS executes queries for multiple structs in a transaction
|
|
||||||
func (t *StarTx) QueryXS(targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
|
|
||||||
return t.queryXS(nil, targets, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryXSContext executes queries with context for multiple structs in a transaction
|
|
||||||
func (t *StarTx) QueryXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
|
|
||||||
return t.queryXS(ctx, targets, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// queryXS is the internal implementation
|
|
||||||
func (t *StarTx) queryXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]*StarRows, error) {
|
|
||||||
var results []*StarRows
|
|
||||||
if targets == nil {
|
|
||||||
return results, ErrTargetsNil
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType := reflect.TypeOf(targets)
|
|
||||||
targetValue := reflect.ValueOf(targets)
|
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
|
||||||
if targetValue.IsNil() {
|
|
||||||
return results, ErrTargetsPointerNil
|
|
||||||
}
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
targetValue = targetValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
|
||||||
results = make([]*StarRows, 0, targetValue.Len())
|
|
||||||
for i := 0; i < targetValue.Len(); i++ {
|
|
||||||
result, err := t.queryX(ctx, targetValue.Index(i).Interface(), query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result, err := t.queryX(ctx, targets, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecX executes a statement with named parameter binding in a transaction
|
|
||||||
func (t *StarTx) ExecX(target interface{}, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
return t.execX(nil, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecXContext executes a statement with context in a transaction
|
|
||||||
func (t *StarTx) ExecXContext(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
return t.execX(ctx, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// execX is the internal implementation
|
|
||||||
func (t *StarTx) execX(ctx context.Context, target interface{}, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
fieldValues, err := getStructFieldValues(target, "db")
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
processedArgs, err := bindNamedArgs(args, fieldValues)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.exec(ctx, query, processedArgs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecXS executes statements for multiple structs in a transaction
|
|
||||||
func (t *StarTx) ExecXS(targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
|
|
||||||
return t.execXS(nil, targets, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecXSContext executes statements with context for multiple structs in a transaction
|
|
||||||
func (t *StarTx) ExecXSContext(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
|
|
||||||
return t.execXS(ctx, targets, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// execXS is the internal implementation
|
|
||||||
func (t *StarTx) execXS(ctx context.Context, targets interface{}, query string, args ...interface{}) ([]sql.Result, error) {
|
|
||||||
var results []sql.Result
|
|
||||||
if targets == nil {
|
|
||||||
return results, ErrTargetsNil
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType := reflect.TypeOf(targets)
|
|
||||||
targetValue := reflect.ValueOf(targets)
|
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
|
||||||
if targetValue.IsNil() {
|
|
||||||
return results, ErrTargetsPointerNil
|
|
||||||
}
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
targetValue = targetValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetValue.Kind() == reflect.Slice || targetValue.Kind() == reflect.Array {
|
|
||||||
results = make([]sql.Result, 0, targetValue.Len())
|
|
||||||
for i := 0; i < targetValue.Len(); i++ {
|
|
||||||
result, err := t.execX(ctx, targetValue.Index(i).Interface(), query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result, err := t.execX(ctx, targets, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
results = append(results, result)
|
|
||||||
}
|
|
||||||
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert inserts a struct in a transaction
|
|
||||||
func (t *StarTx) Insert(target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
|
|
||||||
return t.insert(nil, target, tableName, autoIncrementFields...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertContext inserts a struct with context in a transaction
|
|
||||||
func (t *StarTx) InsertContext(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
|
|
||||||
return t.insert(ctx, target, tableName, autoIncrementFields...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// insert is the internal implementation
|
|
||||||
func (t *StarTx) insert(ctx context.Context, target interface{}, tableName string, autoIncrementFields ...string) (sql.Result, error) {
|
|
||||||
query, params, err := buildInsertSQL(target, tableName, autoIncrementFields...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
args := make([]interface{}, len(params))
|
|
||||||
for i, param := range params {
|
|
||||||
args[i] = param
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.execX(ctx, target, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update updates a struct in a transaction
|
|
||||||
func (t *StarTx) Update(target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
|
|
||||||
return t.update(nil, target, tableName, primaryKeys...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateContext updates a struct with context in a transaction
|
|
||||||
func (t *StarTx) UpdateContext(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
|
|
||||||
return t.update(ctx, target, tableName, primaryKeys...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// update is the internal implementation
|
|
||||||
func (t *StarTx) update(ctx context.Context, target interface{}, tableName string, primaryKeys ...string) (sql.Result, error) {
|
|
||||||
query, params, err := buildUpdateSQL(target, tableName, primaryKeys...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
args := make([]interface{}, len(params))
|
|
||||||
for i, param := range params {
|
|
||||||
args[i] = param
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.execX(ctx, target, query, args...)
|
|
||||||
}
|
|
||||||
275
orm_test.go
275
orm_test.go
@ -1,275 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Email string `db:"email"`
|
|
||||||
Age int `db:"age"`
|
|
||||||
Balance float64 `db:"balance"`
|
|
||||||
Active bool `db:"active"`
|
|
||||||
CreatedAt time.Time `db:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Profile struct {
|
|
||||||
UserID int `db:"user_id"`
|
|
||||||
Bio string `db:"bio"`
|
|
||||||
Avatar string `db:"avatar"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type NestedUser struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Profile `db:"---"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type NestedUserPtr struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Profile *Profile `db:"---"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AutoIDOnly struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type HiddenTagged struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
hidden string `db:"hidden"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildInsertSQL(t *testing.T) {
|
|
||||||
user := User{
|
|
||||||
ID: 1,
|
|
||||||
Name: "Test",
|
|
||||||
Email: "test@example.com",
|
|
||||||
Age: 30,
|
|
||||||
Balance: 100.0,
|
|
||||||
Active: true,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
query, params, err := buildInsertSQL(&user, "users", "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("buildInsertSQL failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedQuery := "INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, ?, ?, ?, ?)"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedParams := []string{":name", ":email", ":age", ":balance", ":active", ":created_at"}
|
|
||||||
if len(params) != len(expectedParams) {
|
|
||||||
t.Errorf("Expected %d params, got %d", len(expectedParams), len(params))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, param := range params {
|
|
||||||
if param != expectedParams[i] {
|
|
||||||
t.Errorf("Expected param %s, got %s", expectedParams[i], param)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildUpdateSQL(t *testing.T) {
|
|
||||||
user := User{
|
|
||||||
ID: 1,
|
|
||||||
Name: "Test",
|
|
||||||
Email: "test@example.com",
|
|
||||||
Age: 30,
|
|
||||||
Balance: 100.0,
|
|
||||||
Active: true,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
query, params, err := buildUpdateSQL(&user, "users", "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("buildUpdateSQL failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedQuery := "UPDATE users SET name = ?, email = ?, age = ?, balance = ?, active = ?, created_at = ? WHERE id = ?"
|
|
||||||
if query != expectedQuery {
|
|
||||||
t.Errorf("Expected query:\n%s\nGot:\n%s", expectedQuery, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedParamCount := 7 // 6 fields + 1 primary key
|
|
||||||
if len(params) != expectedParamCount {
|
|
||||||
t.Errorf("Expected %d params, got %d", expectedParamCount, len(params))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildInsertSQL_NoColumns(t *testing.T) {
|
|
||||||
model := AutoIDOnly{ID: 1}
|
|
||||||
|
|
||||||
_, _, err := buildInsertSQL(&model, "users", "id")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when no columns remain to insert, got nil")
|
|
||||||
}
|
|
||||||
if !errors.Is(err, ErrNoInsertColumns) {
|
|
||||||
t.Fatalf("Expected ErrNoInsertColumns, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildInsertSQL_EmptyTableName(t *testing.T) {
|
|
||||||
user := User{
|
|
||||||
ID: 1,
|
|
||||||
Name: "Test",
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, err := buildInsertSQL(&user, "", "id")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when table name is empty, got nil")
|
|
||||||
}
|
|
||||||
if !errors.Is(err, ErrTableNameEmpty) {
|
|
||||||
t.Fatalf("Expected ErrTableNameEmpty, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildUpdateSQL_NoPrimaryKey(t *testing.T) {
|
|
||||||
user := User{
|
|
||||||
ID: 1,
|
|
||||||
Name: "Test",
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, err := buildUpdateSQL(&user, "users")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when no primary key is provided, got nil")
|
|
||||||
}
|
|
||||||
if !errors.Is(err, ErrPrimaryKeyRequired) {
|
|
||||||
t.Fatalf("Expected ErrPrimaryKeyRequired, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildUpdateSQL_EmptyTableName(t *testing.T) {
|
|
||||||
user := User{
|
|
||||||
ID: 1,
|
|
||||||
Name: "Test",
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, err := buildUpdateSQL(&user, "", "id")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when table name is empty, got nil")
|
|
||||||
}
|
|
||||||
if !errors.Is(err, ErrTableNameEmpty) {
|
|
||||||
t.Fatalf("Expected ErrTableNameEmpty, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBuildUpdateSQL_OnlyPrimaryKey(t *testing.T) {
|
|
||||||
model := AutoIDOnly{ID: 1}
|
|
||||||
|
|
||||||
_, _, err := buildUpdateSQL(&model, "users", "id")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error when no fields remain for SET clause, got nil")
|
|
||||||
}
|
|
||||||
if !errors.Is(err, ErrNoUpdateFields) {
|
|
||||||
t.Fatalf("Expected ErrNoUpdateFields, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStructFieldValues_NilNestedPointer(t *testing.T) {
|
|
||||||
user := NestedUserPtr{
|
|
||||||
ID: 1,
|
|
||||||
Name: "Test",
|
|
||||||
Profile: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
values, err := getStructFieldValues(user, "db")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("getStructFieldValues failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if values["id"] != int64(1) {
|
|
||||||
t.Errorf("Expected id=1, got %v", values["id"])
|
|
||||||
}
|
|
||||||
if values["name"] != "Test" {
|
|
||||||
t.Errorf("Expected name=Test, got %v", values["name"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStructFieldNames_NilNestedPointer(t *testing.T) {
|
|
||||||
user := NestedUserPtr{
|
|
||||||
ID: 1,
|
|
||||||
Name: "Test",
|
|
||||||
Profile: nil,
|
|
||||||
}
|
|
||||||
|
|
||||||
names, err := getStructFieldNames(user, "db")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("getStructFieldNames failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := []string{"id", "name"}
|
|
||||||
if !reflect.DeepEqual(names, expected) {
|
|
||||||
t.Errorf("Expected names %v, got %v", expected, names)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStructFieldNames_SkipUnexportedField(t *testing.T) {
|
|
||||||
model := HiddenTagged{
|
|
||||||
ID: 1,
|
|
||||||
hidden: "secret",
|
|
||||||
}
|
|
||||||
|
|
||||||
names, err := getStructFieldNames(model, "db")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("getStructFieldNames failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expected := []string{"id"}
|
|
||||||
if !reflect.DeepEqual(names, expected) {
|
|
||||||
t.Errorf("Expected names %v, got %v", expected, names)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStructFieldValues_NilTarget(t *testing.T) {
|
|
||||||
_, err := getStructFieldValues(nil, "db")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error for nil target, got nil")
|
|
||||||
}
|
|
||||||
if !errors.Is(err, ErrTargetNil) {
|
|
||||||
t.Fatalf("Expected ErrTargetNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGetStructFieldNames_NilTarget(t *testing.T) {
|
|
||||||
_, err := getStructFieldNames(nil, "db")
|
|
||||||
if err == nil {
|
|
||||||
t.Fatal("Expected error for nil target, got nil")
|
|
||||||
}
|
|
||||||
if !errors.Is(err, ErrTargetNil) {
|
|
||||||
t.Fatalf("Expected ErrTargetNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestClearReflectCache(t *testing.T) {
|
|
||||||
type cacheUser struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
}
|
|
||||||
|
|
||||||
typ := reflect.TypeOf(cacheUser{})
|
|
||||||
plan1, err := getStructTagPlan(typ, "db")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("getStructTagPlan failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(plan1) != 2 {
|
|
||||||
t.Fatalf("Expected 2 fields in plan, got %d", len(plan1))
|
|
||||||
}
|
|
||||||
|
|
||||||
ClearReflectCache()
|
|
||||||
|
|
||||||
plan2, err := getStructTagPlan(typ, "db")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("getStructTagPlan after clear failed: %v", err)
|
|
||||||
}
|
|
||||||
if len(plan2) != 2 {
|
|
||||||
t.Fatalf("Expected 2 fields in plan after clear, got %d", len(plan2))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
58
pool.go
58
pool.go
@ -1,58 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PoolConfig represents database connection pool configuration
|
|
||||||
type PoolConfig struct {
|
|
||||||
MaxOpenConns int
|
|
||||||
MaxIdleConns int
|
|
||||||
ConnMaxLifetime time.Duration
|
|
||||||
ConnMaxIdleTime time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultPoolConfig returns default pool configuration
|
|
||||||
func DefaultPoolConfig() *PoolConfig {
|
|
||||||
return &PoolConfig{
|
|
||||||
MaxOpenConns: 25,
|
|
||||||
MaxIdleConns: 5,
|
|
||||||
ConnMaxLifetime: time.Hour,
|
|
||||||
ConnMaxIdleTime: 10 * time.Minute,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPoolConfig applies pool configuration to the database
|
|
||||||
func (s *StarDB) SetPoolConfig(config *PoolConfig) {
|
|
||||||
if s == nil || s.db == nil || config == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.MaxOpenConns > 0 {
|
|
||||||
s.db.SetMaxOpenConns(config.MaxOpenConns)
|
|
||||||
}
|
|
||||||
if config.MaxIdleConns > 0 {
|
|
||||||
s.db.SetMaxIdleConns(config.MaxIdleConns)
|
|
||||||
}
|
|
||||||
if config.ConnMaxLifetime > 0 {
|
|
||||||
s.db.SetConnMaxLifetime(config.ConnMaxLifetime)
|
|
||||||
}
|
|
||||||
if config.ConnMaxIdleTime > 0 {
|
|
||||||
s.db.SetConnMaxIdleTime(config.ConnMaxIdleTime)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenWithPool opens a database connection with pool configuration
|
|
||||||
func OpenWithPool(driver, connStr string, config *PoolConfig) (*StarDB, error) {
|
|
||||||
db := NewStarDB()
|
|
||||||
if err := db.Open(driver, connStr); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config == nil {
|
|
||||||
config = DefaultPoolConfig()
|
|
||||||
}
|
|
||||||
db.SetPoolConfig(config)
|
|
||||||
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
375
reflect.go
375
reflect.go
@ -1,375 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type structTagField struct {
|
|
||||||
path []int
|
|
||||||
tag string
|
|
||||||
}
|
|
||||||
|
|
||||||
type structTagPlanKey struct {
|
|
||||||
typ reflect.Type
|
|
||||||
tagKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
var structTagPlanCache sync.Map
|
|
||||||
|
|
||||||
// ClearReflectCache clears internal reflection metadata cache.
|
|
||||||
// Useful after schema/tag refactors in long-running processes.
|
|
||||||
func ClearReflectCache() {
|
|
||||||
structTagPlanCache = sync.Map{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getStructTagPlan(targetType reflect.Type, tagKey string) ([]structTagField, error) {
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
}
|
|
||||||
if targetType.Kind() != reflect.Struct {
|
|
||||||
return nil, ErrTargetNotStruct
|
|
||||||
}
|
|
||||||
|
|
||||||
cacheKey := structTagPlanKey{
|
|
||||||
typ: targetType,
|
|
||||||
tagKey: tagKey,
|
|
||||||
}
|
|
||||||
if cached, ok := structTagPlanCache.Load(cacheKey); ok {
|
|
||||||
return cached.([]structTagField), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := make([]structTagField, 0, targetType.NumField())
|
|
||||||
if err := buildStructTagPlan(targetType, tagKey, nil, &fields); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
structTagPlanCache.Store(cacheKey, fields)
|
|
||||||
return fields, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildStructTagPlan(currentType reflect.Type, tagKey string, prefix []int, out *[]structTagField) error {
|
|
||||||
if currentType.Kind() == reflect.Ptr {
|
|
||||||
currentType = currentType.Elem()
|
|
||||||
}
|
|
||||||
if currentType.Kind() != reflect.Struct {
|
|
||||||
return ErrTargetNotStruct
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < currentType.NumField(); i++ {
|
|
||||||
field := currentType.Field(i)
|
|
||||||
tagValue := field.Tag.Get(tagKey)
|
|
||||||
fieldType := field.Type
|
|
||||||
|
|
||||||
path := make([]int, len(prefix)+1)
|
|
||||||
copy(path, prefix)
|
|
||||||
path[len(prefix)] = i
|
|
||||||
|
|
||||||
if tagValue == "---" {
|
|
||||||
if fieldType.Kind() == reflect.Ptr && fieldType.Elem().Kind() == reflect.Struct {
|
|
||||||
if err := buildStructTagPlan(fieldType.Elem(), tagKey, path, out); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if fieldType.Kind() == reflect.Struct {
|
|
||||||
if err := buildStructTagPlan(fieldType, tagKey, path, out); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tagValue != "" {
|
|
||||||
*out = append(*out, structTagField{
|
|
||||||
path: path,
|
|
||||||
tag: tagValue,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func resolveFieldByPath(root reflect.Value, path []int) (reflect.Value, bool) {
|
|
||||||
current := root
|
|
||||||
for _, idx := range path {
|
|
||||||
if current.Kind() == reflect.Ptr {
|
|
||||||
if current.IsNil() {
|
|
||||||
return reflect.Value{}, false
|
|
||||||
}
|
|
||||||
current = current.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if current.Kind() != reflect.Struct {
|
|
||||||
return reflect.Value{}, false
|
|
||||||
}
|
|
||||||
if idx < 0 || idx >= current.NumField() {
|
|
||||||
return reflect.Value{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
current = current.Field(idx)
|
|
||||||
}
|
|
||||||
|
|
||||||
return current, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// setStructFieldsFromRow sets struct fields from a row result using reflection
|
|
||||||
func (r *StarRows) setStructFieldsFromRow(target interface{}, tagKey string, rowIndex int) error {
|
|
||||||
if target == nil {
|
|
||||||
return ErrTargetNil
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
|
||||||
targetValue := reflect.ValueOf(target)
|
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
|
||||||
targetValue = targetValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetType.Kind() != reflect.Ptr && !targetValue.CanSet() {
|
|
||||||
return ErrTargetNotWritable
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetValue.Kind() != reflect.Struct {
|
|
||||||
return ErrTargetNotStruct
|
|
||||||
}
|
|
||||||
|
|
||||||
row := r.Row(rowIndex)
|
|
||||||
if row.columnIndex == nil {
|
|
||||||
return ErrRowIndexOutOfRange
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < targetType.NumField(); i++ {
|
|
||||||
field := targetType.Field(i)
|
|
||||||
fieldValue := targetValue.Field(i)
|
|
||||||
tagValue := field.Tag.Get(tagKey)
|
|
||||||
|
|
||||||
if !fieldValue.CanInterface() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip unexported or otherwise non-settable fields.
|
|
||||||
if !fieldValue.CanSet() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle nested structs
|
|
||||||
if fieldValue.Kind() == reflect.Ptr && fieldValue.Type().Elem().Kind() == reflect.Struct {
|
|
||||||
if tagValue == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if tagValue == "---" {
|
|
||||||
nestedPtr := reflect.New(fieldValue.Type().Elem()).Interface()
|
|
||||||
if err := r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
targetValue.Field(i).Set(reflect.ValueOf(nestedPtr))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if fieldValue.Kind() == reflect.Struct {
|
|
||||||
if tagValue == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if tagValue == "---" {
|
|
||||||
nestedPtr := reflect.New(reflect.TypeOf(targetValue.Field(i).Interface())).Interface()
|
|
||||||
if err := r.setStructFieldsFromRow(nestedPtr, tagKey, rowIndex); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
targetValue.Field(i).Set(reflect.ValueOf(nestedPtr).Elem())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if tagValue == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if column exists
|
|
||||||
if _, ok := row.columnIndex[tagValue]; !ok {
|
|
||||||
if r.db != nil && r.db.StrictORM {
|
|
||||||
return wrapColumnNotFound(tagValue)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set field value based on type
|
|
||||||
r.setFieldValue(fieldValue, tagValue, row)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setFieldValue sets a single field value
|
|
||||||
func (r *StarRows) setFieldValue(fieldValue reflect.Value, columnName string, row *StarResult) {
|
|
||||||
switch fieldValue.Kind() {
|
|
||||||
case reflect.String:
|
|
||||||
fieldValue.SetString(row.MustString(columnName))
|
|
||||||
case reflect.Int:
|
|
||||||
fieldValue.SetInt(int64(row.MustInt(columnName)))
|
|
||||||
case reflect.Int8:
|
|
||||||
fieldValue.SetInt(int64(int8(row.MustInt64(columnName))))
|
|
||||||
case reflect.Int16:
|
|
||||||
fieldValue.SetInt(int64(int16(row.MustInt64(columnName))))
|
|
||||||
case reflect.Int32:
|
|
||||||
fieldValue.SetInt(int64(row.MustInt32(columnName)))
|
|
||||||
case reflect.Int64:
|
|
||||||
fieldValue.SetInt(row.MustInt64(columnName))
|
|
||||||
case reflect.Uint:
|
|
||||||
fieldValue.SetUint(uint64(row.MustUint64(columnName)))
|
|
||||||
case reflect.Uint8:
|
|
||||||
fieldValue.SetUint(uint64(uint8(row.MustUint64(columnName))))
|
|
||||||
case reflect.Uint16:
|
|
||||||
fieldValue.SetUint(uint64(uint16(row.MustUint64(columnName))))
|
|
||||||
case reflect.Uint32:
|
|
||||||
fieldValue.SetUint(uint64(uint32(row.MustUint64(columnName))))
|
|
||||||
case reflect.Uint64:
|
|
||||||
fieldValue.SetUint(row.MustUint64(columnName))
|
|
||||||
case reflect.Bool:
|
|
||||||
fieldValue.SetBool(row.MustBool(columnName))
|
|
||||||
case reflect.Float32:
|
|
||||||
fieldValue.SetFloat(float64(row.MustFloat32(columnName)))
|
|
||||||
case reflect.Float64:
|
|
||||||
fieldValue.SetFloat(row.MustFloat64(columnName))
|
|
||||||
case reflect.Struct:
|
|
||||||
// Handle special struct types like time.Time
|
|
||||||
colIndex := row.columnIndex[columnName]
|
|
||||||
val := row.Result()[colIndex]
|
|
||||||
if fieldValue.Type() == reflect.TypeOf(time.Time{}) {
|
|
||||||
if t, ok := val.(time.Time); ok {
|
|
||||||
fieldValue.Set(reflect.ValueOf(t))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case reflect.Ptr:
|
|
||||||
// Handle pointer to special types like *time.Time
|
|
||||||
colIndex := row.columnIndex[columnName]
|
|
||||||
val := row.Result()[colIndex]
|
|
||||||
if fieldValue.Type().Elem() == reflect.TypeOf(time.Time{}) {
|
|
||||||
if t, ok := val.(time.Time); ok {
|
|
||||||
tCopy := t
|
|
||||||
fieldValue.Set(reflect.ValueOf(&tCopy))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case reflect.Interface:
|
|
||||||
colIndex := row.columnIndex[columnName]
|
|
||||||
val := row.Result()[colIndex]
|
|
||||||
if val != nil {
|
|
||||||
fieldValue.Set(reflect.ValueOf(val))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// getStructFieldValues extracts all field values from a struct
|
|
||||||
func getStructFieldValues(target interface{}, tagKey string) (map[string]interface{}, error) {
|
|
||||||
if target == nil {
|
|
||||||
return nil, ErrTargetNil
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
|
||||||
targetValue := reflect.ValueOf(target)
|
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
|
||||||
if targetValue.IsNil() {
|
|
||||||
return nil, ErrPointerTargetNil
|
|
||||||
}
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
targetValue = targetValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetValue.Kind() != reflect.Struct {
|
|
||||||
return nil, ErrTargetNotStruct
|
|
||||||
}
|
|
||||||
|
|
||||||
plan, err := getStructTagPlan(targetType, tagKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make(map[string]interface{}, len(plan))
|
|
||||||
for _, field := range plan {
|
|
||||||
fieldValue, ok := resolveFieldByPath(targetValue, field.path)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !fieldValue.CanInterface() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result[field.tag] = fieldValue.Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// getStructFieldNames extracts all field names (tag values) from a struct
|
|
||||||
func getStructFieldNames(target interface{}, tagKey string) ([]string, error) {
|
|
||||||
if target == nil {
|
|
||||||
return []string{}, ErrTargetNil
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
|
||||||
targetValue := reflect.ValueOf(target)
|
|
||||||
|
|
||||||
if targetType.Kind() == reflect.Ptr {
|
|
||||||
if targetValue.IsNil() {
|
|
||||||
return []string{}, ErrPointerTargetNil
|
|
||||||
}
|
|
||||||
targetType = targetType.Elem()
|
|
||||||
targetValue = targetValue.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if targetValue.Kind() != reflect.Struct {
|
|
||||||
return []string{}, ErrTargetNotStruct
|
|
||||||
}
|
|
||||||
|
|
||||||
plan, err := getStructTagPlan(targetType, tagKey)
|
|
||||||
if err != nil {
|
|
||||||
return []string{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
result := make([]string, 0, len(plan))
|
|
||||||
for _, field := range plan {
|
|
||||||
fieldValue, ok := resolveFieldByPath(targetValue, field.path)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !fieldValue.CanInterface() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
result = append(result, field.tag)
|
|
||||||
}
|
|
||||||
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// isWritable checks if a value is writable
|
|
||||||
func isWritable(target interface{}) bool {
|
|
||||||
if target == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
|
||||||
targetValue := reflect.ValueOf(target)
|
|
||||||
return targetType.Kind() == reflect.Ptr || targetValue.CanSet()
|
|
||||||
}
|
|
||||||
|
|
||||||
// isStruct checks if a value is a struct
|
|
||||||
func isStruct(target interface{}) bool {
|
|
||||||
if target == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
targetValue := reflect.ValueOf(target)
|
|
||||||
if targetValue.Kind() == reflect.Ptr {
|
|
||||||
if targetValue.IsNil() {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
targetValue = targetValue.Elem()
|
|
||||||
}
|
|
||||||
return targetValue.Kind() == reflect.Struct
|
|
||||||
}
|
|
||||||
286
result.go
286
result.go
@ -1,286 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StarResult represents a single row result
|
|
||||||
type StarResult struct {
|
|
||||||
result []interface{}
|
|
||||||
columns []string
|
|
||||||
columnIndex map[string]int
|
|
||||||
columnsType []reflect.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
// Result returns the raw result slice
|
|
||||||
func (r *StarResult) Result() []interface{} {
|
|
||||||
return r.result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Columns returns column names
|
|
||||||
func (r *StarResult) Columns() []string {
|
|
||||||
return r.columns
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColumnsType returns column types
|
|
||||||
func (r *StarResult) ColumnsType() []reflect.Type {
|
|
||||||
return r.columnsType
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsNil checks if a column value is nil
|
|
||||||
func (r *StarResult) IsNil(name string) bool {
|
|
||||||
index, ok := r.columnIndex[name]
|
|
||||||
if !ok {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return r.result[index] == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustString returns column value as string
|
|
||||||
func (r *StarResult) MustString(name string) string {
|
|
||||||
index, ok := r.columnIndex[name]
|
|
||||||
if !ok {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return convertToString(r.result[index])
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustInt returns column value as int
|
|
||||||
func (r *StarResult) MustInt(name string) int {
|
|
||||||
return int(r.MustInt64(name))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustInt32 returns column value as int32
|
|
||||||
func (r *StarResult) MustInt32(name string) int32 {
|
|
||||||
return int32(r.MustInt64(name))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustInt64 returns column value as int64
|
|
||||||
func (r *StarResult) MustInt64(name string) int64 {
|
|
||||||
index, ok := r.columnIndex[name]
|
|
||||||
if !ok {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return convertToInt64(r.result[index])
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustUint64 returns column value as uint64
|
|
||||||
func (r *StarResult) MustUint64(name string) uint64 {
|
|
||||||
index, ok := r.columnIndex[name]
|
|
||||||
if !ok {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return convertToUint64(r.result[index])
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustFloat32 returns column value as float32
|
|
||||||
func (r *StarResult) MustFloat32(name string) float32 {
|
|
||||||
return float32(r.MustFloat64(name))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustFloat64 returns column value as float64
|
|
||||||
func (r *StarResult) MustFloat64(name string) float64 {
|
|
||||||
index, ok := r.columnIndex[name]
|
|
||||||
if !ok {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return convertToFloat64(r.result[index])
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustBool returns column value as bool
|
|
||||||
func (r *StarResult) MustBool(name string) bool {
|
|
||||||
index, ok := r.columnIndex[name]
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return convertToBool(r.result[index])
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustBytes returns column value as []byte
|
|
||||||
func (r *StarResult) MustBytes(name string) []byte {
|
|
||||||
index, ok := r.columnIndex[name]
|
|
||||||
if !ok {
|
|
||||||
return []byte{}
|
|
||||||
}
|
|
||||||
if b, ok := r.result[index].([]byte); ok {
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return []byte(r.MustString(name))
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustDate returns column value as time.Time
|
|
||||||
func (r *StarResult) MustDate(name, layout string) time.Time {
|
|
||||||
index, ok := r.columnIndex[name]
|
|
||||||
if !ok {
|
|
||||||
return time.Time{}
|
|
||||||
}
|
|
||||||
return convertToTime(r.result[index], layout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan scans the result into provided pointers
|
|
||||||
// Usage: row.Scan(&id, &name, &email)
|
|
||||||
func (r *StarResult) Scan(dest ...interface{}) error {
|
|
||||||
if len(dest) != len(r.result) {
|
|
||||||
return fmt.Errorf("expected %d destination arguments, got %d", len(r.result), len(dest))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, d := range dest {
|
|
||||||
if err := convertAssign(d, r.result[i]); err != nil {
|
|
||||||
return fmt.Errorf("error scanning column %d: %v", i, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertAssign assigns src to dest
|
|
||||||
func convertAssign(dest, src interface{}) error {
|
|
||||||
switch d := dest.(type) {
|
|
||||||
case *string:
|
|
||||||
*d = convertToString(src)
|
|
||||||
case *int:
|
|
||||||
*d = int(convertToInt64(src))
|
|
||||||
case *int32:
|
|
||||||
*d = int32(convertToInt64(src))
|
|
||||||
case *int64:
|
|
||||||
*d = convertToInt64(src)
|
|
||||||
case *uint64:
|
|
||||||
*d = convertToUint64(src)
|
|
||||||
case *float32:
|
|
||||||
*d = float32(convertToFloat64(src))
|
|
||||||
case *float64:
|
|
||||||
*d = convertToFloat64(src)
|
|
||||||
case *bool:
|
|
||||||
*d = convertToBool(src)
|
|
||||||
case *time.Time:
|
|
||||||
if t, ok := src.(time.Time); ok {
|
|
||||||
*d = t
|
|
||||||
}
|
|
||||||
case *[]byte:
|
|
||||||
if b, ok := src.([]byte); ok {
|
|
||||||
*d = b
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported Scan type: %T", dest)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StarResultCol represents a single column result
|
|
||||||
type StarResultCol struct {
|
|
||||||
result []interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Result returns the raw result slice
|
|
||||||
func (c *StarResultCol) Result() []interface{} {
|
|
||||||
return c.result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Len returns the number of values
|
|
||||||
func (c *StarResultCol) Len() int {
|
|
||||||
return len(c.result)
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsNil returns slice of nil checks for each row
|
|
||||||
func (c *StarResultCol) IsNil() []bool {
|
|
||||||
result := make([]bool, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
result[i] = (v == nil)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustString returns all values as strings
|
|
||||||
func (c *StarResultCol) MustString() []string {
|
|
||||||
result := make([]string, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
result[i] = convertToString(v)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustInt returns all values as ints
|
|
||||||
func (c *StarResultCol) MustInt() []int {
|
|
||||||
result := make([]int, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
result[i] = int(convertToInt64(v))
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustInt32 returns all values as int32
|
|
||||||
func (c *StarResultCol) MustInt32() []int32 {
|
|
||||||
result := make([]int32, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
result[i] = int32(convertToInt64(v))
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustInt64 returns all values as int64
|
|
||||||
func (c *StarResultCol) MustInt64() []int64 {
|
|
||||||
result := make([]int64, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
result[i] = convertToInt64(v)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustUint64 returns all values as uint64
|
|
||||||
func (c *StarResultCol) MustUint64() []uint64 {
|
|
||||||
result := make([]uint64, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
result[i] = convertToUint64(v)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustFloat32 returns all values as float32
|
|
||||||
func (c *StarResultCol) MustFloat32() []float32 {
|
|
||||||
result := make([]float32, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
result[i] = float32(convertToFloat64(v))
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustFloat64 returns all values as float64
|
|
||||||
func (c *StarResultCol) MustFloat64() []float64 {
|
|
||||||
result := make([]float64, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
result[i] = convertToFloat64(v)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustBool returns all values as bool
|
|
||||||
func (c *StarResultCol) MustBool() []bool {
|
|
||||||
result := make([]bool, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
result[i] = convertToBool(v)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustBytes returns all values as []byte
|
|
||||||
func (c *StarResultCol) MustBytes() [][]byte {
|
|
||||||
result := make([][]byte, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
if b, ok := v.([]byte); ok {
|
|
||||||
result[i] = b
|
|
||||||
} else {
|
|
||||||
result[i] = []byte(convertToString(v))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// MustDate returns all values as time.Time
|
|
||||||
func (c *StarResultCol) MustDate(layout string) []time.Time {
|
|
||||||
result := make([]time.Time, len(c.result))
|
|
||||||
for i, v := range c.result {
|
|
||||||
result[i] = convertToTime(v, layout)
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
124
result_safe.go
124
result_safe.go
@ -1,124 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
internalconv "b612.me/stardb/internal/convert"
|
|
||||||
"database/sql"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (r *StarResult) getColumnValue(name string) (interface{}, error) {
|
|
||||||
if r == nil || r.columnIndex == nil {
|
|
||||||
return nil, wrapColumnNotFound(name)
|
|
||||||
}
|
|
||||||
index, ok := r.columnIndex[name]
|
|
||||||
if !ok {
|
|
||||||
return nil, wrapColumnNotFound(name)
|
|
||||||
}
|
|
||||||
return r.Result()[index], nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetString returns column value as string with error
|
|
||||||
func (r *StarResult) GetString(name string) (string, error) {
|
|
||||||
val, err := r.getColumnValue(name)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return ConvertToStringSafe(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetInt64 returns column value as int64 with error
|
|
||||||
func (r *StarResult) GetInt64(name string) (int64, error) {
|
|
||||||
val, err := r.getColumnValue(name)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return ConvertToInt64Safe(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetFloat64 returns column value as float64 with error
|
|
||||||
func (r *StarResult) GetFloat64(name string) (float64, error) {
|
|
||||||
val, err := r.getColumnValue(name)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return internalconv.ToFloat64Safe(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNullString returns a nullable string value.
|
|
||||||
func (r *StarResult) GetNullString(name string) (sql.NullString, error) {
|
|
||||||
val, err := r.getColumnValue(name)
|
|
||||||
if err != nil {
|
|
||||||
return sql.NullString{}, err
|
|
||||||
}
|
|
||||||
if val == nil {
|
|
||||||
return sql.NullString{}, nil
|
|
||||||
}
|
|
||||||
str, err := ConvertToStringSafe(val)
|
|
||||||
if err != nil {
|
|
||||||
return sql.NullString{}, err
|
|
||||||
}
|
|
||||||
return sql.NullString{String: str, Valid: true}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNullInt64 returns a nullable int64 value.
|
|
||||||
func (r *StarResult) GetNullInt64(name string) (sql.NullInt64, error) {
|
|
||||||
val, err := r.getColumnValue(name)
|
|
||||||
if err != nil {
|
|
||||||
return sql.NullInt64{}, err
|
|
||||||
}
|
|
||||||
if val == nil {
|
|
||||||
return sql.NullInt64{}, nil
|
|
||||||
}
|
|
||||||
i, err := ConvertToInt64Safe(val)
|
|
||||||
if err != nil {
|
|
||||||
return sql.NullInt64{}, err
|
|
||||||
}
|
|
||||||
return sql.NullInt64{Int64: i, Valid: true}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNullFloat64 returns a nullable float64 value.
|
|
||||||
func (r *StarResult) GetNullFloat64(name string) (sql.NullFloat64, error) {
|
|
||||||
val, err := r.getColumnValue(name)
|
|
||||||
if err != nil {
|
|
||||||
return sql.NullFloat64{}, err
|
|
||||||
}
|
|
||||||
if val == nil {
|
|
||||||
return sql.NullFloat64{}, nil
|
|
||||||
}
|
|
||||||
f, err := internalconv.ToFloat64Safe(val)
|
|
||||||
if err != nil {
|
|
||||||
return sql.NullFloat64{}, err
|
|
||||||
}
|
|
||||||
return sql.NullFloat64{Float64: f, Valid: true}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNullBool returns a nullable bool value.
|
|
||||||
func (r *StarResult) GetNullBool(name string) (sql.NullBool, error) {
|
|
||||||
val, err := r.getColumnValue(name)
|
|
||||||
if err != nil {
|
|
||||||
return sql.NullBool{}, err
|
|
||||||
}
|
|
||||||
if val == nil {
|
|
||||||
return sql.NullBool{}, nil
|
|
||||||
}
|
|
||||||
b, err := internalconv.ToBoolSafe(val)
|
|
||||||
if err != nil {
|
|
||||||
return sql.NullBool{}, err
|
|
||||||
}
|
|
||||||
return sql.NullBool{Bool: b, Valid: true}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetNullTime returns a nullable time value.
|
|
||||||
func (r *StarResult) GetNullTime(name string) (sql.NullTime, error) {
|
|
||||||
val, err := r.getColumnValue(name)
|
|
||||||
if err != nil {
|
|
||||||
return sql.NullTime{}, err
|
|
||||||
}
|
|
||||||
if val == nil {
|
|
||||||
return sql.NullTime{}, nil
|
|
||||||
}
|
|
||||||
t, err := internalconv.ToTimeSafe(val)
|
|
||||||
if err != nil {
|
|
||||||
return sql.NullTime{}, err
|
|
||||||
}
|
|
||||||
return sql.NullTime{Time: t, Valid: true}, nil
|
|
||||||
}
|
|
||||||
177
rows.go
177
rows.go
@ -1,177 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"b612.me/stardb/internal/scanutil"
|
|
||||||
"database/sql"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StarRows represents a result set from a query
|
|
||||||
type StarRows struct {
|
|
||||||
rows *sql.Rows
|
|
||||||
db *StarDB
|
|
||||||
length int
|
|
||||||
stringResult []map[string]string
|
|
||||||
columns []string
|
|
||||||
columnsType []reflect.Type
|
|
||||||
columnIndex map[string]int
|
|
||||||
data [][]interface{}
|
|
||||||
parsed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Length returns the number of rows
|
|
||||||
func (r *StarRows) Length() int {
|
|
||||||
return r.length
|
|
||||||
}
|
|
||||||
|
|
||||||
// StringResult returns all rows as string maps
|
|
||||||
func (r *StarRows) StringResult() []map[string]string {
|
|
||||||
return r.stringResult
|
|
||||||
}
|
|
||||||
|
|
||||||
// Columns returns column names
|
|
||||||
func (r *StarRows) Columns() []string {
|
|
||||||
return r.columns
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColumnsType returns column types
|
|
||||||
func (r *StarRows) ColumnsType() []reflect.Type {
|
|
||||||
return r.columnsType
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the result set
|
|
||||||
func (r *StarRows) Close() error {
|
|
||||||
return r.rows.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rescan re-parses the result set
|
|
||||||
func (r *StarRows) Rescan() error {
|
|
||||||
return r.parse()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Row returns a specific row by index
|
|
||||||
func (r *StarRows) Row(index int) *StarResult {
|
|
||||||
result := &StarResult{}
|
|
||||||
if index < 0 || index >= len(r.data) {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
result.result = r.data[index]
|
|
||||||
result.columns = r.columns
|
|
||||||
result.columnsType = r.columnsType
|
|
||||||
result.columnIndex = r.columnIndex
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// Col returns all values for a specific column
|
|
||||||
func (r *StarRows) Col(name string) *StarResultCol {
|
|
||||||
result := &StarResultCol{}
|
|
||||||
if _, ok := r.columnIndex[name]; !ok {
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
colIndex := r.columnIndex[name]
|
|
||||||
for _, row := range r.data {
|
|
||||||
result.result = append(result.result, row[colIndex])
|
|
||||||
}
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
// parse parses the sql.Rows into internal data structures
|
|
||||||
func (r *StarRows) parse() error {
|
|
||||||
if r.parsed {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r.data = [][]interface{}{}
|
|
||||||
r.columnIndex = make(map[string]int)
|
|
||||||
r.stringResult = []map[string]string{}
|
|
||||||
r.columnsType = []reflect.Type{}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
r.columns, err = r.rows.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
columnTypes, err := r.rows.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, colType := range columnTypes {
|
|
||||||
r.columnsType = append(r.columnsType, colType.ScanType())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Build column index map
|
|
||||||
for i, colName := range r.columns {
|
|
||||||
r.columnIndex[colName] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare scan arguments
|
|
||||||
scanArgs := make([]interface{}, len(r.columns))
|
|
||||||
values := make([]interface{}, len(r.columns))
|
|
||||||
for i := range values {
|
|
||||||
scanArgs[i] = &values[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan all rows
|
|
||||||
for r.rows.Next() {
|
|
||||||
if err := r.rows.Scan(scanArgs...); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
record := make(map[string]string)
|
|
||||||
rowCopy := make([]interface{}, len(values))
|
|
||||||
|
|
||||||
for i, val := range values {
|
|
||||||
copiedVal := cloneScannedValue(val)
|
|
||||||
rowCopy[i] = copiedVal
|
|
||||||
record[r.columns[i]] = convertToString(copiedVal)
|
|
||||||
}
|
|
||||||
|
|
||||||
r.data = append(r.data, rowCopy)
|
|
||||||
r.stringResult = append(r.stringResult, record)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.rows.Err(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
r.length = len(r.stringResult)
|
|
||||||
r.parsed = true
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneScannedValue(val interface{}) interface{} {
|
|
||||||
return scanutil.CloneScannedValue(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
// convertToString converts any value to string
|
|
||||||
func convertToString(val interface{}) string {
|
|
||||||
switch v := val.(type) {
|
|
||||||
case nil:
|
|
||||||
return ""
|
|
||||||
case string:
|
|
||||||
return v
|
|
||||||
case int:
|
|
||||||
return strconv.Itoa(v)
|
|
||||||
case int32:
|
|
||||||
return strconv.FormatInt(int64(v), 10)
|
|
||||||
case int64:
|
|
||||||
return strconv.FormatInt(v, 10)
|
|
||||||
case float32:
|
|
||||||
return strconv.FormatFloat(float64(v), 'f', -1, 32)
|
|
||||||
case float64:
|
|
||||||
return strconv.FormatFloat(v, 'f', -1, 64)
|
|
||||||
case bool:
|
|
||||||
return strconv.FormatBool(v)
|
|
||||||
case time.Time:
|
|
||||||
return v.String()
|
|
||||||
case []byte:
|
|
||||||
return string(v)
|
|
||||||
default:
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,29 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestCloneScannedValue_BytesAreCopied(t *testing.T) {
|
|
||||||
original := []byte("hello")
|
|
||||||
clonedAny := cloneScannedValue(original)
|
|
||||||
cloned, ok := clonedAny.([]byte)
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected []byte, got %T", clonedAny)
|
|
||||||
}
|
|
||||||
|
|
||||||
original[0] = 'H'
|
|
||||||
if string(cloned) != "hello" {
|
|
||||||
t.Fatalf("expected cloned value to remain 'hello', got '%s'", string(cloned))
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(cloned) > 0 && &cloned[0] == &original[0] {
|
|
||||||
t.Fatal("expected cloned bytes to have a different backing array")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloneScannedValue_NonBytesKeepReference(t *testing.T) {
|
|
||||||
in := int64(42)
|
|
||||||
out := cloneScannedValue(in)
|
|
||||||
if out != in {
|
|
||||||
t.Fatalf("expected %v, got %v", in, out)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
124
scan_each.go
124
scan_each.go
@ -1,124 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ScanEachFunc is called for each scanned row in streaming mode.
|
|
||||||
type ScanEachFunc func(row *StarResult) error
|
|
||||||
|
|
||||||
// ScanEach executes query in streaming mode and invokes fn for each row.
|
|
||||||
func (s *StarDB) ScanEach(query string, fn ScanEachFunc, args ...interface{}) error {
|
|
||||||
return s.scanEach(nil, query, fn, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanEachContext executes query with context in streaming mode and invokes fn for each row.
|
|
||||||
func (s *StarDB) ScanEachContext(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
|
|
||||||
return s.scanEach(ctx, query, fn, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarDB) scanEach(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
|
|
||||||
rows, err := s.queryRaw(ctx, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return scanEachSQLRows(rows, fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanEach executes query in transaction streaming mode and invokes fn for each row.
|
|
||||||
func (t *StarTx) ScanEach(query string, fn ScanEachFunc, args ...interface{}) error {
|
|
||||||
return t.scanEach(nil, query, fn, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanEachContext executes query with context in transaction streaming mode and invokes fn for each row.
|
|
||||||
func (t *StarTx) ScanEachContext(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
|
|
||||||
return t.scanEach(ctx, query, fn, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *StarTx) scanEach(ctx context.Context, query string, fn ScanEachFunc, args ...interface{}) error {
|
|
||||||
rows, err := t.queryRaw(ctx, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return scanEachSQLRows(rows, fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanEach executes prepared statement query in streaming mode and invokes fn for each row.
|
|
||||||
func (s *StarStmt) ScanEach(fn ScanEachFunc, args ...interface{}) error {
|
|
||||||
return s.scanEach(nil, fn, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanEachContext executes prepared statement query with context in streaming mode and invokes fn for each row.
|
|
||||||
func (s *StarStmt) ScanEachContext(ctx context.Context, fn ScanEachFunc, args ...interface{}) error {
|
|
||||||
return s.scanEach(ctx, fn, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarStmt) scanEach(ctx context.Context, fn ScanEachFunc, args ...interface{}) error {
|
|
||||||
rows, err := s.queryRaw(ctx, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return scanEachSQLRows(rows, fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
func scanEachSQLRows(rows *sql.Rows, fn ScanEachFunc) error {
|
|
||||||
if fn == nil {
|
|
||||||
_ = rows.Close()
|
|
||||||
return ErrScanFuncNil
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
columns, err := rows.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
columnTypes, err := rows.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
types := make([]reflect.Type, len(columnTypes))
|
|
||||||
for i, colType := range columnTypes {
|
|
||||||
types[i] = colType.ScanType()
|
|
||||||
}
|
|
||||||
|
|
||||||
columnIndex := make(map[string]int, len(columns))
|
|
||||||
for i, colName := range columns {
|
|
||||||
columnIndex[colName] = i
|
|
||||||
}
|
|
||||||
|
|
||||||
scanArgs := make([]interface{}, len(columns))
|
|
||||||
values := make([]interface{}, len(columns))
|
|
||||||
for i := range values {
|
|
||||||
scanArgs[i] = &values[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
for rows.Next() {
|
|
||||||
if err := rows.Scan(scanArgs...); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
rowCopy := make([]interface{}, len(values))
|
|
||||||
for i, val := range values {
|
|
||||||
rowCopy[i] = cloneScannedValue(val)
|
|
||||||
}
|
|
||||||
|
|
||||||
row := &StarResult{
|
|
||||||
result: rowCopy,
|
|
||||||
columns: columns,
|
|
||||||
columnIndex: columnIndex,
|
|
||||||
columnsType: types,
|
|
||||||
}
|
|
||||||
if err := fn(row); err != nil {
|
|
||||||
if errors.Is(err, ErrScanStopped) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return rows.Err()
|
|
||||||
}
|
|
||||||
119
scan_each_orm.go
119
scan_each_orm.go
@ -1,119 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ScanEachORMFunc is called for each mapped struct in streaming ORM mode.
|
|
||||||
type ScanEachORMFunc func(target interface{}) error
|
|
||||||
|
|
||||||
// ScanEachORM streams query rows and maps each row to target before invoking fn.
|
|
||||||
// target must be a pointer to struct; it is reused for each row.
|
|
||||||
func (s *StarDB) ScanEachORM(query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
|
||||||
return s.ScanEachORMContext(nil, query, target, fn, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanEachORMContext streams query rows with context and maps each row to target before invoking fn.
|
|
||||||
// target must be a pointer to struct; it is reused for each row.
|
|
||||||
func (s *StarDB) ScanEachORMContext(ctx context.Context, query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
|
||||||
if fn == nil {
|
|
||||||
return ErrScanORMFuncNil
|
|
||||||
}
|
|
||||||
if err := validateScanORMTarget(target); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.ScanEachContext(ctx, query, func(row *StarResult) error {
|
|
||||||
if err := mapResultToStructTarget(row, target, s); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fn(target)
|
|
||||||
}, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanEachORM streams transaction rows and maps each row to target before invoking fn.
|
|
||||||
// target must be a pointer to struct; it is reused for each row.
|
|
||||||
func (t *StarTx) ScanEachORM(query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
|
||||||
return t.ScanEachORMContext(nil, query, target, fn, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanEachORMContext streams transaction rows with context and maps each row to target before invoking fn.
|
|
||||||
// target must be a pointer to struct; it is reused for each row.
|
|
||||||
func (t *StarTx) ScanEachORMContext(ctx context.Context, query string, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
|
||||||
if fn == nil {
|
|
||||||
return ErrScanORMFuncNil
|
|
||||||
}
|
|
||||||
if err := validateScanORMTarget(target); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return t.ScanEachContext(ctx, query, func(row *StarResult) error {
|
|
||||||
if err := mapResultToStructTarget(row, target, t.db); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fn(target)
|
|
||||||
}, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanEachORM streams prepared statement rows and maps each row to target before invoking fn.
|
|
||||||
// target must be a pointer to struct; it is reused for each row.
|
|
||||||
func (s *StarStmt) ScanEachORM(target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
|
||||||
return s.ScanEachORMContext(nil, target, fn, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ScanEachORMContext streams prepared statement rows with context and maps each row to target before invoking fn.
|
|
||||||
// target must be a pointer to struct; it is reused for each row.
|
|
||||||
func (s *StarStmt) ScanEachORMContext(ctx context.Context, target interface{}, fn ScanEachORMFunc, args ...interface{}) error {
|
|
||||||
if fn == nil {
|
|
||||||
return ErrScanORMFuncNil
|
|
||||||
}
|
|
||||||
if err := validateScanORMTarget(target); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.ScanEachContext(ctx, func(row *StarResult) error {
|
|
||||||
if err := mapResultToStructTarget(row, target, s.db); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return fn(target)
|
|
||||||
}, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func validateScanORMTarget(target interface{}) error {
|
|
||||||
if target == nil {
|
|
||||||
return ErrTargetNil
|
|
||||||
}
|
|
||||||
|
|
||||||
targetType := reflect.TypeOf(target)
|
|
||||||
targetValue := reflect.ValueOf(target)
|
|
||||||
|
|
||||||
if targetType.Kind() != reflect.Ptr {
|
|
||||||
return ErrTargetNotPointer
|
|
||||||
}
|
|
||||||
if targetValue.IsNil() {
|
|
||||||
return ErrTargetPointerNil
|
|
||||||
}
|
|
||||||
if targetValue.Elem().Kind() != reflect.Struct {
|
|
||||||
return ErrTargetNotStruct
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapResultToStructTarget(row *StarResult, target interface{}, db *StarDB) error {
|
|
||||||
targetValue := reflect.ValueOf(target)
|
|
||||||
targetValue.Elem().Set(reflect.Zero(targetValue.Elem().Type()))
|
|
||||||
|
|
||||||
rowWrapper := &StarRows{
|
|
||||||
db: db,
|
|
||||||
length: 1,
|
|
||||||
columns: row.columns,
|
|
||||||
columnsType: row.columnsType,
|
|
||||||
columnIndex: row.columnIndex,
|
|
||||||
data: [][]interface{}{row.result},
|
|
||||||
parsed: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
return rowWrapper.setStructFieldsFromRow(target, "db", 0)
|
|
||||||
}
|
|
||||||
@ -1,19 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import internalsqlplaceholder "b612.me/stardb/internal/sqlplaceholder"
|
|
||||||
|
|
||||||
// ConvertPlaceholders converts placeholders according to style.
|
|
||||||
func ConvertPlaceholders(query string, style PlaceholderStyle) string {
|
|
||||||
switch normalizePlaceholderStyle(style) {
|
|
||||||
case PlaceholderDollar:
|
|
||||||
return ConvertQuestionToDollarPlaceholders(query)
|
|
||||||
default:
|
|
||||||
return query
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConvertQuestionToDollarPlaceholders converts '?' to '$1,$2,...' in SQL text.
|
|
||||||
// It skips quoted strings, quoted identifiers and comments.
|
|
||||||
func ConvertQuestionToDollarPlaceholders(query string) string {
|
|
||||||
return internalsqlplaceholder.ConvertQuestionToDollarPlaceholders(query)
|
|
||||||
}
|
|
||||||
@ -1,34 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestConvertQuestionToDollarPlaceholders(t *testing.T) {
|
|
||||||
query := "SELECT * FROM users WHERE id = ? AND name = ?"
|
|
||||||
got := ConvertQuestionToDollarPlaceholders(query)
|
|
||||||
want := "SELECT * FROM users WHERE id = $1 AND name = $2"
|
|
||||||
if got != want {
|
|
||||||
t.Fatalf("expected %q, got %q", want, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConvertQuestionToDollarPlaceholders_SkipQuotedAndComments(t *testing.T) {
|
|
||||||
query := "SELECT '?', \"?\", `?`, col FROM t WHERE id = ? -- ?\nAND note = '??' /* ? */ AND x = ?"
|
|
||||||
got := ConvertQuestionToDollarPlaceholders(query)
|
|
||||||
want := "SELECT '?', \"?\", `?`, col FROM t WHERE id = $1 -- ?\nAND note = '??' /* ? */ AND x = $2"
|
|
||||||
if got != want {
|
|
||||||
t.Fatalf("expected %q, got %q", want, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConvertPlaceholders(t *testing.T) {
|
|
||||||
query := "SELECT * FROM t WHERE a = ? AND b = ?"
|
|
||||||
if got := ConvertPlaceholders(query, PlaceholderQuestion); got != query {
|
|
||||||
t.Fatalf("question style should keep query unchanged, got %q", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
got := ConvertPlaceholders(query, PlaceholderDollar)
|
|
||||||
want := "SELECT * FROM t WHERE a = $1 AND b = $2"
|
|
||||||
if got != want {
|
|
||||||
t.Fatalf("expected %q, got %q", want, got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
276
sql_runtime.go
276
sql_runtime.go
@ -1,276 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
internalsqlruntime "b612.me/stardb/internal/sqlruntime"
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SQLBeforeHook runs before a SQL statement is executed.
|
|
||||||
type SQLBeforeHook func(ctx context.Context, query string, args []interface{})
|
|
||||||
|
|
||||||
// SQLAfterHook runs after a SQL statement is executed.
|
|
||||||
type SQLAfterHook func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error)
|
|
||||||
|
|
||||||
// PlaceholderStyle controls SQL placeholder format conversion.
|
|
||||||
type PlaceholderStyle int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// PlaceholderQuestion keeps '?' placeholders unchanged.
|
|
||||||
PlaceholderQuestion PlaceholderStyle = iota
|
|
||||||
// PlaceholderDollar converts '?' placeholders to '$1,$2,...'.
|
|
||||||
PlaceholderDollar
|
|
||||||
)
|
|
||||||
|
|
||||||
// SQLFingerprintMode controls SQL fingerprint generation strategy.
|
|
||||||
type SQLFingerprintMode int
|
|
||||||
|
|
||||||
const (
|
|
||||||
// SQLFingerprintBasic lowercases SQL and collapses whitespace.
|
|
||||||
SQLFingerprintBasic SQLFingerprintMode = iota
|
|
||||||
// SQLFingerprintMaskLiterals also masks numeric/string literals and $n placeholders.
|
|
||||||
SQLFingerprintMaskLiterals
|
|
||||||
)
|
|
||||||
|
|
||||||
type sqlHookMetaKey struct{}
|
|
||||||
|
|
||||||
// SQLHookMeta contains extra hook metadata attached to context.
|
|
||||||
type SQLHookMeta struct {
|
|
||||||
Fingerprint string
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLHookMetaFromContext extracts SQL hook metadata from context.
|
|
||||||
func SQLHookMetaFromContext(ctx context.Context) (SQLHookMeta, bool) {
|
|
||||||
if ctx == nil {
|
|
||||||
return SQLHookMeta{}, false
|
|
||||||
}
|
|
||||||
meta, ok := ctx.Value(sqlHookMetaKey{}).(SQLHookMeta)
|
|
||||||
return meta, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func withSQLHookMeta(ctx context.Context, meta SQLHookMeta) context.Context {
|
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
return context.WithValue(ctx, sqlHookMetaKey{}, meta)
|
|
||||||
}
|
|
||||||
|
|
||||||
type sqlRuntime struct {
|
|
||||||
state internalsqlruntime.State
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizePlaceholderStyle(style PlaceholderStyle) PlaceholderStyle {
|
|
||||||
return PlaceholderStyle(internalsqlruntime.NormalizePlaceholderStyle(int(style)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func normalizeSQLFingerprintMode(mode SQLFingerprintMode) SQLFingerprintMode {
|
|
||||||
return SQLFingerprintMode(internalsqlruntime.NormalizeFingerprintMode(int(mode)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func cloneHookArgs(args []interface{}) []interface{} {
|
|
||||||
return internalsqlruntime.CloneHookArgs(args)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarDB) runtimeOptions() (SQLBeforeHook, SQLAfterHook, PlaceholderStyle, time.Duration) {
|
|
||||||
if s == nil {
|
|
||||||
return nil, nil, PlaceholderQuestion, 0
|
|
||||||
}
|
|
||||||
|
|
||||||
beforeAny, afterAny, rawStyle, slowThreshold := s.runtime.state.Options()
|
|
||||||
var before SQLBeforeHook
|
|
||||||
if b, ok := beforeAny.(SQLBeforeHook); ok {
|
|
||||||
before = b
|
|
||||||
}
|
|
||||||
var after SQLAfterHook
|
|
||||||
if a, ok := afterAny.(SQLAfterHook); ok {
|
|
||||||
after = a
|
|
||||||
}
|
|
||||||
|
|
||||||
return before, after, normalizePlaceholderStyle(PlaceholderStyle(rawStyle)), slowThreshold
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarDB) sqlHooks() (SQLBeforeHook, SQLAfterHook, time.Duration) {
|
|
||||||
before, after, _, slowThreshold := s.runtimeOptions()
|
|
||||||
return before, after, slowThreshold
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarDB) prepareSQLCall(query string, args []interface{}) (string, SQLBeforeHook, SQLAfterHook, []interface{}, time.Duration) {
|
|
||||||
before, after, style, slowThreshold := s.runtimeOptions()
|
|
||||||
query = ConvertPlaceholders(query, style)
|
|
||||||
if before == nil && after == nil {
|
|
||||||
return query, nil, nil, nil, slowThreshold
|
|
||||||
}
|
|
||||||
return query, before, after, cloneHookArgs(args), slowThreshold
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarDB) hookContext(ctx context.Context, query string, before SQLBeforeHook, after SQLAfterHook) context.Context {
|
|
||||||
if s == nil {
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
needCounter := s.SQLFingerprintCounterEnabled()
|
|
||||||
needMeta := (before != nil || after != nil) && s.SQLFingerprintEnabled()
|
|
||||||
if !needCounter && !needMeta {
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
mode := s.SQLFingerprintMode()
|
|
||||||
keepComments := s.SQLFingerprintKeepComments()
|
|
||||||
fingerprint := internalsqlruntime.FingerprintSQL(query, int(mode), keepComments)
|
|
||||||
if fingerprint == "" {
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
|
|
||||||
if needCounter {
|
|
||||||
s.runtime.state.IncFingerprintCount(fingerprint)
|
|
||||||
}
|
|
||||||
if !needMeta {
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
return withSQLHookMeta(ctx, SQLHookMeta{Fingerprint: fingerprint})
|
|
||||||
}
|
|
||||||
|
|
||||||
func shouldRunAfterHook(after SQLAfterHook, slowThreshold, duration time.Duration, err error) bool {
|
|
||||||
return internalsqlruntime.ShouldRunAfterHook(after != nil, slowThreshold, duration, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSQLHooks sets SQL before/after hooks.
|
|
||||||
func (s *StarDB) SetSQLHooks(before SQLBeforeHook, after SQLAfterHook) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtime.state.SetHooks(before, after)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSQLBeforeHook sets SQL before hook.
|
|
||||||
func (s *StarDB) SetSQLBeforeHook(before SQLBeforeHook) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtime.state.SetBeforeHook(before)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSQLAfterHook sets SQL after hook.
|
|
||||||
func (s *StarDB) SetSQLAfterHook(after SQLAfterHook) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtime.state.SetAfterHook(after)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPlaceholderStyle sets placeholder conversion style.
|
|
||||||
func (s *StarDB) SetPlaceholderStyle(style PlaceholderStyle) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtime.state.SetPlaceholderStyle(int(style))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSQLSlowThreshold sets minimum duration for triggering after hook.
|
|
||||||
// When threshold > 0, after hook runs only for statements slower than threshold or those with error.
|
|
||||||
func (s *StarDB) SetSQLSlowThreshold(threshold time.Duration) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtime.state.SetSlowThreshold(threshold)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLSlowThreshold returns current slow SQL threshold.
|
|
||||||
func (s *StarDB) SQLSlowThreshold() time.Duration {
|
|
||||||
if s == nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return s.runtime.state.SlowThreshold()
|
|
||||||
}
|
|
||||||
|
|
||||||
// PlaceholderStyle returns current placeholder style.
|
|
||||||
func (s *StarDB) PlaceholderStyle() PlaceholderStyle {
|
|
||||||
if s == nil {
|
|
||||||
return PlaceholderQuestion
|
|
||||||
}
|
|
||||||
style := PlaceholderStyle(s.runtime.state.PlaceholderStyle())
|
|
||||||
return normalizePlaceholderStyle(style)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSQLFingerprintEnabled toggles SQL fingerprint metadata generation for hooks.
|
|
||||||
func (s *StarDB) SetSQLFingerprintEnabled(enabled bool) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtime.state.SetFingerprintEnabled(enabled)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLFingerprintEnabled reports whether SQL fingerprint metadata generation is enabled.
|
|
||||||
func (s *StarDB) SQLFingerprintEnabled() bool {
|
|
||||||
if s == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return s.runtime.state.FingerprintEnabled()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSQLFingerprintMode sets SQL fingerprint generation mode.
|
|
||||||
func (s *StarDB) SetSQLFingerprintMode(mode SQLFingerprintMode) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtime.state.SetFingerprintMode(int(mode))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLFingerprintMode returns SQL fingerprint generation mode.
|
|
||||||
func (s *StarDB) SQLFingerprintMode() SQLFingerprintMode {
|
|
||||||
if s == nil {
|
|
||||||
return SQLFingerprintBasic
|
|
||||||
}
|
|
||||||
mode := SQLFingerprintMode(s.runtime.state.FingerprintMode())
|
|
||||||
return normalizeSQLFingerprintMode(mode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSQLFingerprintKeepComments controls whether comments are preserved in SQL fingerprints.
|
|
||||||
// Default is false.
|
|
||||||
func (s *StarDB) SetSQLFingerprintKeepComments(keep bool) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtime.state.SetFingerprintKeepComments(keep)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLFingerprintKeepComments reports whether SQL fingerprints preserve comments.
|
|
||||||
func (s *StarDB) SQLFingerprintKeepComments() bool {
|
|
||||||
if s == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return s.runtime.state.FingerprintKeepComments()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSQLFingerprintCounterEnabled toggles in-memory SQL fingerprint hit counter.
|
|
||||||
// Default is false.
|
|
||||||
func (s *StarDB) SetSQLFingerprintCounterEnabled(enabled bool) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtime.state.SetFingerprintCounterEnabled(enabled)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLFingerprintCounterEnabled reports whether in-memory SQL fingerprint hit counter is enabled.
|
|
||||||
func (s *StarDB) SQLFingerprintCounterEnabled() bool {
|
|
||||||
if s == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return s.runtime.state.FingerprintCounterEnabled()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SQLFingerprintCounters returns a snapshot of fingerprint hit counters.
|
|
||||||
func (s *StarDB) SQLFingerprintCounters() map[string]uint64 {
|
|
||||||
if s == nil {
|
|
||||||
return map[string]uint64{}
|
|
||||||
}
|
|
||||||
return s.runtime.state.FingerprintCountsSnapshot()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSQLFingerprintCounters clears all in-memory fingerprint hit counters.
|
|
||||||
func (s *StarDB) ResetSQLFingerprintCounters() {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.runtime.state.ResetFingerprintCounts()
|
|
||||||
}
|
|
||||||
@ -1,102 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStarDB_RuntimeConfigConcurrent(t *testing.T) {
|
|
||||||
db := NewStarDB()
|
|
||||||
|
|
||||||
before := func(ctx context.Context, query string, args []interface{}) {}
|
|
||||||
after := func(ctx context.Context, query string, args []interface{}, duration time.Duration, err error) {}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 16; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func(i int) {
|
|
||||||
defer wg.Done()
|
|
||||||
for j := 0; j < 1000; j++ {
|
|
||||||
if (i+j)%2 == 0 {
|
|
||||||
db.SetPlaceholderStyle(PlaceholderDollar)
|
|
||||||
} else {
|
|
||||||
db.SetPlaceholderStyle(PlaceholderQuestion)
|
|
||||||
}
|
|
||||||
db.SetSQLSlowThreshold(time.Duration((i+j)%5) * time.Millisecond)
|
|
||||||
db.SetSQLFingerprintEnabled((i+j)%3 == 0)
|
|
||||||
db.SetSQLFingerprintMode(SQLFingerprintMode((i + j) % 3))
|
|
||||||
db.SetSQLFingerprintKeepComments((i+j)%4 == 0)
|
|
||||||
db.SetSQLFingerprintCounterEnabled((i+j)%5 == 0)
|
|
||||||
if (i+j)%7 == 0 {
|
|
||||||
db.ResetSQLFingerprintCounters()
|
|
||||||
}
|
|
||||||
db.SetSQLHooks(before, after)
|
|
||||||
_ = db.PlaceholderStyle()
|
|
||||||
_ = db.SQLSlowThreshold()
|
|
||||||
_ = db.SQLFingerprintEnabled()
|
|
||||||
_ = db.SQLFingerprintMode()
|
|
||||||
_ = db.SQLFingerprintKeepComments()
|
|
||||||
_ = db.SQLFingerprintCounterEnabled()
|
|
||||||
_ = db.SQLFingerprintCounters()
|
|
||||||
_, _, _, _ = db.runtimeOptions()
|
|
||||||
}
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_SQLFingerprintMode(t *testing.T) {
|
|
||||||
db := NewStarDB()
|
|
||||||
|
|
||||||
if got := db.SQLFingerprintMode(); got != SQLFingerprintBasic {
|
|
||||||
t.Fatalf("expected default mode SQLFingerprintBasic, got %v", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetSQLFingerprintMode(SQLFingerprintMaskLiterals)
|
|
||||||
if got := db.SQLFingerprintMode(); got != SQLFingerprintMaskLiterals {
|
|
||||||
t.Fatalf("expected SQLFingerprintMaskLiterals, got %v", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetSQLFingerprintMode(SQLFingerprintMode(99))
|
|
||||||
if got := db.SQLFingerprintMode(); got != SQLFingerprintBasic {
|
|
||||||
t.Fatalf("expected invalid mode fallback to SQLFingerprintBasic, got %v", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_SQLFingerprintKeepComments(t *testing.T) {
|
|
||||||
db := NewStarDB()
|
|
||||||
|
|
||||||
if db.SQLFingerprintKeepComments() {
|
|
||||||
t.Fatal("expected default keep comments to be false")
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetSQLFingerprintKeepComments(true)
|
|
||||||
if !db.SQLFingerprintKeepComments() {
|
|
||||||
t.Fatal("expected keep comments to be true")
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetSQLFingerprintKeepComments(false)
|
|
||||||
if db.SQLFingerprintKeepComments() {
|
|
||||||
t.Fatal("expected keep comments to be false")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_SQLFingerprintCounterSwitch(t *testing.T) {
|
|
||||||
db := NewStarDB()
|
|
||||||
|
|
||||||
if db.SQLFingerprintCounterEnabled() {
|
|
||||||
t.Fatal("expected default counter switch to be false")
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetSQLFingerprintCounterEnabled(true)
|
|
||||||
if !db.SQLFingerprintCounterEnabled() {
|
|
||||||
t.Fatal("expected counter switch to be true")
|
|
||||||
}
|
|
||||||
|
|
||||||
db.ResetSQLFingerprintCounters()
|
|
||||||
if got := len(db.SQLFingerprintCounters()); got != 0 {
|
|
||||||
t.Fatalf("expected empty counters after reset, got %d", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
552
stardb.go
552
stardb.go
@ -1,552 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StarDB is a simple wrapper around sql.DB providing enhanced functionality
|
|
||||||
type StarDB struct {
|
|
||||||
db *sql.DB
|
|
||||||
ManualScan bool // If true, rows won't be automatically parsed
|
|
||||||
StrictORM bool // If true, Orm requires all tagged columns to exist in query results
|
|
||||||
// batchInsertMaxRows controls batch split size for BatchInsert/BatchInsertStructs.
|
|
||||||
// <= 0 means no split (single SQL statement).
|
|
||||||
batchInsertMaxRows int64
|
|
||||||
batchInsertMaxParams int64
|
|
||||||
runtime sqlRuntime
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStarDB creates a new StarDB instance
|
|
||||||
func NewStarDB() *StarDB {
|
|
||||||
return &StarDB{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewStarDBWithDB creates a new StarDB instance with an existing *sql.DB
|
|
||||||
func NewStarDBWithDB(db *sql.DB) *StarDB {
|
|
||||||
return &StarDB{db: db}
|
|
||||||
}
|
|
||||||
|
|
||||||
// DB returns the underlying *sql.DB
|
|
||||||
func (s *StarDB) DB() *sql.DB {
|
|
||||||
return s.db
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDB sets the underlying *sql.DB
|
|
||||||
func (s *StarDB) SetDB(db *sql.DB) {
|
|
||||||
s.db = db
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetStrictORM enables or disables strict column validation for Orm mapping.
|
|
||||||
func (s *StarDB) SetStrictORM(strict bool) {
|
|
||||||
if s == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.StrictORM = strict
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarDB) ensureDB() error {
|
|
||||||
if s == nil || s.db == nil {
|
|
||||||
return ErrDBNotInitialized
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open opens a new database connection
|
|
||||||
func (s *StarDB) Open(driver, connStr string) error {
|
|
||||||
var err error
|
|
||||||
s.db, err = sql.Open(driver, connStr)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the database connection
|
|
||||||
func (s *StarDB) Close() error {
|
|
||||||
if err := s.ensureDB(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.db.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ping verifies the database connection is alive
|
|
||||||
func (s *StarDB) Ping() error {
|
|
||||||
if err := s.ensureDB(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.db.Ping()
|
|
||||||
}
|
|
||||||
|
|
||||||
// PingContext verifies the database connection with context
|
|
||||||
func (s *StarDB) PingContext(ctx context.Context) error {
|
|
||||||
if err := s.ensureDB(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.db.PingContext(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stats returns database statistics
|
|
||||||
func (s *StarDB) Stats() sql.DBStats {
|
|
||||||
if s == nil || s.db == nil {
|
|
||||||
return sql.DBStats{}
|
|
||||||
}
|
|
||||||
return s.db.Stats()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMaxOpenConns sets the maximum number of open connections
|
|
||||||
func (s *StarDB) SetMaxOpenConns(n int) {
|
|
||||||
if s == nil || s.db == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.db.SetMaxOpenConns(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMaxIdleConns sets the maximum number of idle connections
|
|
||||||
func (s *StarDB) SetMaxIdleConns(n int) {
|
|
||||||
if s == nil || s.db == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.db.SetMaxIdleConns(n)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conn returns a single connection from the pool
|
|
||||||
func (s *StarDB) Conn(ctx context.Context) (*sql.Conn, error) {
|
|
||||||
if err := s.ensureDB(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s.db.Conn(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query executes a query that returns rows
|
|
||||||
// Usage: Query("SELECT * FROM users WHERE id = ?", 1)
|
|
||||||
func (s *StarDB) Query(query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
return s.query(nil, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryContext executes a query with context
|
|
||||||
func (s *StarDB) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
return s.query(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRaw executes a query and returns *sql.Rows without automatic parsing.
|
|
||||||
func (s *StarDB) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) {
|
|
||||||
return s.queryRaw(nil, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRawContext executes a query with context and returns *sql.Rows without automatic parsing.
|
|
||||||
func (s *StarDB) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
||||||
return s.queryRaw(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarDB) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
||||||
if err := s.ensureDB(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
query, beforeHook, afterHook, hookArgs, slowThreshold := s.prepareSQLCall(query, args)
|
|
||||||
hookCtx := s.hookContext(ctx, query, beforeHook, afterHook)
|
|
||||||
if beforeHook != nil {
|
|
||||||
beforeHook(hookCtx, query, hookArgs)
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
var (
|
|
||||||
rows *sql.Rows
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if ctx == nil {
|
|
||||||
rows, err = s.db.Query(query, args...)
|
|
||||||
} else {
|
|
||||||
rows, err = s.db.QueryContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
duration := time.Since(start)
|
|
||||||
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
||||||
afterHook(hookCtx, query, hookArgs, duration, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return rows, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// query is the internal query implementation
|
|
||||||
func (s *StarDB) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
rows, err := s.queryRaw(ctx, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
starRows := &StarRows{
|
|
||||||
rows: rows,
|
|
||||||
db: s,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.ManualScan {
|
|
||||||
if err := starRows.parse(); err != nil {
|
|
||||||
_ = rows.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return starRows, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec executes a query that doesn't return rows
|
|
||||||
// Usage: Exec("INSERT INTO users (name) VALUES (?)", "John")
|
|
||||||
func (s *StarDB) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
return s.exec(nil, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecContext executes a query with context
|
|
||||||
func (s *StarDB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
return s.exec(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// exec is the internal exec implementation
|
|
||||||
func (s *StarDB) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
if err := s.ensureDB(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
query, beforeHook, afterHook, hookArgs, slowThreshold := s.prepareSQLCall(query, args)
|
|
||||||
hookCtx := s.hookContext(ctx, query, beforeHook, afterHook)
|
|
||||||
if beforeHook != nil {
|
|
||||||
beforeHook(hookCtx, query, hookArgs)
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
var (
|
|
||||||
result sql.Result
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if ctx == nil {
|
|
||||||
result, err = s.db.Exec(query, args...)
|
|
||||||
} else {
|
|
||||||
result, err = s.db.ExecContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
duration := time.Since(start)
|
|
||||||
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
||||||
afterHook(hookCtx, query, hookArgs, duration, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare creates a prepared statement
|
|
||||||
func (s *StarDB) Prepare(query string) (*StarStmt, error) {
|
|
||||||
if err := s.ensureDB(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
query, beforeHook, afterHook, _, slowThreshold := s.prepareSQLCall(query, nil)
|
|
||||||
hookCtx := s.hookContext(nil, query, beforeHook, afterHook)
|
|
||||||
if beforeHook != nil {
|
|
||||||
beforeHook(hookCtx, query, nil)
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
stmt, err := s.db.Prepare(query)
|
|
||||||
duration := time.Since(start)
|
|
||||||
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
||||||
afterHook(hookCtx, query, nil, duration, err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &StarStmt{stmt: stmt, db: s, sqlText: query}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PrepareContext creates a prepared statement with context
|
|
||||||
func (s *StarDB) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
|
|
||||||
if err := s.ensureDB(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
query, beforeHook, afterHook, _, slowThreshold := s.prepareSQLCall(query, nil)
|
|
||||||
hookCtx := s.hookContext(ctx, query, beforeHook, afterHook)
|
|
||||||
if beforeHook != nil {
|
|
||||||
beforeHook(hookCtx, query, nil)
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
stmt, err := s.db.PrepareContext(ctx, query)
|
|
||||||
duration := time.Since(start)
|
|
||||||
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
||||||
afterHook(hookCtx, query, nil, duration, err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &StarStmt{stmt: stmt, db: s, sqlText: query}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryStmt executes a prepared statement query
|
|
||||||
// Usage: QueryStmt("SELECT * FROM users WHERE id = ?", 1)
|
|
||||||
func (s *StarDB) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
stmt, err := s.Prepare(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
return stmt.Query(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryStmtContext executes a prepared statement query with context
|
|
||||||
func (s *StarDB) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
stmt, err := s.PrepareContext(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
return stmt.QueryContext(ctx, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecStmt executes a prepared statement
|
|
||||||
// Usage: ExecStmt("INSERT INTO users (name) VALUES (?)", "John")
|
|
||||||
func (s *StarDB) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
stmt, err := s.Prepare(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
return stmt.Exec(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecStmtContext executes a prepared statement with context
|
|
||||||
func (s *StarDB) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
stmt, err := s.PrepareContext(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
return stmt.ExecContext(ctx, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Begin starts a transaction
|
|
||||||
func (s *StarDB) Begin() (*StarTx, error) {
|
|
||||||
if err := s.ensureDB(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tx, err := s.db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &StarTx{tx: tx, db: s}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// BeginTx starts a transaction with options
|
|
||||||
func (s *StarDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*StarTx, error) {
|
|
||||||
if err := s.ensureDB(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tx, err := s.db.BeginTx(ctx, opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &StarTx{tx: tx, db: s}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StarStmt represents a prepared statement
|
|
||||||
type StarStmt struct {
|
|
||||||
stmt *sql.Stmt
|
|
||||||
db *StarDB
|
|
||||||
sqlText string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarStmt) ensureStmt() error {
|
|
||||||
if s == nil || s.stmt == nil {
|
|
||||||
return ErrStmtNotInitialized
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarStmt) ensureStmtWithDB() error {
|
|
||||||
if err := s.ensureStmt(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if s.db == nil {
|
|
||||||
return ErrStmtDBNotInitialized
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query executes a prepared statement query
|
|
||||||
func (s *StarStmt) Query(args ...interface{}) (*StarRows, error) {
|
|
||||||
return s.query(nil, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryContext executes a prepared statement query with context
|
|
||||||
func (s *StarStmt) QueryContext(ctx context.Context, args ...interface{}) (*StarRows, error) {
|
|
||||||
return s.query(ctx, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRaw executes a prepared statement query and returns *sql.Rows without automatic parsing.
|
|
||||||
func (s *StarStmt) QueryRaw(args ...interface{}) (*sql.Rows, error) {
|
|
||||||
return s.queryRaw(nil, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRawContext executes a prepared statement query with context and returns *sql.Rows without automatic parsing.
|
|
||||||
func (s *StarStmt) QueryRawContext(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
|
|
||||||
return s.queryRaw(ctx, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StarStmt) queryRaw(ctx context.Context, args ...interface{}) (*sql.Rows, error) {
|
|
||||||
if err := s.ensureStmt(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var beforeHook SQLBeforeHook
|
|
||||||
var afterHook SQLAfterHook
|
|
||||||
var slowThreshold time.Duration
|
|
||||||
if s.db != nil {
|
|
||||||
beforeHook, afterHook, slowThreshold = s.db.sqlHooks()
|
|
||||||
}
|
|
||||||
var hookArgs []interface{}
|
|
||||||
if beforeHook != nil || afterHook != nil {
|
|
||||||
hookArgs = cloneHookArgs(args)
|
|
||||||
}
|
|
||||||
hookCtx := ctx
|
|
||||||
if s.db != nil {
|
|
||||||
hookCtx = s.db.hookContext(ctx, s.sqlText, beforeHook, afterHook)
|
|
||||||
}
|
|
||||||
if beforeHook != nil {
|
|
||||||
beforeHook(hookCtx, s.sqlText, hookArgs)
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
var (
|
|
||||||
rows *sql.Rows
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if ctx == nil {
|
|
||||||
rows, err = s.stmt.Query(args...)
|
|
||||||
} else {
|
|
||||||
rows, err = s.stmt.QueryContext(ctx, args...)
|
|
||||||
}
|
|
||||||
duration := time.Since(start)
|
|
||||||
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
||||||
afterHook(hookCtx, s.sqlText, hookArgs, duration, err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return rows, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// query is the internal query implementation
|
|
||||||
func (s *StarStmt) query(ctx context.Context, args ...interface{}) (*StarRows, error) {
|
|
||||||
if err := s.ensureStmtWithDB(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := s.queryRaw(ctx, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
starRows := &StarRows{
|
|
||||||
rows: rows,
|
|
||||||
db: s.db,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !s.db.ManualScan {
|
|
||||||
if err := starRows.parse(); err != nil {
|
|
||||||
_ = rows.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return starRows, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec executes a prepared statement
|
|
||||||
func (s *StarStmt) Exec(args ...interface{}) (sql.Result, error) {
|
|
||||||
return s.exec(nil, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecContext executes a prepared statement with context
|
|
||||||
func (s *StarStmt) ExecContext(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
|
||||||
return s.exec(ctx, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// exec is the internal exec implementation
|
|
||||||
func (s *StarStmt) exec(ctx context.Context, args ...interface{}) (sql.Result, error) {
|
|
||||||
if err := s.ensureStmt(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var beforeHook SQLBeforeHook
|
|
||||||
var afterHook SQLAfterHook
|
|
||||||
var slowThreshold time.Duration
|
|
||||||
if s.db != nil {
|
|
||||||
beforeHook, afterHook, slowThreshold = s.db.sqlHooks()
|
|
||||||
}
|
|
||||||
var hookArgs []interface{}
|
|
||||||
if beforeHook != nil || afterHook != nil {
|
|
||||||
hookArgs = cloneHookArgs(args)
|
|
||||||
}
|
|
||||||
hookCtx := ctx
|
|
||||||
if s.db != nil {
|
|
||||||
hookCtx = s.db.hookContext(ctx, s.sqlText, beforeHook, afterHook)
|
|
||||||
}
|
|
||||||
if beforeHook != nil {
|
|
||||||
beforeHook(hookCtx, s.sqlText, hookArgs)
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
var (
|
|
||||||
result sql.Result
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if ctx == nil {
|
|
||||||
result, err = s.stmt.Exec(args...)
|
|
||||||
} else {
|
|
||||||
result, err = s.stmt.ExecContext(ctx, args...)
|
|
||||||
}
|
|
||||||
duration := time.Since(start)
|
|
||||||
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
||||||
afterHook(hookCtx, s.sqlText, hookArgs, duration, err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the prepared statement
|
|
||||||
func (s *StarStmt) Close() error {
|
|
||||||
if err := s.ensureStmt(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return s.stmt.Close()
|
|
||||||
}
|
|
||||||
@ -1,127 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStarDB_NotInitialized(t *testing.T) {
|
|
||||||
db := NewStarDB()
|
|
||||||
|
|
||||||
if err := db.Close(); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from Close, got %v", err)
|
|
||||||
}
|
|
||||||
if err := db.Ping(); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from Ping, got %v", err)
|
|
||||||
}
|
|
||||||
if err := db.PingContext(context.Background()); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from PingContext, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := db.Conn(context.Background()); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from Conn, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := db.Query("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from Query, got %v", err)
|
|
||||||
}
|
|
||||||
if err := db.ScanEach("SELECT 1", func(row *StarResult) error { return nil }); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from ScanEach, got %v", err)
|
|
||||||
}
|
|
||||||
var model struct {
|
|
||||||
ID int `db:"id"`
|
|
||||||
}
|
|
||||||
if err := db.ScanEachORM("SELECT 1", &model, func(target interface{}) error { return nil }); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from ScanEachORM, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := db.QueryRaw("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from QueryRaw, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := db.Exec("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from Exec, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := db.Prepare("SELECT 1"); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from Prepare, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := db.Begin(); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from Begin, got %v", err)
|
|
||||||
}
|
|
||||||
if err := db.WithTx(nil); !errors.Is(err, ErrTxFuncNil) {
|
|
||||||
t.Fatalf("expected ErrTxFuncNil from WithTx, got %v", err)
|
|
||||||
}
|
|
||||||
if err := db.WithTx(func(tx *StarTx) error { return nil }); !errors.Is(err, ErrDBNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrDBNotInitialized from WithTx, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := db.QueryStmt(""); !errors.Is(err, ErrQueryEmpty) {
|
|
||||||
t.Fatalf("expected ErrQueryEmpty from QueryStmt, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := db.ExecStmt(""); !errors.Is(err, ErrQueryEmpty) {
|
|
||||||
t.Fatalf("expected ErrQueryEmpty from ExecStmt, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = db.Stats()
|
|
||||||
db.SetMaxOpenConns(5)
|
|
||||||
db.SetMaxIdleConns(2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarTx_NotInitialized(t *testing.T) {
|
|
||||||
tx := &StarTx{}
|
|
||||||
var model struct {
|
|
||||||
ID int `db:"id"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := tx.Query("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrTxNotInitialized from Query, got %v", err)
|
|
||||||
}
|
|
||||||
if err := tx.ScanEach("SELECT 1", func(row *StarResult) error { return nil }); !errors.Is(err, ErrTxNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrTxNotInitialized from ScanEach, got %v", err)
|
|
||||||
}
|
|
||||||
if err := tx.ScanEachORM("SELECT 1", &model, func(target interface{}) error { return nil }); !errors.Is(err, ErrTxNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrTxNotInitialized from ScanEachORM, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := tx.QueryRaw("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrTxNotInitialized from QueryRaw, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := tx.Exec("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrTxNotInitialized from Exec, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := tx.Prepare("SELECT 1"); !errors.Is(err, ErrTxNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrTxNotInitialized from Prepare, got %v", err)
|
|
||||||
}
|
|
||||||
if err := tx.Commit(); !errors.Is(err, ErrTxNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrTxNotInitialized from Commit, got %v", err)
|
|
||||||
}
|
|
||||||
if err := tx.Rollback(); !errors.Is(err, ErrTxNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrTxNotInitialized from Rollback, got %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := tx.QueryStmt(""); !errors.Is(err, ErrQueryEmpty) {
|
|
||||||
t.Fatalf("expected ErrQueryEmpty from QueryStmt, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := tx.ExecStmt(""); !errors.Is(err, ErrQueryEmpty) {
|
|
||||||
t.Fatalf("expected ErrQueryEmpty from ExecStmt, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarStmt_NotInitialized(t *testing.T) {
|
|
||||||
stmt := &StarStmt{}
|
|
||||||
|
|
||||||
if _, err := stmt.Query(); !errors.Is(err, ErrStmtNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrStmtNotInitialized from Query, got %v", err)
|
|
||||||
}
|
|
||||||
if err := stmt.ScanEach(func(row *StarResult) error { return nil }); !errors.Is(err, ErrStmtNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrStmtNotInitialized from ScanEach, got %v", err)
|
|
||||||
}
|
|
||||||
if err := stmt.ScanEachORM(nil, func(target interface{}) error { return nil }); !errors.Is(err, ErrTargetNil) {
|
|
||||||
t.Fatalf("expected ErrTargetNil from ScanEachORM, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := stmt.QueryRaw(); !errors.Is(err, ErrStmtNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrStmtNotInitialized from QueryRaw, got %v", err)
|
|
||||||
}
|
|
||||||
if _, err := stmt.Exec(); !errors.Is(err, ErrStmtNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrStmtNotInitialized from Exec, got %v", err)
|
|
||||||
}
|
|
||||||
if err := stmt.Close(); !errors.Is(err, ErrStmtNotInitialized) {
|
|
||||||
t.Fatalf("expected ErrStmtNotInitialized from Close, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
864
stardb_v1.go
Normal file
864
stardb_v1.go
Normal file
@ -0,0 +1,864 @@
|
|||||||
|
package stardb
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StarDB 一个简单封装的DB库
|
||||||
|
type StarDB struct {
|
||||||
|
DB *sql.DB
|
||||||
|
Rows *sql.Rows
|
||||||
|
}
|
||||||
|
|
||||||
|
// StarRows 为查询结果集(按行)
|
||||||
|
type StarRows struct {
|
||||||
|
Rows *sql.Rows
|
||||||
|
Length int
|
||||||
|
StringResult []map[string]string
|
||||||
|
Columns []string
|
||||||
|
ColumnsType []reflect.Type
|
||||||
|
columnref map[string]int
|
||||||
|
result [][]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StarResult 为查询结果集(总)
|
||||||
|
type StarResult struct {
|
||||||
|
Result []interface{}
|
||||||
|
Columns []string
|
||||||
|
columnref map[string]int
|
||||||
|
ColumnsType []reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
// StarResultCol 为查询结果集(按列)
|
||||||
|
type StarResultCol struct {
|
||||||
|
Result []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustBytes 列查询结果转Bytes
|
||||||
|
func (star *StarResultCol) MustBytes() [][]byte {
|
||||||
|
var res [][]byte
|
||||||
|
for _, v := range star.Result {
|
||||||
|
res = append(res, v.([]byte))
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustBool 列查询结果转Bool
|
||||||
|
func (star *StarResultCol) MustBool() []bool {
|
||||||
|
var res []bool
|
||||||
|
var tmp bool
|
||||||
|
for _, v := range star.Result {
|
||||||
|
switch vtype := v.(type) {
|
||||||
|
case nil:
|
||||||
|
tmp = false
|
||||||
|
case bool:
|
||||||
|
tmp = vtype
|
||||||
|
case float64:
|
||||||
|
if vtype > 0 {
|
||||||
|
tmp = true
|
||||||
|
} else {
|
||||||
|
tmp = false
|
||||||
|
}
|
||||||
|
case float32:
|
||||||
|
if vtype > 0 {
|
||||||
|
tmp = true
|
||||||
|
} else {
|
||||||
|
tmp = false
|
||||||
|
}
|
||||||
|
case int:
|
||||||
|
if vtype > 0 {
|
||||||
|
tmp = true
|
||||||
|
} else {
|
||||||
|
tmp = false
|
||||||
|
}
|
||||||
|
case int32:
|
||||||
|
if vtype > 0 {
|
||||||
|
tmp = true
|
||||||
|
} else {
|
||||||
|
tmp = false
|
||||||
|
}
|
||||||
|
case int64:
|
||||||
|
if vtype > 0 {
|
||||||
|
tmp = true
|
||||||
|
} else {
|
||||||
|
tmp = false
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
tmp, _ = strconv.ParseBool(vtype)
|
||||||
|
default:
|
||||||
|
tmp, _ = strconv.ParseBool(string(vtype.([]byte)))
|
||||||
|
}
|
||||||
|
res = append(res, tmp)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFloat32 列查询结果转Float32
|
||||||
|
func (star *StarResultCol) MustFloat32() []float32 {
|
||||||
|
var res []float32
|
||||||
|
var tmp float32
|
||||||
|
for _, v := range star.Result {
|
||||||
|
switch vtype := v.(type) {
|
||||||
|
case nil:
|
||||||
|
tmp = 0
|
||||||
|
case float32:
|
||||||
|
tmp = vtype
|
||||||
|
case float64:
|
||||||
|
tmp = float32(vtype)
|
||||||
|
case string:
|
||||||
|
tmps, _ := strconv.ParseFloat(vtype, 32)
|
||||||
|
tmp = float32(tmps)
|
||||||
|
case int:
|
||||||
|
tmp = float32(vtype)
|
||||||
|
case int32:
|
||||||
|
tmp = float32(vtype)
|
||||||
|
case int64:
|
||||||
|
tmp = float32(vtype)
|
||||||
|
case time.Time:
|
||||||
|
tmp = float32(vtype.Unix())
|
||||||
|
default:
|
||||||
|
tmpt := string(vtype.([]byte))
|
||||||
|
tmps, _ := strconv.ParseFloat(tmpt, 32)
|
||||||
|
tmp = float32(tmps)
|
||||||
|
}
|
||||||
|
res = append(res, tmp)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFloat64 列查询结果转Float64
|
||||||
|
func (star *StarResultCol) MustFloat64() []float64 {
|
||||||
|
var res []float64
|
||||||
|
var tmp float64
|
||||||
|
for _, v := range star.Result {
|
||||||
|
switch vtype := v.(type) {
|
||||||
|
case nil:
|
||||||
|
tmp = 0
|
||||||
|
case float64:
|
||||||
|
tmp = vtype
|
||||||
|
case float32:
|
||||||
|
tmp = float64(vtype)
|
||||||
|
case string:
|
||||||
|
tmp, _ = strconv.ParseFloat(vtype, 64)
|
||||||
|
case int:
|
||||||
|
tmp = float64(vtype)
|
||||||
|
case int32:
|
||||||
|
tmp = float64(vtype)
|
||||||
|
case int64:
|
||||||
|
tmp = float64(vtype)
|
||||||
|
case time.Time:
|
||||||
|
tmp = float64(vtype.Unix())
|
||||||
|
default:
|
||||||
|
tmpt := string(vtype.([]byte))
|
||||||
|
tmps, _ := strconv.ParseFloat(tmpt, 64)
|
||||||
|
tmp = float64(tmps)
|
||||||
|
}
|
||||||
|
res = append(res, tmp)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustString 列查询结果转String
|
||||||
|
func (star *StarResultCol) MustString() []string {
|
||||||
|
var res []string
|
||||||
|
var tmp string
|
||||||
|
for _, v := range star.Result {
|
||||||
|
switch vtype := v.(type) {
|
||||||
|
case nil:
|
||||||
|
tmp = ""
|
||||||
|
case string:
|
||||||
|
tmp = vtype
|
||||||
|
case int64:
|
||||||
|
tmp = strconv.FormatInt(vtype, 10)
|
||||||
|
case int32:
|
||||||
|
tmp = strconv.Itoa(int(vtype))
|
||||||
|
case bool:
|
||||||
|
tmp = strconv.FormatBool(vtype)
|
||||||
|
case float64:
|
||||||
|
tmp = strconv.FormatFloat(vtype, 'f', 10, 64)
|
||||||
|
case float32:
|
||||||
|
tmp = strconv.FormatFloat(float64(vtype), 'f', 10, 32)
|
||||||
|
case int:
|
||||||
|
tmp = strconv.Itoa(vtype)
|
||||||
|
case time.Time:
|
||||||
|
tmp = vtype.String()
|
||||||
|
default:
|
||||||
|
tmp = string(vtype.([]byte))
|
||||||
|
}
|
||||||
|
res = append(res, tmp)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustInt32 列查询结果转Int32
|
||||||
|
func (star *StarResultCol) MustInt32() []int32 {
|
||||||
|
var res []int32
|
||||||
|
var tmp int32
|
||||||
|
for _, v := range star.Result {
|
||||||
|
switch vtype := v.(type) {
|
||||||
|
case nil:
|
||||||
|
tmp = 0
|
||||||
|
case float64:
|
||||||
|
tmp = int32(vtype)
|
||||||
|
case float32:
|
||||||
|
tmp = int32(vtype)
|
||||||
|
case string:
|
||||||
|
tmps, _ := strconv.ParseInt(vtype, 10, 32)
|
||||||
|
tmp = int32(tmps)
|
||||||
|
case int:
|
||||||
|
tmp = int32(vtype)
|
||||||
|
case int64:
|
||||||
|
tmp = int32(vtype)
|
||||||
|
case int32:
|
||||||
|
tmp = vtype
|
||||||
|
case time.Time:
|
||||||
|
tmp = int32(vtype.Unix())
|
||||||
|
default:
|
||||||
|
tmpt := string(vtype.([]byte))
|
||||||
|
tmps, _ := strconv.ParseInt(tmpt, 10, 32)
|
||||||
|
tmp = int32(tmps)
|
||||||
|
}
|
||||||
|
res = append(res, tmp)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustInt64 列查询结果转Int64
|
||||||
|
func (star *StarResultCol) MustInt64() []int64 {
|
||||||
|
var res []int64
|
||||||
|
var tmp int64
|
||||||
|
for _, v := range star.Result {
|
||||||
|
switch vtype := v.(type) {
|
||||||
|
case nil:
|
||||||
|
tmp = 0
|
||||||
|
case float64:
|
||||||
|
tmp = int64(vtype)
|
||||||
|
case float32:
|
||||||
|
tmp = int64(vtype)
|
||||||
|
case string:
|
||||||
|
tmps, _ := strconv.ParseInt(vtype, 10, 64)
|
||||||
|
tmp = int64(tmps)
|
||||||
|
case int:
|
||||||
|
tmp = int64(vtype)
|
||||||
|
case int32:
|
||||||
|
tmp = int64(vtype)
|
||||||
|
case int64:
|
||||||
|
tmp = vtype
|
||||||
|
case time.Time:
|
||||||
|
tmp = vtype.Unix()
|
||||||
|
default:
|
||||||
|
tmpt := string(vtype.([]byte))
|
||||||
|
tmp, _ = strconv.ParseInt(tmpt, 10, 64)
|
||||||
|
}
|
||||||
|
res = append(res, tmp)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustInt 列查询结果转Int
|
||||||
|
func (star *StarResultCol) MustInt() []int {
|
||||||
|
var res []int
|
||||||
|
var tmp int
|
||||||
|
for _, v := range star.Result {
|
||||||
|
switch vtype := v.(type) {
|
||||||
|
case nil:
|
||||||
|
tmp = 0
|
||||||
|
case float64:
|
||||||
|
tmp = int(vtype)
|
||||||
|
case float32:
|
||||||
|
tmp = int(vtype)
|
||||||
|
case string:
|
||||||
|
tmps, _ := strconv.ParseInt(vtype, 10, 64)
|
||||||
|
tmp = int(tmps)
|
||||||
|
case int:
|
||||||
|
tmp = vtype
|
||||||
|
case int32:
|
||||||
|
tmp = int(vtype)
|
||||||
|
case int64:
|
||||||
|
tmp = int(vtype)
|
||||||
|
case time.Time:
|
||||||
|
tmp = int(vtype.Unix())
|
||||||
|
default:
|
||||||
|
tmpt := string(vtype.([]byte))
|
||||||
|
tmps, _ := strconv.ParseInt(tmpt, 10, 64)
|
||||||
|
tmp = int(tmps)
|
||||||
|
}
|
||||||
|
res = append(res, tmp)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustDate 列查询结果转Date(time.Time)
|
||||||
|
func (star *StarResultCol) MustDate(layout string) []time.Time {
|
||||||
|
var res []time.Time
|
||||||
|
var tmp time.Time
|
||||||
|
for _, v := range star.Result {
|
||||||
|
switch vtype := v.(type) {
|
||||||
|
case nil:
|
||||||
|
tmp = time.Time{}
|
||||||
|
case float64:
|
||||||
|
tmp = time.Unix(int64(vtype), int64(vtype-float64(int64(vtype)))*1000000000)
|
||||||
|
case float32:
|
||||||
|
tmp = time.Unix(int64(vtype), int64(vtype-float32(int64(vtype)))*1000000000)
|
||||||
|
case string:
|
||||||
|
tmp, _ = time.Parse(layout, vtype)
|
||||||
|
case int:
|
||||||
|
tmp = time.Unix(int64(vtype), 0)
|
||||||
|
case int32:
|
||||||
|
tmp = time.Unix(int64(vtype), 0)
|
||||||
|
case int64:
|
||||||
|
tmp = time.Unix(vtype, 0)
|
||||||
|
case time.Time:
|
||||||
|
tmp = vtype
|
||||||
|
default:
|
||||||
|
tmpt := string(vtype.([]byte))
|
||||||
|
tmp, _ = time.Parse(layout, tmpt)
|
||||||
|
}
|
||||||
|
res = append(res, tmp)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNil 检测是不是nil 列查询结果是不是nil
|
||||||
|
func (star *StarResultCol) IsNil(name string) []bool {
|
||||||
|
var res []bool
|
||||||
|
var tmp bool
|
||||||
|
for _, v := range star.Result {
|
||||||
|
switch v.(type) {
|
||||||
|
case nil:
|
||||||
|
tmp = true
|
||||||
|
default:
|
||||||
|
tmp = false
|
||||||
|
}
|
||||||
|
res = append(res, tmp)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsNil 检测是不是nil
|
||||||
|
func (star *StarResult) IsNil(name string) bool {
|
||||||
|
num, ok := star.columnref[name]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tmp := star.Result[num]
|
||||||
|
switch tmp.(type) {
|
||||||
|
case nil:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustDate 列查询结果转Date
|
||||||
|
func (star *StarResult) MustDate(name, layout string) time.Time {
|
||||||
|
var res time.Time
|
||||||
|
num, ok := star.columnref[name]
|
||||||
|
if !ok {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
tmp := star.Result[num]
|
||||||
|
switch vtype := tmp.(type) {
|
||||||
|
case nil:
|
||||||
|
res = time.Time{}
|
||||||
|
case float64:
|
||||||
|
res = time.Unix(int64(vtype), int64(vtype-float64(int64(vtype)))*1000000000)
|
||||||
|
case float32:
|
||||||
|
res = time.Unix(int64(vtype), int64(vtype-float32(int64(vtype)))*1000000000)
|
||||||
|
case string:
|
||||||
|
res, _ = time.Parse(layout, vtype)
|
||||||
|
case int:
|
||||||
|
res = time.Unix(int64(vtype), 0)
|
||||||
|
case int32:
|
||||||
|
res = time.Unix(int64(vtype), 0)
|
||||||
|
case int64:
|
||||||
|
res = time.Unix(vtype, 0)
|
||||||
|
case time.Time:
|
||||||
|
res = vtype
|
||||||
|
default:
|
||||||
|
res, _ = time.Parse(layout, string(tmp.([]byte)))
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustInt64 列查询结果转int64
|
||||||
|
func (star *StarResult) MustInt64(name string) int64 {
|
||||||
|
var res int64
|
||||||
|
num, ok := star.columnref[name]
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
tmp := star.Result[num]
|
||||||
|
switch vtype := tmp.(type) {
|
||||||
|
case nil:
|
||||||
|
res = 0
|
||||||
|
case float64:
|
||||||
|
res = int64(vtype)
|
||||||
|
case float32:
|
||||||
|
res = int64(vtype)
|
||||||
|
case string:
|
||||||
|
res, _ = strconv.ParseInt(vtype, 10, 64)
|
||||||
|
case int:
|
||||||
|
res = int64(vtype)
|
||||||
|
case int32:
|
||||||
|
res = int64(vtype)
|
||||||
|
case int64:
|
||||||
|
res = vtype
|
||||||
|
case time.Time:
|
||||||
|
res = int64(vtype.Unix())
|
||||||
|
default:
|
||||||
|
res, _ = strconv.ParseInt(string(tmp.([]byte)), 10, 64)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustInt32 列查询结果转Int32
|
||||||
|
func (star *StarResult) MustInt32(name string) int32 {
|
||||||
|
var res int32
|
||||||
|
num, ok := star.columnref[name]
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
tmp := star.Result[num]
|
||||||
|
switch vtype := tmp.(type) {
|
||||||
|
case nil:
|
||||||
|
res = 0
|
||||||
|
case float64:
|
||||||
|
res = int32(vtype)
|
||||||
|
case float32:
|
||||||
|
res = int32(vtype)
|
||||||
|
case string:
|
||||||
|
ress, _ := strconv.ParseInt(vtype, 10, 32)
|
||||||
|
res = int32(ress)
|
||||||
|
case int:
|
||||||
|
res = int32(vtype)
|
||||||
|
case int32:
|
||||||
|
res = vtype
|
||||||
|
case int64:
|
||||||
|
res = int32(vtype)
|
||||||
|
case time.Time:
|
||||||
|
res = int32(vtype.Unix())
|
||||||
|
default:
|
||||||
|
ress, _ := strconv.ParseInt(string(tmp.([]byte)), 10, 32)
|
||||||
|
res = int32(ress)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustString 列查询结果转string
|
||||||
|
func (star *StarResult) MustString(name string) string {
|
||||||
|
var res string
|
||||||
|
num, ok := star.columnref[name]
|
||||||
|
if !ok {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch vtype := star.Result[num].(type) {
|
||||||
|
case nil:
|
||||||
|
res = ""
|
||||||
|
case string:
|
||||||
|
res = vtype
|
||||||
|
case int64:
|
||||||
|
res = strconv.FormatInt(vtype, 10)
|
||||||
|
case int32:
|
||||||
|
res = strconv.Itoa(int(vtype))
|
||||||
|
case bool:
|
||||||
|
res = strconv.FormatBool(vtype)
|
||||||
|
case float64:
|
||||||
|
res = strconv.FormatFloat(vtype, 'f', 10, 64)
|
||||||
|
case float32:
|
||||||
|
res = strconv.FormatFloat(float64(vtype), 'f', 10, 32)
|
||||||
|
case int:
|
||||||
|
res = strconv.Itoa(vtype)
|
||||||
|
case time.Time:
|
||||||
|
res = vtype.String()
|
||||||
|
default:
|
||||||
|
res = string(vtype.([]byte))
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFloat64 列查询结果转float64
|
||||||
|
func (star *StarResult) MustFloat64(name string) float64 {
|
||||||
|
var res float64
|
||||||
|
num, ok := star.columnref[name]
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
switch vtype := star.Result[num].(type) {
|
||||||
|
case nil:
|
||||||
|
res = 0
|
||||||
|
case string:
|
||||||
|
res, _ = strconv.ParseFloat(vtype, 64)
|
||||||
|
case float64:
|
||||||
|
res = vtype
|
||||||
|
case int:
|
||||||
|
res = float64(vtype)
|
||||||
|
case int64:
|
||||||
|
res = float64(vtype)
|
||||||
|
case int32:
|
||||||
|
res = float64(vtype)
|
||||||
|
case float32:
|
||||||
|
res = float64(vtype)
|
||||||
|
case time.Time:
|
||||||
|
res = float64(vtype.Unix())
|
||||||
|
default:
|
||||||
|
res, _ = strconv.ParseFloat(string(vtype.([]byte)), 64)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFloat32 列查询结果转float32
|
||||||
|
func (star *StarResult) MustFloat32(name string) float32 {
|
||||||
|
var res float32
|
||||||
|
num, ok := star.columnref[name]
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
switch vtype := star.Result[num].(type) {
|
||||||
|
case nil:
|
||||||
|
res = 0
|
||||||
|
case string:
|
||||||
|
tmp, _ := strconv.ParseFloat(vtype, 32)
|
||||||
|
res = float32(tmp)
|
||||||
|
case float64:
|
||||||
|
res = float32(vtype)
|
||||||
|
case float32:
|
||||||
|
res = vtype
|
||||||
|
case int:
|
||||||
|
res = float32(vtype)
|
||||||
|
case int64:
|
||||||
|
res = float32(vtype)
|
||||||
|
case int32:
|
||||||
|
res = float32(vtype)
|
||||||
|
case time.Time:
|
||||||
|
res = float32(vtype.Unix())
|
||||||
|
default:
|
||||||
|
tmp, _ := strconv.ParseFloat(string(vtype.([]byte)), 32)
|
||||||
|
res = float32(tmp)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustInt 列查询结果转int
|
||||||
|
func (star *StarResult) MustInt(name string) int {
|
||||||
|
var res int
|
||||||
|
num, ok := star.columnref[name]
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
tmp := star.Result[num]
|
||||||
|
switch vtype := tmp.(type) {
|
||||||
|
case nil:
|
||||||
|
res = 0
|
||||||
|
case float64:
|
||||||
|
res = int(vtype)
|
||||||
|
case float32:
|
||||||
|
res = int(vtype)
|
||||||
|
case string:
|
||||||
|
ress, _ := strconv.ParseInt(vtype, 10, 64)
|
||||||
|
res = int(ress)
|
||||||
|
case int:
|
||||||
|
res = vtype
|
||||||
|
case int32:
|
||||||
|
res = int(vtype)
|
||||||
|
case int64:
|
||||||
|
res = int(vtype)
|
||||||
|
case time.Time:
|
||||||
|
res = int(vtype.Unix())
|
||||||
|
default:
|
||||||
|
ress, _ := strconv.ParseInt(string(tmp.([]byte)), 10, 64)
|
||||||
|
res = int(ress)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustBool 列查询结果转bool
|
||||||
|
func (star *StarResult) MustBool(name string) bool {
|
||||||
|
var res bool
|
||||||
|
num, ok := star.columnref[name]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
tmp := star.Result[num]
|
||||||
|
switch vtype := tmp.(type) {
|
||||||
|
case nil:
|
||||||
|
res = false
|
||||||
|
case bool:
|
||||||
|
res = vtype
|
||||||
|
case float64:
|
||||||
|
if vtype > 0 {
|
||||||
|
res = true
|
||||||
|
} else {
|
||||||
|
res = false
|
||||||
|
}
|
||||||
|
case float32:
|
||||||
|
if vtype > 0 {
|
||||||
|
res = true
|
||||||
|
} else {
|
||||||
|
res = false
|
||||||
|
}
|
||||||
|
case int:
|
||||||
|
if vtype > 0 {
|
||||||
|
res = true
|
||||||
|
} else {
|
||||||
|
res = false
|
||||||
|
}
|
||||||
|
case int32:
|
||||||
|
if vtype > 0 {
|
||||||
|
res = true
|
||||||
|
} else {
|
||||||
|
res = false
|
||||||
|
}
|
||||||
|
case int64:
|
||||||
|
if vtype > 0 {
|
||||||
|
res = true
|
||||||
|
} else {
|
||||||
|
res = false
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
res, _ = strconv.ParseBool(vtype)
|
||||||
|
default:
|
||||||
|
res, _ = strconv.ParseBool(string(vtype.([]byte)))
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustBytes 列查询结果转byte
|
||||||
|
func (star *StarResult) MustBytes(name string) []byte {
|
||||||
|
num, ok := star.columnref[name]
|
||||||
|
if !ok {
|
||||||
|
return []byte{}
|
||||||
|
}
|
||||||
|
res := star.Result[num].([]byte)
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
// Rescan 重新分析结果集
|
||||||
|
func (star *StarRows) Rescan() {
|
||||||
|
star.parserows()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Col 选择需要进行操作的数据结果列
|
||||||
|
func (star *StarRows) Col(name string) *StarResultCol {
|
||||||
|
result := new(StarResultCol)
|
||||||
|
if _, ok := star.columnref[name]; !ok {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
var rescol []interface{}
|
||||||
|
for _, v := range star.result {
|
||||||
|
rescol = append(rescol, v[star.columnref[name]])
|
||||||
|
}
|
||||||
|
result.Result = rescol
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Row 选择需要进行操作的数据结果行
|
||||||
|
func (star *StarRows) Row(id int) *StarResult {
|
||||||
|
result := new(StarResult)
|
||||||
|
if id+1 > len(star.result) {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
result.Result = star.result[id]
|
||||||
|
result.Columns = star.Columns
|
||||||
|
result.ColumnsType = star.ColumnsType
|
||||||
|
result.columnref = star.columnref
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭打开的结果集
|
||||||
|
func (star *StarRows) Close() error {
|
||||||
|
return star.Rows.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (star *StarRows) parserows() {
|
||||||
|
star.result = [][]interface{}{}
|
||||||
|
star.columnref = make(map[string]int)
|
||||||
|
star.StringResult = []map[string]string{}
|
||||||
|
star.Columns, _ = star.Rows.Columns()
|
||||||
|
types, _ := star.Rows.ColumnTypes()
|
||||||
|
for _, v := range types {
|
||||||
|
star.ColumnsType = append(star.ColumnsType, v.ScanType())
|
||||||
|
}
|
||||||
|
scanArgs := make([]interface{}, len(star.Columns))
|
||||||
|
values := make([]interface{}, len(star.Columns))
|
||||||
|
for i := range values {
|
||||||
|
star.columnref[star.Columns[i]] = i
|
||||||
|
scanArgs[i] = &values[i]
|
||||||
|
}
|
||||||
|
for star.Rows.Next() {
|
||||||
|
if err := star.Rows.Scan(scanArgs...); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
record := make(map[string]string)
|
||||||
|
var rescopy []interface{}
|
||||||
|
for i, col := range values {
|
||||||
|
rescopy = append(rescopy, col)
|
||||||
|
switch vtype := col.(type) {
|
||||||
|
case float32:
|
||||||
|
record[star.Columns[i]] = strconv.FormatFloat(float64(vtype), 'f', -1, 64)
|
||||||
|
case float64:
|
||||||
|
record[star.Columns[i]] = strconv.FormatFloat(vtype, 'f', -1, 64)
|
||||||
|
case int64:
|
||||||
|
record[star.Columns[i]] = strconv.FormatInt(vtype, 10)
|
||||||
|
case int32:
|
||||||
|
record[star.Columns[i]] = strconv.FormatInt(int64(vtype), 10)
|
||||||
|
case int:
|
||||||
|
record[star.Columns[i]] = strconv.Itoa(vtype)
|
||||||
|
case string:
|
||||||
|
record[star.Columns[i]] = vtype
|
||||||
|
case bool:
|
||||||
|
record[star.Columns[i]] = strconv.FormatBool(vtype)
|
||||||
|
case time.Time:
|
||||||
|
record[star.Columns[i]] = vtype.String()
|
||||||
|
case nil:
|
||||||
|
record[star.Columns[i]] = ""
|
||||||
|
default:
|
||||||
|
record[star.Columns[i]] = string(vtype.([]byte))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
star.result = append(star.result, rescopy)
|
||||||
|
star.StringResult = append(star.StringResult, record)
|
||||||
|
}
|
||||||
|
star.Length = len(star.StringResult)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query 进行Query操作
|
||||||
|
func (star *StarDB) Query(args ...interface{}) (*StarRows, error) {
|
||||||
|
var err error
|
||||||
|
effect := new(StarRows)
|
||||||
|
if err = star.DB.Ping(); err != nil {
|
||||||
|
return effect, err
|
||||||
|
}
|
||||||
|
if len(args) == 0 {
|
||||||
|
return effect, errors.New("no args")
|
||||||
|
}
|
||||||
|
if len(args) == 1 {
|
||||||
|
sql := args[0]
|
||||||
|
if star.Rows, err = star.DB.Query(sql.(string)); err != nil {
|
||||||
|
return effect, err
|
||||||
|
}
|
||||||
|
effect.Rows = star.Rows
|
||||||
|
effect.parserows()
|
||||||
|
return effect, nil
|
||||||
|
}
|
||||||
|
sql := args[0]
|
||||||
|
stmt, err := star.DB.Prepare(sql.(string))
|
||||||
|
if err != nil {
|
||||||
|
return effect, err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
var para []interface{}
|
||||||
|
for k, v := range args {
|
||||||
|
if k != 0 {
|
||||||
|
switch vtype := v.(type) {
|
||||||
|
default:
|
||||||
|
para = append(para, vtype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if star.Rows, err = stmt.Query(para...); err != nil {
|
||||||
|
return effect, err
|
||||||
|
}
|
||||||
|
effect.Rows = star.Rows
|
||||||
|
effect.parserows()
|
||||||
|
return effect, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open 打开一个新的数据库
|
||||||
|
func (star *StarDB) Open(Method, ConnStr string) error {
|
||||||
|
var err error
|
||||||
|
star.DB, err = sql.Open(Method, ConnStr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = star.DB.Ping()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭打开的数据库
|
||||||
|
func (star *StarDB) Close() error {
|
||||||
|
if err := star.DB.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return star.DB.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Exec 执行Exec操作
|
||||||
|
func (star *StarDB) Exec(args ...interface{}) (sql.Result, error) {
|
||||||
|
var err error
|
||||||
|
var effect sql.Result
|
||||||
|
if err = star.DB.Ping(); err != nil {
|
||||||
|
return effect, err
|
||||||
|
}
|
||||||
|
if len(args) == 0 {
|
||||||
|
return effect, errors.New("no args")
|
||||||
|
}
|
||||||
|
if len(args) == 1 {
|
||||||
|
sql := args[0]
|
||||||
|
if _, err = star.DB.Exec(sql.(string)); err != nil {
|
||||||
|
return effect, err
|
||||||
|
}
|
||||||
|
return effect, nil
|
||||||
|
}
|
||||||
|
sql := args[0]
|
||||||
|
stmt, err := star.DB.Prepare(sql.(string))
|
||||||
|
if err != nil {
|
||||||
|
return effect, err
|
||||||
|
}
|
||||||
|
defer stmt.Close()
|
||||||
|
var para []interface{}
|
||||||
|
for k, v := range args {
|
||||||
|
if k != 0 {
|
||||||
|
switch vtype := v.(type) {
|
||||||
|
default:
|
||||||
|
para = append(para, vtype)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if effect, err = stmt.Exec(para...); err != nil {
|
||||||
|
return effect, err
|
||||||
|
}
|
||||||
|
return effect, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// FetchAll 把结果集全部转为key-value型<string>数据
|
||||||
|
func FetchAll(rows *sql.Rows) (error, map[int]map[string]string) {
|
||||||
|
var ii int = 0
|
||||||
|
records := make(map[int]map[string]string)
|
||||||
|
columns, err := rows.Columns()
|
||||||
|
if err != nil {
|
||||||
|
return err, records
|
||||||
|
}
|
||||||
|
scanArgs := make([]interface{}, len(columns))
|
||||||
|
values := make([]interface{}, len(columns))
|
||||||
|
for i := range values {
|
||||||
|
scanArgs[i] = &values[i]
|
||||||
|
}
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(scanArgs...); err != nil {
|
||||||
|
return err, records
|
||||||
|
}
|
||||||
|
record := make(map[string]string)
|
||||||
|
for i, col := range values {
|
||||||
|
switch vtype := col.(type) {
|
||||||
|
case float64:
|
||||||
|
record[columns[i]] = strconv.FormatFloat(vtype, 'f', -1, 64)
|
||||||
|
case int64:
|
||||||
|
record[columns[i]] = strconv.FormatInt(vtype, 10)
|
||||||
|
case string:
|
||||||
|
record[columns[i]] = vtype
|
||||||
|
case nil:
|
||||||
|
record[columns[i]] = ""
|
||||||
|
default:
|
||||||
|
record[columns[i]] = string(vtype.([]byte))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
records[ii] = record
|
||||||
|
ii++
|
||||||
|
}
|
||||||
|
return nil, records
|
||||||
|
}
|
||||||
@ -1,888 +0,0 @@
|
|||||||
package testing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"b612.me/stardb"
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TestUser struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Email string `db:"email"`
|
|
||||||
Age int `db:"age"`
|
|
||||||
CreatedAt time.Time `db:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupBatchTestDB(t *testing.T) *stardb.StarDB {
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
err := db.Open("sqlite", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to open database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = db.Exec(`
|
|
||||||
CREATE TABLE users (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
email TEXT NOT NULL,
|
|
||||||
age INTEGER,
|
|
||||||
created_at DATETIME
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create table: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_Basic(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
{"Bob", "bob@example.com", 30},
|
|
||||||
{"Charlie", "charlie@example.com", 35},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsert("users", columns, values)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsert failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 3 {
|
|
||||||
t.Errorf("Expected 3 rows affected, got %d", affected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify insertion
|
|
||||||
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
count := rows.Row(0).MustInt("count")
|
|
||||||
if count != 3 {
|
|
||||||
t.Errorf("Expected 3 rows in database, got %d", count)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_Single(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsert("users", columns, values)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsert failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 1 {
|
|
||||||
t.Errorf("Expected 1 row affected, got %d", affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_Empty(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{}
|
|
||||||
|
|
||||||
_, err := db.BatchInsert("users", columns, values)
|
|
||||||
if !errors.Is(err, stardb.ErrNoInsertValues) {
|
|
||||||
t.Errorf("Expected ErrNoInsertValues, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_EmptyColumns(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice"},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.BatchInsert("users", nil, values)
|
|
||||||
if !errors.Is(err, stardb.ErrNoInsertColumns) {
|
|
||||||
t.Errorf("Expected ErrNoInsertColumns, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_EmptyTableName(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.BatchInsert("", columns, values)
|
|
||||||
if !errors.Is(err, stardb.ErrTableNameEmpty) {
|
|
||||||
t.Errorf("Expected ErrTableNameEmpty, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_RowLengthMismatch(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
{"Bob", "bob@example.com"},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.BatchInsert("users", columns, values)
|
|
||||||
if !errors.Is(err, stardb.ErrBatchRowValueCountMismatch) {
|
|
||||||
t.Errorf("Expected ErrBatchRowValueCountMismatch, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_Large(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
var values [][]interface{}
|
|
||||||
|
|
||||||
// Insert 100 rows
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
values = append(values, []interface{}{
|
|
||||||
"User" + string(rune(i)),
|
|
||||||
"user" + string(rune(i)) + "@example.com",
|
|
||||||
20 + i%50,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsert("users", columns, values)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsert failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 100 {
|
|
||||||
t.Errorf("Expected 100 rows affected, got %d", affected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify
|
|
||||||
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
count := rows.Row(0).MustInt("count")
|
|
||||||
if count != 100 {
|
|
||||||
t.Errorf("Expected 100 rows in database, got %d", count)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertContext(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
{"Bob", "bob@example.com", 30},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsertContext(ctx, "users", columns, values)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsertContext failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 2 {
|
|
||||||
t.Errorf("Expected 2 rows affected, got %d", affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertContext_Timeout(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
time.Sleep(10 * time.Millisecond) // Ensure timeout
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.BatchInsertContext(ctx, "users", columns, values)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected timeout error, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertMaxRows_Config(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
if got := db.BatchInsertMaxRows(); got != 0 {
|
|
||||||
t.Fatalf("Expected default chunk size 0, got %d", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxRows(3)
|
|
||||||
if got := db.BatchInsertMaxRows(); got != 3 {
|
|
||||||
t.Fatalf("Expected chunk size 3, got %d", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxRows(-10)
|
|
||||||
if got := db.BatchInsertMaxRows(); got != 0 {
|
|
||||||
t.Fatalf("Expected chunk size reset to 0, got %d", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertMaxParams_Config(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
if got := db.BatchInsertMaxParams(); got != 0 {
|
|
||||||
t.Fatalf("Expected default max params 0, got %d", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxParams(100)
|
|
||||||
if got := db.BatchInsertMaxParams(); got != 100 {
|
|
||||||
t.Fatalf("Expected max params 100, got %d", got)
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxParams(-1)
|
|
||||||
if got := db.BatchInsertMaxParams(); got != 0 {
|
|
||||||
t.Fatalf("Expected max params reset to 0, got %d", got)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_Chunked(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxRows(2)
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
{"Bob", "bob@example.com", 30},
|
|
||||||
{"Charlie", "charlie@example.com", 35},
|
|
||||||
{"David", "david@example.com", 40},
|
|
||||||
{"Eva", "eva@example.com", 28},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsert("users", columns, values)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Chunked BatchInsert failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
if affected != int64(len(values)) {
|
|
||||||
t.Fatalf("Expected %d affected rows, got %d", len(values), affected)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
if count := rows.Row(0).MustInt("count"); count != len(values) {
|
|
||||||
t.Fatalf("Expected %d rows in db, got %d", len(values), count)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_ChunkedRollbackOnError(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxRows(2)
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
{"Bob", "bob@example.com", 30},
|
|
||||||
{"Charlie", nil, 35}, // email NOT NULL, forces second chunk failure
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := db.BatchInsert("users", columns, values); err == nil {
|
|
||||||
t.Fatal("Expected chunked BatchInsert to fail, got nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
if count := rows.Row(0).MustInt("count"); count != 0 {
|
|
||||||
t.Fatalf("Expected rollback to keep table empty, got %d rows", count)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_ChunkedByMaxParams(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxRows(0) // disabled
|
|
||||||
db.SetBatchInsertMaxParams(4) // 3 columns -> 1 row per chunk
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
{"Bob", "bob@example.com", 30},
|
|
||||||
{"Charlie", "charlie@example.com", 35},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsert("users", columns, values)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsert by max params failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
if affected != int64(len(values)) {
|
|
||||||
t.Fatalf("Expected %d affected rows, got %d", len(values), affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_MaxParamsTooLow(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxRows(0)
|
|
||||||
db.SetBatchInsertMaxParams(2) // columns=3 -> invalid
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.BatchInsert("users", columns, values)
|
|
||||||
if !errors.Is(err, stardb.ErrBatchInsertMaxParamsTooLow) {
|
|
||||||
t.Fatalf("Expected ErrBatchInsertMaxParamsTooLow, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_ChunkedHookMeta(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxRows(2)
|
|
||||||
db.SetBatchInsertMaxParams(0)
|
|
||||||
|
|
||||||
var metas []stardb.BatchExecMeta
|
|
||||||
db.SetSQLHooks(nil, func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) {
|
|
||||||
if meta, ok := stardb.BatchExecMetaFromContext(ctx); ok {
|
|
||||||
metas = append(metas, meta)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
{"Bob", "bob@example.com", 30},
|
|
||||||
{"Charlie", "charlie@example.com", 35},
|
|
||||||
{"David", "david@example.com", 40},
|
|
||||||
{"Eva", "eva@example.com", 28},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.BatchInsertContext(context.Background(), "users", columns, values)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsertContext failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(metas) != 3 {
|
|
||||||
t.Fatalf("Expected 3 chunk metas, got %d", len(metas))
|
|
||||||
}
|
|
||||||
|
|
||||||
wantRows := []int{2, 2, 1}
|
|
||||||
for i, meta := range metas {
|
|
||||||
if meta.ChunkIndex != i+1 {
|
|
||||||
t.Fatalf("Chunk %d: expected index %d, got %d", i, i+1, meta.ChunkIndex)
|
|
||||||
}
|
|
||||||
if meta.ChunkCount != 3 {
|
|
||||||
t.Fatalf("Chunk %d: expected count 3, got %d", i, meta.ChunkCount)
|
|
||||||
}
|
|
||||||
if meta.ChunkRows != wantRows[i] {
|
|
||||||
t.Fatalf("Chunk %d: expected rows %d, got %d", i, wantRows[i], meta.ChunkRows)
|
|
||||||
}
|
|
||||||
if meta.TotalRows != len(values) {
|
|
||||||
t.Fatalf("Chunk %d: expected total rows %d, got %d", i, len(values), meta.TotalRows)
|
|
||||||
}
|
|
||||||
if meta.ColumnCount != len(columns) {
|
|
||||||
t.Fatalf("Chunk %d: expected column count %d, got %d", i, len(columns), meta.ColumnCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_HookMetaAbsentWithoutChunking(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxRows(0)
|
|
||||||
db.SetBatchInsertMaxParams(0)
|
|
||||||
|
|
||||||
metaFound := false
|
|
||||||
db.SetSQLHooks(nil, func(ctx context.Context, query string, args []interface{}, d time.Duration, err error) {
|
|
||||||
if _, ok := stardb.BatchExecMetaFromContext(ctx); ok {
|
|
||||||
metaFound = true
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
{"Bob", "bob@example.com", 30},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.BatchInsertContext(context.Background(), "users", columns, values)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsertContext failed: %v", err)
|
|
||||||
}
|
|
||||||
if metaFound {
|
|
||||||
t.Fatal("Expected no batch meta for non-chunked execution")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_Basic(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
users := []TestUser{
|
|
||||||
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
|
|
||||||
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
|
|
||||||
{Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsertStructs("users", users, "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsertStructs failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 3 {
|
|
||||||
t.Errorf("Expected 3 rows affected, got %d", affected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify insertion
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if rows.Length() != 3 {
|
|
||||||
t.Errorf("Expected 3 rows, got %d", rows.Length())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify first user
|
|
||||||
name := rows.Row(0).MustString("name")
|
|
||||||
if name != "Alice" {
|
|
||||||
t.Errorf("Expected first user 'Alice', got '%s'", name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_Chunked(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
db.SetBatchInsertMaxRows(2)
|
|
||||||
|
|
||||||
users := []TestUser{
|
|
||||||
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
|
|
||||||
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
|
|
||||||
{Name: "Charlie", Email: "charlie@example.com", Age: 35, CreatedAt: time.Now()},
|
|
||||||
{Name: "David", Email: "david@example.com", Age: 40, CreatedAt: time.Now()},
|
|
||||||
{Name: "Eva", Email: "eva@example.com", Age: 28, CreatedAt: time.Now()},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsertStructs("users", users, "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Chunked BatchInsertStructs failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
if affected != int64(len(users)) {
|
|
||||||
t.Fatalf("Expected %d affected rows, got %d", len(users), affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_Single(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
users := []TestUser{
|
|
||||||
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsertStructs("users", users, "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsertStructs failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 1 {
|
|
||||||
t.Errorf("Expected 1 row affected, got %d", affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_Empty(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
users := []TestUser{}
|
|
||||||
|
|
||||||
_, err := db.BatchInsertStructs("users", users, "id")
|
|
||||||
if !errors.Is(err, stardb.ErrNoStructsToInsert) {
|
|
||||||
t.Errorf("Expected ErrNoStructsToInsert, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_NotSlice(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
user := TestUser{Name: "Alice", Email: "alice@example.com", Age: 25}
|
|
||||||
|
|
||||||
_, err := db.BatchInsertStructs("users", user, "id")
|
|
||||||
if !errors.Is(err, stardb.ErrStructsNotSlice) {
|
|
||||||
t.Errorf("Expected ErrStructsNotSlice, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_Nil(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
_, err := db.BatchInsertStructs("users", nil, "id")
|
|
||||||
if !errors.Is(err, stardb.ErrStructsNil) {
|
|
||||||
t.Errorf("Expected ErrStructsNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_NilPointer(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
var users *[]TestUser
|
|
||||||
_, err := db.BatchInsertStructs("users", users, "id")
|
|
||||||
if !errors.Is(err, stardb.ErrStructsPointerNil) {
|
|
||||||
t.Errorf("Expected ErrStructsPointerNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_Pointer(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
users := []TestUser{
|
|
||||||
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
|
|
||||||
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsertStructs("users", &users, "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsertStructs with pointer failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 2 {
|
|
||||||
t.Errorf("Expected 2 rows affected, got %d", affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructsContext(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
users := []TestUser{
|
|
||||||
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: time.Now()},
|
|
||||||
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: time.Now()},
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsertStructsContext(ctx, "users", users, "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsertStructsContext failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 2 {
|
|
||||||
t.Errorf("Expected 2 rows affected, got %d", affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_Large(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
var users []TestUser
|
|
||||||
for i := 0; i < 50; i++ {
|
|
||||||
users = append(users, TestUser{
|
|
||||||
Name: "User" + string(rune(i)),
|
|
||||||
Email: "user" + string(rune(i)) + "@example.com",
|
|
||||||
Age: 20 + i%30,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.BatchInsertStructs("users", users, "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsertStructs failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 50 {
|
|
||||||
t.Errorf("Expected 50 rows affected, got %d", affected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify
|
|
||||||
rows, err := db.Query("SELECT COUNT(*) as count FROM users")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
count := rows.Row(0).MustInt("count")
|
|
||||||
if count != 50 {
|
|
||||||
t.Errorf("Expected 50 rows in database, got %d", count)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsert_VerifyData(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
values := [][]interface{}{
|
|
||||||
{"Alice", "alice@example.com", 25},
|
|
||||||
{"Bob", "bob@example.com", 30},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.BatchInsert("users", columns, values)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsert failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify data integrity
|
|
||||||
rows, err := db.Query("SELECT name, email, age FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
// Check Alice
|
|
||||||
row0 := rows.Row(0)
|
|
||||||
if row0.MustString("name") != "Alice" {
|
|
||||||
t.Errorf("Expected name 'Alice', got '%s'", row0.MustString("name"))
|
|
||||||
}
|
|
||||||
if row0.MustString("email") != "alice@example.com" {
|
|
||||||
t.Errorf("Expected email 'alice@example.com', got '%s'", row0.MustString("email"))
|
|
||||||
}
|
|
||||||
if row0.MustInt("age") != 25 {
|
|
||||||
t.Errorf("Expected age 25, got %d", row0.MustInt("age"))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check Bob
|
|
||||||
row1 := rows.Row(1)
|
|
||||||
if row1.MustString("name") != "Bob" {
|
|
||||||
t.Errorf("Expected name 'Bob', got '%s'", row1.MustString("name"))
|
|
||||||
}
|
|
||||||
if row1.MustString("email") != "bob@example.com" {
|
|
||||||
t.Errorf("Expected email 'bob@example.com', got '%s'", row1.MustString("email"))
|
|
||||||
}
|
|
||||||
if row1.MustInt("age") != 30 {
|
|
||||||
t.Errorf("Expected age 30, got %d", row1.MustInt("age"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_BatchInsertStructs_VerifyData(t *testing.T) {
|
|
||||||
db := setupBatchTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
users := []TestUser{
|
|
||||||
{Name: "Alice", Email: "alice@example.com", Age: 25, CreatedAt: now},
|
|
||||||
{Name: "Bob", Email: "bob@example.com", Age: 30, CreatedAt: now},
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.BatchInsertStructs("users", users, "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("BatchInsertStructs failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query and verify with ORM
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var resultUsers []TestUser
|
|
||||||
err = rows.Orm(&resultUsers)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resultUsers) != 2 {
|
|
||||||
t.Fatalf("Expected 2 users, got %d", len(resultUsers))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify Alice
|
|
||||||
if resultUsers[0].Name != "Alice" {
|
|
||||||
t.Errorf("Expected name 'Alice', got '%s'", resultUsers[0].Name)
|
|
||||||
}
|
|
||||||
if resultUsers[0].Email != "alice@example.com" {
|
|
||||||
t.Errorf("Expected email 'alice@example.com', got '%s'", resultUsers[0].Email)
|
|
||||||
}
|
|
||||||
if resultUsers[0].Age != 25 {
|
|
||||||
t.Errorf("Expected age 25, got %d", resultUsers[0].Age)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify Bob
|
|
||||||
if resultUsers[1].Name != "Bob" {
|
|
||||||
t.Errorf("Expected name 'Bob', got '%s'", resultUsers[1].Name)
|
|
||||||
}
|
|
||||||
if resultUsers[1].Email != "bob@example.com" {
|
|
||||||
t.Errorf("Expected email 'bob@example.com', got '%s'", resultUsers[1].Email)
|
|
||||||
}
|
|
||||||
if resultUsers[1].Age != 30 {
|
|
||||||
t.Errorf("Expected age 30, got %d", resultUsers[1].Age)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Benchmark tests
|
|
||||||
func BenchmarkBatchInsert_10(b *testing.B) {
|
|
||||||
db := setupBatchTestDB(&testing.T{})
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
var values [][]interface{}
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
values = append(values, []interface{}{"User", "user@example.com", 25})
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
db.BatchInsert("users", columns, values)
|
|
||||||
db.Exec("DELETE FROM users")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBatchInsert_100(b *testing.B) {
|
|
||||||
db := setupBatchTestDB(&testing.T{})
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age"}
|
|
||||||
var values [][]interface{}
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
values = append(values, []interface{}{"User", "user@example.com", 25})
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
db.BatchInsert("users", columns, values)
|
|
||||||
db.Exec("DELETE FROM users")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBatchInsertStructs_10(b *testing.B) {
|
|
||||||
db := setupBatchTestDB(&testing.T{})
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
var users []TestUser
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
users = append(users, TestUser{
|
|
||||||
Name: "User",
|
|
||||||
Email: "user@example.com",
|
|
||||||
Age: 25,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
db.BatchInsertStructs("users", users, "id")
|
|
||||||
db.Exec("DELETE FROM users")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBatchInsertStructs_100(b *testing.B) {
|
|
||||||
db := setupBatchTestDB(&testing.T{})
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
var users []TestUser
|
|
||||||
for i := 0; i < 100; i++ {
|
|
||||||
users = append(users, TestUser{
|
|
||||||
Name: "User",
|
|
||||||
Email: "user@example.com",
|
|
||||||
Age: 25,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
db.BatchInsertStructs("users", users, "id")
|
|
||||||
db.Exec("DELETE FROM users")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,23 +0,0 @@
|
|||||||
module b612.me/stardb/testing
|
|
||||||
|
|
||||||
go 1.25.6
|
|
||||||
|
|
||||||
require (
|
|
||||||
b612.me/stardb v0.0.0
|
|
||||||
modernc.org/sqlite v1.46.1
|
|
||||||
)
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
|
||||||
github.com/google/uuid v1.6.0 // indirect
|
|
||||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
|
||||||
github.com/ncruces/go-strftime v1.0.0 // indirect
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect
|
|
||||||
golang.org/x/sys v0.37.0 // indirect
|
|
||||||
modernc.org/libc v1.67.6 // indirect
|
|
||||||
modernc.org/mathutil v1.7.1 // indirect
|
|
||||||
modernc.org/memory v1.11.0 // indirect
|
|
||||||
)
|
|
||||||
|
|
||||||
replace b612.me/stardb => ../
|
|
||||||
@ -1,53 +0,0 @@
|
|||||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
|
||||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
|
|
||||||
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
|
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM=
|
|
||||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
|
||||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
|
||||||
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
|
|
||||||
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE=
|
|
||||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
|
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY=
|
|
||||||
golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
|
|
||||||
golang.org/x/mod v0.29.0 h1:HV8lRxZC4l2cr3Zq1LvtOsi/ThTgWnUk/y64QSs8GwA=
|
|
||||||
golang.org/x/mod v0.29.0/go.mod h1:NyhrlYXJ2H4eJiRy/WDBO6HMqZQ6q9nk4JzS3NuCK+w=
|
|
||||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
|
||||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
|
||||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
||||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
|
||||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
|
||||||
golang.org/x/tools v0.38.0 h1:Hx2Xv8hISq8Lm16jvBZ2VQf+RLmbd7wVUsALibYI/IQ=
|
|
||||||
golang.org/x/tools v0.38.0/go.mod h1:yEsQ/d/YK8cjh0L6rZlY8tgtlKiBNTL14pGDJPJpYQs=
|
|
||||||
modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis=
|
|
||||||
modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0=
|
|
||||||
modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc=
|
|
||||||
modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM=
|
|
||||||
modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA=
|
|
||||||
modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc=
|
|
||||||
modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI=
|
|
||||||
modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito=
|
|
||||||
modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE=
|
|
||||||
modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY=
|
|
||||||
modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks=
|
|
||||||
modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI=
|
|
||||||
modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
|
|
||||||
modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
|
|
||||||
modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
|
|
||||||
modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
|
|
||||||
modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
|
|
||||||
modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
|
|
||||||
modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8=
|
|
||||||
modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns=
|
|
||||||
modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w=
|
|
||||||
modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE=
|
|
||||||
modernc.org/sqlite v1.46.1 h1:eFJ2ShBLIEnUWlLy12raN0Z1plqmFX9Qe3rjQTKt6sU=
|
|
||||||
modernc.org/sqlite v1.46.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
|
|
||||||
modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0=
|
|
||||||
modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A=
|
|
||||||
modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y=
|
|
||||||
modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM=
|
|
||||||
@ -1,934 +0,0 @@
|
|||||||
package testing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"b612.me/stardb"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Email string `db:"email"`
|
|
||||||
Age int `db:"age"`
|
|
||||||
Balance float64 `db:"balance"`
|
|
||||||
Active bool `db:"active"`
|
|
||||||
CreatedAt time.Time `db:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Profile struct {
|
|
||||||
UserID int `db:"user_id"`
|
|
||||||
Bio string `db:"bio"`
|
|
||||||
Avatar string `db:"avatar"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type NestedUser struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
Profile `db:"---"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserWithPrivateField struct {
|
|
||||||
ID int64 `db:"id"`
|
|
||||||
Name string `db:"name"`
|
|
||||||
age int `db:"age"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_Single(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var user User
|
|
||||||
err = rows.Orm(&user)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Name != "Alice" {
|
|
||||||
t.Errorf("Expected name 'Alice', got '%s'", user.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Email != "alice@example.com" {
|
|
||||||
t.Errorf("Expected email 'alice@example.com', got '%s'", user.Email)
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Age != 25 {
|
|
||||||
t.Errorf("Expected age 25, got %d", user.Age)
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Balance != 100.50 {
|
|
||||||
t.Errorf("Expected balance 100.50, got %f", user.Balance)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !user.Active {
|
|
||||||
t.Errorf("Expected user to be active")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_Multiple(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var users []User
|
|
||||||
err = rows.Orm(&users)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != 3 {
|
|
||||||
t.Fatalf("Expected 3 users, got %d", len(users))
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedNames := []string{"Alice", "Bob", "Charlie"}
|
|
||||||
for i, user := range users {
|
|
||||||
if user.Name != expectedNames[i] {
|
|
||||||
t.Errorf("Expected name '%s', got '%s'", expectedNames[i], user.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_Array(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var users [3]User
|
|
||||||
err = rows.Orm(&users)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
expectedNames := []string{"Alice", "Bob", "Charlie"}
|
|
||||||
for i, user := range users {
|
|
||||||
if user.Name != expectedNames[i] {
|
|
||||||
t.Errorf("Expected name '%s', got '%s'", expectedNames[i], user.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_ArrayTooSmall(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var users [2]User
|
|
||||||
err = rows.Orm(&users)
|
|
||||||
if err == nil {
|
|
||||||
t.Error("Expected error when target array is smaller than row count, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_Empty(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "NonExistent")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var users []User
|
|
||||||
err = rows.Orm(&users)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(users) != 0 {
|
|
||||||
t.Errorf("Expected 0 users, got %d", len(users))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_NotPointer(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var user User
|
|
||||||
err = rows.Orm(user) // Not a pointer
|
|
||||||
if !errors.Is(err, stardb.ErrTargetNotPointer) {
|
|
||||||
t.Errorf("Expected ErrTargetNotPointer, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_NilTarget(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
err = rows.Orm(nil)
|
|
||||||
if !errors.Is(err, stardb.ErrTargetNil) {
|
|
||||||
t.Errorf("Expected ErrTargetNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_NilPointerTarget(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var user *User
|
|
||||||
err = rows.Orm(user)
|
|
||||||
if !errors.Is(err, stardb.ErrTargetPointerNil) {
|
|
||||||
t.Errorf("Expected ErrTargetPointerNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_MissingColumns_NonStrict(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT id, name FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var user User
|
|
||||||
err = rows.Orm(&user)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Expected non-strict ORM to ignore missing columns, got error: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.Name != "Alice" {
|
|
||||||
t.Errorf("Expected name 'Alice', got '%s'", user.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_MissingColumns_Strict(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
db.SetStrictORM(true)
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT id, name FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var user User
|
|
||||||
err = rows.Orm(&user)
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("Expected strict ORM to fail on missing columns, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Orm_UnexportedTaggedField(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var user UserWithPrivateField
|
|
||||||
err = rows.Orm(&user)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if user.ID <= 0 {
|
|
||||||
t.Errorf("Expected positive ID, got %d", user.ID)
|
|
||||||
}
|
|
||||||
if user.Name != "Alice" {
|
|
||||||
t.Errorf("Expected name 'Alice', got '%s'", user.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_Insert(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
user := User{
|
|
||||||
Name: "David",
|
|
||||||
Email: "david@example.com",
|
|
||||||
Age: 40,
|
|
||||||
Balance: 400.00,
|
|
||||||
Active: true,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.Insert(&user, "users", "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Insert failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
lastID, err := result.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("LastInsertId failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if lastID <= 0 {
|
|
||||||
t.Errorf("Expected positive last insert ID, got %d", lastID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify insertion
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "David")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var insertedUser User
|
|
||||||
err = rows.Orm(&insertedUser)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if insertedUser.Name != "David" {
|
|
||||||
t.Errorf("Expected name 'David', got '%s'", insertedUser.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
if insertedUser.Email != "david@example.com" {
|
|
||||||
t.Errorf("Expected email 'david@example.com', got '%s'", insertedUser.Email)
|
|
||||||
}
|
|
||||||
|
|
||||||
if insertedUser.Age != 40 {
|
|
||||||
t.Errorf("Expected age 40, got %d", insertedUser.Age)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_InsertContext(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
user := User{
|
|
||||||
Name: "Eve",
|
|
||||||
Email: "eve@example.com",
|
|
||||||
Age: 28,
|
|
||||||
Balance: 250.00,
|
|
||||||
Active: true,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.InsertContext(ctx, &user, "users", "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("InsertContext failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 1 {
|
|
||||||
t.Errorf("Expected 1 row affected, got %d", affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_Update(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// First, get the user
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var user User
|
|
||||||
err = rows.Orm(&user)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update the user
|
|
||||||
user.Age = 26
|
|
||||||
user.Balance = 150.75
|
|
||||||
|
|
||||||
result, err := db.Update(&user, "users", "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Update failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 1 {
|
|
||||||
t.Errorf("Expected 1 row affected, got %d", affected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the update
|
|
||||||
rows2, err := db.Query("SELECT * FROM users WHERE id = ?", user.ID)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows2.Close()
|
|
||||||
|
|
||||||
var updatedUser User
|
|
||||||
err = rows2.Orm(&updatedUser)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if updatedUser.Age != 26 {
|
|
||||||
t.Errorf("Expected age 26, got %d", updatedUser.Age)
|
|
||||||
}
|
|
||||||
|
|
||||||
if updatedUser.Balance != 150.75 {
|
|
||||||
t.Errorf("Expected balance 150.75, got %f", updatedUser.Balance)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_UpdateContext(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
// Get user
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Bob")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var user User
|
|
||||||
err = rows.Orm(&user)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update
|
|
||||||
user.Active = false
|
|
||||||
result, err := db.UpdateContext(ctx, &user, "users", "id")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("UpdateContext failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 1 {
|
|
||||||
t.Errorf("Expected 1 row affected, got %d", affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_QueryX(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
user := User{
|
|
||||||
Name: "Alice",
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("QueryX failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if rows.Length() != 1 {
|
|
||||||
t.Errorf("Expected 1 row, got %d", rows.Length())
|
|
||||||
}
|
|
||||||
|
|
||||||
var result User
|
|
||||||
err = rows.Orm(&result)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Name != "Alice" {
|
|
||||||
t.Errorf("Expected name 'Alice', got '%s'", result.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_QueryXContext(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
user := User{
|
|
||||||
Age: 30,
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := db.QueryXContext(ctx, &user, "SELECT * FROM users WHERE age = ?", ":age")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("QueryXContext failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if rows.Length() != 1 {
|
|
||||||
t.Errorf("Expected 1 row, got %d", rows.Length())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_QueryX_MissingField(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
user := User{Name: "Alice"}
|
|
||||||
|
|
||||||
_, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":unknown")
|
|
||||||
if !errors.Is(err, stardb.ErrFieldNotFound) {
|
|
||||||
t.Errorf("Expected ErrFieldNotFound, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_QueryX_NilTarget(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
_, err := db.QueryX(nil, "SELECT * FROM users WHERE id = ?", ":id")
|
|
||||||
if !errors.Is(err, stardb.ErrTargetNil) {
|
|
||||||
t.Errorf("Expected ErrTargetNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_QueryXS(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
users := []User{
|
|
||||||
{Name: "Alice"},
|
|
||||||
{Name: "Bob"},
|
|
||||||
}
|
|
||||||
|
|
||||||
results, err := db.QueryXS(&users, "SELECT * FROM users WHERE name = ?", ":name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("QueryXS failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(results) != 2 {
|
|
||||||
t.Errorf("Expected 2 results, got %d", len(results))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, rows := range results {
|
|
||||||
if rows.Length() != 1 {
|
|
||||||
t.Errorf("Expected 1 row in result %d, got %d", i, rows.Length())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_QueryXS_NilTargets(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
_, err := db.QueryXS(nil, "SELECT * FROM users")
|
|
||||||
if !errors.Is(err, stardb.ErrTargetsNil) {
|
|
||||||
t.Errorf("Expected ErrTargetsNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_ExecX(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
user := User{
|
|
||||||
Name: "Alice",
|
|
||||||
Age: 99,
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExecX failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 1 {
|
|
||||||
t.Errorf("Expected 1 row affected, got %d", affected)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify
|
|
||||||
rows, err := db.Query("SELECT age FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
age := rows.Row(0).MustInt("age")
|
|
||||||
if age != 99 {
|
|
||||||
t.Errorf("Expected age 99, got %d", age)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_ExecXContext(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
user := User{
|
|
||||||
Name: "Bob",
|
|
||||||
Active: false,
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := db.ExecXContext(ctx, &user, "UPDATE users SET active = ? WHERE name = ?", ":active", ":name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExecXContext failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 1 {
|
|
||||||
t.Errorf("Expected 1 row affected, got %d", affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_ExecX_MissingField(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
user := User{Name: "Alice", Age: 99}
|
|
||||||
|
|
||||||
_, err := db.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":unknown")
|
|
||||||
if !errors.Is(err, stardb.ErrFieldNotFound) {
|
|
||||||
t.Errorf("Expected ErrFieldNotFound, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_ExecX_NilTarget(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
_, err := db.ExecX(nil, "UPDATE users SET age = ? WHERE id = ?", ":age", ":id")
|
|
||||||
if !errors.Is(err, stardb.ErrTargetNil) {
|
|
||||||
t.Errorf("Expected ErrTargetNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_ExecXS(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
users := []User{
|
|
||||||
{Name: "Alice", Age: 26},
|
|
||||||
{Name: "Bob", Age: 31},
|
|
||||||
}
|
|
||||||
|
|
||||||
results, err := db.ExecXS(&users, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("ExecXS failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(results) != 2 {
|
|
||||||
t.Errorf("Expected 2 results, got %d", len(results))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, result := range results {
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed for result %d: %v", i, err)
|
|
||||||
}
|
|
||||||
if affected != 1 {
|
|
||||||
t.Errorf("Expected 1 row affected in result %d, got %d", i, affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_ExecXS_NilTargets(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
_, err := db.ExecXS(nil, "UPDATE users SET age = age")
|
|
||||||
if !errors.Is(err, stardb.ErrTargetsNil) {
|
|
||||||
t.Errorf("Expected ErrTargetsNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarTx_Insert(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Begin failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user := User{
|
|
||||||
Name: "Frank",
|
|
||||||
Email: "frank@example.com",
|
|
||||||
Age: 45,
|
|
||||||
Balance: 500.00,
|
|
||||||
Active: true,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := tx.Insert(&user, "users", "id")
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
t.Fatalf("Tx.Insert failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Commit failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
lastID, err := result.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("LastInsertId failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if lastID <= 0 {
|
|
||||||
t.Errorf("Expected positive last insert ID, got %d", lastID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Frank")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if rows.Length() != 1 {
|
|
||||||
t.Errorf("Expected 1 row, got %d", rows.Length())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarTx_Update(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Begin failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get user
|
|
||||||
rows, err := tx.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var user User
|
|
||||||
err = rows.Orm(&user)
|
|
||||||
rows.Close()
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
t.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update
|
|
||||||
user.Age = 27
|
|
||||||
result, err := tx.Update(&user, "users", "id")
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
t.Fatalf("Tx.Update failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Commit failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 1 {
|
|
||||||
t.Errorf("Expected 1 row affected, got %d", affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarTx_QueryX(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Begin failed: %v", err)
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
user := User{
|
|
||||||
Name: "Charlie",
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := tx.QueryX(&user, "SELECT * FROM users WHERE name = ?", ":name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Tx.QueryX failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if rows.Length() != 1 {
|
|
||||||
t.Errorf("Expected 1 row, got %d", rows.Length())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarTx_QueryX_NilTarget(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Begin failed: %v", err)
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
_, err = tx.QueryX(nil, "SELECT * FROM users WHERE id = ?", ":id")
|
|
||||||
if !errors.Is(err, stardb.ErrTargetNil) {
|
|
||||||
t.Errorf("Expected ErrTargetNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarTx_ExecX(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Begin failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user := User{
|
|
||||||
Name: "Charlie",
|
|
||||||
Age: 36,
|
|
||||||
}
|
|
||||||
|
|
||||||
result, err := tx.ExecX(&user, "UPDATE users SET age = ? WHERE name = ?", ":age", ":name")
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
t.Fatalf("Tx.ExecX failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = tx.Commit()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Commit failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := result.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("RowsAffected failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if affected != 1 {
|
|
||||||
t.Errorf("Expected 1 row affected, got %d", affected)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarTx_ExecX_NilTarget(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Begin failed: %v", err)
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
_, err = tx.ExecX(nil, "UPDATE users SET age = ? WHERE id = ?", ":age", ":id")
|
|
||||||
if !errors.Is(err, stardb.ErrTargetNil) {
|
|
||||||
t.Errorf("Expected ErrTargetNil, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarTx_Rollback(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
tx, err := db.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Begin failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
user := User{
|
|
||||||
Name: "Rollback User",
|
|
||||||
Email: "rollback@example.com",
|
|
||||||
Age: 50,
|
|
||||||
Balance: 600.00,
|
|
||||||
Active: true,
|
|
||||||
CreatedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = tx.Insert(&user, "users", "id")
|
|
||||||
if err != nil {
|
|
||||||
tx.Rollback()
|
|
||||||
t.Fatalf("Tx.Insert failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rollback instead of commit
|
|
||||||
err = tx.Rollback()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Rollback failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify the insert was rolled back
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Rollback User")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if rows.Length() != 0 {
|
|
||||||
t.Errorf("Expected 0 rows after rollback, got %d", rows.Length())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNamedParameterEscape(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
user := User{
|
|
||||||
Name: "Alice",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test escaped colon
|
|
||||||
rows, err := db.QueryX(&user, "SELECT * FROM users WHERE name = ?", `\:name`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("QueryX with escaped parameter failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
// Should use literal ":name" string, not the field value
|
|
||||||
if rows.Length() != 0 {
|
|
||||||
t.Errorf("Expected 0 rows with literal ':name', got %d", rows.Length())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,128 +0,0 @@
|
|||||||
package testing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"b612.me/stardb"
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setupBenchmarkDB(b *testing.B) *stardb.StarDB {
|
|
||||||
b.Helper()
|
|
||||||
|
|
||||||
db := &stardb.StarDB{}
|
|
||||||
if err := db.Open("sqlite", ":memory:"); err != nil {
|
|
||||||
b.Fatalf("Failed to open database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := db.Exec(`
|
|
||||||
CREATE TABLE users (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
email TEXT NOT NULL,
|
|
||||||
age INTEGER,
|
|
||||||
balance REAL,
|
|
||||||
active BOOLEAN,
|
|
||||||
created_at DATETIME
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("Failed to create table: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = db.Exec(`
|
|
||||||
INSERT INTO users (name, email, age, balance, active, created_at) VALUES
|
|
||||||
('Alice', 'alice@example.com', 25, 100.50, 1, '2024-01-01 10:00:00'),
|
|
||||||
('Bob', 'bob@example.com', 30, 200.75, 1, '2024-01-02 11:00:00'),
|
|
||||||
('Charlie', 'charlie@example.com', 35, 300.25, 0, '2024-01-03 12:00:00')
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("Failed to insert seed data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkQueryX(b *testing.B) {
|
|
||||||
db := setupBenchmarkDB(b)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
target := User{Name: "Alice"}
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
rows, err := db.QueryX(&target, "SELECT * FROM users WHERE name = ?", ":name")
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("QueryX failed: %v", err)
|
|
||||||
}
|
|
||||||
_ = rows.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkOrm(b *testing.B) {
|
|
||||||
db := setupBenchmarkDB(b)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var users []User
|
|
||||||
if err := rows.Orm(&users); err != nil {
|
|
||||||
_ = rows.Close()
|
|
||||||
b.Fatalf("Orm failed: %v", err)
|
|
||||||
}
|
|
||||||
_ = rows.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkScanEach(b *testing.B) {
|
|
||||||
db := setupBenchmarkDB(b)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
count := 0
|
|
||||||
err := db.ScanEach("SELECT * FROM users ORDER BY name", func(row *stardb.StarResult) error {
|
|
||||||
_ = row.MustString("name")
|
|
||||||
count++
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
b.Fatalf("ScanEach failed: %v", err)
|
|
||||||
}
|
|
||||||
if count != 3 {
|
|
||||||
b.Fatalf("Unexpected row count: %d", count)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkBatchInsert(b *testing.B) {
|
|
||||||
db := setupBenchmarkDB(b)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
columns := []string{"name", "email", "age", "balance", "active", "created_at"}
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
base := i * 2
|
|
||||||
values := [][]interface{}{
|
|
||||||
{fmt.Sprintf("bench_user_%d", base), fmt.Sprintf("bench_%d@example.com", base), 20 + (base % 20), 99.5, true, time.Now()},
|
|
||||||
{fmt.Sprintf("bench_user_%d", base+1), fmt.Sprintf("bench_%d@example.com", base+1), 20 + ((base + 1) % 20), 199.5, false, time.Now()},
|
|
||||||
}
|
|
||||||
if _, err := db.BatchInsert("users", columns, values); err != nil {
|
|
||||||
b.Fatalf("BatchInsert failed: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,298 +0,0 @@
|
|||||||
package testing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"b612.me/stardb"
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDefaultPoolConfig(t *testing.T) {
|
|
||||||
config := stardb.DefaultPoolConfig()
|
|
||||||
|
|
||||||
if config.MaxOpenConns != 25 {
|
|
||||||
t.Errorf("Expected MaxOpenConns 25, got %d", config.MaxOpenConns)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.MaxIdleConns != 5 {
|
|
||||||
t.Errorf("Expected MaxIdleConns 5, got %d", config.MaxIdleConns)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ConnMaxLifetime != time.Hour {
|
|
||||||
t.Errorf("Expected ConnMaxLifetime 1 hour, got %v", config.ConnMaxLifetime)
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ConnMaxIdleTime != 10*time.Minute {
|
|
||||||
t.Errorf("Expected ConnMaxIdleTime 10 minutes, got %v", config.ConnMaxIdleTime)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_SetPoolConfig(t *testing.T) {
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
err := db.Open("sqlite", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to open database: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
config := &stardb.PoolConfig{
|
|
||||||
MaxOpenConns: 50,
|
|
||||||
MaxIdleConns: 10,
|
|
||||||
ConnMaxLifetime: 30 * time.Minute,
|
|
||||||
ConnMaxIdleTime: 5 * time.Minute,
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetPoolConfig(config)
|
|
||||||
|
|
||||||
stats := db.Stats()
|
|
||||||
if stats.MaxOpenConnections != 50 {
|
|
||||||
t.Errorf("Expected MaxOpenConnections 50, got %d", stats.MaxOpenConnections)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_SetPoolConfig_Partial(t *testing.T) {
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
err := db.Open("sqlite", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to open database: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// Only set MaxOpenConns
|
|
||||||
config := &stardb.PoolConfig{
|
|
||||||
MaxOpenConns: 100,
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetPoolConfig(config)
|
|
||||||
|
|
||||||
stats := db.Stats()
|
|
||||||
if stats.MaxOpenConnections != 100 {
|
|
||||||
t.Errorf("Expected MaxOpenConnections 100, got %d", stats.MaxOpenConnections)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_SetPoolConfig_Zero(t *testing.T) {
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
err := db.Open("sqlite", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to open database: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// Zero values should be ignored
|
|
||||||
config := &stardb.PoolConfig{
|
|
||||||
MaxOpenConns: 0,
|
|
||||||
MaxIdleConns: 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetPoolConfig(config)
|
|
||||||
|
|
||||||
// Should not panic or error
|
|
||||||
err = db.Ping()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Ping failed after SetPoolConfig with zero values: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_SetPoolConfig_NilConfig(t *testing.T) {
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
err := db.Open("sqlite", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to open database: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
db.SetPoolConfig(nil)
|
|
||||||
|
|
||||||
err = db.Ping()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Ping failed after SetPoolConfig(nil): %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarDB_SetPoolConfig_BeforeOpen(t *testing.T) {
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
|
|
||||||
db.SetPoolConfig(&stardb.PoolConfig{
|
|
||||||
MaxOpenConns: 10,
|
|
||||||
})
|
|
||||||
|
|
||||||
// should not panic when called before Open
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenWithPool_Default(t *testing.T) {
|
|
||||||
db, err := stardb.OpenWithPool("sqlite", ":memory:", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("OpenWithPool failed: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
err = db.Ping()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Ping failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stats := db.Stats()
|
|
||||||
if stats.MaxOpenConnections != 25 {
|
|
||||||
t.Errorf("Expected default MaxOpenConnections 25, got %d", stats.MaxOpenConnections)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenWithPool_Custom(t *testing.T) {
|
|
||||||
config := &stardb.PoolConfig{
|
|
||||||
MaxOpenConns: 15,
|
|
||||||
MaxIdleConns: 3,
|
|
||||||
ConnMaxLifetime: 20 * time.Minute,
|
|
||||||
ConnMaxIdleTime: 3 * time.Minute,
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := stardb.OpenWithPool("sqlite", ":memory:", config)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("OpenWithPool failed: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
err = db.Ping()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Ping failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stats := db.Stats()
|
|
||||||
if stats.MaxOpenConnections != 15 {
|
|
||||||
t.Errorf("Expected MaxOpenConnections 15, got %d", stats.MaxOpenConnections)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenWithPool_InvalidDriver(t *testing.T) {
|
|
||||||
config := stardb.DefaultPoolConfig()
|
|
||||||
|
|
||||||
db, err := stardb.OpenWithPool("invalid_driver", "invalid_conn", config)
|
|
||||||
if err == nil {
|
|
||||||
db.Close()
|
|
||||||
t.Error("Expected error with invalid driver, got nil")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenWithPool_Query(t *testing.T) {
|
|
||||||
db, err := stardb.OpenWithPool("sqlite", ":memory:", nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("OpenWithPool failed: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// Create table
|
|
||||||
_, err = db.Exec(`CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create table: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert data
|
|
||||||
_, err = db.Exec(`INSERT INTO test (name) VALUES (?)`, "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to insert data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query data
|
|
||||||
rows, err := db.Query("SELECT * FROM test")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if rows.Length() != 1 {
|
|
||||||
t.Errorf("Expected 1 row, got %d", rows.Length())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPoolConfig_AllFields(t *testing.T) {
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
err := db.Open("sqlite", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to open database: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
config := &stardb.PoolConfig{
|
|
||||||
MaxOpenConns: 30,
|
|
||||||
MaxIdleConns: 8,
|
|
||||||
ConnMaxLifetime: 45 * time.Minute,
|
|
||||||
ConnMaxIdleTime: 7 * time.Minute,
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetPoolConfig(config)
|
|
||||||
|
|
||||||
// Verify by checking stats
|
|
||||||
stats := db.Stats()
|
|
||||||
if stats.MaxOpenConnections != 30 {
|
|
||||||
t.Errorf("Expected MaxOpenConnections 30, got %d", stats.MaxOpenConnections)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test that connections work
|
|
||||||
err = db.Ping()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Ping failed after SetPoolConfig: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPoolConfig_NegativeValues(t *testing.T) {
|
|
||||||
db := stardb.NewStarDB()
|
|
||||||
err := db.Open("sqlite", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to open database: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// Negative values should be ignored
|
|
||||||
config := &stardb.PoolConfig{
|
|
||||||
MaxOpenConns: -1,
|
|
||||||
MaxIdleConns: -1,
|
|
||||||
}
|
|
||||||
|
|
||||||
db.SetPoolConfig(config)
|
|
||||||
|
|
||||||
// Should not panic or error
|
|
||||||
err = db.Ping()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Ping failed after SetPoolConfig with negative values: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOpenWithPool_MultipleConnections(t *testing.T) {
|
|
||||||
config := &stardb.PoolConfig{
|
|
||||||
MaxOpenConns: 5,
|
|
||||||
MaxIdleConns: 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
db, err := stardb.OpenWithPool("sqlite", ":memory:", config)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("OpenWithPool failed: %v", err)
|
|
||||||
}
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// Create table
|
|
||||||
_, err = db.Exec(`CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT)`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create table: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Perform multiple operations
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
_, err = db.Exec(`INSERT INTO test (value) VALUES (?)`, "test")
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Insert %d failed: %v", i, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify
|
|
||||||
rows, err := db.Query("SELECT COUNT(*) as count FROM test")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
count := rows.Row(0).MustInt("count")
|
|
||||||
if count != 10 {
|
|
||||||
t.Errorf("Expected 10 rows, got %d", count)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,331 +0,0 @@
|
|||||||
package testing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"b612.me/stardb"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStarResult_MustString(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
row := rows.Row(0)
|
|
||||||
|
|
||||||
name := row.MustString("name")
|
|
||||||
if name != "Alice" {
|
|
||||||
t.Errorf("Expected 'Alice', got '%s'", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
email := row.MustString("email")
|
|
||||||
if email != "alice@example.com" {
|
|
||||||
t.Errorf("Expected 'alice@example.com', got '%s'", email)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test non-existent column
|
|
||||||
nonexistent := row.MustString("nonexistent")
|
|
||||||
if nonexistent != "" {
|
|
||||||
t.Errorf("Expected empty string for non-existent column, got '%s'", nonexistent)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResult_MustInt(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Bob")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
row := rows.Row(0)
|
|
||||||
|
|
||||||
age := row.MustInt("age")
|
|
||||||
if age != 30 {
|
|
||||||
t.Errorf("Expected age 30, got %d", age)
|
|
||||||
}
|
|
||||||
|
|
||||||
id := row.MustInt("id")
|
|
||||||
if id <= 0 {
|
|
||||||
t.Errorf("Expected positive id, got %d", id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResult_MustInt64(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Charlie")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
row := rows.Row(0)
|
|
||||||
|
|
||||||
age := row.MustInt64("age")
|
|
||||||
if age != 35 {
|
|
||||||
t.Errorf("Expected age 35, got %d", age)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResult_MustFloat64(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
row := rows.Row(0)
|
|
||||||
|
|
||||||
balance := row.MustFloat64("balance")
|
|
||||||
if balance != 100.50 {
|
|
||||||
t.Errorf("Expected balance 100.50, got %f", balance)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResult_MustBool(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
// Alice - active
|
|
||||||
row := rows.Row(0)
|
|
||||||
active := row.MustBool("active")
|
|
||||||
if !active {
|
|
||||||
t.Errorf("Expected Alice to be active")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Charlie - inactive
|
|
||||||
row = rows.Row(2)
|
|
||||||
active = row.MustBool("active")
|
|
||||||
if active {
|
|
||||||
t.Errorf("Expected Charlie to be inactive")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResult_IsNil(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
// Insert a row with NULL value
|
|
||||||
_, err := db.Exec("INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, NULL, ?, ?, ?)",
|
|
||||||
"David", "david@example.com", 150.0, true, time.Now())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Insert failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "David")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
row := rows.Row(0)
|
|
||||||
|
|
||||||
if !row.IsNil("age") {
|
|
||||||
t.Errorf("Expected age to be NULL")
|
|
||||||
}
|
|
||||||
|
|
||||||
if row.IsNil("name") {
|
|
||||||
t.Errorf("Expected name to not be NULL")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResultCol_MustString(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
col := rows.Col("name")
|
|
||||||
names := col.MustString()
|
|
||||||
|
|
||||||
expected := []string{"Alice", "Bob", "Charlie"}
|
|
||||||
for i, name := range names {
|
|
||||||
if name != expected[i] {
|
|
||||||
t.Errorf("Expected name '%s', got '%s'", expected[i], name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResultCol_MustInt(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
col := rows.Col("age")
|
|
||||||
ages := col.MustInt()
|
|
||||||
|
|
||||||
expected := []int{25, 30, 35}
|
|
||||||
for i, age := range ages {
|
|
||||||
if age != expected[i] {
|
|
||||||
t.Errorf("Expected age %d, got %d", expected[i], age)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResultCol_MustFloat64(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
col := rows.Col("balance")
|
|
||||||
balances := col.MustFloat64()
|
|
||||||
|
|
||||||
expected := []float64{100.50, 200.75, 300.25}
|
|
||||||
for i, balance := range balances {
|
|
||||||
if balance != expected[i] {
|
|
||||||
t.Errorf("Expected balance %f, got %f", expected[i], balance)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResultCol_MustBool(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
col := rows.Col("active")
|
|
||||||
actives := col.MustBool()
|
|
||||||
|
|
||||||
expected := []bool{true, true, false}
|
|
||||||
for i, active := range actives {
|
|
||||||
if active != expected[i] {
|
|
||||||
t.Errorf("Expected active %v, got %v", expected[i], active)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResult_GetColumnNotFoundError(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
row := rows.Row(0)
|
|
||||||
_, err = row.GetString("does_not_exist")
|
|
||||||
if !errors.Is(err, stardb.ErrColumnNotFound) {
|
|
||||||
t.Fatalf("Expected ErrColumnNotFound, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResult_GetNullValues(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
_, err := db.Exec(
|
|
||||||
"INSERT INTO users (name, email, age, balance, active, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
|
||||||
"NullUser", "null@example.com", nil, nil, nil, nil,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Insert failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users WHERE name = ?", "NullUser")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
row := rows.Row(0)
|
|
||||||
|
|
||||||
name, err := row.GetNullString("name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetNullString failed: %v", err)
|
|
||||||
}
|
|
||||||
if !name.Valid || name.String != "NullUser" {
|
|
||||||
t.Fatalf("Expected valid name NullUser, got %+v", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
age, err := row.GetNullInt64("age")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetNullInt64 failed: %v", err)
|
|
||||||
}
|
|
||||||
if age.Valid {
|
|
||||||
t.Fatalf("Expected NULL age, got %+v", age)
|
|
||||||
}
|
|
||||||
|
|
||||||
balance, err := row.GetNullFloat64("balance")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetNullFloat64 failed: %v", err)
|
|
||||||
}
|
|
||||||
if balance.Valid {
|
|
||||||
t.Fatalf("Expected NULL balance, got %+v", balance)
|
|
||||||
}
|
|
||||||
|
|
||||||
active, err := row.GetNullBool("active")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetNullBool failed: %v", err)
|
|
||||||
}
|
|
||||||
if active.Valid {
|
|
||||||
t.Fatalf("Expected NULL active, got %+v", active)
|
|
||||||
}
|
|
||||||
|
|
||||||
createdAt, err := row.GetNullTime("created_at")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetNullTime failed: %v", err)
|
|
||||||
}
|
|
||||||
if createdAt.Valid {
|
|
||||||
t.Fatalf("Expected NULL created_at, got %+v", createdAt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarResult_GetNullTime_Valid(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT created_at FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
value, err := rows.Row(0).GetNullTime("created_at")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("GetNullTime failed: %v", err)
|
|
||||||
}
|
|
||||||
if !value.Valid {
|
|
||||||
t.Fatal("Expected valid created_at")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@ -1,109 +0,0 @@
|
|||||||
package testing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStarRows_Row(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
// Test first row
|
|
||||||
row := rows.Row(0)
|
|
||||||
name := row.MustString("name")
|
|
||||||
if name != "Alice" {
|
|
||||||
t.Errorf("Expected name 'Alice', got '%s'", name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test out of bounds
|
|
||||||
row = rows.Row(999)
|
|
||||||
if len(row.Result()) != 0 {
|
|
||||||
t.Errorf("Expected empty result for out of bounds index")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test negative index
|
|
||||||
row = rows.Row(-1)
|
|
||||||
if len(row.Result()) != 0 {
|
|
||||||
t.Errorf("Expected empty result for negative index")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Col(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT * FROM users ORDER BY name")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
// Test column extraction
|
|
||||||
col := rows.Col("name")
|
|
||||||
names := col.MustString()
|
|
||||||
|
|
||||||
if len(names) != 3 {
|
|
||||||
t.Errorf("Expected 3 names, got %d", len(names))
|
|
||||||
}
|
|
||||||
|
|
||||||
if names[0] != "Alice" {
|
|
||||||
t.Errorf("Expected first name 'Alice', got '%s'", names[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test non-existent column
|
|
||||||
col = rows.Col("nonexistent")
|
|
||||||
if len(col.Result()) != 0 {
|
|
||||||
t.Errorf("Expected empty result for non-existent column")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_Rescan(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
db.ManualScan = true
|
|
||||||
rows, err := db.Query("SELECT * FROM users")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
err = rows.Rescan()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Rescan failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows.Length() != 3 {
|
|
||||||
t.Errorf("Expected 3 rows, got %d", rows.Length())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestStarRows_StringResult(t *testing.T) {
|
|
||||||
db := setupTestDB(t)
|
|
||||||
defer db.Close()
|
|
||||||
|
|
||||||
rows, err := db.Query("SELECT name, age FROM users WHERE name = ?", "Alice")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Query failed: %v", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if len(rows.StringResult()) != 1 {
|
|
||||||
t.Fatalf("Expected 1 string result, got %d", len(rows.StringResult()))
|
|
||||||
}
|
|
||||||
|
|
||||||
record := rows.StringResult()[0]
|
|
||||||
if record["name"] != "Alice" {
|
|
||||||
t.Errorf("Expected name 'Alice', got '%s'", record["name"])
|
|
||||||
}
|
|
||||||
|
|
||||||
if record["age"] != "25" {
|
|
||||||
t.Errorf("Expected age '25', got '%s'", record["age"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
File diff suppressed because it is too large
Load Diff
@ -1,47 +0,0 @@
|
|||||||
package testing
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"b612.me/stardb"
|
|
||||||
_ "modernc.org/sqlite"
|
|
||||||
)
|
|
||||||
|
|
||||||
// setupTestDB creates a test database with sample data
|
|
||||||
// This function is only available when building with -tags=testing
|
|
||||||
func setupTestDB(t *testing.T) *stardb.StarDB {
|
|
||||||
db := &stardb.StarDB{}
|
|
||||||
err := db.Open("sqlite", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to open database: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create test table
|
|
||||||
_, err = db.Exec(`
|
|
||||||
CREATE TABLE users (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
email TEXT NOT NULL,
|
|
||||||
age INTEGER,
|
|
||||||
balance REAL,
|
|
||||||
active BOOLEAN,
|
|
||||||
created_at DATETIME
|
|
||||||
)
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to create table: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert test data
|
|
||||||
_, err = db.Exec(`
|
|
||||||
INSERT INTO users (name, email, age, balance, active, created_at) VALUES
|
|
||||||
('Alice', 'alice@example.com', 25, 100.50, 1, '2024-01-01 10:00:00'),
|
|
||||||
('Bob', 'bob@example.com', 30, 200.75, 1, '2024-01-02 11:00:00'),
|
|
||||||
('Charlie', 'charlie@example.com', 35, 300.25, 0, '2024-01-03 12:00:00')
|
|
||||||
`)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to insert test data: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
264
tx.go
264
tx.go
@ -1,264 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// StarTx represents a database transaction
|
|
||||||
type StarTx struct {
|
|
||||||
tx *sql.Tx
|
|
||||||
db *StarDB
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *StarTx) ensureTx() error {
|
|
||||||
if t == nil || t.tx == nil || t.db == nil {
|
|
||||||
return ErrTxNotInitialized
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query executes a query within the transaction
|
|
||||||
func (t *StarTx) Query(query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
return t.query(nil, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryContext executes a query with context within the transaction
|
|
||||||
func (t *StarTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
return t.query(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRaw executes a query in transaction and returns *sql.Rows without automatic parsing.
|
|
||||||
func (t *StarTx) QueryRaw(query string, args ...interface{}) (*sql.Rows, error) {
|
|
||||||
return t.queryRaw(nil, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRawContext executes a query with context in transaction and returns *sql.Rows without automatic parsing.
|
|
||||||
func (t *StarTx) QueryRawContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
||||||
return t.queryRaw(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *StarTx) queryRaw(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
||||||
if err := t.ensureTx(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
|
|
||||||
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
|
|
||||||
if beforeHook != nil {
|
|
||||||
beforeHook(hookCtx, query, hookArgs)
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
var (
|
|
||||||
rows *sql.Rows
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if ctx == nil {
|
|
||||||
rows, err = t.tx.Query(query, args...)
|
|
||||||
} else {
|
|
||||||
rows, err = t.tx.QueryContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
duration := time.Since(start)
|
|
||||||
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
||||||
afterHook(hookCtx, query, hookArgs, duration, err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return rows, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// query is the internal query implementation
|
|
||||||
func (t *StarTx) query(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
rows, err := t.queryRaw(ctx, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
starRows := &StarRows{
|
|
||||||
rows: rows,
|
|
||||||
db: t.db,
|
|
||||||
}
|
|
||||||
|
|
||||||
if !t.db.ManualScan {
|
|
||||||
if err := starRows.parse(); err != nil {
|
|
||||||
_ = rows.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return starRows, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec executes a query within the transaction
|
|
||||||
func (t *StarTx) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
return t.exec(nil, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecContext executes a query with context within the transaction
|
|
||||||
func (t *StarTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
return t.exec(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// exec is the internal exec implementation
|
|
||||||
func (t *StarTx) exec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
if err := t.ensureTx(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
query, beforeHook, afterHook, hookArgs, slowThreshold := t.db.prepareSQLCall(query, args)
|
|
||||||
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
|
|
||||||
if beforeHook != nil {
|
|
||||||
beforeHook(hookCtx, query, hookArgs)
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
var (
|
|
||||||
result sql.Result
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if ctx == nil {
|
|
||||||
result, err = t.tx.Exec(query, args...)
|
|
||||||
} else {
|
|
||||||
result, err = t.tx.ExecContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
duration := time.Since(start)
|
|
||||||
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
||||||
afterHook(hookCtx, query, hookArgs, duration, err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare creates a prepared statement within the transaction
|
|
||||||
func (t *StarTx) Prepare(query string) (*StarStmt, error) {
|
|
||||||
if err := t.ensureTx(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
|
|
||||||
hookCtx := t.db.hookContext(nil, query, beforeHook, afterHook)
|
|
||||||
if beforeHook != nil {
|
|
||||||
beforeHook(hookCtx, query, nil)
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
stmt, err := t.tx.Prepare(query)
|
|
||||||
duration := time.Since(start)
|
|
||||||
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
||||||
afterHook(hookCtx, query, nil, duration, err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// PrepareContext creates a prepared statement with context
|
|
||||||
func (t *StarTx) PrepareContext(ctx context.Context, query string) (*StarStmt, error) {
|
|
||||||
if err := t.ensureTx(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
|
|
||||||
query, beforeHook, afterHook, _, slowThreshold := t.db.prepareSQLCall(query, nil)
|
|
||||||
hookCtx := t.db.hookContext(ctx, query, beforeHook, afterHook)
|
|
||||||
if beforeHook != nil {
|
|
||||||
beforeHook(hookCtx, query, nil)
|
|
||||||
}
|
|
||||||
start := time.Now()
|
|
||||||
|
|
||||||
stmt, err := t.tx.PrepareContext(ctx, query)
|
|
||||||
duration := time.Since(start)
|
|
||||||
if shouldRunAfterHook(afterHook, slowThreshold, duration, err) {
|
|
||||||
afterHook(hookCtx, query, nil, duration, err)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &StarStmt{stmt: stmt, db: t.db, sqlText: query}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryStmt executes a prepared statement query within the transaction
|
|
||||||
func (t *StarTx) QueryStmt(query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
stmt, err := t.Prepare(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
return stmt.Query(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryStmtContext executes a prepared statement query with context
|
|
||||||
func (t *StarTx) QueryStmtContext(ctx context.Context, query string, args ...interface{}) (*StarRows, error) {
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
stmt, err := t.PrepareContext(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
return stmt.QueryContext(ctx, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecStmt executes a prepared statement within the transaction
|
|
||||||
func (t *StarTx) ExecStmt(query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
stmt, err := t.Prepare(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
return stmt.Exec(args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecStmtContext executes a prepared statement with context
|
|
||||||
func (t *StarTx) ExecStmtContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
if strings.TrimSpace(query) == "" {
|
|
||||||
return nil, ErrQueryEmpty
|
|
||||||
}
|
|
||||||
stmt, err := t.PrepareContext(ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer stmt.Close()
|
|
||||||
return stmt.ExecContext(ctx, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commit commits the transaction
|
|
||||||
func (t *StarTx) Commit() error {
|
|
||||||
if err := t.ensureTx(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return t.tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rollback rolls back the transaction
|
|
||||||
func (t *StarTx) Rollback() error {
|
|
||||||
if err := t.ensureTx(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return t.tx.Rollback()
|
|
||||||
}
|
|
||||||
45
tx_helper.go
45
tx_helper.go
@ -1,45 +0,0 @@
|
|||||||
package stardb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
)
|
|
||||||
|
|
||||||
// WithTx runs fn in a transaction and handles commit/rollback automatically.
|
|
||||||
func (s *StarDB) WithTx(fn func(tx *StarTx) error) error {
|
|
||||||
return s.WithTxContext(context.Background(), nil, fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTxContext runs fn in a transaction with context/options and handles commit/rollback automatically.
|
|
||||||
func (s *StarDB) WithTxContext(ctx context.Context, opts *sql.TxOptions, fn func(tx *StarTx) error) (err error) {
|
|
||||||
if fn == nil {
|
|
||||||
return ErrTxFuncNil
|
|
||||||
}
|
|
||||||
if ctx == nil {
|
|
||||||
ctx = context.Background()
|
|
||||||
}
|
|
||||||
|
|
||||||
tx, err := s.BeginTx(ctx, opts)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if p := recover(); p != nil {
|
|
||||||
_ = tx.Rollback()
|
|
||||||
panic(p)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := fn(tx); err != nil {
|
|
||||||
_ = tx.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
_ = tx.Rollback()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
Loading…
x
Reference in New Issue
Block a user