feat(notify): 重构通信内核并补齐 stream/bulk/record/transfer 能力
- 引入 LogicalConn/TransportConn 分层,ClientConn 保留兼容适配层 - 新增 Stream、Bulk、RecordStream 三条数据面能力及对应控制路径 - 完成 transfer/file 传输内核与状态快照、诊断能力 - 补齐 reconnect、inbound dispatcher、modern psk 等基础模块 - 增加大规模回归、并发与基准测试覆盖 - 更新依赖库
This commit is contained in:
@@ -0,0 +1,8 @@
|
||||
.sentrux/
|
||||
agent_readme.md
|
||||
target.md
|
||||
notify_plan.md
|
||||
.gocache
|
||||
.gocache/
|
||||
.tmp_*/
|
||||
.idea
|
||||
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2026 starnet contributors
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
@@ -0,0 +1,133 @@
|
||||
# notify
|
||||
|
||||
`b612.me/notify` 是一个面向点对点直连场景的 Go 通信基础包,覆盖消息信令、流式传输、批量数据通道和文件传输内核能力。
|
||||
|
||||
## 模块定位
|
||||
|
||||
- 消息面:`Send`、`SendWait`、`Reply`、`SetLink`
|
||||
- 流式数据面:`OpenStream`
|
||||
- 记录流数据面:`OpenRecordStream`
|
||||
- 批量数据面:`OpenBulk`(`shared` / `dedicated`)
|
||||
- 文件传输内核:transfer control / progress / resume
|
||||
- 会话模型:`LogicalConn`(逻辑会话)与 `TransportConn`(物理承载)分离
|
||||
|
||||
## 版本要求
|
||||
|
||||
- Go `1.24+`
|
||||
|
||||
## 安全初始化要求
|
||||
|
||||
`Client` / `Server` 在 `Connect` / `Listen` 前必须完成安全配置。默认使用现代 PSK 方案。
|
||||
|
||||
- 客户端:`UseModernPSKClient`
|
||||
- 服务端:`UseModernPSKServer`
|
||||
|
||||
未配置时会返回 `errModernPSKRequired`。
|
||||
|
||||
## 快速开始
|
||||
|
||||
服务端:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"b612.me/notify"
|
||||
)
|
||||
|
||||
func main() {
|
||||
srv := notify.NewServer()
|
||||
if err := notify.UseModernPSKServer(srv, []byte("shared-secret"), nil); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
srv.SetLink("ping", func(msg *notify.Message) {
|
||||
_ = msg.Reply([]byte("pong"))
|
||||
})
|
||||
if err := srv.Listen("tcp", "127.0.0.1:28080"); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
select {}
|
||||
}
|
||||
```
|
||||
|
||||
客户端:
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"b612.me/notify"
|
||||
)
|
||||
|
||||
func main() {
|
||||
cli := notify.NewClient()
|
||||
if err := notify.UseModernPSKClient(cli, []byte("shared-secret"), nil); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := cli.Connect("tcp", "127.0.0.1:28080"); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer cli.Stop()
|
||||
|
||||
reply, err := cli.SendWait("ping", []byte("hello"), 5*time.Second)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
log.Printf("reply=%s", string(reply.Value))
|
||||
}
|
||||
```
|
||||
|
||||
## 传输与 IPC
|
||||
|
||||
- `tcp`
|
||||
- `udp`
|
||||
- `unix`
|
||||
- `npipe`(Windows)
|
||||
|
||||
示例目录:
|
||||
|
||||
- [examples/signal](/mnt/c/coding/gocode/src/b612.me/notify/examples/signal)
|
||||
|
||||
## 现代 PSK 与兼容入口
|
||||
|
||||
现代方案特性:
|
||||
|
||||
- 共享密钥派生(Argon2id)
|
||||
- 消息层加密(AES-GCM)
|
||||
- `stream` / `bulk` fast path 复用现代编码栈
|
||||
|
||||
兼容入口仍保留,但属于历史路径:
|
||||
|
||||
- `UseLegacySecurityClient`
|
||||
- `UseLegacySecurityServer`
|
||||
- `ExchangeKey`
|
||||
- `SetSecretKey`
|
||||
- `SetMsgEn` / `SetMsgDe`
|
||||
|
||||
## 发布前检查
|
||||
|
||||
```bash
|
||||
export SENTRUX_SKIP_GRAMMAR_DOWNLOAD='1'
|
||||
sentrux check .
|
||||
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go test ./...
|
||||
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go test -race ./...
|
||||
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go vet ./...
|
||||
```
|
||||
|
||||
手工 soak 测试(可选):
|
||||
|
||||
```bash
|
||||
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache \
|
||||
go test -tags notify_manual_soak -run 'Test_ServerTuAndClientCommon|Test_normal|Test_normal_udp'
|
||||
```
|
||||
|
||||
## 兼容性说明
|
||||
|
||||
- 对外主入口保留:`NewClient`、`NewServer`、`Connect`、`Listen`、`SetLink`、`SetDefaultLink`、`Send`、`SendWait`、`SendObj`、`Reply`、`Stop`
|
||||
- 内部主对象已迁移为 `LogicalConn` / `TransportConn`
|
||||
- `ClientConn` 作为兼容适配层继续保留
|
||||
@@ -0,0 +1,266 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
bulkBatchMaxPayloads = 16
|
||||
)
|
||||
|
||||
const (
|
||||
bulkBatchRequestQueued int32 = iota
|
||||
bulkBatchRequestStarted
|
||||
bulkBatchRequestCanceled
|
||||
)
|
||||
|
||||
type bulkBatchRequestState struct {
|
||||
value atomic.Int32
|
||||
}
|
||||
|
||||
type bulkBatchRequest struct {
|
||||
ctx context.Context
|
||||
payload []byte
|
||||
deadline time.Time
|
||||
done chan error
|
||||
state *bulkBatchRequestState
|
||||
}
|
||||
|
||||
type bulkBatchSender struct {
|
||||
binding *transportBinding
|
||||
reqCh chan bulkBatchRequest
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
|
||||
stopOnce sync.Once
|
||||
errMu sync.Mutex
|
||||
err error
|
||||
}
|
||||
|
||||
func newBulkBatchSender(binding *transportBinding) *bulkBatchSender {
|
||||
sender := &bulkBatchSender{
|
||||
binding: binding,
|
||||
reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4),
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
go sender.run()
|
||||
return sender
|
||||
}
|
||||
|
||||
func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
req := bulkBatchRequest{
|
||||
ctx: ctx,
|
||||
payload: payload,
|
||||
done: make(chan error, 1),
|
||||
state: &bulkBatchRequestState{},
|
||||
}
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
req.deadline = deadline
|
||||
}
|
||||
if err := s.errSnapshot(); err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return normalizeStreamDeadlineError(ctx.Err())
|
||||
case <-s.stopCh:
|
||||
return s.stoppedErr()
|
||||
case s.reqCh <- req:
|
||||
}
|
||||
select {
|
||||
case err := <-req.done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
if req.tryCancel() {
|
||||
return normalizeStreamDeadlineError(ctx.Err())
|
||||
}
|
||||
return <-req.done
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bulkBatchSender) run() {
|
||||
defer close(s.doneCh)
|
||||
for {
|
||||
req, ok := s.nextRequest()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
batch := []bulkBatchRequest{req}
|
||||
drain:
|
||||
for len(batch) < bulkBatchMaxPayloads {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
s.failPending(s.stoppedErr())
|
||||
return
|
||||
case next := <-s.reqCh:
|
||||
batch = append(batch, next)
|
||||
default:
|
||||
break drain
|
||||
}
|
||||
}
|
||||
active, payloads := activeBulkBatchRequests(batch)
|
||||
if len(active) == 0 {
|
||||
continue
|
||||
}
|
||||
deadline := bulkBatchRequestsEarliestDeadline(active)
|
||||
err := s.flush(payloads, deadline)
|
||||
if err != nil {
|
||||
s.setErr(err)
|
||||
for _, item := range active {
|
||||
item.done <- err
|
||||
}
|
||||
s.failPending(err)
|
||||
return
|
||||
}
|
||||
for _, item := range active {
|
||||
item.done <- err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
s.failPending(s.stoppedErr())
|
||||
return bulkBatchRequest{}, false
|
||||
case req := <-s.reqCh:
|
||||
return req, true
|
||||
}
|
||||
}
|
||||
|
||||
func activeBulkBatchRequests(batch []bulkBatchRequest) ([]bulkBatchRequest, [][]byte) {
|
||||
active := make([]bulkBatchRequest, 0, len(batch))
|
||||
payloads := make([][]byte, 0, len(batch))
|
||||
for _, item := range batch {
|
||||
if !item.tryStart() {
|
||||
item.done <- item.canceledErr()
|
||||
continue
|
||||
}
|
||||
if err := item.contextErr(); err != nil {
|
||||
item.done <- err
|
||||
continue
|
||||
}
|
||||
active = append(active, item)
|
||||
payloads = append(payloads, item.payload)
|
||||
}
|
||||
return active, payloads
|
||||
}
|
||||
|
||||
func bulkBatchRequestsEarliestDeadline(batch []bulkBatchRequest) time.Time {
|
||||
var deadline time.Time
|
||||
for _, item := range batch {
|
||||
if item.deadline.IsZero() {
|
||||
continue
|
||||
}
|
||||
if deadline.IsZero() || item.deadline.Before(deadline) {
|
||||
deadline = item.deadline
|
||||
}
|
||||
}
|
||||
return deadline
|
||||
}
|
||||
|
||||
func (r bulkBatchRequest) contextErr() error {
|
||||
if r.ctx == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-r.ctx.Done():
|
||||
return normalizeStreamDeadlineError(r.ctx.Err())
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r bulkBatchRequest) tryStart() bool {
|
||||
if r.state == nil {
|
||||
return true
|
||||
}
|
||||
return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestStarted)
|
||||
}
|
||||
|
||||
func (r bulkBatchRequest) tryCancel() bool {
|
||||
if r.state == nil {
|
||||
return false
|
||||
}
|
||||
return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestCanceled)
|
||||
}
|
||||
|
||||
func (r bulkBatchRequest) canceledErr() error {
|
||||
if err := r.contextErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
return context.Canceled
|
||||
}
|
||||
|
||||
func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error {
|
||||
if s == nil || s.binding == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
queue := s.binding.queueSnapshot()
|
||||
if queue == nil {
|
||||
return errTransportFrameQueueUnavailable
|
||||
}
|
||||
return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error {
|
||||
return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *bulkBatchSender) stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
s.setErr(errTransportDetached)
|
||||
close(s.stopCh)
|
||||
})
|
||||
<-s.doneCh
|
||||
}
|
||||
|
||||
func (s *bulkBatchSender) failPending(err error) {
|
||||
for {
|
||||
select {
|
||||
case item := <-s.reqCh:
|
||||
item.done <- err
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bulkBatchSender) setErr(err error) {
|
||||
if s == nil || err == nil {
|
||||
return
|
||||
}
|
||||
s.errMu.Lock()
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
}
|
||||
s.errMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *bulkBatchSender) errSnapshot() error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
s.errMu.Lock()
|
||||
defer s.errMu.Unlock()
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *bulkBatchSender) stoppedErr() error {
|
||||
if err := s.errSnapshot(); err != nil {
|
||||
return err
|
||||
}
|
||||
return errTransportDetached
|
||||
}
|
||||
@@ -0,0 +1,392 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BenchmarkBulkTCPThroughput(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
}{
|
||||
{
|
||||
name: "chunk_256KiB",
|
||||
payloadSize: 256 * 1024,
|
||||
},
|
||||
{
|
||||
name: "chunk_512KiB",
|
||||
payloadSize: 512 * 1024,
|
||||
},
|
||||
{
|
||||
name: "chunk_768KiB",
|
||||
payloadSize: 768 * 1024,
|
||||
},
|
||||
{
|
||||
name: "chunk_1MiB",
|
||||
payloadSize: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "chunk_2MiB",
|
||||
payloadSize: 2 * 1024 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughput(b, tc.payloadSize, false)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBulkTCPThroughputDedicated(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
}{
|
||||
{
|
||||
name: "chunk_256KiB",
|
||||
payloadSize: 256 * 1024,
|
||||
},
|
||||
{
|
||||
name: "chunk_512KiB",
|
||||
payloadSize: 512 * 1024,
|
||||
},
|
||||
{
|
||||
name: "chunk_768KiB",
|
||||
payloadSize: 768 * 1024,
|
||||
},
|
||||
{
|
||||
name: "chunk_1MiB",
|
||||
payloadSize: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "chunk_2MiB",
|
||||
payloadSize: 2 * 1024 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughput(b, tc.payloadSize, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBulkTCPThroughputConcurrent(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
concurrency int
|
||||
}{
|
||||
{
|
||||
name: "bulks_2_512KiB",
|
||||
payloadSize: 512 * 1024,
|
||||
concurrency: 2,
|
||||
},
|
||||
{
|
||||
name: "bulks_4_512KiB",
|
||||
payloadSize: 512 * 1024,
|
||||
concurrency: 4,
|
||||
},
|
||||
{
|
||||
name: "bulks_2_1MiB",
|
||||
payloadSize: 1024 * 1024,
|
||||
concurrency: 2,
|
||||
},
|
||||
{
|
||||
name: "bulks_4_1MiB",
|
||||
payloadSize: 1024 * 1024,
|
||||
concurrency: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBulkTCPThroughputConcurrentDedicated(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
concurrency int
|
||||
}{
|
||||
{
|
||||
name: "bulks_2_512KiB",
|
||||
payloadSize: 512 * 1024,
|
||||
concurrency: 2,
|
||||
},
|
||||
{
|
||||
name: "bulks_4_512KiB",
|
||||
payloadSize: 512 * 1024,
|
||||
concurrency: 4,
|
||||
},
|
||||
{
|
||||
name: "bulks_2_1MiB",
|
||||
payloadSize: 1024 * 1024,
|
||||
concurrency: 2,
|
||||
},
|
||||
{
|
||||
name: "bulks_4_1MiB",
|
||||
payloadSize: 1024 * 1024,
|
||||
concurrency: 4,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
|
||||
b.Helper()
|
||||
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
|
||||
acceptCh := make(chan BulkAcceptInfo, 1)
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
acceptCh <- info
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||
b.Fatalf("server Listen failed: %v", err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
_ = server.Stop()
|
||||
})
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||
b.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
_ = client.Stop()
|
||||
})
|
||||
|
||||
totalBytes := int64(payloadSize)
|
||||
if b.N > 1 {
|
||||
totalBytes = int64(payloadSize) * int64(b.N)
|
||||
}
|
||||
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||
Range: BulkRange{
|
||||
Offset: 0,
|
||||
Length: totalBytes,
|
||||
},
|
||||
ChunkSize: payloadSize,
|
||||
Dedicated: dedicated,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("client OpenBulk failed: %v", err)
|
||||
}
|
||||
accepted := waitBenchmarkAcceptedBulk(b, acceptCh, 5*time.Second)
|
||||
|
||||
drainDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := io.Copy(io.Discard, accepted.Bulk)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
drainDone <- err
|
||||
return
|
||||
}
|
||||
drainDone <- nil
|
||||
}()
|
||||
|
||||
payload := make([]byte, payloadSize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(payloadSize))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
n, err := bulk.Write(payload)
|
||||
if err != nil {
|
||||
b.Fatalf("bulk Write failed at iter %d: %v", i, err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
b.Fatalf("bulk Write bytes mismatch at iter %d: got %d want %d", i, n, len(payload))
|
||||
}
|
||||
}
|
||||
b.StopTimer()
|
||||
|
||||
if err := bulk.CloseWrite(); err != nil {
|
||||
b.Fatalf("bulk CloseWrite failed: %v", err)
|
||||
}
|
||||
select {
|
||||
case err := <-drainDone:
|
||||
if err != nil {
|
||||
b.Fatalf("server drain failed: %v", err)
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
b.Fatal("timed out waiting for server drain")
|
||||
}
|
||||
|
||||
_ = accepted.Bulk.Close()
|
||||
_ = bulk.Close()
|
||||
}
|
||||
|
||||
func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool) {
|
||||
b.Helper()
|
||||
if concurrency <= 0 {
|
||||
b.Fatal("concurrency must be > 0")
|
||||
}
|
||||
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
|
||||
acceptCh := make(chan BulkAcceptInfo, concurrency*2)
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
acceptCh <- info
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||
b.Fatalf("server Listen failed: %v", err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
_ = server.Stop()
|
||||
})
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
b.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||
b.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
_ = client.Stop()
|
||||
})
|
||||
|
||||
bulks := make([]Bulk, 0, concurrency)
|
||||
acceptedBulks := make([]Bulk, 0, concurrency)
|
||||
totalBytes := int64(payloadSize)
|
||||
if b.N > 1 {
|
||||
totalBytes = int64(payloadSize) * int64(b.N)
|
||||
}
|
||||
for index := 0; index < concurrency; index++ {
|
||||
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||
Range: BulkRange{
|
||||
Offset: int64(index) * totalBytes,
|
||||
Length: totalBytes,
|
||||
},
|
||||
ChunkSize: payloadSize,
|
||||
Dedicated: dedicated,
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("client OpenBulk failed for bulk %d: %v", index, err)
|
||||
}
|
||||
bulks = append(bulks, bulk)
|
||||
accepted := waitBenchmarkAcceptedBulk(b, acceptCh, 5*time.Second)
|
||||
acceptedBulks = append(acceptedBulks, accepted.Bulk)
|
||||
}
|
||||
|
||||
drainDone := make(chan error, concurrency)
|
||||
for _, acceptedBulk := range acceptedBulks {
|
||||
bulk := acceptedBulk
|
||||
go func() {
|
||||
_, err := io.Copy(io.Discard, bulk)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
drainDone <- err
|
||||
return
|
||||
}
|
||||
drainDone <- nil
|
||||
}()
|
||||
}
|
||||
|
||||
payload := make([]byte, payloadSize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(payloadSize))
|
||||
b.ResetTimer()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, concurrency)
|
||||
for index, bulk := range bulks {
|
||||
count := b.N / concurrency
|
||||
if index < b.N%concurrency {
|
||||
count++
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(bulk Bulk, count int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < count; i++ {
|
||||
n, err := bulk.Write(payload)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
if n != len(payload) {
|
||||
errCh <- errors.New("bulk write bytes mismatch")
|
||||
return
|
||||
}
|
||||
}
|
||||
}(bulk, count)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
if err != nil {
|
||||
b.Fatalf("concurrent bulk write failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
|
||||
for index, bulk := range bulks {
|
||||
if err := bulk.CloseWrite(); err != nil {
|
||||
b.Fatalf("bulk %d CloseWrite failed: %v", index, err)
|
||||
}
|
||||
}
|
||||
|
||||
for index := 0; index < concurrency; index++ {
|
||||
select {
|
||||
case err := <-drainDone:
|
||||
if err != nil {
|
||||
b.Fatalf("server drain failed: %v", err)
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
b.Fatalf("timed out waiting for server drain %d/%d", index+1, concurrency)
|
||||
}
|
||||
}
|
||||
|
||||
for _, bulk := range acceptedBulks {
|
||||
_ = bulk.Close()
|
||||
}
|
||||
for _, bulk := range bulks {
|
||||
_ = bulk.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func waitBenchmarkAcceptedBulk(tb testing.TB, ch <-chan BulkAcceptInfo, timeout time.Duration) BulkAcceptInfo {
|
||||
tb.Helper()
|
||||
select {
|
||||
case info := <-ch:
|
||||
return info
|
||||
case <-time.After(timeout):
|
||||
tb.Fatalf("timed out waiting for accepted bulk after %v", timeout)
|
||||
return BulkAcceptInfo{}
|
||||
}
|
||||
}
|
||||
+702
@@ -0,0 +1,702 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BulkOpenRequest struct {
|
||||
BulkID string
|
||||
DataID uint64
|
||||
Range BulkRange
|
||||
Metadata BulkMetadata
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
Dedicated bool
|
||||
AttachToken string
|
||||
ChunkSize int
|
||||
WindowBytes int
|
||||
MaxInFlight int
|
||||
}
|
||||
|
||||
type BulkOpenResponse struct {
|
||||
BulkID string
|
||||
DataID uint64
|
||||
Accepted bool
|
||||
Dedicated bool
|
||||
AttachToken string
|
||||
TransportGeneration uint64
|
||||
Error string
|
||||
}
|
||||
|
||||
type BulkCloseRequest struct {
|
||||
BulkID string
|
||||
Full bool
|
||||
}
|
||||
|
||||
type BulkCloseResponse struct {
|
||||
BulkID string
|
||||
Accepted bool
|
||||
Error string
|
||||
}
|
||||
|
||||
type BulkResetRequest struct {
|
||||
BulkID string
|
||||
DataID uint64
|
||||
Error string
|
||||
}
|
||||
|
||||
type BulkResetResponse struct {
|
||||
BulkID string
|
||||
Accepted bool
|
||||
Error string
|
||||
}
|
||||
|
||||
type BulkReleaseRequest struct {
|
||||
BulkID string
|
||||
DataID uint64
|
||||
Bytes int64
|
||||
Chunks int
|
||||
}
|
||||
|
||||
func bindClientBulkControl(c *ClientCommon) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.SetLink(BulkOpenSignalKey, func(msg *Message) {
|
||||
c.handleInboundBulkOpen(msg)
|
||||
})
|
||||
c.SetLink(BulkCloseSignalKey, func(msg *Message) {
|
||||
c.handleInboundBulkClose(msg)
|
||||
})
|
||||
c.SetLink(BulkResetSignalKey, func(msg *Message) {
|
||||
c.handleInboundBulkReset(msg)
|
||||
})
|
||||
c.SetLink(BulkReleaseSignalKey, func(msg *Message) {
|
||||
c.handleInboundBulkRelease(msg)
|
||||
})
|
||||
}
|
||||
|
||||
func bindServerBulkControl(s *ServerCommon) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.SetLink(BulkOpenSignalKey, func(msg *Message) {
|
||||
s.handleInboundBulkOpen(msg)
|
||||
})
|
||||
s.SetLink(BulkCloseSignalKey, func(msg *Message) {
|
||||
s.handleInboundBulkClose(msg)
|
||||
})
|
||||
s.SetLink(BulkResetSignalKey, func(msg *Message) {
|
||||
s.handleInboundBulkReset(msg)
|
||||
})
|
||||
s.SetLink(BulkReleaseSignalKey, func(msg *Message) {
|
||||
s.handleInboundBulkRelease(msg)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
|
||||
req, err := decodeBulkOpenRequest(msg)
|
||||
resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated}
|
||||
if err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
if req.Dedicated {
|
||||
if err := clientDedicatedBulkSupportError(c); err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
}
|
||||
runtime := c.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
resp.Error = errBulkRuntimeNil.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
scope := clientFileScope()
|
||||
if req.DataID == 0 {
|
||||
req.DataID = runtime.nextDataID()
|
||||
resp.DataID = req.DataID
|
||||
}
|
||||
if req.Dedicated && req.AttachToken == "" {
|
||||
req.AttachToken = newBulkAttachToken()
|
||||
}
|
||||
resp.AttachToken = req.AttachToken
|
||||
bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, scope, req, c.currentClientSessionEpoch(), nil, nil, 0, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c))
|
||||
bulk.setClientSnapshotOwner(c)
|
||||
if err := runtime.register(scope, bulk); err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
handler := runtime.handlerSnapshot()
|
||||
if handler == nil {
|
||||
bulk.markReset(errBulkHandlerNotConfigured)
|
||||
resp.Error = errBulkHandlerNotConfigured.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
if req.Dedicated {
|
||||
if err := c.attachDedicatedBulkSidecar(context.Background(), bulk); err != nil {
|
||||
bulk.markReset(err)
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
}
|
||||
info := BulkAcceptInfo{
|
||||
ID: bulk.ID(),
|
||||
Range: bulk.Range(),
|
||||
Metadata: bulk.Metadata(),
|
||||
Dedicated: bulk.Dedicated(),
|
||||
TransportGeneration: bulk.TransportGeneration(),
|
||||
Bulk: bulk,
|
||||
}
|
||||
if err := handler(info); err != nil {
|
||||
bulk.markReset(err)
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
resp.Accepted = true
|
||||
resp.DataID = bulk.dataIDSnapshot()
|
||||
resp.TransportGeneration = bulk.TransportGeneration()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) handleInboundBulkOpen(msg *Message) {
|
||||
req, err := decodeBulkOpenRequest(msg)
|
||||
resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated}
|
||||
if err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
runtime := s.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
resp.Error = errBulkRuntimeNil.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
logical := messageLogicalConnSnapshot(msg)
|
||||
if logical == nil {
|
||||
resp.Error = errBulkLogicalConnNil.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
transport := messageTransportConnSnapshot(msg)
|
||||
if req.Dedicated {
|
||||
if err := logicalDedicatedBulkSupportError(logical); err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
if transport != nil {
|
||||
if err := transportDedicatedBulkSupportError(transport); err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
scope := serverFileScope(logical)
|
||||
if req.DataID == 0 {
|
||||
req.DataID = runtime.nextDataID()
|
||||
resp.DataID = req.DataID
|
||||
}
|
||||
if req.Dedicated && req.AttachToken == "" {
|
||||
req.AttachToken = newBulkAttachToken()
|
||||
}
|
||||
resp.AttachToken = req.AttachToken
|
||||
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, bulkTransportGeneration(logical, transport), serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport))
|
||||
if err := runtime.register(scope, bulk); err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
handler := runtime.handlerSnapshot()
|
||||
if handler == nil {
|
||||
bulk.markReset(errBulkHandlerNotConfigured)
|
||||
resp.Error = errBulkHandlerNotConfigured.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
info := BulkAcceptInfo{
|
||||
ID: bulk.ID(),
|
||||
Range: bulk.Range(),
|
||||
Metadata: bulk.Metadata(),
|
||||
Dedicated: bulk.Dedicated(),
|
||||
LogicalConn: logical,
|
||||
TransportConn: transport,
|
||||
TransportGeneration: bulk.TransportGeneration(),
|
||||
Bulk: bulk,
|
||||
}
|
||||
if err := handler(info); err != nil {
|
||||
bulk.markReset(err)
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
resp.Accepted = true
|
||||
resp.DataID = bulk.dataIDSnapshot()
|
||||
resp.TransportGeneration = bulk.TransportGeneration()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) handleInboundBulkClose(msg *Message) {
|
||||
req, err := decodeBulkCloseRequest(msg)
|
||||
resp := BulkCloseResponse{BulkID: req.BulkID}
|
||||
if err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
runtime := c.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
resp.Error = errBulkRuntimeNil.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
|
||||
if !ok {
|
||||
resp.Error = errBulkNotFound.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
if req.Full {
|
||||
bulk.markPeerClosed()
|
||||
} else {
|
||||
bulk.markRemoteClosed()
|
||||
}
|
||||
resp.Accepted = true
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) handleInboundBulkClose(msg *Message) {
|
||||
req, err := decodeBulkCloseRequest(msg)
|
||||
resp := BulkCloseResponse{BulkID: req.BulkID}
|
||||
if err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
runtime := s.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
resp.Error = errBulkRuntimeNil.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
logical := messageLogicalConnSnapshot(msg)
|
||||
scope := serverFileScope(logical)
|
||||
bulk, ok := runtime.lookup(scope, req.BulkID)
|
||||
if !ok {
|
||||
resp.Error = errBulkNotFound.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
if req.Full {
|
||||
bulk.markPeerClosed()
|
||||
} else {
|
||||
bulk.markRemoteClosed()
|
||||
}
|
||||
resp.Accepted = true
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) handleInboundBulkReset(msg *Message) {
|
||||
req, err := decodeBulkResetRequest(msg)
|
||||
resp := BulkResetResponse{BulkID: req.BulkID}
|
||||
if err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
runtime := c.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
resp.Error = errBulkRuntimeNil.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
|
||||
if !ok && req.DataID != 0 {
|
||||
bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID)
|
||||
}
|
||||
if !ok {
|
||||
resp.Error = errBulkNotFound.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
if resp.BulkID == "" {
|
||||
resp.BulkID = bulk.ID()
|
||||
}
|
||||
bulk.markReset(bulkResetError(bulkRemoteResetError(req.Error)))
|
||||
resp.Accepted = true
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) handleInboundBulkRelease(msg *Message) {
|
||||
req, err := decodeBulkReleaseRequest(msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
runtime := c.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
return
|
||||
}
|
||||
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
|
||||
if !ok && req.DataID != 0 {
|
||||
bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID)
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
bulk.releaseOutboundWindow(req.Bytes, req.Chunks)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) handleInboundBulkReset(msg *Message) {
|
||||
req, err := decodeBulkResetRequest(msg)
|
||||
resp := BulkResetResponse{BulkID: req.BulkID}
|
||||
if err != nil {
|
||||
resp.Error = err.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
runtime := s.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
resp.Error = errBulkRuntimeNil.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
logical := messageLogicalConnSnapshot(msg)
|
||||
scope := serverFileScope(logical)
|
||||
bulk, ok := runtime.lookup(scope, req.BulkID)
|
||||
if !ok && req.DataID != 0 {
|
||||
bulk, ok = runtime.lookupByDataID(scope, req.DataID)
|
||||
}
|
||||
if !ok {
|
||||
resp.Error = errBulkNotFound.Error()
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
return
|
||||
}
|
||||
if resp.BulkID == "" {
|
||||
resp.BulkID = bulk.ID()
|
||||
}
|
||||
bulk.markReset(bulkResetError(bulkRemoteResetError(req.Error)))
|
||||
resp.Accepted = true
|
||||
replyBulkControlIfNeeded(msg, resp)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) handleInboundBulkRelease(msg *Message) {
|
||||
req, err := decodeBulkReleaseRequest(msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
runtime := s.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
return
|
||||
}
|
||||
logical := messageLogicalConnSnapshot(msg)
|
||||
scope := serverFileScope(logical)
|
||||
bulk, ok := runtime.lookup(scope, req.BulkID)
|
||||
if !ok && req.DataID != 0 {
|
||||
bulk, ok = runtime.lookupByDataID(scope, req.DataID)
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
bulk.releaseOutboundWindow(req.Bytes, req.Chunks)
|
||||
}
|
||||
|
||||
func replyBulkControlIfNeeded(msg *Message, value interface{}) {
|
||||
if msg == nil || !requiresSignalReplyWait(msg.TransferMsg) {
|
||||
return
|
||||
}
|
||||
_ = msg.ReplyObj(value)
|
||||
}
|
||||
|
||||
func sendBulkOpenClient(ctx context.Context, c Client, req BulkOpenRequest) (BulkOpenResponse, error) {
|
||||
if c == nil {
|
||||
return BulkOpenResponse{}, errBulkClientNil
|
||||
}
|
||||
msg, err := c.SendObjCtx(ctx, BulkOpenSignalKey, req)
|
||||
if err != nil {
|
||||
return BulkOpenResponse{}, err
|
||||
}
|
||||
return decodeBulkOpenResponse(msg)
|
||||
}
|
||||
|
||||
func sendBulkOpenServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkOpenRequest) (BulkOpenResponse, error) {
|
||||
if s == nil {
|
||||
return BulkOpenResponse{}, errBulkServerNil
|
||||
}
|
||||
if logical == nil {
|
||||
return BulkOpenResponse{}, errBulkLogicalConnNil
|
||||
}
|
||||
msg, err := s.SendObjCtxLogical(ctx, logical, BulkOpenSignalKey, req)
|
||||
if err != nil {
|
||||
return BulkOpenResponse{}, err
|
||||
}
|
||||
return decodeBulkOpenResponse(msg)
|
||||
}
|
||||
|
||||
func sendBulkOpenServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkOpenRequest) (BulkOpenResponse, error) {
|
||||
if s == nil {
|
||||
return BulkOpenResponse{}, errBulkServerNil
|
||||
}
|
||||
if transport == nil {
|
||||
return BulkOpenResponse{}, errBulkTransportNil
|
||||
}
|
||||
msg, err := s.SendObjCtxTransport(ctx, transport, BulkOpenSignalKey, req)
|
||||
if err != nil {
|
||||
return BulkOpenResponse{}, err
|
||||
}
|
||||
return decodeBulkOpenResponse(msg)
|
||||
}
|
||||
|
||||
func sendBulkCloseClient(ctx context.Context, c Client, req BulkCloseRequest) (BulkCloseResponse, error) {
|
||||
if c == nil {
|
||||
return BulkCloseResponse{}, errBulkClientNil
|
||||
}
|
||||
msg, err := c.SendObjCtx(ctx, BulkCloseSignalKey, req)
|
||||
if err != nil {
|
||||
return BulkCloseResponse{}, err
|
||||
}
|
||||
return decodeBulkCloseResponse(msg)
|
||||
}
|
||||
|
||||
func sendBulkCloseServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkCloseRequest) (BulkCloseResponse, error) {
|
||||
if s == nil {
|
||||
return BulkCloseResponse{}, errBulkServerNil
|
||||
}
|
||||
if logical == nil {
|
||||
return BulkCloseResponse{}, errBulkLogicalConnNil
|
||||
}
|
||||
msg, err := s.SendObjCtxLogical(ctx, logical, BulkCloseSignalKey, req)
|
||||
if err != nil {
|
||||
return BulkCloseResponse{}, err
|
||||
}
|
||||
return decodeBulkCloseResponse(msg)
|
||||
}
|
||||
|
||||
func sendBulkCloseServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkCloseRequest) (BulkCloseResponse, error) {
|
||||
if s == nil {
|
||||
return BulkCloseResponse{}, errBulkServerNil
|
||||
}
|
||||
if transport == nil {
|
||||
return BulkCloseResponse{}, errBulkTransportNil
|
||||
}
|
||||
msg, err := s.SendObjCtxTransport(ctx, transport, BulkCloseSignalKey, req)
|
||||
if err != nil {
|
||||
return BulkCloseResponse{}, err
|
||||
}
|
||||
return decodeBulkCloseResponse(msg)
|
||||
}
|
||||
|
||||
func sendBulkResetClient(ctx context.Context, c Client, req BulkResetRequest) (BulkResetResponse, error) {
|
||||
if c == nil {
|
||||
return BulkResetResponse{}, errBulkClientNil
|
||||
}
|
||||
msg, err := c.SendObjCtx(ctx, BulkResetSignalKey, req)
|
||||
if err != nil {
|
||||
return BulkResetResponse{}, err
|
||||
}
|
||||
return decodeBulkResetResponse(msg)
|
||||
}
|
||||
|
||||
func sendBulkResetServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkResetRequest) (BulkResetResponse, error) {
|
||||
if s == nil {
|
||||
return BulkResetResponse{}, errBulkServerNil
|
||||
}
|
||||
if logical == nil {
|
||||
return BulkResetResponse{}, errBulkLogicalConnNil
|
||||
}
|
||||
msg, err := s.SendObjCtxLogical(ctx, logical, BulkResetSignalKey, req)
|
||||
if err != nil {
|
||||
return BulkResetResponse{}, err
|
||||
}
|
||||
return decodeBulkResetResponse(msg)
|
||||
}
|
||||
|
||||
func sendBulkResetServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkResetRequest) (BulkResetResponse, error) {
|
||||
if s == nil {
|
||||
return BulkResetResponse{}, errBulkServerNil
|
||||
}
|
||||
if transport == nil {
|
||||
return BulkResetResponse{}, errBulkTransportNil
|
||||
}
|
||||
msg, err := s.SendObjCtxTransport(ctx, transport, BulkResetSignalKey, req)
|
||||
if err != nil {
|
||||
return BulkResetResponse{}, err
|
||||
}
|
||||
return decodeBulkResetResponse(msg)
|
||||
}
|
||||
|
||||
func sendBulkReleaseClient(c Client, req BulkReleaseRequest) error {
|
||||
if c == nil {
|
||||
return errBulkClientNil
|
||||
}
|
||||
return c.SendObj(BulkReleaseSignalKey, req)
|
||||
}
|
||||
|
||||
func sendBulkReleaseServerLogical(s Server, logical *LogicalConn, req BulkReleaseRequest) error {
|
||||
if s == nil {
|
||||
return errBulkServerNil
|
||||
}
|
||||
if logical == nil {
|
||||
return errBulkLogicalConnNil
|
||||
}
|
||||
return s.SendObjLogical(logical, BulkReleaseSignalKey, req)
|
||||
}
|
||||
|
||||
func sendBulkReleaseServerTransport(s Server, transport *TransportConn, req BulkReleaseRequest) error {
|
||||
if s == nil {
|
||||
return errBulkServerNil
|
||||
}
|
||||
if transport == nil {
|
||||
return errBulkTransportNil
|
||||
}
|
||||
return s.SendObjTransport(transport, BulkReleaseSignalKey, req)
|
||||
}
|
||||
|
||||
func decodeBulkOpenRequest(msg *Message) (BulkOpenRequest, error) {
|
||||
var req BulkOpenRequest
|
||||
if msg == nil {
|
||||
return BulkOpenRequest{}, errBulkIDEmpty
|
||||
}
|
||||
if err := msg.Value.Orm(&req); err != nil {
|
||||
return BulkOpenRequest{}, err
|
||||
}
|
||||
req = normalizeBulkOpenRequest(req)
|
||||
if req.BulkID == "" {
|
||||
return BulkOpenRequest{}, errBulkIDEmpty
|
||||
}
|
||||
if !validBulkRange(req.Range) {
|
||||
return BulkOpenRequest{}, errBulkRangeInvalid
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeBulkCloseRequest(msg *Message) (BulkCloseRequest, error) {
|
||||
var req BulkCloseRequest
|
||||
if msg == nil {
|
||||
return BulkCloseRequest{}, errBulkIDEmpty
|
||||
}
|
||||
if err := msg.Value.Orm(&req); err != nil {
|
||||
return BulkCloseRequest{}, err
|
||||
}
|
||||
if req.BulkID == "" {
|
||||
return BulkCloseRequest{}, errBulkIDEmpty
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeBulkResetRequest(msg *Message) (BulkResetRequest, error) {
|
||||
var req BulkResetRequest
|
||||
if msg == nil {
|
||||
return BulkResetRequest{}, errBulkIDEmpty
|
||||
}
|
||||
if err := msg.Value.Orm(&req); err != nil {
|
||||
return BulkResetRequest{}, err
|
||||
}
|
||||
if req.BulkID == "" && req.DataID == 0 {
|
||||
return BulkResetRequest{}, errBulkIDEmpty
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeBulkReleaseRequest(msg *Message) (BulkReleaseRequest, error) {
|
||||
var req BulkReleaseRequest
|
||||
if msg == nil {
|
||||
return BulkReleaseRequest{}, errBulkIDEmpty
|
||||
}
|
||||
if err := msg.Value.Orm(&req); err != nil {
|
||||
return BulkReleaseRequest{}, err
|
||||
}
|
||||
if req.BulkID == "" && req.DataID == 0 {
|
||||
return BulkReleaseRequest{}, errBulkIDEmpty
|
||||
}
|
||||
if req.Bytes < 0 || req.Chunks < 0 {
|
||||
return BulkReleaseRequest{}, errBulkRangeInvalid
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func decodeBulkOpenResponse(msg Message) (BulkOpenResponse, error) {
|
||||
var resp BulkOpenResponse
|
||||
if err := msg.Value.Orm(&resp); err != nil {
|
||||
return BulkOpenResponse{}, err
|
||||
}
|
||||
return resp, bulkControlResultError("open", resp.Accepted, resp.Error, nil)
|
||||
}
|
||||
|
||||
func decodeBulkCloseResponse(msg Message) (BulkCloseResponse, error) {
|
||||
var resp BulkCloseResponse
|
||||
if err := msg.Value.Orm(&resp); err != nil {
|
||||
return BulkCloseResponse{}, err
|
||||
}
|
||||
return resp, bulkControlResultError("close", resp.Accepted, resp.Error, nil)
|
||||
}
|
||||
|
||||
func decodeBulkResetResponse(msg Message) (BulkResetResponse, error) {
|
||||
var resp BulkResetResponse
|
||||
if err := msg.Value.Orm(&resp); err != nil {
|
||||
return BulkResetResponse{}, err
|
||||
}
|
||||
return resp, bulkControlResultError("reset", resp.Accepted, resp.Error, nil)
|
||||
}
|
||||
|
||||
func bulkControlResultError(op string, accepted bool, message string, callErr error) error {
|
||||
if callErr != nil {
|
||||
return callErr
|
||||
}
|
||||
if message != "" {
|
||||
return bulkControlMessageError(message)
|
||||
}
|
||||
if accepted {
|
||||
return nil
|
||||
}
|
||||
if op == "open" {
|
||||
return errBulkRejected
|
||||
}
|
||||
return errors.New("bulk " + op + " rejected")
|
||||
}
|
||||
|
||||
func bulkControlMessageError(message string) error {
|
||||
switch message {
|
||||
case errBulkNotFound.Error():
|
||||
return errBulkNotFound
|
||||
case errBulkAlreadyExists.Error():
|
||||
return errBulkAlreadyExists
|
||||
case errBulkHandlerNotConfigured.Error():
|
||||
return errBulkHandlerNotConfigured
|
||||
case errBulkLogicalConnNil.Error():
|
||||
return errBulkLogicalConnNil
|
||||
case errBulkTransportNil.Error():
|
||||
return errBulkTransportNil
|
||||
case errBulkRuntimeNil.Error():
|
||||
return errBulkRuntimeNil
|
||||
case errBulkIDEmpty.Error():
|
||||
return errBulkIDEmpty
|
||||
case errBulkRangeInvalid.Error():
|
||||
return errBulkRangeInvalid
|
||||
case errBulkDataIDEmpty.Error():
|
||||
return errBulkDataIDEmpty
|
||||
default:
|
||||
return errors.New(message)
|
||||
}
|
||||
}
|
||||
|
||||
func bulkRemoteResetError(message string) error {
|
||||
if message == "" {
|
||||
return errBulkReset
|
||||
}
|
||||
return errors.New(message)
|
||||
}
|
||||
|
||||
func bulkTransportGeneration(logical *LogicalConn, transport *TransportConn) uint64 {
|
||||
return streamTransportGeneration(logical, transport)
|
||||
}
|
||||
@@ -0,0 +1,723 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/notify/internal/transport"
|
||||
"b612.me/stario"
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
systemBulkAttachKey = "_notify_bulk_attach"
|
||||
bulkDedicatedRecordMagic = "NBR1"
|
||||
bulkDedicatedRecordHeaderLen = 8
|
||||
bulkDedicatedAttachTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type bulkAttachRequest struct {
|
||||
PeerID string
|
||||
BulkID string
|
||||
AttachToken string
|
||||
}
|
||||
|
||||
type bulkAttachResponse struct {
|
||||
Accepted bool
|
||||
Error string
|
||||
}
|
||||
|
||||
func newBulkAttachToken() string {
|
||||
var buf [16]byte
|
||||
if _, err := cryptorand.Read(buf[:]); err == nil {
|
||||
return hex.EncodeToString(buf[:])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func decodeBulkAttachRequest(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachRequest, error) {
|
||||
var req bulkAttachRequest
|
||||
if decodeFn == nil {
|
||||
decodeFn = Decode
|
||||
}
|
||||
raw := []byte(data)
|
||||
value, err := decodeFn(raw)
|
||||
if err != nil {
|
||||
return req, err
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case bulkAttachRequest:
|
||||
return typed, nil
|
||||
case *bulkAttachRequest:
|
||||
if typed == nil {
|
||||
return req, errors.New("bulk attach request is nil")
|
||||
}
|
||||
return *typed, nil
|
||||
default:
|
||||
return req, errors.New("invalid bulk attach payload")
|
||||
}
|
||||
}
|
||||
|
||||
func decodeBulkAttachResponse(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachResponse, error) {
|
||||
var resp bulkAttachResponse
|
||||
if decodeFn == nil {
|
||||
decodeFn = Decode
|
||||
}
|
||||
raw := []byte(data)
|
||||
value, err := decodeFn(raw)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
switch typed := value.(type) {
|
||||
case bulkAttachResponse:
|
||||
return typed, nil
|
||||
case *bulkAttachResponse:
|
||||
if typed == nil {
|
||||
return resp, errors.New("bulk attach response is nil")
|
||||
}
|
||||
return *typed, nil
|
||||
default:
|
||||
return resp, errors.New("invalid bulk attach response")
|
||||
}
|
||||
}
|
||||
|
||||
func encodeDirectSignalFrame(queue *stario.StarQueue, sequenceEn func(interface{}) ([]byte, error), msgEn func([]byte, []byte) []byte, secretKey []byte, msg TransferMsg) ([]byte, error) {
|
||||
if queue == nil {
|
||||
queue = stario.NewQueue()
|
||||
}
|
||||
env, err := wrapTransferMsgEnvelope(msg, sequenceEn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plain, err := sequenceEn(env)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
payload := msgEn(secretKey, plain)
|
||||
if payload == nil && len(plain) != 0 {
|
||||
return nil, errTransportPayloadEncryptFailed
|
||||
}
|
||||
return queue.BuildMessage(payload), nil
|
||||
}
|
||||
|
||||
func decodeDirectSignalPayload(sequenceDe func([]byte) (interface{}, error), msgDe func([]byte, []byte) []byte, secretKey []byte, payload []byte) (TransferMsg, error) {
|
||||
plain := msgDe(secretKey, payload)
|
||||
if plain == nil && len(payload) != 0 {
|
||||
return TransferMsg{}, errTransportPayloadDecryptFailed
|
||||
}
|
||||
value, err := sequenceDe(plain)
|
||||
if err != nil {
|
||||
return TransferMsg{}, err
|
||||
}
|
||||
env, ok := value.(Envelope)
|
||||
if !ok {
|
||||
return TransferMsg{}, errors.New("invalid signal envelope")
|
||||
}
|
||||
return unwrapTransferMsgEnvelope(env, sequenceDe)
|
||||
}
|
||||
|
||||
func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error {
|
||||
return writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{})
|
||||
}
|
||||
|
||||
func writeBulkDedicatedRecordWithDeadline(conn net.Conn, payload []byte, deadline time.Time) error {
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return withRawConnWriteLockDeadline(conn, deadline, func(conn net.Conn) error {
|
||||
var header [bulkDedicatedRecordHeaderLen]byte
|
||||
copy(header[:4], bulkDedicatedRecordMagic)
|
||||
binary.BigEndian.PutUint32(header[4:8], uint32(len(payload)))
|
||||
buffers := net.Buffers{header[:], payload}
|
||||
_, err := buffers.WriteTo(conn)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func readBulkDedicatedRecord(conn net.Conn) ([]byte, error) {
|
||||
if conn == nil {
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
var header [bulkDedicatedRecordHeaderLen]byte
|
||||
if _, err := io.ReadFull(conn, header[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if string(header[:4]) != bulkDedicatedRecordMagic {
|
||||
return nil, errBulkFastPayloadInvalid
|
||||
}
|
||||
size := int(binary.BigEndian.Uint32(header[4:8]))
|
||||
if size < 0 {
|
||||
return nil, errBulkFastPayloadInvalid
|
||||
}
|
||||
payload := make([]byte, size)
|
||||
if _, err := io.ReadFull(conn, payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context) (net.Conn, error) {
|
||||
source := c.clientConnectSourceSnapshot()
|
||||
if source != nil && source.canReconnect() {
|
||||
return source.dial(ctx)
|
||||
}
|
||||
conn := c.clientTransportConnSnapshot()
|
||||
if conn == nil || conn.RemoteAddr() == nil {
|
||||
return nil, errClientReconnectSourceUnavailable
|
||||
}
|
||||
return transport.Dial(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) attachDedicatedBulkSidecar(ctx context.Context, bulk *bulkHandle) error {
|
||||
if c == nil || bulk == nil || !bulk.Dedicated() || bulk.dedicatedAttachedSnapshot() {
|
||||
return nil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, bulkDedicatedAttachTimeout)
|
||||
defer cancel()
|
||||
conn, err := c.dialDedicatedBulkConn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := c.sendDedicatedBulkAttachRequest(ctx, conn, bulk)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
if !resp.Accepted {
|
||||
_ = conn.Close()
|
||||
if resp.Error != "" {
|
||||
return errors.New(resp.Error)
|
||||
}
|
||||
return errors.New("bulk attach rejected")
|
||||
}
|
||||
if err := bulk.attachDedicatedConn(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
go c.readDedicatedBulkLoop(bulk, conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn net.Conn, bulk *bulkHandle) (bulkAttachResponse, error) {
|
||||
if c == nil {
|
||||
return bulkAttachResponse{}, errBulkClientNil
|
||||
}
|
||||
if bulk == nil {
|
||||
return bulkAttachResponse{}, errBulkIDEmpty
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.SetReadDeadline(time.Time{})
|
||||
}()
|
||||
reqPayload, err := c.sequenceEn(bulkAttachRequest{
|
||||
PeerID: c.ensureClientPeerIdentity(),
|
||||
BulkID: bulk.ID(),
|
||||
AttachToken: bulk.dedicatedAttachTokenSnapshot(),
|
||||
})
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
queue := stario.NewQueue()
|
||||
msg := TransferMsg{
|
||||
ID: atomic.AddUint64(&c.msgID, 1),
|
||||
Key: systemBulkAttachKey,
|
||||
Value: reqPayload,
|
||||
Type: MSG_SYS_WAIT,
|
||||
}
|
||||
frame, err := encodeDirectSignalFrame(queue, c.sequenceEn, c.msgEn, c.SecretKey, msg)
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
if err := writeFullToConn(conn, frame); err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
replyCh := make(chan Message, 1)
|
||||
readBuf := streamReadBuffer()
|
||||
for {
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
_ = conn.SetReadDeadline(deadline)
|
||||
}
|
||||
n, err := conn.Read(readBuf)
|
||||
if err != nil {
|
||||
return bulkAttachResponse{}, err
|
||||
}
|
||||
parseErr := queue.ParseMessageOwned(readBuf[:n], "bulk-attach", func(msgq stario.MsgQueue) error {
|
||||
transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, msgq.Msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
replyCh <- Message{
|
||||
ServerConn: c,
|
||||
TransferMsg: transfer,
|
||||
NetType: NET_CLIENT,
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if parseErr != nil {
|
||||
return bulkAttachResponse{}, parseErr
|
||||
}
|
||||
select {
|
||||
case reply := <-replyCh:
|
||||
return decodeBulkAttachResponse(c.sequenceDe, reply.Value)
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) {
|
||||
for {
|
||||
payload, err := readBulkDedicatedRecord(conn)
|
||||
if err != nil {
|
||||
handleDedicatedBulkReadError(bulk, err)
|
||||
return
|
||||
}
|
||||
plain, err := c.decryptTransportPayload(payload)
|
||||
if err != nil {
|
||||
_ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error())
|
||||
bulk.markReset(err)
|
||||
return
|
||||
}
|
||||
items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain)
|
||||
if err != nil {
|
||||
_ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error())
|
||||
bulk.markReset(err)
|
||||
return
|
||||
}
|
||||
for _, item := range items {
|
||||
if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
_ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error())
|
||||
bulk.markReset(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if bulk.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool {
|
||||
if message.Key != systemBulkAttachKey {
|
||||
return false
|
||||
}
|
||||
current := messageLogicalConnSnapshot(&message)
|
||||
resp := bulkAttachResponse{}
|
||||
var (
|
||||
req bulkAttachRequest
|
||||
logical *LogicalConn
|
||||
bulk *bulkHandle
|
||||
err error
|
||||
)
|
||||
req, err = decodeBulkAttachRequest(s.sequenceDe, message.Value)
|
||||
if err == nil {
|
||||
logical, bulk, err = s.resolveInboundDedicatedBulk(current, req)
|
||||
}
|
||||
if err != nil {
|
||||
resp.Error = err.Error()
|
||||
} else {
|
||||
resp.Accepted = true
|
||||
}
|
||||
if current != nil {
|
||||
_ = s.replyDedicatedBulkAttach(current, message, resp)
|
||||
}
|
||||
if err == nil {
|
||||
if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk); attachErr != nil {
|
||||
bulk.markReset(attachErr)
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *ServerCommon) resolveInboundDedicatedBulk(current *LogicalConn, req bulkAttachRequest) (*LogicalConn, *bulkHandle, error) {
|
||||
if s == nil {
|
||||
return nil, nil, errBulkServerNil
|
||||
}
|
||||
if current == nil {
|
||||
return nil, nil, errBulkLogicalConnNil
|
||||
}
|
||||
if req.PeerID == "" || req.BulkID == "" || req.AttachToken == "" {
|
||||
return nil, nil, errBulkIDEmpty
|
||||
}
|
||||
logical := s.GetLogicalConn(req.PeerID)
|
||||
if logical == nil {
|
||||
return nil, nil, errBulkLogicalConnNil
|
||||
}
|
||||
runtime := s.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
return nil, nil, errBulkRuntimeNil
|
||||
}
|
||||
bulk, ok := runtime.lookup(serverFileScope(logical), req.BulkID)
|
||||
if !ok {
|
||||
return nil, nil, errBulkNotFound
|
||||
}
|
||||
if !bulk.Dedicated() {
|
||||
return nil, nil, errors.New("bulk is not dedicated")
|
||||
}
|
||||
if bulk.dedicatedAttachTokenSnapshot() != req.AttachToken {
|
||||
return nil, nil, errors.New("bulk attach token mismatch")
|
||||
}
|
||||
return logical, bulk, nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle) error {
|
||||
if current == nil || logical == nil || bulk == nil {
|
||||
return errBulkLogicalConnNil
|
||||
}
|
||||
conn, err := current.detachTransportForTransfer()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bulk.attachDedicatedConn(conn); err != nil {
|
||||
if conn != nil {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
go s.readDedicatedBulkLoop(logical, bulk, conn)
|
||||
current.markSessionStopped("bulk dedicated attach", nil)
|
||||
s.removeLogical(current)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Message, resp bulkAttachResponse) error {
|
||||
if s == nil || client == nil {
|
||||
return errBulkServerNil
|
||||
}
|
||||
encoded, err := s.sequenceEn(resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
reply := TransferMsg{
|
||||
ID: message.ID,
|
||||
Key: systemBulkAttachKey,
|
||||
Value: encoded,
|
||||
Type: MSG_SYS_REPLY,
|
||||
}
|
||||
if message.inboundConn != nil {
|
||||
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply)
|
||||
}
|
||||
_, err = s.sendLogical(client, reply)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *ServerCommon) readDedicatedBulkLoop(logical *LogicalConn, bulk *bulkHandle, conn net.Conn) {
|
||||
for {
|
||||
payload, err := readBulkDedicatedRecord(conn)
|
||||
if err != nil {
|
||||
handleDedicatedBulkReadError(bulk, err)
|
||||
return
|
||||
}
|
||||
plain, err := s.decryptTransportPayloadLogical(logical, payload)
|
||||
if err != nil {
|
||||
_ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error())
|
||||
bulk.markReset(err)
|
||||
return
|
||||
}
|
||||
items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain)
|
||||
if err != nil {
|
||||
_ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error())
|
||||
bulk.markReset(err)
|
||||
return
|
||||
}
|
||||
for _, item := range items {
|
||||
if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
_ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error())
|
||||
bulk.markReset(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if bulk.Context().Err() != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleDedicatedBulkReadError(bulk *bulkHandle, err error) {
|
||||
if bulk == nil {
|
||||
return
|
||||
}
|
||||
if bulk.Context().Err() != nil || bulk.remoteClosedSnapshot() {
|
||||
return
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||
if bulk.Dedicated() || bulk.localClosedSnapshot() {
|
||||
bulk.markRemoteClosed()
|
||||
return
|
||||
}
|
||||
}
|
||||
bulk.markReset(transportDetachedError("dedicated bulk read error", err))
|
||||
}
|
||||
|
||||
func (c *ClientCommon) dedicatedBulkSender(bulk *bulkHandle) (*bulkDedicatedSender, error) {
|
||||
if c == nil || bulk == nil {
|
||||
return nil, errBulkClientNil
|
||||
}
|
||||
if sender := bulk.dedicatedSenderSnapshot(); sender != nil {
|
||||
return sender, nil
|
||||
}
|
||||
conn := bulk.dedicatedConnSnapshot()
|
||||
if conn == nil {
|
||||
return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
|
||||
}
|
||||
sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), c.encryptTransportPayload, func(items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||
return c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), items)
|
||||
}, func(err error) {
|
||||
bulk.markReset(err)
|
||||
})
|
||||
actual := bulk.installDedicatedSender(sender)
|
||||
if actual != sender {
|
||||
sender.stop()
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHandle, chunk []byte) error {
|
||||
if c == nil || bulk == nil {
|
||||
return errBulkClientNil
|
||||
}
|
||||
sender, err := c.dedicatedBulkSender(bulk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) {
|
||||
if c == nil || bulk == nil {
|
||||
return 0, errBulkClientNil
|
||||
}
|
||||
sender, err := c.dedicatedBulkSender(bulk)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error {
|
||||
if c == nil || bulk == nil {
|
||||
return errBulkClientNil
|
||||
}
|
||||
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
flags := uint8(0)
|
||||
if full {
|
||||
flags = bulkFastPayloadFlagFullClose
|
||||
}
|
||||
sender, err := c.dedicatedBulkSender(bulk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendDedicatedBulkReset(ctx context.Context, bulk *bulkHandle, message string) error {
|
||||
if c == nil || bulk == nil {
|
||||
return errBulkClientNil
|
||||
}
|
||||
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
sender, err := c.dedicatedBulkSender(bulk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message))
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendDedicatedBulkRelease(ctx context.Context, bulk *bulkHandle, bytes int64, chunks int) error {
|
||||
if c == nil || bulk == nil {
|
||||
return errBulkClientNil
|
||||
}
|
||||
payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
frame, err := c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{
|
||||
Type: bulkFastPayloadTypeRelease,
|
||||
Payload: payload,
|
||||
}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn := bulk.dedicatedConnSnapshot()
|
||||
if conn == nil {
|
||||
return transportDetachedError("dedicated bulk sidecar not attached", nil)
|
||||
}
|
||||
deadline, _ := sendCtx.Deadline()
|
||||
return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) dedicatedBulkSender(logical *LogicalConn, bulk *bulkHandle) (*bulkDedicatedSender, error) {
|
||||
if s == nil || bulk == nil {
|
||||
return nil, errBulkServerNil
|
||||
}
|
||||
if logical == nil {
|
||||
logical = bulk.LogicalConn()
|
||||
}
|
||||
if logical == nil {
|
||||
return nil, errBulkLogicalConnNil
|
||||
}
|
||||
if sender := bulk.dedicatedSenderSnapshot(); sender != nil {
|
||||
return sender, nil
|
||||
}
|
||||
conn := bulk.dedicatedConnSnapshot()
|
||||
if conn == nil {
|
||||
return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
|
||||
}
|
||||
sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), func(plain []byte) ([]byte, error) {
|
||||
return s.encryptTransportPayloadLogical(logical, plain)
|
||||
}, func(items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||
return s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), items)
|
||||
}, func(err error) {
|
||||
bulk.markReset(err)
|
||||
})
|
||||
actual := bulk.installDedicatedSender(sender)
|
||||
if actual != sender {
|
||||
sender.stop()
|
||||
}
|
||||
return actual, nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, chunk []byte) error {
|
||||
if s == nil || bulk == nil {
|
||||
return errBulkServerNil
|
||||
}
|
||||
sender, err := s.dedicatedBulkSender(logical, bulk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, payload []byte) (int, error) {
|
||||
if s == nil || bulk == nil {
|
||||
return 0, errBulkServerNil
|
||||
}
|
||||
sender, err := s.dedicatedBulkSender(logical, bulk)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error {
|
||||
if s == nil || bulk == nil {
|
||||
return errBulkServerNil
|
||||
}
|
||||
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
flags := uint8(0)
|
||||
if full {
|
||||
flags = bulkFastPayloadFlagFullClose
|
||||
}
|
||||
sender, err := s.dedicatedBulkSender(logical, bulk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendDedicatedBulkReset(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, message string) error {
|
||||
if s == nil || bulk == nil {
|
||||
return errBulkServerNil
|
||||
}
|
||||
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
sender, err := s.dedicatedBulkSender(logical, bulk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message))
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendDedicatedBulkRelease(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, bytes int64, chunks int) error {
|
||||
if s == nil || bulk == nil {
|
||||
return errBulkServerNil
|
||||
}
|
||||
payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cancel()
|
||||
frame, err := s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{
|
||||
Type: bulkFastPayloadTypeRelease,
|
||||
Payload: payload,
|
||||
}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn := bulk.dedicatedConnSnapshot()
|
||||
if conn == nil {
|
||||
return transportDetachedError("dedicated bulk sidecar not attached", nil)
|
||||
}
|
||||
deadline, _ := sendCtx.Deadline()
|
||||
return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||
if c == nil {
|
||||
return nil, errBulkClientNil
|
||||
}
|
||||
if c.fastPlainEncode != nil {
|
||||
return encodeBulkDedicatedBatchPayloadFast(c.fastPlainEncode, c.SecretKey, dataID, items)
|
||||
}
|
||||
plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.encryptTransportPayload(plain)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) encodeDedicatedBulkBatchPayload(logical *LogicalConn, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||
if s == nil {
|
||||
return nil, errBulkServerNil
|
||||
}
|
||||
if logical == nil {
|
||||
return nil, errBulkLogicalConnNil
|
||||
}
|
||||
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
|
||||
return encodeBulkDedicatedBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), dataID, items)
|
||||
}
|
||||
plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.encryptTransportPayloadLogical(logical, plain)
|
||||
}
|
||||
@@ -0,0 +1,663 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
bulkDedicatedBatchMagic = "NBD2"
|
||||
bulkDedicatedBatchVersion = 1
|
||||
bulkDedicatedBatchHeaderLen = 20
|
||||
bulkDedicatedBatchItemHeaderLen = 16
|
||||
bulkDedicatedBatchMaxItems = 32
|
||||
bulkDedicatedBatchMaxPlainBytes = 8 * 1024 * 1024
|
||||
bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems
|
||||
bulkDedicatedReleasePayloadLen = 12
|
||||
)
|
||||
|
||||
const (
|
||||
bulkDedicatedRequestQueued int32 = iota
|
||||
bulkDedicatedRequestStarted
|
||||
bulkDedicatedRequestCanceled
|
||||
)
|
||||
|
||||
type bulkDedicatedRequestState struct {
|
||||
value atomic.Int32
|
||||
}
|
||||
|
||||
type bulkDedicatedBatchItem struct {
|
||||
Type uint8
|
||||
Flags uint8
|
||||
Seq uint64
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
type bulkDedicatedSendRequest struct {
|
||||
Type uint8
|
||||
Flags uint8
|
||||
Seq uint64
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
type bulkDedicatedBatchRequest struct {
|
||||
Ctx context.Context
|
||||
Items []bulkDedicatedSendRequest
|
||||
Deadline time.Time
|
||||
Ack chan error
|
||||
State *bulkDedicatedRequestState
|
||||
}
|
||||
|
||||
type bulkDedicatedSender struct {
|
||||
conn net.Conn
|
||||
dataID uint64
|
||||
encrypt func([]byte) ([]byte, error)
|
||||
encodeBatch func([]bulkDedicatedSendRequest) ([]byte, error)
|
||||
fail func(error)
|
||||
|
||||
reqCh chan bulkDedicatedBatchRequest
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
flushMu sync.Mutex
|
||||
queued atomic.Int64
|
||||
|
||||
errMu sync.Mutex
|
||||
err error
|
||||
}
|
||||
|
||||
func newBulkDedicatedSender(conn net.Conn, dataID uint64, encrypt func([]byte) ([]byte, error), encodeBatch func([]bulkDedicatedSendRequest) ([]byte, error), fail func(error)) *bulkDedicatedSender {
|
||||
sender := &bulkDedicatedSender{
|
||||
conn: conn,
|
||||
dataID: dataID,
|
||||
encrypt: encrypt,
|
||||
encodeBatch: encodeBatch,
|
||||
fail: fail,
|
||||
reqCh: make(chan bulkDedicatedBatchRequest, bulkDedicatedSendQueueSize),
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
go sender.run()
|
||||
return sender
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) submitData(ctx context.Context, seq uint64, payload []byte) error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
items := []bulkDedicatedSendRequest{{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
Seq: seq,
|
||||
Payload: append([]byte(nil), payload...),
|
||||
}}
|
||||
return s.submitBatch(ctx, items, false)
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) submitWrite(ctx context.Context, startSeq uint64, payload []byte, chunkSize int) (int, error) {
|
||||
if s == nil {
|
||||
return 0, errTransportDetached
|
||||
}
|
||||
if len(payload) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if chunkSize <= 0 {
|
||||
chunkSize = defaultBulkChunkSize
|
||||
}
|
||||
written := 0
|
||||
seq := startSeq
|
||||
for written < len(payload) {
|
||||
var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest
|
||||
items := itemBuf[:0]
|
||||
batchBytes := bulkDedicatedBatchHeaderLen
|
||||
start := written
|
||||
for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems {
|
||||
end := written + chunkSize
|
||||
if end > len(payload) {
|
||||
end = len(payload)
|
||||
}
|
||||
itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written)
|
||||
if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes {
|
||||
break
|
||||
}
|
||||
items = append(items, bulkDedicatedSendRequest{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
Seq: seq,
|
||||
Payload: payload[written:end],
|
||||
})
|
||||
batchBytes += itemLen
|
||||
seq++
|
||||
written = end
|
||||
}
|
||||
if len(items) == 0 {
|
||||
end := written + chunkSize
|
||||
if end > len(payload) {
|
||||
end = len(payload)
|
||||
}
|
||||
items = append(items, bulkDedicatedSendRequest{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
Seq: seq,
|
||||
Payload: payload[written:end],
|
||||
})
|
||||
seq++
|
||||
written = end
|
||||
}
|
||||
if err := s.submitWriteBatch(ctx, items); err != nil {
|
||||
return start, err
|
||||
}
|
||||
start = written
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulkDedicatedSendRequest) error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return nil
|
||||
}
|
||||
if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted {
|
||||
return err
|
||||
}
|
||||
queuedItems := make([]bulkDedicatedSendRequest, len(items))
|
||||
copy(queuedItems, items)
|
||||
return s.submitBatch(ctx, queuedItems, true)
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) submitControl(ctx context.Context, frameType uint8, flags uint8, seq uint64, payload []byte) error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
items := []bulkDedicatedSendRequest{{
|
||||
Type: frameType,
|
||||
Flags: flags,
|
||||
Seq: seq,
|
||||
}}
|
||||
if len(payload) > 0 {
|
||||
items[0].Payload = append([]byte(nil), payload...)
|
||||
}
|
||||
return s.submitBatch(ctx, items, true)
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) submitBatch(ctx context.Context, items []bulkDedicatedSendRequest, wait bool) error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if err := s.errSnapshot(); err != nil {
|
||||
return err
|
||||
}
|
||||
req := bulkDedicatedBatchRequest{
|
||||
Ctx: ctx,
|
||||
Items: items,
|
||||
State: &bulkDedicatedRequestState{},
|
||||
}
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
req.Deadline = deadline
|
||||
}
|
||||
if wait {
|
||||
req.Ack = make(chan error, 1)
|
||||
}
|
||||
s.queued.Add(1)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.queued.Add(-1)
|
||||
return normalizeStreamDeadlineError(ctx.Err())
|
||||
case <-s.stopCh:
|
||||
s.queued.Add(-1)
|
||||
return s.stoppedErr()
|
||||
case s.reqCh <- req:
|
||||
if !wait {
|
||||
return nil
|
||||
}
|
||||
return s.waitAck(req)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) tryDirectSubmitBatch(ctx context.Context, items []bulkDedicatedSendRequest) (bool, error) {
|
||||
if s == nil {
|
||||
return true, errTransportDetached
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if len(items) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
if err := s.errSnapshot(); err != nil {
|
||||
return true, err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return true, normalizeStreamDeadlineError(ctx.Err())
|
||||
case <-s.stopCh:
|
||||
return true, s.stoppedErr()
|
||||
default:
|
||||
}
|
||||
if s.queued.Load() != 0 {
|
||||
return false, nil
|
||||
}
|
||||
if !s.flushMu.TryLock() {
|
||||
return false, nil
|
||||
}
|
||||
defer s.flushMu.Unlock()
|
||||
if s.queued.Load() != 0 {
|
||||
return false, nil
|
||||
}
|
||||
if err := s.errSnapshot(); err != nil {
|
||||
return true, err
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return true, normalizeStreamDeadlineError(ctx.Err())
|
||||
case <-s.stopCh:
|
||||
return true, s.stoppedErr()
|
||||
default:
|
||||
}
|
||||
deadline, _ := ctx.Deadline()
|
||||
if err := s.flush(items, deadline); err != nil {
|
||||
err = normalizeDedicatedBulkSendError(err)
|
||||
s.setErr(err)
|
||||
s.failPending(err)
|
||||
if s.fail != nil {
|
||||
go s.fail(err)
|
||||
}
|
||||
return true, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) waitAck(req bulkDedicatedBatchRequest) error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
ctx := req.Ctx
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
select {
|
||||
case err := <-req.Ack:
|
||||
return normalizeDedicatedBulkSendError(err)
|
||||
case <-ctx.Done():
|
||||
if req.tryCancel() {
|
||||
return normalizeStreamDeadlineError(ctx.Err())
|
||||
}
|
||||
return normalizeDedicatedBulkSendError(<-req.Ack)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
s.setErr(errTransportDetached)
|
||||
close(s.stopCh)
|
||||
})
|
||||
<-s.doneCh
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) run() {
|
||||
defer close(s.doneCh)
|
||||
|
||||
for {
|
||||
req, ok := s.nextRequest()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !req.tryStart() {
|
||||
s.finishRequest(req, req.canceledErr())
|
||||
continue
|
||||
}
|
||||
if err := req.contextErr(); err != nil {
|
||||
s.finishRequest(req, err)
|
||||
continue
|
||||
}
|
||||
s.flushMu.Lock()
|
||||
err := s.errSnapshot()
|
||||
if err == nil {
|
||||
err = s.flush(req.Items, req.Deadline)
|
||||
}
|
||||
s.flushMu.Unlock()
|
||||
if err != nil {
|
||||
err = normalizeDedicatedBulkSendError(err)
|
||||
s.setErr(err)
|
||||
s.finishRequest(req, err)
|
||||
s.failPending(err)
|
||||
if s.fail != nil {
|
||||
go s.fail(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
s.finishRequest(req, nil)
|
||||
}
|
||||
}
|
||||
|
||||
func (r bulkDedicatedBatchRequest) contextErr() error {
|
||||
if r.Ctx == nil {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-r.Ctx.Done():
|
||||
return normalizeStreamDeadlineError(r.Ctx.Err())
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r bulkDedicatedBatchRequest) tryStart() bool {
|
||||
if r.State == nil {
|
||||
return true
|
||||
}
|
||||
return r.State.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestStarted)
|
||||
}
|
||||
|
||||
func (r bulkDedicatedBatchRequest) tryCancel() bool {
|
||||
if r.State == nil {
|
||||
return false
|
||||
}
|
||||
return r.State.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestCanceled)
|
||||
}
|
||||
|
||||
func (r bulkDedicatedBatchRequest) canceledErr() error {
|
||||
if err := r.contextErr(); err != nil {
|
||||
return err
|
||||
}
|
||||
return context.Canceled
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) nextRequest() (bulkDedicatedBatchRequest, bool) {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
s.failPending(s.stoppedErr())
|
||||
return bulkDedicatedBatchRequest{}, false
|
||||
case req := <-s.reqCh:
|
||||
return req, true
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) flush(batch []bulkDedicatedSendRequest, deadline time.Time) error {
|
||||
if s == nil || s.conn == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
var (
|
||||
payload []byte
|
||||
err error
|
||||
)
|
||||
if s.encodeBatch != nil {
|
||||
payload, err = s.encodeBatch(batch)
|
||||
} else {
|
||||
plain, plainErr := encodeBulkDedicatedBatchPlain(s.dataID, batch)
|
||||
if plainErr != nil {
|
||||
return plainErr
|
||||
}
|
||||
payload, err = s.encrypt(plain)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writeBulkDedicatedRecordWithDeadline(s.conn, payload, deadline)
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) ack(req bulkDedicatedBatchRequest, err error) {
|
||||
if req.Ack != nil {
|
||||
req.Ack <- err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) finishRequest(req bulkDedicatedBatchRequest, err error) {
|
||||
if s != nil {
|
||||
s.queued.Add(-1)
|
||||
}
|
||||
s.ack(req, err)
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) failPending(err error) {
|
||||
for {
|
||||
select {
|
||||
case item := <-s.reqCh:
|
||||
s.finishRequest(item, err)
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) setErr(err error) {
|
||||
if s == nil || err == nil {
|
||||
return
|
||||
}
|
||||
s.errMu.Lock()
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
}
|
||||
s.errMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) errSnapshot() error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
s.errMu.Lock()
|
||||
defer s.errMu.Unlock()
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *bulkDedicatedSender) stoppedErr() error {
|
||||
if err := s.errSnapshot(); err != nil {
|
||||
return err
|
||||
}
|
||||
return errTransportDetached
|
||||
}
|
||||
|
||||
func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int {
|
||||
return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload))
|
||||
}
|
||||
|
||||
func bulkDedicatedSendRequestLenFromPayloadLen(payloadLen int) int {
|
||||
return bulkDedicatedBatchItemHeaderLen + payloadLen
|
||||
}
|
||||
|
||||
func encodeBulkDedicatedReleasePayload(bytes int64, chunks int) ([]byte, error) {
|
||||
if bytes <= 0 && chunks <= 0 {
|
||||
return nil, errBulkFastPayloadInvalid
|
||||
}
|
||||
if chunks < 0 {
|
||||
return nil, errBulkFastPayloadInvalid
|
||||
}
|
||||
payload := make([]byte, bulkDedicatedReleasePayloadLen)
|
||||
binary.BigEndian.PutUint64(payload[:8], uint64(bytes))
|
||||
binary.BigEndian.PutUint32(payload[8:12], uint32(chunks))
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
func decodeBulkDedicatedReleasePayload(payload []byte) (int64, int, error) {
|
||||
if len(payload) != bulkDedicatedReleasePayloadLen {
|
||||
return 0, 0, errBulkFastPayloadInvalid
|
||||
}
|
||||
bytes := int64(binary.BigEndian.Uint64(payload[:8]))
|
||||
chunks := int(binary.BigEndian.Uint32(payload[8:12]))
|
||||
if bytes <= 0 && chunks <= 0 {
|
||||
return 0, 0, errBulkFastPayloadInvalid
|
||||
}
|
||||
return bytes, chunks, nil
|
||||
}
|
||||
|
||||
func encodeBulkDedicatedBatchPlain(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||
if dataID == 0 || len(items) == 0 {
|
||||
return nil, errBulkFastPayloadInvalid
|
||||
}
|
||||
total := bulkDedicatedBatchPlainLen(items)
|
||||
buf := make([]byte, total)
|
||||
if err := writeBulkDedicatedBatchPlain(buf, dataID, items); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func encodeBulkDedicatedBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||
if encode == nil {
|
||||
return nil, errTransportPayloadEncryptFailed
|
||||
}
|
||||
plainLen := bulkDedicatedBatchPlainLen(items)
|
||||
return encode(secretKey, plainLen, func(dst []byte) error {
|
||||
return writeBulkDedicatedBatchPlain(dst, dataID, items)
|
||||
})
|
||||
}
|
||||
|
||||
func bulkDedicatedBatchPlainLen(items []bulkDedicatedSendRequest) int {
|
||||
total := bulkDedicatedBatchHeaderLen
|
||||
for _, item := range items {
|
||||
total += bulkDedicatedSendRequestLen(item)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func writeBulkDedicatedBatchPlain(buf []byte, dataID uint64, items []bulkDedicatedSendRequest) error {
|
||||
if dataID == 0 || len(items) == 0 {
|
||||
return errBulkFastPayloadInvalid
|
||||
}
|
||||
if len(buf) != bulkDedicatedBatchPlainLen(items) {
|
||||
return errBulkFastPayloadInvalid
|
||||
}
|
||||
copy(buf[:4], bulkDedicatedBatchMagic)
|
||||
buf[4] = bulkDedicatedBatchVersion
|
||||
binary.BigEndian.PutUint64(buf[8:16], dataID)
|
||||
binary.BigEndian.PutUint32(buf[16:20], uint32(len(items)))
|
||||
offset := bulkDedicatedBatchHeaderLen
|
||||
for _, item := range items {
|
||||
buf[offset] = item.Type
|
||||
buf[offset+1] = item.Flags
|
||||
binary.BigEndian.PutUint64(buf[offset+4:offset+12], item.Seq)
|
||||
binary.BigEndian.PutUint32(buf[offset+12:offset+16], uint32(len(item.Payload)))
|
||||
offset += bulkDedicatedBatchItemHeaderLen
|
||||
copy(buf[offset:offset+len(item.Payload)], item.Payload)
|
||||
offset += len(item.Payload)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeBulkDedicatedBatchPlain(payload []byte) (uint64, []bulkDedicatedBatchItem, bool, error) {
|
||||
if len(payload) < 4 || string(payload[:4]) != bulkDedicatedBatchMagic {
|
||||
return 0, nil, false, nil
|
||||
}
|
||||
if len(payload) < bulkDedicatedBatchHeaderLen {
|
||||
return 0, nil, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
if payload[4] != bulkDedicatedBatchVersion {
|
||||
return 0, nil, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
dataID := binary.BigEndian.Uint64(payload[8:16])
|
||||
count := int(binary.BigEndian.Uint32(payload[16:20]))
|
||||
if dataID == 0 || count <= 0 {
|
||||
return 0, nil, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
items := make([]bulkDedicatedBatchItem, 0, count)
|
||||
offset := bulkDedicatedBatchHeaderLen
|
||||
for i := 0; i < count; i++ {
|
||||
if len(payload)-offset < bulkDedicatedBatchItemHeaderLen {
|
||||
return 0, nil, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
itemType := payload[offset]
|
||||
switch itemType {
|
||||
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
|
||||
default:
|
||||
return 0, nil, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
flags := payload[offset+1]
|
||||
seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
|
||||
dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16]))
|
||||
offset += bulkDedicatedBatchItemHeaderLen
|
||||
if dataLen < 0 || len(payload)-offset < dataLen {
|
||||
return 0, nil, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
items = append(items, bulkDedicatedBatchItem{
|
||||
Type: itemType,
|
||||
Flags: flags,
|
||||
Seq: seq,
|
||||
Payload: payload[offset : offset+dataLen],
|
||||
})
|
||||
offset += dataLen
|
||||
}
|
||||
if offset != len(payload) {
|
||||
return 0, nil, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
return dataID, items, true, nil
|
||||
}
|
||||
|
||||
func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) {
|
||||
if dataID, items, matched, err := decodeBulkDedicatedBatchPlain(plain); matched {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expectedDataID == 0 || dataID != expectedDataID {
|
||||
return nil, errBulkFastPayloadInvalid
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
frame, matched, err := decodeBulkFastFrame(plain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !matched || expectedDataID == 0 || frame.DataID != expectedDataID {
|
||||
return nil, errBulkFastPayloadInvalid
|
||||
}
|
||||
return []bulkDedicatedBatchItem{{
|
||||
Type: frame.Type,
|
||||
Flags: frame.Flags,
|
||||
Seq: frame.Seq,
|
||||
Payload: frame.Payload,
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func normalizeDedicatedBulkSendError(err error) error {
|
||||
switch {
|
||||
case err == nil:
|
||||
return nil
|
||||
case errors.Is(err, net.ErrClosed):
|
||||
return errTransportDetached
|
||||
default:
|
||||
return normalizeStreamDeadlineError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchItem) error {
|
||||
if bulk == nil {
|
||||
return io.ErrClosedPipe
|
||||
}
|
||||
switch item.Type {
|
||||
case bulkFastPayloadTypeData:
|
||||
return bulk.pushOwnedChunkNoReset(item.Payload)
|
||||
case bulkFastPayloadTypeClose:
|
||||
if item.Flags&bulkFastPayloadFlagFullClose != 0 {
|
||||
bulk.markPeerClosed()
|
||||
return nil
|
||||
}
|
||||
bulk.markRemoteClosed()
|
||||
return nil
|
||||
case bulkFastPayloadTypeReset:
|
||||
resetErr := errBulkReset
|
||||
if len(item.Payload) > 0 {
|
||||
resetErr = bulkRemoteResetError(string(item.Payload))
|
||||
}
|
||||
bulk.markReset(bulkResetError(resetErr))
|
||||
return nil
|
||||
case bulkFastPayloadTypeRelease:
|
||||
bytes, chunks, err := decodeBulkDedicatedReleasePayload(item.Payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bulk.releaseOutboundWindow(bytes, chunks)
|
||||
return nil
|
||||
default:
|
||||
return errBulkFastPayloadInvalid
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,296 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBulkDedicatedBatchPlainRoundTrip(t *testing.T) {
|
||||
releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("encodeBulkDedicatedReleasePayload failed: %v", err)
|
||||
}
|
||||
items := []bulkDedicatedSendRequest{
|
||||
{
|
||||
Type: bulkFastPayloadTypeData,
|
||||
Seq: 7,
|
||||
Payload: []byte("hello"),
|
||||
},
|
||||
{
|
||||
Type: bulkFastPayloadTypeClose,
|
||||
Flags: bulkFastPayloadFlagFullClose,
|
||||
},
|
||||
{
|
||||
Type: bulkFastPayloadTypeReset,
|
||||
Payload: []byte("boom"),
|
||||
},
|
||||
{
|
||||
Type: bulkFastPayloadTypeRelease,
|
||||
Payload: releasePayload,
|
||||
},
|
||||
}
|
||||
|
||||
plain, err := encodeBulkDedicatedBatchPlain(42, items)
|
||||
if err != nil {
|
||||
t.Fatalf("encodeBulkDedicatedBatchPlain failed: %v", err)
|
||||
}
|
||||
|
||||
dataID, decoded, matched, err := decodeBulkDedicatedBatchPlain(plain)
|
||||
if err != nil {
|
||||
t.Fatalf("decodeBulkDedicatedBatchPlain failed: %v", err)
|
||||
}
|
||||
if !matched {
|
||||
t.Fatal("decodeBulkDedicatedBatchPlain should match dedicated batch")
|
||||
}
|
||||
if dataID != 42 {
|
||||
t.Fatalf("decoded data id = %d, want 42", dataID)
|
||||
}
|
||||
if len(decoded) != len(items) {
|
||||
t.Fatalf("decoded item count = %d, want %d", len(decoded), len(items))
|
||||
}
|
||||
|
||||
for i := range items {
|
||||
if decoded[i].Type != items[i].Type {
|
||||
t.Fatalf("item %d type = %d, want %d", i, decoded[i].Type, items[i].Type)
|
||||
}
|
||||
if decoded[i].Flags != items[i].Flags {
|
||||
t.Fatalf("item %d flags = %d, want %d", i, decoded[i].Flags, items[i].Flags)
|
||||
}
|
||||
if decoded[i].Seq != items[i].Seq {
|
||||
t.Fatalf("item %d seq = %d, want %d", i, decoded[i].Seq, items[i].Seq)
|
||||
}
|
||||
if got, want := string(decoded[i].Payload), string(items[i].Payload); got != want {
|
||||
t.Fatalf("item %d payload = %q, want %q", i, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkOpenRoundTripDedicatedMultiWriteTCP(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
|
||||
acceptCh := make(chan BulkAcceptInfo, 1)
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
acceptCh <- info
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||
t.Fatalf("server Listen failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = server.Stop()
|
||||
}()
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||
t.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||
Range: BulkRange{
|
||||
Offset: 0,
|
||||
Length: 1024,
|
||||
},
|
||||
Dedicated: true,
|
||||
ChunkSize: 4,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("client OpenBulk dedicated failed: %v", err)
|
||||
}
|
||||
|
||||
accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second)
|
||||
|
||||
clientParts := []string{"aa", "bb", "cc", "dd", "ee", "ff"}
|
||||
for _, part := range clientParts {
|
||||
if _, err := bulk.Write([]byte(part)); err != nil {
|
||||
t.Fatalf("client dedicated bulk Write(%q) failed: %v", part, err)
|
||||
}
|
||||
}
|
||||
readBulkExactly(t, accepted.Bulk, "aabbccddeeff", 2*time.Second)
|
||||
|
||||
serverParts := []string{"11", "22", "33", "44", "55", "66"}
|
||||
for _, part := range serverParts {
|
||||
if _, err := accepted.Bulk.Write([]byte(part)); err != nil {
|
||||
t.Fatalf("server dedicated bulk Write(%q) failed: %v", part, err)
|
||||
}
|
||||
}
|
||||
readBulkExactly(t, bulk, "112233445566", 2*time.Second)
|
||||
|
||||
if err := bulk.CloseWrite(); err != nil {
|
||||
t.Fatalf("client dedicated bulk CloseWrite failed: %v", err)
|
||||
}
|
||||
waitForBulkReadEOF(t, accepted.Bulk, 2*time.Second)
|
||||
|
||||
if err := accepted.Bulk.Close(); err != nil {
|
||||
t.Fatalf("server dedicated bulk Close failed: %v", err)
|
||||
}
|
||||
waitForBulkReadEOF(t, bulk, 2*time.Second)
|
||||
waitForBulkContextDone(t, bulk.Context(), 2*time.Second)
|
||||
}
|
||||
|
||||
func TestBulkDedicatedSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
sender := newBulkDedicatedSender(left, 1, func(plain []byte) ([]byte, error) {
|
||||
return plain, nil
|
||||
}, nil, nil)
|
||||
defer sender.stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- sender.submitControl(ctx, bulkFastPayloadTypeClose, 0, 0, nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err == nil {
|
||||
t.Fatal("sender.submitControl should fail when receiver stalls")
|
||||
}
|
||||
if !isTimeoutLikeError(err) {
|
||||
t.Fatalf("sender.submitControl error = %v, want timeout-like error", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("sender.submitControl should not hang when receiver stalls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkDedicatedSenderSubmitWriteDirectPathRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
sender := newBulkDedicatedSender(left, 1, func(plain []byte) ([]byte, error) {
|
||||
return plain, nil
|
||||
}, nil, nil)
|
||||
defer sender.stop()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
payload := make([]byte, 256*1024)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := sender.submitWrite(ctx, 1, payload, len(payload))
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err == nil {
|
||||
t.Fatal("sender.submitWrite should fail when receiver stalls")
|
||||
}
|
||||
if !isTimeoutLikeError(err) {
|
||||
t.Fatalf("sender.submitWrite error = %v, want timeout-like error", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("sender.submitWrite should not hang when receiver stalls")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkDedicatedSenderSkipsQueuedCanceledRequest(t *testing.T) {
|
||||
conn := newBlockingPacketWriteConn()
|
||||
sender := newBulkDedicatedSender(conn, 1, func(plain []byte) ([]byte, error) {
|
||||
return plain, nil
|
||||
}, nil, nil)
|
||||
defer sender.stop()
|
||||
|
||||
firstErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
firstErrCh <- sender.submitControl(context.Background(), bulkFastPayloadTypeClose, 0, 1, nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-conn.startCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("first dedicated bulk write did not start")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
secondErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
secondErrCh <- sender.submitControl(ctx, bulkFastPayloadTypeReset, 0, 2, nil)
|
||||
}()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-secondErrCh:
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("second dedicated bulk submit error = %v, want %v", err, context.Canceled)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("second dedicated bulk submit did not return after cancel")
|
||||
}
|
||||
|
||||
close(conn.unblockCh)
|
||||
|
||||
select {
|
||||
case err := <-firstErrCh:
|
||||
if err != nil {
|
||||
t.Fatalf("first dedicated bulk submit failed: %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("first dedicated bulk submit did not finish")
|
||||
}
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
if got, want := conn.writeCount.Load(), int32(2); got != want {
|
||||
t.Fatalf("dedicated bulk write count = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBulkDedicatedSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T) {
|
||||
conn := newBlockingPacketWriteConn()
|
||||
sender := newBulkDedicatedSender(conn, 1, func(plain []byte) ([]byte, error) {
|
||||
return plain, nil
|
||||
}, nil, nil)
|
||||
defer sender.stop()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- sender.submitControl(ctx, bulkFastPayloadTypeClose, 0, 1, nil)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-conn.startCh:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("dedicated bulk write did not start")
|
||||
}
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatalf("sender.submitControl returned before flush completed: %v", err)
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(conn.unblockCh)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
t.Fatalf("sender.submitControl failed after started flush: %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("sender.submitControl did not return after started flush completed")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const bulkDispatchRejectTimeout = 300 * time.Millisecond
|
||||
|
||||
func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) {
|
||||
if frame.DataID == 0 {
|
||||
return
|
||||
}
|
||||
runtime := c.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
return
|
||||
}
|
||||
bulk, ok := runtime.lookupByDataID(clientFileScope(), frame.DataID)
|
||||
if !ok {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client bulk data for unknown data id", frame.DataID)
|
||||
}
|
||||
c.bestEffortRejectInboundBulkData("", frame.DataID, errBulkNotFound.Error())
|
||||
return
|
||||
}
|
||||
if !bulk.acceptsClientSessionEpoch(c.currentClientSessionEpoch()) {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client bulk data rejected by stale session epoch", frame.DataID)
|
||||
}
|
||||
detachErr := transportDetachedSessionEpochError()
|
||||
bulk.markReset(detachErr)
|
||||
c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, detachErr.Error())
|
||||
return
|
||||
}
|
||||
switch frame.Type {
|
||||
case bulkFastPayloadTypeData:
|
||||
if err := bulk.pushOwnedChunk(frame.Payload); err != nil {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client bulk push chunk error", err)
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, err.Error())
|
||||
}
|
||||
}
|
||||
case bulkFastPayloadTypeClose:
|
||||
if frame.Flags&bulkFastPayloadFlagFullClose != 0 {
|
||||
bulk.markPeerClosed()
|
||||
return
|
||||
}
|
||||
bulk.markRemoteClosed()
|
||||
case bulkFastPayloadTypeReset:
|
||||
resetErr := errBulkReset
|
||||
if len(frame.Payload) > 0 {
|
||||
resetErr = bulkRemoteResetError(string(frame.Payload))
|
||||
}
|
||||
bulk.markReset(bulkResetError(resetErr))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) dispatchFastBulkData(frame bulkFastDataFrame) {
|
||||
c.dispatchFastBulkFrame(frame)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame) {
|
||||
if logical == nil || frame.DataID == 0 {
|
||||
return
|
||||
}
|
||||
runtime := s.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
return
|
||||
}
|
||||
bulk, ok := runtime.lookupByDataID(serverFileScope(logical), frame.DataID)
|
||||
if !ok {
|
||||
if s.showError || s.debugMode {
|
||||
fmt.Println("server bulk data for unknown data id", frame.DataID)
|
||||
}
|
||||
s.bestEffortRejectInboundBulkData(logical, transport, conn, "", frame.DataID, errBulkNotFound.Error())
|
||||
return
|
||||
}
|
||||
if !bulk.acceptsTransportGeneration(transport) {
|
||||
if s.showError || s.debugMode {
|
||||
fmt.Println("server bulk data rejected by transport generation mismatch", frame.DataID)
|
||||
}
|
||||
detachErr := transportDetachedGenerationMismatchError(bulk.TransportGeneration(), transport)
|
||||
s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, detachErr.Error())
|
||||
return
|
||||
}
|
||||
switch frame.Type {
|
||||
case bulkFastPayloadTypeData:
|
||||
if err := bulk.pushOwnedChunk(frame.Payload); err != nil {
|
||||
if s.showError || s.debugMode {
|
||||
fmt.Println("server bulk push chunk error", err)
|
||||
}
|
||||
if !errors.Is(err, io.EOF) {
|
||||
s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, err.Error())
|
||||
}
|
||||
}
|
||||
case bulkFastPayloadTypeClose:
|
||||
if frame.Flags&bulkFastPayloadFlagFullClose != 0 {
|
||||
bulk.markPeerClosed()
|
||||
return
|
||||
}
|
||||
bulk.markRemoteClosed()
|
||||
case bulkFastPayloadTypeReset:
|
||||
resetErr := errBulkReset
|
||||
if len(frame.Payload) > 0 {
|
||||
resetErr = bulkRemoteResetError(string(frame.Payload))
|
||||
}
|
||||
bulk.markReset(bulkResetError(resetErr))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerCommon) dispatchFastBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastDataFrame) {
|
||||
s.dispatchFastBulkFrame(logical, transport, conn, frame)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) bestEffortRejectInboundBulkData(bulkID string, dataID uint64, message string) {
|
||||
if c == nil || (bulkID == "" && dataID == 0) {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), bulkDispatchRejectTimeout)
|
||||
defer cancel()
|
||||
_, _ = sendBulkResetClient(ctx, c, BulkResetRequest{
|
||||
BulkID: bulkID,
|
||||
DataID: dataID,
|
||||
Error: message,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ServerCommon) bestEffortRejectInboundBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, bulkID string, dataID uint64, message string) {
|
||||
if s == nil || logical == nil || (bulkID == "" && dataID == 0) {
|
||||
return
|
||||
}
|
||||
payload, err := encode(BulkResetRequest{
|
||||
BulkID: bulkID,
|
||||
DataID: dataID,
|
||||
Error: message,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
env, err := wrapTransferMsgEnvelope(TransferMsg{
|
||||
Key: BulkResetSignalKey,
|
||||
Value: payload,
|
||||
Type: MSG_ASYNC,
|
||||
}, s.sequenceEn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = s.sendEnvelopeInboundTransport(logical, transport, conn, env)
|
||||
}
|
||||
@@ -0,0 +1,350 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BenchmarkBulkEndToEndThroughput(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
network string
|
||||
payloadSize int
|
||||
dedicated bool
|
||||
}{
|
||||
{
|
||||
name: "tcp_shared_1MiB",
|
||||
network: "tcp",
|
||||
payloadSize: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "tcp_dedicated_1MiB",
|
||||
network: "tcp",
|
||||
payloadSize: 1024 * 1024,
|
||||
dedicated: true,
|
||||
},
|
||||
{
|
||||
name: "unix_shared_1MiB",
|
||||
network: "unix",
|
||||
payloadSize: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "unix_dedicated_1MiB",
|
||||
network: "unix",
|
||||
payloadSize: 1024 * 1024,
|
||||
dedicated: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkEndToEndThroughputNetwork(b, tc.network, tc.payloadSize, tc.dedicated)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkBulkEndToEndThroughputConcurrent(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
network string
|
||||
payloadSize int
|
||||
concurrency int
|
||||
dedicated bool
|
||||
}{
|
||||
{
|
||||
name: "tcp_dedicated_4x1MiB",
|
||||
network: "tcp",
|
||||
payloadSize: 1024 * 1024,
|
||||
concurrency: 4,
|
||||
dedicated: true,
|
||||
},
|
||||
{
|
||||
name: "unix_dedicated_4x1MiB",
|
||||
network: "unix",
|
||||
payloadSize: 1024 * 1024,
|
||||
concurrency: 4,
|
||||
dedicated: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkBulkEndToEndThroughputConcurrentNetwork(b, tc.network, tc.payloadSize, tc.concurrency, tc.dedicated)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkBulkEndToEndThroughputNetwork(b *testing.B, network string, payloadSize int, dedicated bool) {
|
||||
b.Helper()
|
||||
if network == "unix" && runtime.GOOS == "windows" {
|
||||
b.Skip("unix socket is not available on windows")
|
||||
}
|
||||
|
||||
server := newBulkBenchmarkServer(b, network)
|
||||
client := newBulkBenchmarkClient(b, network, server)
|
||||
|
||||
totalBytes := int64(payloadSize)
|
||||
if b.N > 1 {
|
||||
totalBytes = int64(payloadSize) * int64(b.N)
|
||||
}
|
||||
bulk, accepted := openBenchmarkBulkPair(b, client, server.acceptCh, BulkOpenOptions{
|
||||
Range: BulkRange{
|
||||
Offset: 0,
|
||||
Length: totalBytes,
|
||||
},
|
||||
ChunkSize: payloadSize,
|
||||
Dedicated: dedicated,
|
||||
})
|
||||
|
||||
drainDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := io.Copy(io.Discard, accepted.Bulk)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
drainDone <- err
|
||||
return
|
||||
}
|
||||
drainDone <- nil
|
||||
}()
|
||||
|
||||
payload := make([]byte, payloadSize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(payloadSize))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
n, err := bulk.Write(payload)
|
||||
if err != nil {
|
||||
b.Fatalf("bulk Write failed at iter %d: %v", i, err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
b.Fatalf("bulk Write bytes mismatch at iter %d: got %d want %d", i, n, len(payload))
|
||||
}
|
||||
}
|
||||
if err := bulk.CloseWrite(); err != nil {
|
||||
b.Fatalf("bulk CloseWrite failed: %v", err)
|
||||
}
|
||||
select {
|
||||
case err := <-drainDone:
|
||||
if err != nil {
|
||||
b.Fatalf("server drain failed: %v", err)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
b.Fatal("timed out waiting for server drain")
|
||||
}
|
||||
b.StopTimer()
|
||||
|
||||
_ = accepted.Bulk.Close()
|
||||
_ = bulk.Close()
|
||||
}
|
||||
|
||||
func benchmarkBulkEndToEndThroughputConcurrentNetwork(b *testing.B, network string, payloadSize int, concurrency int, dedicated bool) {
|
||||
b.Helper()
|
||||
if concurrency <= 0 {
|
||||
b.Fatal("concurrency must be > 0")
|
||||
}
|
||||
if network == "unix" && runtime.GOOS == "windows" {
|
||||
b.Skip("unix socket is not available on windows")
|
||||
}
|
||||
|
||||
server := newBulkBenchmarkServer(b, network)
|
||||
client := newBulkBenchmarkClient(b, network, server)
|
||||
|
||||
totalBytes := int64(payloadSize)
|
||||
if b.N > 1 {
|
||||
totalBytes = int64(payloadSize) * int64(b.N)
|
||||
}
|
||||
|
||||
bulks := make([]Bulk, 0, concurrency)
|
||||
acceptedBulks := make([]Bulk, 0, concurrency)
|
||||
for index := 0; index < concurrency; index++ {
|
||||
bulk, accepted := openBenchmarkBulkPair(b, client, server.acceptCh, BulkOpenOptions{
|
||||
Range: BulkRange{
|
||||
Offset: int64(index) * totalBytes,
|
||||
Length: totalBytes,
|
||||
},
|
||||
ChunkSize: payloadSize,
|
||||
Dedicated: dedicated,
|
||||
})
|
||||
bulks = append(bulks, bulk)
|
||||
acceptedBulks = append(acceptedBulks, accepted.Bulk)
|
||||
}
|
||||
|
||||
drainDone := make(chan error, concurrency)
|
||||
for _, acceptedBulk := range acceptedBulks {
|
||||
bulk := acceptedBulk
|
||||
go func() {
|
||||
_, err := io.Copy(io.Discard, bulk)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
drainDone <- err
|
||||
return
|
||||
}
|
||||
drainDone <- nil
|
||||
}()
|
||||
}
|
||||
|
||||
payload := make([]byte, payloadSize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(payloadSize))
|
||||
b.ResetTimer()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, concurrency)
|
||||
for index, bulk := range bulks {
|
||||
count := b.N / concurrency
|
||||
if index < b.N%concurrency {
|
||||
count++
|
||||
}
|
||||
wg.Add(1)
|
||||
go func(bulk Bulk, count int) {
|
||||
defer wg.Done()
|
||||
for i := 0; i < count; i++ {
|
||||
n, err := bulk.Write(payload)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
if n != len(payload) {
|
||||
errCh <- errors.New("bulk write bytes mismatch")
|
||||
return
|
||||
}
|
||||
}
|
||||
}(bulk, count)
|
||||
}
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
for err := range errCh {
|
||||
if err != nil {
|
||||
b.Fatalf("concurrent bulk write failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
for index, bulk := range bulks {
|
||||
if err := bulk.CloseWrite(); err != nil {
|
||||
b.Fatalf("bulk %d CloseWrite failed: %v", index, err)
|
||||
}
|
||||
}
|
||||
for index := 0; index < concurrency; index++ {
|
||||
select {
|
||||
case err := <-drainDone:
|
||||
if err != nil {
|
||||
b.Fatalf("server drain failed: %v", err)
|
||||
}
|
||||
case <-time.After(15 * time.Second):
|
||||
b.Fatalf("timed out waiting for server drain %d/%d", index+1, concurrency)
|
||||
}
|
||||
}
|
||||
b.StopTimer()
|
||||
|
||||
for _, bulk := range acceptedBulks {
|
||||
_ = bulk.Close()
|
||||
}
|
||||
for _, bulk := range bulks {
|
||||
_ = bulk.Close()
|
||||
}
|
||||
}
|
||||
|
||||
type bulkBenchmarkServer struct {
|
||||
server *ServerCommon
|
||||
acceptCh chan BulkAcceptInfo
|
||||
addr string
|
||||
}
|
||||
|
||||
func newBulkBenchmarkServer(tb testing.TB, network string) bulkBenchmarkServer {
|
||||
tb.Helper()
|
||||
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
tb.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
if network == "udp" {
|
||||
if err := UseSignalReliabilityServer(server, bulkBenchmarkSignalReliabilityOptions()); err != nil {
|
||||
tb.Fatalf("UseSignalReliabilityServer failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
acceptCh := make(chan BulkAcceptInfo, 32)
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
acceptCh <- info
|
||||
return nil
|
||||
})
|
||||
|
||||
addr := bulkBenchmarkListenAddr(tb, network)
|
||||
if err := server.Listen(network, addr); err != nil {
|
||||
tb.Fatalf("server Listen failed: %v", err)
|
||||
}
|
||||
tb.Cleanup(func() {
|
||||
_ = server.Stop()
|
||||
})
|
||||
|
||||
return bulkBenchmarkServer{
|
||||
server: server,
|
||||
acceptCh: acceptCh,
|
||||
addr: signalRoundTripServerAddr(server, addr),
|
||||
}
|
||||
}
|
||||
|
||||
func newBulkBenchmarkClient(tb testing.TB, network string, server bulkBenchmarkServer) *ClientCommon {
|
||||
tb.Helper()
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
tb.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if network == "udp" {
|
||||
if err := UseSignalReliabilityClient(client, bulkBenchmarkSignalReliabilityOptions()); err != nil {
|
||||
tb.Fatalf("UseSignalReliabilityClient failed: %v", err)
|
||||
}
|
||||
}
|
||||
if err := client.Connect(network, server.addr); err != nil {
|
||||
tb.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
tb.Cleanup(func() {
|
||||
_ = client.Stop()
|
||||
})
|
||||
return client
|
||||
}
|
||||
|
||||
func openBenchmarkBulkPair(tb testing.TB, client *ClientCommon, acceptCh <-chan BulkAcceptInfo, opt BulkOpenOptions) (Bulk, BulkAcceptInfo) {
|
||||
tb.Helper()
|
||||
|
||||
bulk, err := client.OpenBulk(context.Background(), opt)
|
||||
if err != nil {
|
||||
tb.Fatalf("client OpenBulk failed: %v", err)
|
||||
}
|
||||
return bulk, waitBenchmarkAcceptedBulk(tb, acceptCh, 5*time.Second)
|
||||
}
|
||||
|
||||
func bulkBenchmarkListenAddr(tb testing.TB, network string) string {
|
||||
tb.Helper()
|
||||
switch network {
|
||||
case "unix":
|
||||
return filepath.Join(tb.TempDir(), "notify-bulk.sock")
|
||||
case "udp", "tcp":
|
||||
return "127.0.0.1:0"
|
||||
default:
|
||||
tb.Fatalf("unsupported benchmark network %q", network)
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func bulkBenchmarkSignalReliabilityOptions() *SignalReliabilityOptions {
|
||||
return &SignalReliabilityOptions{
|
||||
Enabled: true,
|
||||
AckTimeout: 3 * time.Second,
|
||||
SendRetry: 8,
|
||||
ReceiveCacheLimit: 512,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,280 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errBulkFastPayloadInvalid = errors.New("invalid bulk fast payload")
|
||||
)
|
||||
|
||||
var bulkFastFrameScratchPool sync.Pool
|
||||
|
||||
const (
|
||||
bulkFastPayloadMagic = "NBF1"
|
||||
bulkFastPayloadVersion = 1
|
||||
bulkFastPayloadTypeData = 1
|
||||
bulkFastPayloadTypeClose = 2
|
||||
bulkFastPayloadTypeReset = 3
|
||||
bulkFastPayloadTypeRelease = 4
|
||||
bulkFastPayloadHeaderLen = 28
|
||||
bulkFastPayloadFlagFullClose = 1 << 0
|
||||
)
|
||||
|
||||
type bulkFastFrame struct {
|
||||
Type uint8
|
||||
Flags uint8
|
||||
DataID uint64
|
||||
Seq uint64
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
type bulkFastDataFrame = bulkFastFrame
|
||||
|
||||
func encodeBulkFastFrameHeader(dst []byte, frameType uint8, flags uint8, dataID uint64, seq uint64, payloadLen int) error {
|
||||
if dataID == 0 {
|
||||
return errBulkDataIDEmpty
|
||||
}
|
||||
if len(dst) < bulkFastPayloadHeaderLen {
|
||||
return errBulkFastPayloadInvalid
|
||||
}
|
||||
copy(dst[:4], bulkFastPayloadMagic)
|
||||
dst[4] = bulkFastPayloadVersion
|
||||
dst[5] = frameType
|
||||
dst[6] = flags
|
||||
dst[7] = 0
|
||||
binary.BigEndian.PutUint64(dst[8:16], dataID)
|
||||
binary.BigEndian.PutUint64(dst[16:24], seq)
|
||||
binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen))
|
||||
return nil
|
||||
}
|
||||
|
||||
func encodeBulkFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error {
|
||||
return encodeBulkFastFrameHeader(dst, bulkFastPayloadTypeData, 0, dataID, seq, payloadLen)
|
||||
}
|
||||
|
||||
func encodeBulkFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||
frame := make([]byte, bulkFastPayloadHeaderLen+len(payload))
|
||||
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(frame[bulkFastPayloadHeaderLen:], payload)
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func encodeBulkFastControlFrame(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||
frame := make([]byte, bulkFastPayloadHeaderLen+len(payload))
|
||||
if err := encodeBulkFastFrameHeader(frame, frameType, flags, dataID, seq, len(payload)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(frame[bulkFastPayloadHeaderLen:], payload)
|
||||
return frame, nil
|
||||
}
|
||||
|
||||
func decodeBulkFastFrame(payload []byte) (bulkFastFrame, bool, error) {
|
||||
if len(payload) < 4 || string(payload[:4]) != bulkFastPayloadMagic {
|
||||
return bulkFastFrame{}, false, nil
|
||||
}
|
||||
if len(payload) < bulkFastPayloadHeaderLen {
|
||||
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
if payload[4] != bulkFastPayloadVersion {
|
||||
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
switch payload[5] {
|
||||
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
|
||||
default:
|
||||
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
dataLen := int(binary.BigEndian.Uint32(payload[24:28]))
|
||||
if dataLen < 0 || len(payload) != bulkFastPayloadHeaderLen+dataLen {
|
||||
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
dataID := binary.BigEndian.Uint64(payload[8:16])
|
||||
if dataID == 0 {
|
||||
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
|
||||
}
|
||||
return bulkFastFrame{
|
||||
Type: payload[5],
|
||||
Flags: payload[6],
|
||||
DataID: dataID,
|
||||
Seq: binary.BigEndian.Uint64(payload[16:24]),
|
||||
Payload: payload[bulkFastPayloadHeaderLen:],
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) {
|
||||
frame, matched, err := decodeBulkFastFrame(payload)
|
||||
if !matched || err != nil {
|
||||
return frame, matched, err
|
||||
}
|
||||
if frame.Type != bulkFastPayloadTypeData {
|
||||
return bulkFastDataFrame{}, false, nil
|
||||
}
|
||||
return frame, true, nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
|
||||
if c != nil && c.fastBulkEncode != nil {
|
||||
return c.fastBulkEncode(c.SecretKey, dataID, seq, chunk)
|
||||
}
|
||||
scratch := getBulkFastFrameScratch(len(chunk))
|
||||
defer putBulkFastFrameScratch(scratch)
|
||||
frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)]
|
||||
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(frame[bulkFastPayloadHeaderLen:], chunk)
|
||||
return c.encryptTransportPayload(frame)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte) error {
|
||||
payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
binding := c.clientTransportBindingSnapshot()
|
||||
if binding == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
if sender := binding.bulkBatchSenderSnapshot(); sender != nil {
|
||||
return sender.submit(ctx, payload)
|
||||
}
|
||||
return c.writePayloadToTransport(payload)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||
plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.encryptTransportPayload(plain)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
|
||||
if logical != nil {
|
||||
if fastBulkEncode := logical.fastBulkEncodeSnapshot(); fastBulkEncode != nil {
|
||||
return fastBulkEncode(logical.secretKeySnapshot(), dataID, seq, chunk)
|
||||
}
|
||||
}
|
||||
scratch := getBulkFastFrameScratch(len(chunk))
|
||||
defer putBulkFastFrameScratch(scratch)
|
||||
frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)]
|
||||
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(frame[bulkFastPayloadHeaderLen:], chunk)
|
||||
return s.encryptTransportPayloadLogical(logical, frame)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error {
|
||||
if err := s.ensureServerTransportSendReady(transport); err != nil {
|
||||
return err
|
||||
}
|
||||
if logical == nil && transport != nil {
|
||||
logical = transport.logicalConnSnapshot()
|
||||
}
|
||||
if logical == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if binding := logical.transportBindingSnapshot(); binding != nil {
|
||||
if binding.queueSnapshot() != nil {
|
||||
if sender := binding.bulkBatchSenderSnapshot(); sender != nil {
|
||||
return sender.submit(ctx, payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
return s.writeEnvelopePayload(logical, transport, nil, payload)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||
plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.encryptTransportPayloadLogical(logical, plain)
|
||||
}
|
||||
|
||||
func getBulkFastFrameScratch(payloadLen int) []byte {
|
||||
need := bulkFastPayloadHeaderLen + payloadLen
|
||||
if buf, ok := bulkFastFrameScratchPool.Get().([]byte); ok && cap(buf) >= need {
|
||||
return buf[:need]
|
||||
}
|
||||
return make([]byte, need)
|
||||
}
|
||||
|
||||
func putBulkFastFrameScratch(buf []byte) {
|
||||
if cap(buf) == 0 || cap(buf) > 4*1024*1024 {
|
||||
return
|
||||
}
|
||||
bulkFastFrameScratchPool.Put(buf[:0])
|
||||
}
|
||||
|
||||
func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error {
|
||||
plain, err := c.decryptTransportPayload(payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if frame, matched, err := decodeBulkFastFrame(plain); matched {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.dispatchFastBulkFrame(frame)
|
||||
return nil
|
||||
}
|
||||
if frame, matched, err := decodeStreamFastDataFrame(plain); matched {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.dispatchFastStreamData(frame)
|
||||
return nil
|
||||
}
|
||||
env, err := c.decodeEnvelopePlain(plain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.dispatchEnvelope(env, now)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte, now time.Time) error {
|
||||
if logical == nil && transport != nil {
|
||||
logical = transport.logicalConnSnapshot()
|
||||
}
|
||||
if logical == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
plain, err := s.decryptTransportPayloadLogical(logical, payload)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if frame, matched, err := decodeBulkFastFrame(plain); matched {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.dispatchFastBulkFrame(logical, transport, conn, frame)
|
||||
return nil
|
||||
}
|
||||
if frame, matched, err := decodeStreamFastDataFrame(plain); matched {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.dispatchFastStreamData(logical, transport, conn, frame)
|
||||
return nil
|
||||
}
|
||||
env, err := s.decodeEnvelopePlain(plain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.dispatchEnvelope(logical, transport, conn, env, now)
|
||||
return nil
|
||||
}
|
||||
+196
@@ -0,0 +1,196 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type bulkRuntime struct {
|
||||
rolePrefix string
|
||||
seq atomic.Uint64
|
||||
dataSeq atomic.Uint64
|
||||
|
||||
mu sync.RWMutex
|
||||
handler func(BulkAcceptInfo) error
|
||||
bulks map[string]*bulkHandle
|
||||
data map[string]*bulkHandle
|
||||
}
|
||||
|
||||
func newBulkRuntime(rolePrefix string) *bulkRuntime {
|
||||
return &bulkRuntime{
|
||||
rolePrefix: rolePrefix,
|
||||
bulks: make(map[string]*bulkHandle),
|
||||
data: make(map[string]*bulkHandle),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) nextID() string {
|
||||
if r == nil {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s-%d", r.rolePrefix, r.seq.Add(1))
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) nextDataID() uint64 {
|
||||
if r == nil {
|
||||
return 0
|
||||
}
|
||||
return r.dataSeq.Add(1)
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) setHandler(fn func(BulkAcceptInfo) error) {
|
||||
if r == nil {
|
||||
return
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
r.handler = fn
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) handlerSnapshot() func(BulkAcceptInfo) error {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.handler
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) register(scope string, bulk *bulkHandle) error {
|
||||
if r == nil {
|
||||
return errBulkRuntimeNil
|
||||
}
|
||||
if bulk == nil || bulk.id == "" {
|
||||
return errBulkIDEmpty
|
||||
}
|
||||
key := bulkRuntimeKey(scope, bulk.id)
|
||||
dataKey := bulkRuntimeDataKey(scope, bulk.dataID)
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if _, ok := r.bulks[key]; ok {
|
||||
return errBulkAlreadyExists
|
||||
}
|
||||
if bulk.dataID == 0 {
|
||||
return errBulkDataIDEmpty
|
||||
}
|
||||
if _, ok := r.data[dataKey]; ok {
|
||||
return errBulkAlreadyExists
|
||||
}
|
||||
r.bulks[key] = bulk
|
||||
r.data[dataKey] = bulk
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) lookup(scope string, bulkID string) (*bulkHandle, bool) {
|
||||
if r == nil || bulkID == "" {
|
||||
return nil, false
|
||||
}
|
||||
key := bulkRuntimeKey(scope, bulkID)
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
bulk, ok := r.bulks[key]
|
||||
return bulk, ok
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) lookupByDataID(scope string, dataID uint64) (*bulkHandle, bool) {
|
||||
if r == nil || dataID == 0 {
|
||||
return nil, false
|
||||
}
|
||||
key := bulkRuntimeDataKey(scope, dataID)
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
bulk, ok := r.data[key]
|
||||
return bulk, ok
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) remove(scope string, bulkID string) {
|
||||
if r == nil || bulkID == "" {
|
||||
return
|
||||
}
|
||||
key := bulkRuntimeKey(scope, bulkID)
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if bulk := r.bulks[key]; bulk != nil && bulk.dataID != 0 {
|
||||
delete(r.data, bulkRuntimeDataKey(scope, bulk.dataID))
|
||||
}
|
||||
delete(r.bulks, key)
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) closeAll(err error) {
|
||||
r.closeMatching(func(string) bool { return true }, err)
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) closeScope(scope string, err error) {
|
||||
scope = normalizeFileScope(scope)
|
||||
r.closeMatching(func(key string) bool {
|
||||
return strings.HasPrefix(key, scope+"\x00")
|
||||
}, err)
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) closeMatching(match func(string) bool, err error) {
|
||||
if r == nil || match == nil {
|
||||
return
|
||||
}
|
||||
resetErr := bulkRuntimeCloseError(err)
|
||||
r.mu.RLock()
|
||||
bulks := make([]*bulkHandle, 0, len(r.bulks))
|
||||
for key, bulk := range r.bulks {
|
||||
if bulk == nil || !match(key) {
|
||||
continue
|
||||
}
|
||||
bulks = append(bulks, bulk)
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
for _, bulk := range bulks {
|
||||
bulk.markReset(resetErr)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *bulkRuntime) snapshots() []BulkSnapshot {
|
||||
if r == nil {
|
||||
return nil
|
||||
}
|
||||
r.mu.RLock()
|
||||
snapshots := make([]BulkSnapshot, 0, len(r.bulks))
|
||||
for _, bulk := range r.bulks {
|
||||
if bulk == nil {
|
||||
continue
|
||||
}
|
||||
snapshots = append(snapshots, bulk.snapshot())
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
sortBulkSnapshots(snapshots)
|
||||
return snapshots
|
||||
}
|
||||
|
||||
func bulkRuntimeKey(scope string, bulkID string) string {
|
||||
return normalizeFileScope(scope) + "\x00" + bulkID
|
||||
}
|
||||
|
||||
func bulkRuntimeDataKey(scope string, dataID uint64) string {
|
||||
return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10)
|
||||
}
|
||||
|
||||
func bulkRuntimeCloseError(err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return errServiceShutdown
|
||||
}
|
||||
|
||||
func (c *ClientCommon) getBulkRuntime() *bulkRuntime {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.bulkRuntime
|
||||
}
|
||||
|
||||
func (s *ServerCommon) getBulkRuntime() *bulkRuntime {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.bulkRuntime
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BulkSnapshot struct {
|
||||
ID string
|
||||
DataID uint64
|
||||
Scope string
|
||||
Range BulkRange
|
||||
Metadata BulkMetadata
|
||||
BindingOwner string
|
||||
BindingAlive bool
|
||||
BindingCurrent bool
|
||||
BindingReason string
|
||||
BindingError string
|
||||
Dedicated bool
|
||||
DedicatedAttached bool
|
||||
SessionEpoch uint64
|
||||
LogicalClientID string
|
||||
TransportGeneration uint64
|
||||
TransportAttached bool
|
||||
TransportHasRuntimeConn bool
|
||||
TransportCurrent bool
|
||||
TransportDetachReason string
|
||||
TransportDetachKind string
|
||||
TransportDetachGeneration uint64
|
||||
TransportDetachError string
|
||||
TransportDetachedAt time.Time
|
||||
ReattachEligible bool
|
||||
LocalClosed bool
|
||||
LocalReadClosed bool
|
||||
RemoteClosed bool
|
||||
PeerReadClosed bool
|
||||
BufferedChunks int
|
||||
BufferedBytes int
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
ChunkSize int
|
||||
WindowBytes int
|
||||
MaxInFlight int
|
||||
BytesRead int64
|
||||
BytesWritten int64
|
||||
ReadCalls int64
|
||||
WriteCalls int64
|
||||
OpenedAt time.Time
|
||||
LastReadAt time.Time
|
||||
LastWriteAt time.Time
|
||||
ResetError string
|
||||
}
|
||||
|
||||
type clientBulkSnapshotReader interface {
|
||||
clientBulkSnapshots() []BulkSnapshot
|
||||
}
|
||||
|
||||
type serverBulkSnapshotReader interface {
|
||||
serverBulkSnapshots() []BulkSnapshot
|
||||
}
|
||||
|
||||
var (
|
||||
errClientBulkSnapshotNil = errors.New("client bulk snapshot target is nil")
|
||||
errServerBulkSnapshotNil = errors.New("server bulk snapshot target is nil")
|
||||
errClientBulkSnapshotUnsupported = errors.New("client bulk snapshot target type is unsupported")
|
||||
errServerBulkSnapshotUnsupported = errors.New("server bulk snapshot target type is unsupported")
|
||||
)
|
||||
|
||||
func GetClientBulkSnapshots(c Client) ([]BulkSnapshot, error) {
|
||||
if c == nil {
|
||||
return nil, errClientBulkSnapshotNil
|
||||
}
|
||||
reader, ok := any(c).(clientBulkSnapshotReader)
|
||||
if !ok {
|
||||
return nil, errClientBulkSnapshotUnsupported
|
||||
}
|
||||
return reader.clientBulkSnapshots(), nil
|
||||
}
|
||||
|
||||
func GetServerBulkSnapshots(s Server) ([]BulkSnapshot, error) {
|
||||
if s == nil {
|
||||
return nil, errServerBulkSnapshotNil
|
||||
}
|
||||
reader, ok := any(s).(serverBulkSnapshotReader)
|
||||
if !ok {
|
||||
return nil, errServerBulkSnapshotUnsupported
|
||||
}
|
||||
return reader.serverBulkSnapshots(), nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientBulkSnapshots() []BulkSnapshot {
|
||||
return bulkSnapshotsFromRuntime(c.getBulkRuntime())
|
||||
}
|
||||
|
||||
func (s *ServerCommon) serverBulkSnapshots() []BulkSnapshot {
|
||||
return bulkSnapshotsFromRuntime(s.getBulkRuntime())
|
||||
}
|
||||
|
||||
func bulkSnapshotsFromRuntime(runtime *bulkRuntime) []BulkSnapshot {
|
||||
if runtime == nil {
|
||||
return nil
|
||||
}
|
||||
return runtime.snapshots()
|
||||
}
|
||||
|
||||
func sortBulkSnapshots(src []BulkSnapshot) {
|
||||
sort.Slice(src, func(i, j int) bool {
|
||||
if src[i].Scope != src[j].Scope {
|
||||
return src[i].Scope < src[j].Scope
|
||||
}
|
||||
if src[i].ID != src[j].ID {
|
||||
return src[i].ID < src[j].ID
|
||||
}
|
||||
if src[i].DataID != src[j].DataID {
|
||||
return src[i].DataID < src[j].DataID
|
||||
}
|
||||
return src[i].TransportGeneration < src[j].TransportGeneration
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func BenchmarkModernPSKSealPlainThroughput(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
}{
|
||||
{
|
||||
name: "seal_1MiB",
|
||||
payloadSize: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "seal_4MiB",
|
||||
payloadSize: 4 * 1024 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
key, aad, err := deriveModernPSKKey(integrationSharedSecret, integrationModernPSKOptions())
|
||||
if err != nil {
|
||||
b.Fatalf("deriveModernPSKKey failed: %v", err)
|
||||
}
|
||||
transport := buildModernPSKTransportBundle(aad)
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
payload := make([]byte, tc.payloadSize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i)
|
||||
}
|
||||
|
||||
var sink []byte
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(tc.payloadSize))
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
wire, err := transport.fastPlainEncode(key, len(payload), func(dst []byte) error {
|
||||
copy(dst, payload)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("fastPlainEncode failed: %v", err)
|
||||
}
|
||||
sink = wire
|
||||
}
|
||||
b.StopTimer()
|
||||
_ = sink
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDedicatedWireLocalhostThroughput(b *testing.B) {
|
||||
cases := []struct {
|
||||
name string
|
||||
payloadSize int
|
||||
}{
|
||||
{
|
||||
name: "wire_1MiB",
|
||||
payloadSize: 1024 * 1024,
|
||||
},
|
||||
{
|
||||
name: "wire_4MiB",
|
||||
payloadSize: 4 * 1024 * 1024,
|
||||
},
|
||||
}
|
||||
|
||||
key, aad, err := deriveModernPSKKey(integrationSharedSecret, integrationModernPSKOptions())
|
||||
if err != nil {
|
||||
b.Fatalf("deriveModernPSKKey failed: %v", err)
|
||||
}
|
||||
transport := buildModernPSKTransportBundle(aad)
|
||||
|
||||
for _, tc := range cases {
|
||||
b.Run(tc.name, func(b *testing.B) {
|
||||
benchmarkDedicatedWireLocalhostThroughput(b, key, transport.fastPlainEncode, tc.payloadSize)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode transportFastPlainEncoder, payloadSize int) {
|
||||
b.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
b.Fatalf("net.Listen failed: %v", err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
_ = listener.Close()
|
||||
})
|
||||
|
||||
acceptCh := make(chan net.Conn, 1)
|
||||
acceptErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
acceptErrCh <- err
|
||||
return
|
||||
}
|
||||
acceptCh <- conn
|
||||
}()
|
||||
|
||||
clientConn, err := net.Dial("tcp", listener.Addr().String())
|
||||
if err != nil {
|
||||
b.Fatalf("net.Dial failed: %v", err)
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
_ = clientConn.Close()
|
||||
})
|
||||
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
|
||||
_ = tcpConn.SetNoDelay(true)
|
||||
}
|
||||
|
||||
var serverConn net.Conn
|
||||
select {
|
||||
case conn := <-acceptCh:
|
||||
serverConn = conn
|
||||
case err := <-acceptErrCh:
|
||||
b.Fatalf("Accept failed: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
b.Fatal("timed out waiting for accept")
|
||||
}
|
||||
b.Cleanup(func() {
|
||||
if serverConn != nil {
|
||||
_ = serverConn.Close()
|
||||
}
|
||||
})
|
||||
|
||||
drainDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := io.Copy(io.Discard, serverConn)
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
drainDone <- err
|
||||
return
|
||||
}
|
||||
drainDone <- nil
|
||||
}()
|
||||
|
||||
sender := newBulkDedicatedSender(clientConn, 1, func(plain []byte) ([]byte, error) {
|
||||
return encode(key, len(plain), func(dst []byte) error {
|
||||
copy(dst, plain)
|
||||
return nil
|
||||
})
|
||||
}, func(items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||
return encodeBulkDedicatedBatchPayloadFast(encode, key, 1, items)
|
||||
}, nil)
|
||||
defer sender.stop()
|
||||
|
||||
payload := make([]byte, payloadSize)
|
||||
for i := range payload {
|
||||
payload[i] = byte(i)
|
||||
}
|
||||
|
||||
b.ReportAllocs()
|
||||
b.SetBytes(int64(payloadSize))
|
||||
b.ResetTimer()
|
||||
seq := uint64(1)
|
||||
for i := 0; i < b.N; i++ {
|
||||
n, err := sender.submitWrite(context.Background(), seq, payload, payloadSize)
|
||||
if err != nil {
|
||||
b.Fatalf("submitWrite failed at iter %d: %v", i, err)
|
||||
}
|
||||
if n != len(payload) {
|
||||
b.Fatalf("submitWrite bytes mismatch at iter %d: got %d want %d", i, n, len(payload))
|
||||
}
|
||||
seq++
|
||||
}
|
||||
b.StopTimer()
|
||||
|
||||
_ = clientConn.Close()
|
||||
select {
|
||||
case err := <-drainDone:
|
||||
if err != nil {
|
||||
b.Fatalf("server drain failed: %v", err)
|
||||
}
|
||||
case <-time.After(10 * time.Second):
|
||||
b.Fatal("timed out waiting for server drain")
|
||||
}
|
||||
}
|
||||
+1494
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,85 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBulkOpenDedicatedUDPRejected(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
return nil
|
||||
})
|
||||
if err := server.Listen("udp", "127.0.0.1:0"); err != nil {
|
||||
t.Fatalf("server Listen failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = server.Stop()
|
||||
}()
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := client.Connect("udp", signalRoundTripServerAddr(server, "")); err != nil {
|
||||
t.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
_, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||
Range: BulkRange{
|
||||
Offset: 0,
|
||||
Length: 128,
|
||||
},
|
||||
Dedicated: true,
|
||||
})
|
||||
if !errors.Is(err, errBulkDedicatedStreamOnly) {
|
||||
t.Fatalf("client OpenBulk dedicated over udp error = %v, want %v", err, errBulkDedicatedStreamOnly)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerOpenBulkLogicalDedicatedUDPRejected(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
return nil
|
||||
})
|
||||
if err := server.Listen("udp", "127.0.0.1:0"); err != nil {
|
||||
t.Fatalf("server Listen failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = server.Stop()
|
||||
}()
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := client.Connect("udp", signalRoundTripServerAddr(server, "")); err != nil {
|
||||
t.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
|
||||
_, err := server.OpenBulkLogical(context.Background(), logical, BulkOpenOptions{
|
||||
Range: BulkRange{
|
||||
Offset: 0,
|
||||
Length: 128,
|
||||
},
|
||||
Dedicated: true,
|
||||
})
|
||||
if !errors.Is(err, errBulkDedicatedStreamOnly) {
|
||||
t.Fatalf("server OpenBulkLogical dedicated over udp error = %v, want %v", err, errBulkDedicatedStreamOnly)
|
||||
}
|
||||
}
|
||||
@@ -1,15 +1,9 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/starcrypto"
|
||||
"b612.me/stario"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
@@ -22,6 +16,11 @@ type ClientCommon struct {
|
||||
conn net.Conn
|
||||
mu sync.Mutex
|
||||
msgID uint64
|
||||
peerIdentity string
|
||||
sessionEpoch uint64
|
||||
sessionOwnerState atomic.Int32
|
||||
sessionRuntime atomic.Pointer[clientSessionRuntime]
|
||||
connectSource atomic.Pointer[clientConnectSource]
|
||||
queue *stario.StarQueue
|
||||
stopFn context.CancelFunc
|
||||
stopCtx context.Context
|
||||
@@ -33,7 +32,9 @@ type ClientCommon struct {
|
||||
defaultFns func(message *Message)
|
||||
msgEn func([]byte, []byte) []byte
|
||||
msgDe func([]byte, []byte) []byte
|
||||
noFinSyncMsgPool sync.Map
|
||||
fastStreamEncode transportFastStreamEncoder
|
||||
fastBulkEncode transportFastBulkEncoder
|
||||
fastPlainEncode transportFastPlainEncoder
|
||||
handshakeRsaPubKey []byte
|
||||
SecretKey []byte
|
||||
noFinSyncMsgMaxKeepSeconds int
|
||||
@@ -46,126 +47,39 @@ type ClientCommon struct {
|
||||
useHeartBeat bool
|
||||
sequenceDe func([]byte) (interface{}, error)
|
||||
sequenceEn func(interface{}) ([]byte, error)
|
||||
logicalSession *logicalSessionState
|
||||
onFileEvent func(FileEvent)
|
||||
fileEventObserver func(FileEvent)
|
||||
fileTransferCfg fileTransferConfig
|
||||
signalReliableCfg signalReliabilityConfig
|
||||
streamRuntime *streamRuntime
|
||||
recordRuntime *recordRuntime
|
||||
bulkRuntime *bulkRuntime
|
||||
connectionRetryState *connectionRetryState
|
||||
securityReadyCheck bool
|
||||
debugMode bool
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Connect(network string, addr string) error {
|
||||
if c.alive.Load().(bool) {
|
||||
return errors.New("client already run")
|
||||
}
|
||||
c.stopCtx, c.stopFn = context.WithCancel(context.Background())
|
||||
c.queue = stario.NewQueueCtx(c.stopCtx, 4, math.MaxUint32)
|
||||
conn, err := net.Dial(network, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.alive.Store(true)
|
||||
c.status.Alive = true
|
||||
c.conn = conn
|
||||
if c.useHeartBeat {
|
||||
go c.Heartbeat()
|
||||
}
|
||||
return c.clientPostInit()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) DebugMode(dmg bool) {
|
||||
c.mu.Lock()
|
||||
c.debugMode = dmg
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) IsDebugMode() bool {
|
||||
return c.debugMode
|
||||
}
|
||||
|
||||
func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error {
|
||||
if c.alive.Load().(bool) {
|
||||
return errors.New("client already run")
|
||||
}
|
||||
c.stopCtx, c.stopFn = context.WithCancel(context.Background())
|
||||
c.queue = stario.NewQueueCtx(c.stopCtx, 4, math.MaxUint32)
|
||||
conn, err := net.DialTimeout(network, addr, timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
c.alive.Store(true)
|
||||
c.status.Alive = true
|
||||
c.conn = conn
|
||||
if c.useHeartBeat {
|
||||
go c.Heartbeat()
|
||||
}
|
||||
return c.clientPostInit()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) monitorPool() {
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCtx.Done():
|
||||
c.noFinSyncMsgPool.Range(func(k, v interface{}) bool {
|
||||
data := v.(WaitMsg)
|
||||
close(data.Reply)
|
||||
c.noFinSyncMsgPool.Delete(k)
|
||||
return true
|
||||
})
|
||||
return
|
||||
case <-time.After(time.Second * 30):
|
||||
}
|
||||
now := time.Now()
|
||||
if c.noFinSyncMsgMaxKeepSeconds > 0 {
|
||||
c.noFinSyncMsgPool.Range(func(k, v interface{}) bool {
|
||||
data := v.(WaitMsg)
|
||||
if data.Time.Add(time.Duration(c.noFinSyncMsgMaxKeepSeconds) * time.Second).Before(now) {
|
||||
close(data.Reply)
|
||||
c.noFinSyncMsgPool.Delete(k)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SkipExchangeKey() bool {
|
||||
return c.skipKeyExchange
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetSkipExchangeKey(val bool) {
|
||||
c.skipKeyExchange = val
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientPostInit() error {
|
||||
go c.readMessage()
|
||||
go c.loadMessage()
|
||||
if !c.skipKeyExchange {
|
||||
err := c.keyExchangeFn(c)
|
||||
if err != nil {
|
||||
c.alive.Store(false)
|
||||
c.mu.Lock()
|
||||
c.status = Status{
|
||||
Alive: false,
|
||||
Reason: "key exchange failed",
|
||||
Err: err,
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.stopFn()
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func NewClient() Client {
|
||||
transport := defaultModernPSKTransportBundle()
|
||||
var client = ClientCommon{
|
||||
maxReadTimeout: 0,
|
||||
maxWriteTimeout: 0,
|
||||
peerIdentity: newClientPeerIdentity(),
|
||||
sequenceEn: encode,
|
||||
sequenceDe: Decode,
|
||||
keyExchangeFn: aesRsaHello,
|
||||
SecretKey: defaultAesKey,
|
||||
SecretKey: nil,
|
||||
handshakeRsaPubKey: defaultRsaPubKey,
|
||||
msgEn: defaultMsgEn,
|
||||
msgDe: defaultMsgDe,
|
||||
msgEn: transport.msgEn,
|
||||
msgDe: transport.msgDe,
|
||||
fastStreamEncode: transport.fastStreamEncode,
|
||||
fastBulkEncode: transport.fastBulkEncode,
|
||||
fastPlainEncode: transport.fastPlainEncode,
|
||||
skipKeyExchange: true,
|
||||
securityReadyCheck: true,
|
||||
}
|
||||
client.alive.Store(false)
|
||||
//heartbeat should not controlable for user
|
||||
client.useHeartBeat = true
|
||||
client.heartbeatPeriod = time.Second * 20
|
||||
client.linkFns = make(map[string]func(*Message))
|
||||
@@ -173,442 +87,19 @@ func NewClient() Client {
|
||||
return
|
||||
}
|
||||
client.wg = stario.NewWaitGroup(0)
|
||||
client.fileTransferCfg = defaultFileTransferConfig()
|
||||
client.signalReliableCfg = defaultSignalReliabilityConfig()
|
||||
client.logicalSession = newLogicalSessionState(client.fileTransferCfg, client.signalReliableCfg)
|
||||
client.streamRuntime = newStreamRuntime("cstrm")
|
||||
client.recordRuntime = newRecordRuntime()
|
||||
client.bulkRuntime = newBulkRuntime("cblk")
|
||||
client.connectionRetryState = newConnectionRetryState()
|
||||
client.onFileEvent = normalizeFileEventCallback(nil)
|
||||
client.fileEventObserver = normalizeFileEventCallback(nil)
|
||||
client.stopCtx, client.stopFn = context.WithCancel(context.Background())
|
||||
client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn))
|
||||
bindClientStreamControl(&client)
|
||||
bindClientBulkControl(&client)
|
||||
client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler)
|
||||
return &client
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Heartbeat() {
|
||||
failedCount := 0
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCtx.Done():
|
||||
return
|
||||
case <-time.After(c.heartbeatPeriod):
|
||||
}
|
||||
_, err := c.sendWait(TransferMsg{
|
||||
ID: 10000,
|
||||
Key: "heartbeat",
|
||||
Value: nil,
|
||||
Type: MSG_SYS_WAIT,
|
||||
}, time.Second*5)
|
||||
if err == nil {
|
||||
c.lastHeartbeat = time.Now().Unix()
|
||||
failedCount = 0
|
||||
}
|
||||
if c.debugMode {
|
||||
fmt.Println("failed to recv heartbeat,timeout!")
|
||||
}
|
||||
failedCount++
|
||||
if failedCount >= 3 {
|
||||
if c.debugMode {
|
||||
fmt.Println("heatbeat failed more than 3 times,stop client")
|
||||
}
|
||||
c.alive.Store(false)
|
||||
c.mu.Lock()
|
||||
c.status = Status{
|
||||
Alive: false,
|
||||
Reason: "heartbeat failed more than 3 times",
|
||||
Err: errors.New("heartbeat failed more than 3 times"),
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.stopFn()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) ShowError(std bool) {
|
||||
c.mu.Lock()
|
||||
c.showError = std
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) readMessage() {
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCtx.Done():
|
||||
c.conn.Close()
|
||||
return
|
||||
default:
|
||||
}
|
||||
data := make([]byte, 8192)
|
||||
if c.maxReadTimeout.Seconds() != 0 {
|
||||
if err := c.conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout)); err != nil {
|
||||
//TODO:ALERT
|
||||
}
|
||||
}
|
||||
readNum, err := c.conn.Read(data)
|
||||
if err == os.ErrDeadlineExceeded {
|
||||
if readNum != 0 {
|
||||
c.queue.ParseMessage(data[:readNum], "b612")
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client read error", err)
|
||||
}
|
||||
c.alive.Store(false)
|
||||
c.mu.Lock()
|
||||
c.status = Status{
|
||||
Alive: false,
|
||||
Reason: "client read error",
|
||||
Err: err,
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.stopFn()
|
||||
continue
|
||||
}
|
||||
c.queue.ParseMessage(data[:readNum], "b612")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sayGoodBye() error {
|
||||
_, err := c.sendWait(TransferMsg{
|
||||
ID: 10010,
|
||||
Key: "bye",
|
||||
Value: nil,
|
||||
Type: MSG_SYS_WAIT,
|
||||
}, time.Second*3)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) loadMessage() {
|
||||
for {
|
||||
select {
|
||||
case <-c.stopCtx.Done():
|
||||
//say goodbye
|
||||
if !c.byeFromServer {
|
||||
c.sayGoodBye()
|
||||
}
|
||||
c.conn.Close()
|
||||
return
|
||||
case data, ok := <-c.queue.RestoreChan():
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
c.wg.Add(1)
|
||||
go func(data stario.MsgQueue) {
|
||||
defer c.wg.Done()
|
||||
//fmt.Println("c received:", float64(time.Now().UnixNano()-nowd)/1000000)
|
||||
now := time.Now()
|
||||
//transfer to Msg
|
||||
msg, err := c.sequenceDe(c.msgDe(c.SecretKey, data.Msg))
|
||||
if err != nil {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client decode data error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
message := Message{
|
||||
ServerConn: c,
|
||||
TransferMsg: msg.(TransferMsg),
|
||||
NetType: NET_CLIENT,
|
||||
}
|
||||
message.Time = now
|
||||
c.dispatchMsg(message)
|
||||
}(data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) dispatchMsg(message Message) {
|
||||
switch message.TransferMsg.Type {
|
||||
case MSG_SYS_WAIT:
|
||||
fallthrough
|
||||
case MSG_SYS:
|
||||
c.sysMsg(message)
|
||||
return
|
||||
case MSG_KEY_CHANGE:
|
||||
fallthrough
|
||||
case MSG_SYS_REPLY:
|
||||
fallthrough
|
||||
case MSG_SYNC_REPLY:
|
||||
data, ok := c.noFinSyncMsgPool.Load(message.ID)
|
||||
if ok {
|
||||
wait := data.(WaitMsg)
|
||||
wait.Reply <- message
|
||||
c.noFinSyncMsgPool.Delete(message.ID)
|
||||
return
|
||||
}
|
||||
//return
|
||||
fallthrough
|
||||
default:
|
||||
}
|
||||
callFn := func(fn func(*Message)) {
|
||||
fn(&message)
|
||||
}
|
||||
fn, ok := c.linkFns[message.Key]
|
||||
if ok {
|
||||
callFn(fn)
|
||||
}
|
||||
if c.defaultFns != nil {
|
||||
callFn(c.defaultFns)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sysMsg(message Message) {
|
||||
switch message.Key {
|
||||
case "bye":
|
||||
if message.TransferMsg.Type == MSG_SYS_WAIT {
|
||||
//fmt.Println("recv stop signal from server")
|
||||
c.byeFromServer = true
|
||||
message.Reply(nil)
|
||||
}
|
||||
c.alive.Store(false)
|
||||
c.mu.Lock()
|
||||
c.status = Status{
|
||||
Alive: false,
|
||||
Reason: "recv stop signal from server",
|
||||
Err: nil,
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.stopFn()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) {
|
||||
c.defaultFns = fn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetLink(key string, fn func(*Message)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.linkFns[key] = fn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) {
|
||||
var wait WaitMsg
|
||||
if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 {
|
||||
msg.ID = atomic.AddUint64(&c.msgID, 1)
|
||||
}
|
||||
data, err := c.sequenceEn(msg)
|
||||
if err != nil {
|
||||
return WaitMsg{}, err
|
||||
}
|
||||
data = c.msgEn(c.SecretKey, data)
|
||||
data = c.queue.BuildMessage(data)
|
||||
if c.maxWriteTimeout.Seconds() != 0 {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
|
||||
}
|
||||
_, err = c.conn.Write(data)
|
||||
if err == nil && (msg.Type == MSG_SYNC_ASK || msg.Type == MSG_KEY_CHANGE || msg.Type == MSG_SYS_WAIT) {
|
||||
wait.Time = time.Now()
|
||||
wait.TransferMsg = msg
|
||||
wait.Reply = make(chan Message, 1)
|
||||
c.noFinSyncMsgPool.Store(msg.ID, wait)
|
||||
}
|
||||
return wait, err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Send(key string, value MsgVal) error {
|
||||
_, err := c.send(TransferMsg{
|
||||
Key: key,
|
||||
Value: value,
|
||||
Type: MSG_ASYNC,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) {
|
||||
data, err := c.send(msg)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
if timeout.Seconds() == 0 {
|
||||
msg, ok := <-data.Reply
|
||||
if !ok {
|
||||
return msg, os.ErrInvalid
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
select {
|
||||
case <-time.After(timeout):
|
||||
close(data.Reply)
|
||||
c.noFinSyncMsgPool.Delete(data.TransferMsg.ID)
|
||||
return Message{}, os.ErrDeadlineExceeded
|
||||
case <-c.stopCtx.Done():
|
||||
return Message{}, errors.New("service shutdown")
|
||||
case msg, ok := <-data.Reply:
|
||||
if !ok {
|
||||
return msg, os.ErrInvalid
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) {
|
||||
data, err := c.send(msg)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
close(data.Reply)
|
||||
c.noFinSyncMsgPool.Delete(data.TransferMsg.ID)
|
||||
return Message{}, os.ErrDeadlineExceeded
|
||||
case <-c.stopCtx.Done():
|
||||
return Message{}, errors.New("service shutdown")
|
||||
case msg, ok := <-data.Reply:
|
||||
if !ok {
|
||||
return msg, os.ErrInvalid
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) {
|
||||
data, err := c.sequenceEn(val)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
return c.sendCtx(TransferMsg{
|
||||
Key: key,
|
||||
Value: data,
|
||||
Type: MSG_SYNC_ASK,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendObj(key string, val interface{}) error {
|
||||
data, err := encode(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.send(TransferMsg{
|
||||
Key: key,
|
||||
Value: data,
|
||||
Type: MSG_ASYNC,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) {
|
||||
return c.sendCtx(TransferMsg{
|
||||
Key: key,
|
||||
Value: value,
|
||||
Type: MSG_SYNC_ASK,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) {
|
||||
return c.sendWait(TransferMsg{
|
||||
Key: key,
|
||||
Value: value,
|
||||
Type: MSG_SYNC_ASK,
|
||||
}, timeout)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) {
|
||||
data, err := c.sequenceEn(value)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
return c.SendWait(key, data, timeout)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Reply(m Message, value MsgVal) error {
|
||||
return m.Reply(value)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) ExchangeKey(newKey []byte) error {
|
||||
pubKey, err := starcrypto.DecodeRsaPublicKey(c.handshakeRsaPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newSendKey, err := starcrypto.RSAEncrypt(pubKey, newKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := c.sendWait(TransferMsg{
|
||||
ID: 19961127,
|
||||
Key: "sirius",
|
||||
Value: newSendKey,
|
||||
Type: MSG_KEY_CHANGE,
|
||||
}, time.Second*10)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if string(data.Value) != "success" {
|
||||
return errors.New("cannot exchange new aes-key")
|
||||
}
|
||||
c.SecretKey = newKey
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
return nil
|
||||
}
|
||||
|
||||
func aesRsaHello(c Client) error {
|
||||
newAesKey := []byte(fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me"))
|
||||
newAesKey = []byte(starcrypto.Md5Str(newAesKey))
|
||||
return c.ExchangeKey(newAesKey)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
|
||||
return c.msgEn
|
||||
}
|
||||
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
|
||||
c.msgEn = fn
|
||||
}
|
||||
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
|
||||
return c.msgDe
|
||||
}
|
||||
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
|
||||
c.msgDe = fn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) HeartbeatPeroid() time.Duration {
|
||||
return c.heartbeatPeriod
|
||||
}
|
||||
func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
|
||||
c.heartbeatPeriod = duration
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetSecretKey() []byte {
|
||||
return c.SecretKey
|
||||
}
|
||||
func (c *ClientCommon) SetSecretKey(key []byte) {
|
||||
c.SecretKey = key
|
||||
}
|
||||
func (c *ClientCommon) RsaPubKey() []byte {
|
||||
return c.handshakeRsaPubKey
|
||||
}
|
||||
func (c *ClientCommon) SetRsaPubKey(key []byte) {
|
||||
c.handshakeRsaPubKey = key
|
||||
}
|
||||
func (c *ClientCommon) Stop() error {
|
||||
if !c.alive.Load().(bool) {
|
||||
return nil
|
||||
}
|
||||
c.alive.Store(false)
|
||||
c.mu.Lock()
|
||||
c.status = Status{
|
||||
Alive: false,
|
||||
Reason: "recv stop signal from user",
|
||||
Err: nil,
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.stopFn()
|
||||
return nil
|
||||
}
|
||||
func (c *ClientCommon) StopMonitorChan() <-chan struct{} {
|
||||
return c.stopCtx.Done()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Status() Status {
|
||||
return c.status
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) {
|
||||
return c.sequenceEn
|
||||
}
|
||||
func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) {
|
||||
c.sequenceEn = fn
|
||||
}
|
||||
func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) {
|
||||
return c.sequenceDe
|
||||
}
|
||||
func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) {
|
||||
c.sequenceDe = fn
|
||||
}
|
||||
|
||||
+198
@@ -0,0 +1,198 @@
|
||||
package notify
|
||||
|
||||
import "context"
|
||||
|
||||
func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
|
||||
runtime := c.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
return
|
||||
}
|
||||
runtime.setHandler(fn)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
|
||||
if c == nil {
|
||||
return nil, errBulkClientNil
|
||||
}
|
||||
runtime := c.getBulkRuntime()
|
||||
if runtime == nil {
|
||||
return nil, errBulkRuntimeNil
|
||||
}
|
||||
req := clientBulkRequest(runtime, opt)
|
||||
if req.BulkID == "" {
|
||||
return nil, errBulkIDEmpty
|
||||
}
|
||||
if req.Dedicated {
|
||||
if err := clientDedicatedBulkSupportError(c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if !validBulkRange(req.Range) {
|
||||
return nil, errBulkRangeInvalid
|
||||
}
|
||||
if _, exists := runtime.lookup(clientFileScope(), req.BulkID); exists {
|
||||
return nil, errBulkAlreadyExists
|
||||
}
|
||||
resp, err := sendBulkOpenClient(ctx, c, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.DataID != 0 {
|
||||
req.DataID = resp.DataID
|
||||
}
|
||||
req.Dedicated = resp.Dedicated
|
||||
if resp.AttachToken != "" {
|
||||
req.AttachToken = resp.AttachToken
|
||||
}
|
||||
if req.DataID == 0 {
|
||||
return nil, errBulkDataIDEmpty
|
||||
}
|
||||
bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c))
|
||||
bulk.setClientSnapshotOwner(c)
|
||||
if err := runtime.register(clientFileScope(), bulk); err != nil {
|
||||
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
|
||||
BulkID: req.BulkID,
|
||||
DataID: req.DataID,
|
||||
Error: err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
if bulk.Dedicated() {
|
||||
if err := c.attachDedicatedBulkSidecar(ctx, bulk); err != nil {
|
||||
runtime.remove(clientFileScope(), bulk.ID())
|
||||
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
|
||||
BulkID: bulk.ID(),
|
||||
DataID: bulk.dataIDSnapshot(),
|
||||
Error: err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return bulk, nil
|
||||
}
|
||||
|
||||
func clientBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenRequest {
|
||||
opt = normalizeBulkOpenOptions(opt)
|
||||
id := opt.ID
|
||||
if id == "" && runtime != nil {
|
||||
id = runtime.nextID()
|
||||
}
|
||||
return normalizeBulkOpenRequest(BulkOpenRequest{
|
||||
BulkID: id,
|
||||
Range: opt.Range,
|
||||
Metadata: cloneBulkMetadata(opt.Metadata),
|
||||
ReadTimeout: opt.ReadTimeout,
|
||||
WriteTimeout: opt.WriteTimeout,
|
||||
Dedicated: opt.Dedicated,
|
||||
ChunkSize: opt.ChunkSize,
|
||||
WindowBytes: opt.WindowBytes,
|
||||
MaxInFlight: opt.MaxInFlight,
|
||||
})
|
||||
}
|
||||
|
||||
func clientBulkCloseSender(c *ClientCommon) bulkCloseSender {
|
||||
return func(ctx context.Context, bulk *bulkHandle, full bool) error {
|
||||
if bulk != nil && bulk.Dedicated() {
|
||||
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.sendDedicatedBulkClose(ctx, bulk, full)
|
||||
}
|
||||
_, err := sendBulkCloseClient(ctx, c, BulkCloseRequest{
|
||||
BulkID: bulk.ID(),
|
||||
Full: full,
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func clientBulkResetSender(c *ClientCommon) bulkResetSender {
|
||||
return func(ctx context.Context, bulk *bulkHandle, message string) error {
|
||||
if bulk != nil && bulk.Dedicated() {
|
||||
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.sendDedicatedBulkReset(ctx, bulk, message)
|
||||
}
|
||||
_, err := sendBulkResetClient(ctx, c, BulkResetRequest{
|
||||
BulkID: bulk.ID(),
|
||||
DataID: bulk.dataIDSnapshot(),
|
||||
Error: message,
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func clientBulkDataSender(c *ClientCommon, epoch uint64) bulkDataSender {
|
||||
return func(ctx context.Context, bulk *bulkHandle, chunk []byte) error {
|
||||
if c == nil {
|
||||
return errBulkClientNil
|
||||
}
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
if bulk != nil && bulk.Dedicated() {
|
||||
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.sendDedicatedBulkData(ctx, bulk, chunk)
|
||||
}
|
||||
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
|
||||
return errTransportDetached
|
||||
}
|
||||
dataID := bulk.dataIDSnapshot()
|
||||
if dataID == 0 {
|
||||
return errBulkDataPathNotReady
|
||||
}
|
||||
return c.sendFastBulkData(ctx, dataID, bulk.nextOutboundDataSeq(), chunk)
|
||||
}
|
||||
}
|
||||
|
||||
func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender {
|
||||
return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) {
|
||||
if c == nil {
|
||||
return 0, errBulkClientNil
|
||||
}
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return 0, ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
if bulk != nil && bulk.Dedicated() {
|
||||
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.sendDedicatedBulkWrite(ctx, bulk, payload)
|
||||
}
|
||||
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
|
||||
return 0, errTransportDetached
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
func clientBulkReleaseSender(c *ClientCommon) bulkReleaseSender {
|
||||
return func(bulk *bulkHandle, bytes int64, chunks int) error {
|
||||
if c == nil || bulk == nil {
|
||||
return errBulkClientNil
|
||||
}
|
||||
if bytes <= 0 && chunks <= 0 {
|
||||
return nil
|
||||
}
|
||||
if bulk.Dedicated() {
|
||||
return c.sendDedicatedBulkRelease(context.Background(), bulk, bytes, chunks)
|
||||
}
|
||||
return sendBulkReleaseClient(c, BulkReleaseRequest{
|
||||
BulkID: bulk.ID(),
|
||||
DataID: bulk.dataIDSnapshot(),
|
||||
Bytes: bytes,
|
||||
Chunks: chunks,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (c *ClientCommon) DebugMode(dmg bool) {
|
||||
c.mu.Lock()
|
||||
c.debugMode = dmg
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) IsDebugMode() bool {
|
||||
return c.debugMode
|
||||
}
|
||||
|
||||
// Deprecated: SkipExchangeKey only controls the legacy RSA-based key exchange.
|
||||
func (c *ClientCommon) SkipExchangeKey() bool {
|
||||
return c.skipKeyExchange
|
||||
}
|
||||
|
||||
// Deprecated: SetSkipExchangeKey only controls the legacy RSA-based key exchange.
|
||||
func (c *ClientCommon) SetSkipExchangeKey(val bool) {
|
||||
c.skipKeyExchange = val
|
||||
}
|
||||
|
||||
func (c *ClientCommon) ShowError(std bool) {
|
||||
c.mu.Lock()
|
||||
c.showError = std
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) {
|
||||
c.defaultFns = fn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetLink(key string, fn func(*Message)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.linkFns[key] = fn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetFileHandler(fn func(FileEvent)) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.onFileEvent = normalizeFileEventCallback(fn)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetFileReceiveDir(dir string) error {
|
||||
return c.getFileReceivePool().setDir(dir)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetTransferResumeStore(store TransferResumeStore) {
|
||||
if runtime := c.getTransferRuntime(); runtime != nil {
|
||||
runtime.setResumeStore(store)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) RecoverTransferSnapshots(ctx context.Context) error {
|
||||
if runtime := c.getTransferRuntime(); runtime != nil {
|
||||
return runtime.recover(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
|
||||
return c.msgEn
|
||||
}
|
||||
|
||||
// Deprecated: SetMsgEn overrides the transport codec directly.
|
||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
|
||||
c.msgEn = fn
|
||||
c.fastStreamEncode = nil
|
||||
c.fastBulkEncode = nil
|
||||
c.fastPlainEncode = nil
|
||||
c.securityReadyCheck = false
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
|
||||
return c.msgDe
|
||||
}
|
||||
|
||||
// Deprecated: SetMsgDe overrides the transport codec directly.
|
||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
|
||||
c.msgDe = fn
|
||||
c.fastStreamEncode = nil
|
||||
c.fastBulkEncode = nil
|
||||
c.fastPlainEncode = nil
|
||||
c.securityReadyCheck = false
|
||||
}
|
||||
|
||||
func (c *ClientCommon) HeartbeatPeroid() time.Duration {
|
||||
return c.heartbeatPeriod
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
|
||||
c.heartbeatPeriod = duration
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetSecretKey() []byte {
|
||||
return c.SecretKey
|
||||
}
|
||||
|
||||
// Deprecated: SetSecretKey injects a raw transport key directly.
|
||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||
func (c *ClientCommon) SetSecretKey(key []byte) {
|
||||
c.SecretKey = key
|
||||
c.securityReadyCheck = len(key) == 0
|
||||
c.skipKeyExchange = true
|
||||
}
|
||||
|
||||
// Deprecated: RsaPubKey exposes the legacy RSA handshake key. Prefer UseModernPSKClient.
|
||||
func (c *ClientCommon) RsaPubKey() []byte {
|
||||
return c.handshakeRsaPubKey
|
||||
}
|
||||
|
||||
// Deprecated: SetRsaPubKey configures the legacy RSA handshake key. Prefer UseModernPSKClient.
|
||||
func (c *ClientCommon) SetRsaPubKey(key []byte) {
|
||||
c.handshakeRsaPubKey = key
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Stop() error {
|
||||
if !sessionIsAlive(&c.alive) {
|
||||
return nil
|
||||
}
|
||||
c.stopClientSession("recv stop signal from user", nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) StopMonitorChan() <-chan struct{} {
|
||||
return sessionStopChan(c.clientStopContextSnapshot())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Status() Status {
|
||||
return sessionStatusValue(&c.mu, &c.status)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) {
|
||||
return c.sequenceEn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) {
|
||||
c.sequenceEn = fn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) {
|
||||
return c.sequenceDe
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) {
|
||||
c.sequenceDe = fn
|
||||
}
|
||||
+437
@@ -0,0 +1,437 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/starcrypto"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type clientConnTransportDetachState struct {
|
||||
Generation uint64
|
||||
Reason string
|
||||
Err string
|
||||
At time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
clientConnTransportDetachKindReadError = "read_error"
|
||||
clientConnTransportDetachKindHeartbeatTimeout = "heartbeat_timeout"
|
||||
clientConnTransportDetachKindOther = "other"
|
||||
)
|
||||
|
||||
type ClientConn struct {
|
||||
alive atomic.Value
|
||||
status Status
|
||||
logicalView atomic.Pointer[LogicalConn]
|
||||
logicalState atomic.Pointer[logicalConnState]
|
||||
runtimeState atomic.Pointer[logicalConnRuntimeState]
|
||||
transportState atomic.Pointer[clientConnTransportState]
|
||||
sessionRuntime atomic.Pointer[clientConnSessionRuntime]
|
||||
attachment atomic.Pointer[clientConnAttachmentState]
|
||||
identityBound atomic.Bool
|
||||
ClientID string
|
||||
ClientAddr net.Addr
|
||||
server Server
|
||||
}
|
||||
|
||||
type Status struct {
|
||||
Alive bool
|
||||
Reason string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (c *ClientConn) readTUMessage() {
|
||||
if logical := c.LogicalConn(); logical != nil {
|
||||
logical.readTUMessage()
|
||||
return
|
||||
}
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
c.readTUMessageLoop(rt)
|
||||
}
|
||||
|
||||
func (c *ClientConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
|
||||
if logical := c.LogicalConn(); logical != nil {
|
||||
logical.readTUMessageLoop(rt)
|
||||
return
|
||||
}
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
stopCtx := rt.transportStopCtx
|
||||
if stopCtx == nil {
|
||||
stopCtx = rt.stopCtx
|
||||
}
|
||||
if stopCtx == nil {
|
||||
return
|
||||
}
|
||||
conn := rt.tuConn
|
||||
generation := rt.transportGeneration
|
||||
defer closeClientConnSessionRuntimeTransportDone(rt)
|
||||
buf := streamReadBuffer()
|
||||
for {
|
||||
select {
|
||||
case <-sessionStopChan(stopCtx):
|
||||
if c.shouldCloseClientConnTransportOnStop(conn) {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return
|
||||
default:
|
||||
}
|
||||
num, data, err := c.readFromTUTransportConnWithBuffer(conn, buf)
|
||||
if !c.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deprecated: rsaDecode exists only for the legacy MSG_KEY_CHANGE flow.
|
||||
func (c *ClientConn) rsaDecode(message Message) {
|
||||
privKey, err := starcrypto.DecodeRsaPrivateKey(c.clientConnHandshakeRsaKeySnapshot(), "")
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
message.Reply([]byte("failed"))
|
||||
return
|
||||
}
|
||||
data, err := starcrypto.RSADecrypt(privKey, message.Value)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
message.Reply([]byte("failed"))
|
||||
return
|
||||
}
|
||||
message.Reply([]byte("success"))
|
||||
c.setClientConnSecretKey(data)
|
||||
}
|
||||
|
||||
func (c *ClientConn) sayGoodByeForTU() error {
|
||||
if c == nil || c.server == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
_, err := c.server.SendWaitLogical(c.LogicalConn(), "bye", nil, time.Second*3)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
_, err = c.server.sendWait(c, TransferMsg{
|
||||
ID: 10010,
|
||||
Key: "bye",
|
||||
Value: nil,
|
||||
Type: MSG_SYS_WAIT,
|
||||
}, time.Second*3)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientConn) GetSecretKey() []byte {
|
||||
return c.clientConnSecretKeySnapshot()
|
||||
}
|
||||
|
||||
// Deprecated: SetSecretKey injects a raw per-connection transport key directly.
|
||||
func (c *ClientConn) SetSecretKey(key []byte) {
|
||||
c.setClientConnSecretKey(key)
|
||||
}
|
||||
|
||||
func (c *ClientConn) GetMsgEn() func([]byte, []byte) []byte {
|
||||
return c.clientConnMsgEnSnapshot()
|
||||
}
|
||||
|
||||
// Deprecated: SetMsgEn overrides the per-connection transport codec directly.
|
||||
func (c *ClientConn) SetMsgEn(fn func([]byte, []byte) []byte) {
|
||||
c.setClientConnMsgEn(fn)
|
||||
}
|
||||
|
||||
func (c *ClientConn) GetMsgDe() func([]byte, []byte) []byte {
|
||||
return c.clientConnMsgDeSnapshot()
|
||||
}
|
||||
|
||||
// Deprecated: SetMsgDe overrides the per-connection transport codec directly.
|
||||
func (c *ClientConn) SetMsgDe(fn func([]byte, []byte) []byte) {
|
||||
c.setClientConnMsgDe(fn)
|
||||
}
|
||||
|
||||
func (c *ClientConn) StopMonitorChan() <-chan struct{} {
|
||||
return sessionStopChan(c.clientConnStopContextSnapshot())
|
||||
}
|
||||
|
||||
func (c *ClientConn) Status() Status {
|
||||
return c.clientConnStatusSnapshot()
|
||||
}
|
||||
|
||||
func (c *ClientConn) Server() Server {
|
||||
if c != nil {
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
if server := logical.Server(); server != nil {
|
||||
return server
|
||||
}
|
||||
}
|
||||
}
|
||||
return c.server
|
||||
}
|
||||
|
||||
func (c *ClientConn) GetRemoteAddr() net.Addr {
|
||||
return c.clientConnRemoteAddrSnapshot()
|
||||
}
|
||||
|
||||
func (c *ClientConn) markClientConnIdentityBound() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
logical.markIdentityBound()
|
||||
return
|
||||
}
|
||||
state := c.ensureLogicalConnState()
|
||||
if state == nil {
|
||||
c.identityBound.Store(true)
|
||||
return
|
||||
}
|
||||
state.updatePeer(func(peer *logicalConnPeerState) {
|
||||
peer.identityBound = true
|
||||
})
|
||||
c.syncLegacyLogicalFieldsFromState(state)
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnIdentityBoundSnapshot() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
return c.clientConnLogicalPeerStateSnapshot().identityBound
|
||||
}
|
||||
|
||||
func (c *ClientConn) markClientConnStreamTransport() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
logical.markStreamTransport()
|
||||
return
|
||||
}
|
||||
state := c.ensureClientConnTransportState()
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
state.streamTransport.Store(true)
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnUsesStreamTransportSnapshot() bool {
|
||||
state := c.ensureClientConnTransportState()
|
||||
if state == nil {
|
||||
return false
|
||||
}
|
||||
return state.streamTransport.Load()
|
||||
}
|
||||
|
||||
func (c *ClientConn) shouldPreserveLogicalPeerOnTransportLoss() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
return c.clientConnIdentityBoundSnapshot() && c.clientConnUsesStreamTransportSnapshot()
|
||||
}
|
||||
|
||||
func (c *ClientConn) markClientConnTransportAttached() uint64 {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
return logical.markTransportAttached()
|
||||
}
|
||||
state := c.ensureClientConnTransportState()
|
||||
if state == nil {
|
||||
return 0
|
||||
}
|
||||
gen := state.transportGen.Add(1)
|
||||
state.attachCount.Add(1)
|
||||
state.lastAttachAt.Store(time.Now().UnixNano())
|
||||
return gen
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportGenerationSnapshot() uint64 {
|
||||
state := c.ensureClientConnTransportState()
|
||||
if state == nil {
|
||||
return 0
|
||||
}
|
||||
return state.transportGen.Load()
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportAttachCountSnapshot() uint64 {
|
||||
state := c.ensureClientConnTransportState()
|
||||
if state == nil {
|
||||
return 0
|
||||
}
|
||||
return state.attachCount.Load()
|
||||
}
|
||||
|
||||
func (c *ClientConn) markClientConnTransportDetached(reason string, err error) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
logical.markTransportDetached(reason, err)
|
||||
return
|
||||
}
|
||||
state := c.ensureClientConnTransportState()
|
||||
if state == nil {
|
||||
return
|
||||
}
|
||||
detachState := &clientConnTransportDetachState{
|
||||
Generation: c.clientConnTransportGenerationSnapshot(),
|
||||
Reason: reason,
|
||||
At: time.Now(),
|
||||
}
|
||||
if err != nil {
|
||||
detachState.Err = err.Error()
|
||||
}
|
||||
state.detachCount.Add(1)
|
||||
state.transportDetach.Store(detachState)
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportDetachCountSnapshot() uint64 {
|
||||
state := c.ensureClientConnTransportState()
|
||||
if state == nil {
|
||||
return 0
|
||||
}
|
||||
return state.detachCount.Load()
|
||||
}
|
||||
|
||||
func (c *ClientConn) clearClientConnTransportDetachState() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
logical.clearTransportDetachState()
|
||||
return
|
||||
}
|
||||
c.setClientConnTransportDetachState(nil)
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportDetachSnapshot() *clientConnTransportDetachState {
|
||||
state := c.ensureClientConnTransportState()
|
||||
if state == nil {
|
||||
return nil
|
||||
}
|
||||
return cloneClientConnTransportDetachState(state.transportDetach.Load())
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnLogicalTransportDetachedSnapshot() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
if !c.clientConnIdentityBoundSnapshot() || !c.clientConnUsesStreamTransportSnapshot() {
|
||||
return false
|
||||
}
|
||||
if !c.clientConnAliveSnapshot() {
|
||||
return false
|
||||
}
|
||||
return !c.clientConnTransportAttachedSnapshot()
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnLastTransportAttachedAtSnapshot() time.Time {
|
||||
state := c.ensureClientConnTransportState()
|
||||
if state == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
unixNano := state.lastAttachAt.Load()
|
||||
if unixNano == 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Unix(0, unixNano)
|
||||
}
|
||||
|
||||
func classifyClientConnTransportDetachReason(reason string) string {
|
||||
switch reason {
|
||||
case "":
|
||||
return ""
|
||||
case "read error":
|
||||
return clientConnTransportDetachKindReadError
|
||||
case "heartbeat timeout":
|
||||
return clientConnTransportDetachKindHeartbeatTimeout
|
||||
default:
|
||||
return clientConnTransportDetachKindOther
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportDetachKindSnapshot() string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
detach := c.clientConnTransportDetachSnapshot()
|
||||
if detach == nil {
|
||||
return ""
|
||||
}
|
||||
return classifyClientConnTransportDetachReason(detach.Reason)
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportDetachGenerationSnapshot() uint64 {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
detach := c.clientConnTransportDetachSnapshot()
|
||||
if detach == nil {
|
||||
return 0
|
||||
}
|
||||
if detach.Generation == 0 {
|
||||
return c.clientConnTransportGenerationSnapshot()
|
||||
}
|
||||
return detach.Generation
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportDetachExpirySnapshot() (time.Time, bool) {
|
||||
if c == nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
detach := c.clientConnTransportDetachSnapshot()
|
||||
if detach == nil || detach.At.IsZero() {
|
||||
return time.Time{}, false
|
||||
}
|
||||
if c.server == nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
keepSec := c.server.DetachedClientKeepSec()
|
||||
if keepSec <= 0 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return detach.At.Add(time.Duration(keepSec) * time.Second), true
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportDetachExpiredSnapshot(now time.Time) bool {
|
||||
if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
|
||||
return false
|
||||
}
|
||||
expiry, ok := c.clientConnTransportDetachExpirySnapshot()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return !now.Before(expiry)
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportDetachRemainingSnapshot(now time.Time) time.Duration {
|
||||
if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
|
||||
return 0
|
||||
}
|
||||
expiry, ok := c.clientConnTransportDetachExpirySnapshot()
|
||||
if !ok {
|
||||
return 0
|
||||
}
|
||||
if !now.Before(expiry) {
|
||||
return 0
|
||||
}
|
||||
return expiry.Sub(now)
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnReattachEligibleSnapshot(now time.Time) bool {
|
||||
if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
|
||||
return false
|
||||
}
|
||||
if !c.clientConnAliveSnapshot() {
|
||||
return false
|
||||
}
|
||||
if c.clientConnTransportAttachedSnapshot() {
|
||||
return false
|
||||
}
|
||||
if c.clientConnTransportDetachExpiredSnapshot(now) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,333 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type clientConnAttachmentState struct {
|
||||
maxReadTimeout time.Duration
|
||||
maxWriteTimeout time.Duration
|
||||
msgEn func([]byte, []byte) []byte
|
||||
msgDe func([]byte, []byte) []byte
|
||||
fastStreamEncode transportFastStreamEncoder
|
||||
fastBulkEncode transportFastBulkEncoder
|
||||
fastPlainEncode transportFastPlainEncoder
|
||||
handshakeRsaKey []byte
|
||||
secretKey []byte
|
||||
lastHeartBeat int64
|
||||
}
|
||||
|
||||
func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnAttachmentState {
|
||||
if src == nil {
|
||||
return &clientConnAttachmentState{}
|
||||
}
|
||||
cloned := *src
|
||||
cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey)
|
||||
cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey)
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func cloneClientConnAttachmentBytes(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
return append([]byte(nil), src...)
|
||||
}
|
||||
|
||||
func (c *LogicalConn) attachmentStateSnapshot() *clientConnAttachmentState {
|
||||
if c == nil {
|
||||
return &clientConnAttachmentState{}
|
||||
}
|
||||
if state := c.attachment.Load(); state != nil {
|
||||
if client := c.compatClientConn(); client != nil {
|
||||
client.attachment.Store(state)
|
||||
}
|
||||
return cloneClientConnAttachmentState(state)
|
||||
}
|
||||
client := c.compatClientConn()
|
||||
if client != nil {
|
||||
if state := client.attachment.Load(); state != nil {
|
||||
if c.attachment.CompareAndSwap(nil, state) {
|
||||
client.attachment.Store(state)
|
||||
return cloneClientConnAttachmentState(state)
|
||||
}
|
||||
return c.attachmentStateSnapshot()
|
||||
}
|
||||
}
|
||||
return &clientConnAttachmentState{}
|
||||
}
|
||||
|
||||
func (c *LogicalConn) setAttachmentState(state *clientConnAttachmentState) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
next := cloneClientConnAttachmentState(state)
|
||||
c.attachment.Store(next)
|
||||
if client := c.compatClientConn(); client != nil {
|
||||
client.attachment.Store(next)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LogicalConn) updateAttachmentState(apply func(*clientConnAttachmentState)) {
|
||||
if c == nil || apply == nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
current := c.attachment.Load()
|
||||
if current == nil {
|
||||
if client := c.compatClientConn(); client != nil {
|
||||
current = client.attachment.Load()
|
||||
}
|
||||
}
|
||||
next := cloneClientConnAttachmentState(current)
|
||||
apply(next)
|
||||
if current == nil {
|
||||
if c.attachment.CompareAndSwap((*clientConnAttachmentState)(nil), next) {
|
||||
if client := c.compatClientConn(); client != nil {
|
||||
client.attachment.Store(next)
|
||||
}
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if c.attachment.CompareAndSwap(current, next) {
|
||||
if client := c.compatClientConn(); client != nil {
|
||||
client.attachment.Store(next)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnAttachmentStateSnapshot() *clientConnAttachmentState {
|
||||
if c == nil {
|
||||
return &clientConnAttachmentState{}
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
return logical.attachmentStateSnapshot()
|
||||
}
|
||||
if state := c.attachment.Load(); state != nil {
|
||||
return cloneClientConnAttachmentState(state)
|
||||
}
|
||||
return &clientConnAttachmentState{}
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnAttachmentState(state *clientConnAttachmentState) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
logical.setAttachmentState(state)
|
||||
return
|
||||
}
|
||||
c.attachment.Store(cloneClientConnAttachmentState(state))
|
||||
}
|
||||
|
||||
func (c *ClientConn) updateClientConnAttachmentState(apply func(*clientConnAttachmentState)) {
|
||||
if c == nil || apply == nil {
|
||||
return
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
logical.updateAttachmentState(apply)
|
||||
return
|
||||
}
|
||||
for {
|
||||
current := c.attachment.Load()
|
||||
next := cloneClientConnAttachmentState(current)
|
||||
apply(next)
|
||||
if current == nil {
|
||||
if c.attachment.CompareAndSwap((*clientConnAttachmentState)(nil), next) {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if c.attachment.CompareAndSwap(current, next) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Duration, maxWriteTimeout time.Duration, msgEn func([]byte, []byte) []byte, msgDe func([]byte, []byte) []byte, handshakeRsaKey []byte, secretKey []byte) {
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.maxReadTimeout = maxReadTimeout
|
||||
state.maxWriteTimeout = maxWriteTimeout
|
||||
state.msgEn = msgEn
|
||||
state.msgDe = msgDe
|
||||
state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
|
||||
state.secretKey = cloneClientConnAttachmentBytes(secretKey)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientConn) inheritClientConnAttachmentProfile(src *ClientConn) {
|
||||
if c == nil || src == nil {
|
||||
return
|
||||
}
|
||||
c.setClientConnAttachmentState(src.clientConnAttachmentStateSnapshot())
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnMaxReadTimeoutSnapshot() time.Duration {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
return c.clientConnAttachmentStateSnapshot().maxReadTimeout
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnMaxWriteTimeout(timeout time.Duration) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
logical.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.maxWriteTimeout = timeout
|
||||
})
|
||||
return
|
||||
}
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.maxWriteTimeout = timeout
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnMaxWriteTimeoutSnapshot() time.Duration {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
return c.clientConnAttachmentStateSnapshot().maxWriteTimeout
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.clientConnAttachmentStateSnapshot().msgEn
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) {
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.msgEn = fn
|
||||
state.fastStreamEncode = nil
|
||||
state.fastBulkEncode = nil
|
||||
state.fastPlainEncode = nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.clientConnAttachmentStateSnapshot().msgDe
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) {
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.msgDe = fn
|
||||
state.fastStreamEncode = nil
|
||||
state.fastBulkEncode = nil
|
||||
state.fastPlainEncode = nil
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnFastStreamEncode(fn transportFastStreamEncoder) {
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.fastStreamEncode = fn
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnFastStreamEncodeSnapshot() transportFastStreamEncoder {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.clientConnAttachmentStateSnapshot().fastStreamEncode
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnFastBulkEncode(fn transportFastBulkEncoder) {
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.fastBulkEncode = fn
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnFastBulkEncodeSnapshot() transportFastBulkEncoder {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.clientConnAttachmentStateSnapshot().fastBulkEncode
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnFastPlainEncode(fn transportFastPlainEncoder) {
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.fastPlainEncode = fn
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnFastPlainEncodeSnapshot() transportFastPlainEncoder {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.clientConnAttachmentStateSnapshot().fastPlainEncode
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnHandshakeRsaKeySnapshot() []byte {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.clientConnAttachmentStateSnapshot().handshakeRsaKey
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnSecretKeySnapshot() []byte {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.clientConnAttachmentStateSnapshot().secretKey
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnSecretKey(key []byte) {
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.secretKey = cloneClientConnAttachmentBytes(key)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnLastHeartbeatUnixSnapshot() int64 {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
return c.clientConnAttachmentStateSnapshot().lastHeartBeat
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnLastHeartbeatUnix(unix int64) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
logical.setClientConnLastHeartbeatUnix(unix)
|
||||
return
|
||||
}
|
||||
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||
state.lastHeartBeat = unix
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientConn) markClientConnHeartbeatNow() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
logical.markHeartbeatNow()
|
||||
return
|
||||
}
|
||||
c.setClientConnLastHeartbeatUnix(time.Now().Unix())
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnRemoteAddr(addr net.Addr) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
state := c.ensureLogicalConnState()
|
||||
if state == nil {
|
||||
c.ClientAddr = addr
|
||||
return
|
||||
}
|
||||
state.updatePeer(func(peer *logicalConnPeerState) {
|
||||
peer.clientAddr = addr
|
||||
})
|
||||
c.syncLegacyLogicalFieldsFromState(state)
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (c *ClientConn) startClientConnSession(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
|
||||
if c == nil {
|
||||
return stopCtx, stopFn
|
||||
}
|
||||
return c.LogicalConn().startSession(tuConn, stopCtx, stopFn)
|
||||
}
|
||||
|
||||
func (c *ClientConn) startClientConnSessionTransport(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
|
||||
if c == nil {
|
||||
return stopCtx, stopFn
|
||||
}
|
||||
return c.LogicalConn().startSessionTransport(tuConn, stopCtx, stopFn)
|
||||
}
|
||||
|
||||
func (c *ClientConn) attachClientConnSessionTransport(tuConn net.Conn) error {
|
||||
if c == nil {
|
||||
return errors.New("client conn is nil")
|
||||
}
|
||||
return c.LogicalConn().attachSessionTransport(tuConn)
|
||||
}
|
||||
|
||||
func (c *ClientConn) detachClientConnTransportForTransfer() (net.Conn, error) {
|
||||
if c == nil {
|
||||
return nil, errors.New("client conn is nil")
|
||||
}
|
||||
return c.LogicalConn().detachTransportForTransfer()
|
||||
}
|
||||
|
||||
func (c *ClientConn) stopServerOwnedSession(reason string, err error) {
|
||||
c.stopServerOwnedSessionWith(nil, reason, err)
|
||||
}
|
||||
|
||||
func (c *LogicalConn) stopServerOwnedSession(reason string, err error) {
|
||||
c.stopServerOwnedSessionWith(nil, reason, err)
|
||||
}
|
||||
|
||||
func (c *ClientConn) stopServerOwnedSessionWith(removeFn func(*ClientConn), reason string, err error) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.markSessionStopped(reason, err)
|
||||
c.detachServerOwnedSessionWith(removeFn)
|
||||
}
|
||||
|
||||
func (c *LogicalConn) stopServerOwnedSessionWith(removeFn func(*LogicalConn), reason string, err error) {
|
||||
client := c.compatClientConn()
|
||||
if client == nil {
|
||||
return
|
||||
}
|
||||
client.markSessionStopped(reason, err)
|
||||
c.detachServerOwnedSessionWith(removeFn)
|
||||
}
|
||||
|
||||
func (c *ClientConn) detachServerOwnedSession() {
|
||||
c.detachServerOwnedSessionWith(nil)
|
||||
}
|
||||
|
||||
func (c *LogicalConn) detachServerOwnedSession() {
|
||||
c.detachServerOwnedSessionWith(nil)
|
||||
}
|
||||
|
||||
func (c *ClientConn) detachServerOwnedSessionWith(removeFn func(*ClientConn)) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.detachServerOwnedTransport()
|
||||
if removeFn != nil {
|
||||
removeFn(c)
|
||||
return
|
||||
}
|
||||
if c.server != nil {
|
||||
c.server.removeClient(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LogicalConn) detachServerOwnedSessionWith(removeFn func(*LogicalConn)) {
|
||||
client := c.compatClientConn()
|
||||
if client == nil {
|
||||
return
|
||||
}
|
||||
c.detachServerOwnedTransport()
|
||||
if removeFn != nil {
|
||||
removeFn(c)
|
||||
return
|
||||
}
|
||||
if client.server != nil {
|
||||
client.server.removeLogical(c)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConn) detachServerOwnedTransport() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.LogicalConn().detachServerOwnedTransport()
|
||||
}
|
||||
|
||||
func (c *LogicalConn) detachServerOwnedTransport() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.closeTransport()
|
||||
c.clearSessionRuntimeTransport()
|
||||
}
|
||||
@@ -0,0 +1,232 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
type clientConnSessionRuntime struct {
|
||||
transport *transportBinding
|
||||
transportAttached bool
|
||||
transportGeneration uint64
|
||||
tuConn net.Conn
|
||||
stopCtx context.Context
|
||||
stopFn context.CancelFunc
|
||||
transportStopCtx context.Context
|
||||
transportStopFn context.CancelFunc
|
||||
transportDone chan struct{}
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnSessionRuntime(rt *clientConnSessionRuntime) {
|
||||
if c == nil || rt == nil {
|
||||
return
|
||||
}
|
||||
logical := c.LogicalConn()
|
||||
if logical == nil {
|
||||
if rt.transport == nil && rt.tuConn != nil {
|
||||
rt.transport = newTransportBinding(rt.tuConn, nil)
|
||||
}
|
||||
normalizeClientConnSessionRuntimeTransportState(rt)
|
||||
ensureClientConnSessionRuntimeTransportLifecycle(rt)
|
||||
ensureClientConnSessionRuntimeTransportDone(rt)
|
||||
c.sessionRuntime.Store(rt)
|
||||
return
|
||||
}
|
||||
logical.setSessionRuntime(rt)
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnSessionRuntimeSnapshot() *clientConnSessionRuntime {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
state := c.ensureLogicalConnRuntimeState()
|
||||
if state == nil {
|
||||
return c.sessionRuntime.Load()
|
||||
}
|
||||
rt := state.sessionRuntimeSnapshot()
|
||||
if rt != c.sessionRuntime.Load() {
|
||||
c.sessionRuntime.Store(rt)
|
||||
}
|
||||
return rt
|
||||
}
|
||||
|
||||
func (c *ClientConn) clearClientConnSessionRuntimeTransport() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
logical := c.LogicalConn()
|
||||
if logical == nil {
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
if rt.transportStopFn != nil {
|
||||
rt.transportStopFn()
|
||||
}
|
||||
next := *rt
|
||||
next.transport = nil
|
||||
next.transportAttached = false
|
||||
next.transportGeneration = 0
|
||||
next.tuConn = nil
|
||||
next.transportStopCtx = nil
|
||||
next.transportStopFn = nil
|
||||
next.transportDone = nil
|
||||
c.setClientConnSessionRuntime(&next)
|
||||
return
|
||||
}
|
||||
logical.clearSessionRuntimeTransport()
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportSnapshot() net.Conn {
|
||||
logical := c.LogicalConn()
|
||||
if logical == nil {
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
if rt.transport != nil {
|
||||
return rt.transport.connSnapshot()
|
||||
}
|
||||
return rt.tuConn
|
||||
}
|
||||
return logical.transportSnapshot()
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnStopContextSnapshot() context.Context {
|
||||
logical := c.LogicalConn()
|
||||
if logical == nil {
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
return rt.stopCtx
|
||||
}
|
||||
return logical.stopContextSnapshot()
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnStopFuncSnapshot() context.CancelFunc {
|
||||
logical := c.LogicalConn()
|
||||
if logical == nil {
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
return rt.stopFn
|
||||
}
|
||||
return logical.stopFuncSnapshot()
|
||||
}
|
||||
|
||||
func (c *ClientConn) closeClientConnTransport() {
|
||||
logical := c.LogicalConn()
|
||||
if logical == nil {
|
||||
conn := c.clientConnTransportSnapshot()
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
logical.closeTransport()
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportBindingSnapshot() *transportBinding {
|
||||
logical := c.LogicalConn()
|
||||
if logical == nil {
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
if rt.transport != nil {
|
||||
return rt.transport
|
||||
}
|
||||
if rt.tuConn == nil {
|
||||
return nil
|
||||
}
|
||||
return newTransportBinding(rt.tuConn, nil)
|
||||
}
|
||||
return logical.transportBindingSnapshot()
|
||||
}
|
||||
|
||||
func normalizeClientConnSessionRuntimeTransportState(rt *clientConnSessionRuntime) {
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
if rt.transport != nil {
|
||||
rt.transportAttached = rt.transport.connSnapshot() != nil
|
||||
return
|
||||
}
|
||||
rt.transportAttached = rt.tuConn != nil
|
||||
}
|
||||
|
||||
func ensureClientConnSessionRuntimeTransportLifecycle(rt *clientConnSessionRuntime) {
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
if rt.tuConn == nil {
|
||||
rt.transportStopCtx = nil
|
||||
rt.transportStopFn = nil
|
||||
rt.transportDone = nil
|
||||
return
|
||||
}
|
||||
if rt.transportStopCtx != nil && rt.transportStopFn != nil {
|
||||
return
|
||||
}
|
||||
parent := rt.stopCtx
|
||||
if parent == nil {
|
||||
parent = context.Background()
|
||||
}
|
||||
rt.transportStopCtx, rt.transportStopFn = context.WithCancel(parent)
|
||||
}
|
||||
|
||||
func ensureClientConnSessionRuntimeTransportDone(rt *clientConnSessionRuntime) {
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
if rt.tuConn == nil {
|
||||
rt.transportDone = nil
|
||||
return
|
||||
}
|
||||
if rt.transportDone != nil {
|
||||
return
|
||||
}
|
||||
rt.transportDone = make(chan struct{})
|
||||
}
|
||||
|
||||
func closeClientConnSessionRuntimeTransportDone(rt *clientConnSessionRuntime) {
|
||||
if rt == nil || rt.transportDone == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-rt.transportDone:
|
||||
return
|
||||
default:
|
||||
close(rt.transportDone)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportStopContextSnapshot() context.Context {
|
||||
logical := c.LogicalConn()
|
||||
if logical == nil {
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
if rt.transportStopCtx != nil {
|
||||
return rt.transportStopCtx
|
||||
}
|
||||
return rt.stopCtx
|
||||
}
|
||||
return logical.transportStopContextSnapshot()
|
||||
}
|
||||
|
||||
func (c *ClientConn) clientConnTransportAttachedSnapshot() bool {
|
||||
logical := c.LogicalConn()
|
||||
if logical == nil {
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return false
|
||||
}
|
||||
return rt.transportAttached
|
||||
}
|
||||
return logical.transportAttachedSnapshot()
|
||||
}
|
||||
@@ -0,0 +1,443 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/stario"
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestClientConnReadTUMessagePreservesServerStopReason(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
left, right := net.Pipe()
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
|
||||
client, _, _ := newRegisteredServerClientForTest(t, server, "client-stop", left, stopCtx, stopFn)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
client.readTUMessage()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
server.stopClientSession(client, "recv stop signal from server", nil)
|
||||
_ = right.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("readTUMessage should exit after server stop")
|
||||
}
|
||||
|
||||
if status := client.Status(); status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
|
||||
t.Fatalf("unexpected status after server stop: %+v", status)
|
||||
}
|
||||
if got := server.GetLogicalConn(client.ClientID); got != nil {
|
||||
t.Fatalf("logical should be removed after server stop, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientConnReadTUMessageReadErrorStopsAndRemovesClient(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
left, right := net.Pipe()
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
|
||||
client, _, _ := newRegisteredServerClientForTest(t, server, "client-read-error", left, stopCtx, stopFn)
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
client.readTUMessage()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
_ = right.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("readTUMessage should exit after read error")
|
||||
}
|
||||
|
||||
status := client.Status()
|
||||
if status.Alive || status.Reason != "read error" || status.Err == nil {
|
||||
t.Fatalf("unexpected status after read error: %+v", status)
|
||||
}
|
||||
if got := server.GetLogicalConn(client.ClientID); got != nil {
|
||||
t.Fatalf("logical should be removed after read error, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientConnMarkSessionStoppedUsesRuntimeStopFn(t *testing.T) {
|
||||
client := &ClientConn{}
|
||||
client.markSessionStarted()
|
||||
|
||||
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
|
||||
defer runtimeCancel()
|
||||
client.setClientConnSessionRuntime(&clientConnSessionRuntime{
|
||||
stopCtx: runtimeCtx,
|
||||
stopFn: runtimeCancel,
|
||||
})
|
||||
|
||||
client.markSessionStopped("runtime stop", nil)
|
||||
|
||||
select {
|
||||
case <-runtimeCtx.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("runtime stop context should be canceled by markSessionStopped")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientConnDetachServerOwnedSessionClearsRuntimeTransport(t *testing.T) {
|
||||
client := &ClientConn{}
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
client.startClientConnSession(left, stopCtx, stopFn)
|
||||
|
||||
client.detachServerOwnedSession()
|
||||
|
||||
if got := client.clientConnTransportSnapshot(); got != nil {
|
||||
t.Fatalf("runtime transport should be cleared after detach, got %v", got)
|
||||
}
|
||||
if got := client.clientConnStopContextSnapshot(); got != stopCtx {
|
||||
t.Fatalf("runtime stop context should be preserved after detach, got %v want %v", got, stopCtx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientConnReadFromTUTransportUsesRuntimeConn(t *testing.T) {
|
||||
client := &ClientConn{}
|
||||
runtimeLeft, runtimeRight := net.Pipe()
|
||||
defer runtimeLeft.Close()
|
||||
defer runtimeRight.Close()
|
||||
|
||||
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
|
||||
defer runtimeCancel()
|
||||
client.setClientConnSessionRuntime(&clientConnSessionRuntime{
|
||||
tuConn: runtimeLeft,
|
||||
stopCtx: runtimeCtx,
|
||||
stopFn: runtimeCancel,
|
||||
})
|
||||
|
||||
payload := []byte("runtime-tu-conn")
|
||||
writeDone := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := runtimeRight.Write(payload)
|
||||
writeDone <- err
|
||||
}()
|
||||
|
||||
num, data, err := client.readFromTUTransport()
|
||||
if err != nil {
|
||||
t.Fatalf("readFromTUTransport failed: %v", err)
|
||||
}
|
||||
if got, want := string(data[:num]), string(payload); got != want {
|
||||
t.Fatalf("payload mismatch: got %q want %q", got, want)
|
||||
}
|
||||
select {
|
||||
case err := <-writeDone:
|
||||
if err != nil {
|
||||
t.Fatalf("runtime writer failed: %v", err)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("runtime writer did not finish")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStartClientConnSessionInitializesDefaultRuntime(t *testing.T) {
|
||||
client := &ClientConn{}
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
stopCtx, stopFn := client.startClientConnSession(left, nil, nil)
|
||||
defer stopFn()
|
||||
|
||||
if !client.Status().Alive {
|
||||
t.Fatalf("client should start alive: %+v", client.Status())
|
||||
}
|
||||
if stopCtx == nil || stopFn == nil {
|
||||
t.Fatal("startClientConnSession should initialize default stop context")
|
||||
}
|
||||
if got := client.clientConnTransportSnapshot(); got != left {
|
||||
t.Fatal("runtime transport snapshot should match passed conn")
|
||||
}
|
||||
if got := client.clientConnStopContextSnapshot(); got != stopCtx {
|
||||
t.Fatal("runtime stop context snapshot should match returned context")
|
||||
}
|
||||
if got := client.clientConnStopFuncSnapshot(); got == nil {
|
||||
t.Fatal("runtime stop func snapshot should be initialized")
|
||||
}
|
||||
if got := client.GetRemoteAddr(); got == nil || got.String() != left.RemoteAddr().String() {
|
||||
t.Fatalf("client remote addr mismatch: got %v want %v", got, left.RemoteAddr())
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogicalConnSessionTransportLifecycleUsesLogicalRuntimeOwner(t *testing.T) {
|
||||
client := &ClientConn{ClientID: "logical-runtime"}
|
||||
logical := client.LogicalConn()
|
||||
if logical == nil {
|
||||
t.Fatal("LogicalConn should exist")
|
||||
}
|
||||
|
||||
firstLeft, firstRight := net.Pipe()
|
||||
defer firstRight.Close()
|
||||
stopCtx, stopFn := logical.startSession(firstLeft, nil, nil)
|
||||
defer stopFn()
|
||||
|
||||
if stopCtx == nil {
|
||||
t.Fatal("logical startSession should initialize stop context")
|
||||
}
|
||||
if got := logical.transportSnapshot(); got != firstLeft {
|
||||
t.Fatalf("logical transport snapshot mismatch: got %v want %v", got, firstLeft)
|
||||
}
|
||||
if !logical.transportAttachedSnapshot() {
|
||||
t.Fatal("logical transport should be attached after startSession")
|
||||
}
|
||||
|
||||
firstGeneration := logical.transportGenerationSnapshot()
|
||||
if firstGeneration == 0 {
|
||||
t.Fatal("logical transport generation should advance for stream runtime")
|
||||
}
|
||||
|
||||
secondLeft, secondRight := net.Pipe()
|
||||
defer secondRight.Close()
|
||||
if err := logical.attachSessionTransport(secondLeft); err != nil {
|
||||
t.Fatalf("logical attachSessionTransport failed: %v", err)
|
||||
}
|
||||
|
||||
if got := logical.transportSnapshot(); got != secondLeft {
|
||||
t.Fatalf("logical transport snapshot after attach mismatch: got %v want %v", got, secondLeft)
|
||||
}
|
||||
if !logical.transportAttachedSnapshot() {
|
||||
t.Fatal("logical transport should stay attached after attachSessionTransport")
|
||||
}
|
||||
if got := logical.transportGenerationSnapshot(); got <= firstGeneration {
|
||||
t.Fatalf("logical transport generation should advance after attach: got %d want > %d", got, firstGeneration)
|
||||
}
|
||||
|
||||
detachedConn, err := logical.detachTransportForTransfer()
|
||||
if err != nil {
|
||||
t.Fatalf("logical detachTransportForTransfer failed: %v", err)
|
||||
}
|
||||
if detachedConn != secondLeft {
|
||||
t.Fatalf("detached conn mismatch: got %v want %v", detachedConn, secondLeft)
|
||||
}
|
||||
if got := logical.transportSnapshot(); got != nil {
|
||||
t.Fatalf("logical transport should be cleared after detach, got %v", got)
|
||||
}
|
||||
if logical.transportAttachedSnapshot() {
|
||||
t.Fatal("logical transport should be detached after detachTransportForTransfer")
|
||||
}
|
||||
if got := logical.stopContextSnapshot(); got != stopCtx {
|
||||
t.Fatalf("logical stop context should be preserved after detach, got %v want %v", got, stopCtx)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogicalConnOwnerStateMutationsSyncLegacyClientView(t *testing.T) {
|
||||
client := &ClientConn{ClientID: "logical-owner-state"}
|
||||
logical := client.LogicalConn()
|
||||
if logical == nil {
|
||||
t.Fatal("LogicalConn should exist")
|
||||
}
|
||||
|
||||
logical.markIdentityBound()
|
||||
logical.markStreamTransport()
|
||||
attachGeneration := logical.markTransportAttached()
|
||||
logical.setClientConnLastHeartbeatUnix(12345)
|
||||
logical.markTransportDetached("read error", errors.New("boom"))
|
||||
|
||||
if !client.clientConnIdentityBoundSnapshot() {
|
||||
t.Fatal("legacy client identity-bound snapshot should follow logical state")
|
||||
}
|
||||
if !client.clientConnUsesStreamTransportSnapshot() {
|
||||
t.Fatal("legacy client stream-transport snapshot should follow logical state")
|
||||
}
|
||||
if got := client.clientConnTransportGenerationSnapshot(); got != attachGeneration {
|
||||
t.Fatalf("legacy client transport generation = %d, want %d", got, attachGeneration)
|
||||
}
|
||||
if got := client.clientConnLastHeartbeatUnixSnapshot(); got != 12345 {
|
||||
t.Fatalf("legacy client last heartbeat = %d, want %d", got, 12345)
|
||||
}
|
||||
detach := client.clientConnTransportDetachSnapshot()
|
||||
if detach == nil {
|
||||
t.Fatal("legacy client detach snapshot should follow logical state")
|
||||
}
|
||||
if detach.Reason != "read error" || detach.Err != "boom" || detach.Generation != attachGeneration {
|
||||
t.Fatalf("legacy client detach snapshot mismatch: %+v", detach)
|
||||
}
|
||||
|
||||
logical.clearTransportDetachState()
|
||||
if got := client.clientConnTransportDetachSnapshot(); got != nil {
|
||||
t.Fatalf("legacy client detach snapshot should clear with logical state, got %+v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogicalDetachTransportForTransferKeepsHandoffConnAlive(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
|
||||
client := &ClientConn{
|
||||
ClientID: "client-handoff",
|
||||
server: server,
|
||||
}
|
||||
client.startClientConnSessionTransport(left, stopCtx, stopFn)
|
||||
|
||||
logical := client.LogicalConn()
|
||||
detachedConn, err := logical.detachTransportForTransfer()
|
||||
if err != nil {
|
||||
t.Fatalf("logical detachTransportForTransfer failed: %v", err)
|
||||
}
|
||||
defer detachedConn.Close()
|
||||
|
||||
payload := []byte("handoff-payload")
|
||||
readDone := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, len(payload))
|
||||
_ = right.SetReadDeadline(time.Now().Add(time.Second))
|
||||
if _, err := io.ReadFull(right, buf); err != nil {
|
||||
readDone <- err
|
||||
return
|
||||
}
|
||||
if !bytes.Equal(buf, payload) {
|
||||
readDone <- fmt.Errorf("payload mismatch: got %q want %q", string(buf), string(payload))
|
||||
return
|
||||
}
|
||||
readDone <- nil
|
||||
}()
|
||||
|
||||
_ = detachedConn.SetWriteDeadline(time.Now().Add(time.Second))
|
||||
if _, err := detachedConn.Write(payload); err != nil {
|
||||
t.Fatalf("detached handoff conn write failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-readDone:
|
||||
if err != nil {
|
||||
t.Fatalf("handoff conn read failed: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for handoff conn read")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientConnTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) {
|
||||
client := &ClientConn{}
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
client.setClientConnSessionRuntime(&clientConnSessionRuntime{
|
||||
transport: newTransportBinding(left, nil),
|
||||
tuConn: left,
|
||||
stopCtx: stopCtx,
|
||||
stopFn: stopFn,
|
||||
})
|
||||
|
||||
binding := client.clientConnTransportBindingSnapshot()
|
||||
if binding == nil {
|
||||
t.Fatal("runtime transport binding should exist")
|
||||
}
|
||||
if got := binding.connSnapshot(); got != left {
|
||||
t.Fatal("runtime transport binding conn should match runtime conn")
|
||||
}
|
||||
if got := binding.queueSnapshot(); got != nil {
|
||||
t.Fatalf("server-side peer binding queue should remain nil, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientConnDetachServerOwnedSessionCancelsTransportOnly(t *testing.T) {
|
||||
client := &ClientConn{}
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
client.startClientConnSession(left, stopCtx, stopFn)
|
||||
|
||||
transportStopCtx := client.clientConnTransportStopContextSnapshot()
|
||||
client.detachServerOwnedSession()
|
||||
|
||||
if transportStopCtx == nil {
|
||||
t.Fatal("transport stop context should exist")
|
||||
}
|
||||
select {
|
||||
case <-transportStopCtx.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("transport stop context should be canceled after detach")
|
||||
}
|
||||
select {
|
||||
case <-client.clientConnStopContextSnapshot().Done():
|
||||
t.Fatal("logical stop context should remain active after pure detach")
|
||||
default:
|
||||
}
|
||||
if client.clientConnTransportAttachedSnapshot() {
|
||||
t.Fatal("client conn transport should be marked detached after pure detach")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttachClientConnSessionTransportRebindsRuntimeAndStartsReadLoop(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
queue := stario.NewQueueCtx(stopCtx, 4, 1024)
|
||||
server.setServerSessionRuntime(&serverSessionRuntime{
|
||||
stopCtx: stopCtx,
|
||||
stopFn: stopFn,
|
||||
queue: queue,
|
||||
})
|
||||
|
||||
oldLeft, oldRight := net.Pipe()
|
||||
defer oldRight.Close()
|
||||
client := &ClientConn{
|
||||
ClientID: "client-reattach",
|
||||
server: server,
|
||||
}
|
||||
client.startClientConnSession(oldLeft, stopCtx, stopFn)
|
||||
|
||||
newLeft, newRight := net.Pipe()
|
||||
defer newRight.Close()
|
||||
if err := client.attachClientConnSessionTransport(newLeft); err != nil {
|
||||
t.Fatalf("attachClientConnSessionTransport failed: %v", err)
|
||||
}
|
||||
|
||||
rt := client.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
t.Fatal("client conn runtime should exist after attach")
|
||||
}
|
||||
if rt.tuConn != newLeft || !rt.transportAttached {
|
||||
t.Fatalf("attached client conn runtime mismatch: %+v", rt)
|
||||
}
|
||||
|
||||
wire := queue.BuildMessage([]byte("reattached"))
|
||||
if _, err := newRight.Write(wire); err != nil {
|
||||
t.Fatalf("new transport write failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case msg := <-queue.RestoreChan():
|
||||
source := assertServerInboundQueueSource(t, msg.Conn, client)
|
||||
if got, want := source.TransportGeneration, client.clientConnTransportGenerationSnapshot(); got != want {
|
||||
t.Fatalf("queue transport generation mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := string(msg.Msg), "reattached"; got != want {
|
||||
t.Fatalf("queue payload mismatch: got %q want %q", got, want)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("reattached server-owned transport did not push framed message")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func newStartedClientConnForTest(t *testing.T, id string, server Server, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*ClientConn, context.Context, context.CancelFunc) {
|
||||
t.Helper()
|
||||
client := &ClientConn{
|
||||
ClientID: id,
|
||||
server: server,
|
||||
}
|
||||
stopCtx, stopFn = client.startClientConnSession(conn, stopCtx, stopFn)
|
||||
return client, stopCtx, stopFn
|
||||
}
|
||||
|
||||
func newRegisteredServerClientForTest(t *testing.T, server *ServerCommon, id string, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*ClientConn, context.Context, context.CancelFunc) {
|
||||
t.Helper()
|
||||
client, stopCtx, stopFn := newStartedClientConnForTest(t, id, server, conn, stopCtx, stopFn)
|
||||
server.getPeerRegistry().registerClient(client)
|
||||
return client, stopCtx, stopFn
|
||||
}
|
||||
|
||||
func newRegisteredServerLogicalForTest(t *testing.T, server *ServerCommon, id string, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*LogicalConn, context.Context, context.CancelFunc) {
|
||||
t.Helper()
|
||||
client, stopCtx, stopFn := newStartedClientConnForTest(t, id, server, conn, stopCtx, stopFn)
|
||||
logical := logicalConnFromClient(client)
|
||||
server.getPeerRegistry().registerLogical(logical)
|
||||
return logical, stopCtx, stopFn
|
||||
}
|
||||
|
||||
func newServerCodecClientConnForTest(server *ServerCommon) *ClientConn {
|
||||
client := &ClientConn{server: server}
|
||||
client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey)
|
||||
return client
|
||||
}
|
||||
@@ -0,0 +1,223 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type serverLogicalTransportDetacher interface {
|
||||
detachLogicalSessionTransport(logical *LogicalConn, reason string, err error)
|
||||
}
|
||||
|
||||
type serverInboundSourcePusher interface {
|
||||
pushMessageSource([]byte, interface{})
|
||||
}
|
||||
|
||||
func (c *LogicalConn) readTUMessage() {
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
c.readTUMessageLoop(rt)
|
||||
}
|
||||
|
||||
func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
stopCtx := rt.transportStopCtx
|
||||
if stopCtx == nil {
|
||||
stopCtx = rt.stopCtx
|
||||
}
|
||||
if stopCtx == nil {
|
||||
return
|
||||
}
|
||||
conn := rt.tuConn
|
||||
generation := rt.transportGeneration
|
||||
defer closeClientConnSessionRuntimeTransportDone(rt)
|
||||
buf := streamReadBuffer()
|
||||
for {
|
||||
select {
|
||||
case <-sessionStopChan(stopCtx):
|
||||
if c.shouldCloseTransportOnStop(conn) {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return
|
||||
default:
|
||||
}
|
||||
num, data, err := c.readFromTUTransportConnWithBuffer(conn, buf)
|
||||
if !c.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) {
|
||||
if len(data) == 0 {
|
||||
data = streamReadBuffer()
|
||||
}
|
||||
if conn == nil {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
num, err := conn.Read(data)
|
||||
return num, data, err
|
||||
}
|
||||
|
||||
func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
|
||||
if err == os.ErrDeadlineExceeded {
|
||||
if num != 0 {
|
||||
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if err != nil {
|
||||
select {
|
||||
case <-sessionStopChan(stopCtx):
|
||||
if c.shouldCloseTransportOnStop(conn) {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return false
|
||||
default:
|
||||
}
|
||||
if detacher, ok := c.Server().(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
|
||||
detacher.detachLogicalSessionTransport(c, "read error", err)
|
||||
return false
|
||||
}
|
||||
c.stopServerOwnedSession("read error", err)
|
||||
return false
|
||||
}
|
||||
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
|
||||
if c == nil || len(data) == 0 {
|
||||
return
|
||||
}
|
||||
server := c.Server()
|
||||
if server == nil {
|
||||
return
|
||||
}
|
||||
if pusher, ok := server.(serverInboundSourcePusher); ok {
|
||||
pusher.pushMessageSource(data, newServerInboundSource(c, conn, nil, generation))
|
||||
return
|
||||
}
|
||||
server.pushMessage(data, c.clientConnIDSnapshot())
|
||||
}
|
||||
|
||||
func (c *LogicalConn) shouldCloseTransportOnStop(conn net.Conn) bool {
|
||||
if c == nil || conn == nil {
|
||||
return false
|
||||
}
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil || !rt.transportAttached {
|
||||
return false
|
||||
}
|
||||
current := rt.tuConn
|
||||
if rt.transport != nil && rt.transport.connSnapshot() != nil {
|
||||
current = rt.transport.connSnapshot()
|
||||
}
|
||||
return current == conn
|
||||
}
|
||||
|
||||
func (c *ClientConn) readFromTUTransport() (int, []byte, error) {
|
||||
binding := c.clientConnTransportBindingSnapshot()
|
||||
if binding == nil {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
conn := binding.connSnapshot()
|
||||
return c.readFromTUTransportConn(conn)
|
||||
}
|
||||
|
||||
func (c *ClientConn) readFromTUTransportConn(conn net.Conn) (int, []byte, error) {
|
||||
return c.readFromTUTransportConnWithBuffer(conn, streamReadBuffer())
|
||||
}
|
||||
|
||||
func (c *ClientConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) {
|
||||
if logical := c.LogicalConn(); logical != nil {
|
||||
return logical.readFromTUTransportConnWithBuffer(conn, data)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
data = streamReadBuffer()
|
||||
}
|
||||
if conn == nil {
|
||||
return 0, nil, net.ErrClosed
|
||||
}
|
||||
if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
}
|
||||
num, err := conn.Read(data)
|
||||
return num, data, err
|
||||
}
|
||||
|
||||
func (c *ClientConn) handleTUTransportReadResult(num int, data []byte, err error) bool {
|
||||
return c.handleTUTransportReadResultWithSession(c.clientConnTransportStopContextSnapshot(), c.clientConnTransportSnapshot(), c.clientConnTransportGenerationSnapshot(), num, data, err)
|
||||
}
|
||||
|
||||
func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
|
||||
if logical := c.LogicalConn(); logical != nil {
|
||||
return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err)
|
||||
}
|
||||
if err == os.ErrDeadlineExceeded {
|
||||
if num != 0 {
|
||||
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
|
||||
}
|
||||
return true
|
||||
}
|
||||
if err != nil {
|
||||
select {
|
||||
case <-sessionStopChan(stopCtx):
|
||||
if c.shouldCloseClientConnTransportOnStop(conn) {
|
||||
_ = conn.Close()
|
||||
}
|
||||
return false
|
||||
default:
|
||||
}
|
||||
if detacher, ok := c.server.(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
|
||||
detacher.detachLogicalSessionTransport(logicalConnFromClient(c), "read error", err)
|
||||
return false
|
||||
}
|
||||
c.stopServerOwnedSession("read error", err)
|
||||
return false
|
||||
}
|
||||
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
|
||||
if logical := c.LogicalConn(); logical != nil {
|
||||
logical.pushServerOwnedTransportMessage(data, conn, generation)
|
||||
return
|
||||
}
|
||||
if c == nil || c.server == nil || len(data) == 0 {
|
||||
return
|
||||
}
|
||||
if pusher, ok := c.server.(serverInboundSourcePusher); ok {
|
||||
pusher.pushMessageSource(data, newServerInboundSource(logicalConnFromClient(c), conn, nil, generation))
|
||||
return
|
||||
}
|
||||
c.server.pushMessage(data, c.clientConnIDSnapshot())
|
||||
}
|
||||
|
||||
func (c *ClientConn) shouldCloseClientConnTransportOnStop(conn net.Conn) bool {
|
||||
if logical := c.LogicalConn(); logical != nil {
|
||||
return logical.shouldCloseTransportOnStop(conn)
|
||||
}
|
||||
if c == nil || conn == nil {
|
||||
return false
|
||||
}
|
||||
rt := c.clientConnSessionRuntimeSnapshot()
|
||||
if rt == nil || !rt.transportAttached {
|
||||
return false
|
||||
}
|
||||
current := rt.tuConn
|
||||
if rt.transport != nil && rt.transport.connSnapshot() != nil {
|
||||
current = rt.transport.connSnapshot()
|
||||
}
|
||||
return current == conn
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package notify
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
type clientConnTransportState struct {
|
||||
streamTransport atomic.Bool
|
||||
transportGen atomic.Uint64
|
||||
attachCount atomic.Uint64
|
||||
detachCount atomic.Uint64
|
||||
lastAttachAt atomic.Int64
|
||||
transportDetach atomic.Pointer[clientConnTransportDetachState]
|
||||
}
|
||||
|
||||
func cloneClientConnTransportDetachState(src *clientConnTransportDetachState) *clientConnTransportDetachState {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
cloned := *src
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func (c *LogicalConn) ensureTransportState() *clientConnTransportState {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if state := c.transportState.Load(); state != nil {
|
||||
if client := c.compatClientConn(); client != nil {
|
||||
client.transportState.Store(state)
|
||||
}
|
||||
return state
|
||||
}
|
||||
client := c.compatClientConn()
|
||||
if client != nil {
|
||||
if state := client.transportState.Load(); state != nil {
|
||||
if c.transportState.CompareAndSwap(nil, state) {
|
||||
client.transportState.Store(state)
|
||||
return state
|
||||
}
|
||||
return c.ensureTransportState()
|
||||
}
|
||||
}
|
||||
state := &clientConnTransportState{}
|
||||
if c.transportState.CompareAndSwap(nil, state) {
|
||||
if client != nil {
|
||||
client.transportState.Store(state)
|
||||
}
|
||||
return state
|
||||
}
|
||||
return c.ensureTransportState()
|
||||
}
|
||||
|
||||
func (c *ClientConn) ensureClientConnTransportState() *clientConnTransportState {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
return logical.ensureTransportState()
|
||||
}
|
||||
if state := c.transportState.Load(); state != nil {
|
||||
return state
|
||||
}
|
||||
state := &clientConnTransportState{}
|
||||
if c.transportState.CompareAndSwap(nil, state) {
|
||||
return state
|
||||
}
|
||||
return c.transportState.Load()
|
||||
}
|
||||
|
||||
func (c *ClientConn) setClientConnTransportDetachState(state *clientConnTransportDetachState) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if logical := c.logicalView.Load(); logical != nil {
|
||||
logical.setTransportDetachState(state)
|
||||
return
|
||||
}
|
||||
transportState := c.ensureClientConnTransportState()
|
||||
if transportState == nil {
|
||||
return
|
||||
}
|
||||
transportState.transportDetach.Store(cloneClientConnTransportDetachState(state))
|
||||
}
|
||||
|
||||
func (c *LogicalConn) setTransportDetachState(state *clientConnTransportDetachState) {
|
||||
transportState := c.ensureTransportState()
|
||||
if transportState == nil {
|
||||
return
|
||||
}
|
||||
transportState.transportDetach.Store(cloneClientConnTransportDetachState(state))
|
||||
if client := c.compatClientConn(); client != nil {
|
||||
client.transportState.Store(transportState)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/notify/internal/transport"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
clientConnectSourceConn = "conn"
|
||||
clientConnectSourceNetwork = "network"
|
||||
clientConnectSourceTimeout = "timeout"
|
||||
clientConnectSourceFactory = "factory"
|
||||
)
|
||||
|
||||
var errClientReconnectSourceUnavailable = errors.New("client reconnect source is unavailable")
|
||||
|
||||
type clientConnectSource struct {
|
||||
kind string
|
||||
network string
|
||||
addr string
|
||||
dialFn func(context.Context) (net.Conn, error)
|
||||
}
|
||||
|
||||
func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
|
||||
source := &clientConnectSource{kind: clientConnectSourceConn}
|
||||
if conn == nil {
|
||||
return source
|
||||
}
|
||||
if remoteAddr := conn.RemoteAddr(); remoteAddr != nil {
|
||||
source.network = remoteAddr.Network()
|
||||
source.addr = remoteAddr.String()
|
||||
}
|
||||
if source.network == "" {
|
||||
if localAddr := conn.LocalAddr(); localAddr != nil {
|
||||
source.network = localAddr.Network()
|
||||
}
|
||||
}
|
||||
return source
|
||||
}
|
||||
|
||||
func newClientNetworkConnectSource(network string, addr string) *clientConnectSource {
|
||||
return &clientConnectSource{
|
||||
kind: clientConnectSourceNetwork,
|
||||
network: network,
|
||||
addr: addr,
|
||||
dialFn: func(context.Context) (net.Conn, error) {
|
||||
return transport.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newClientTimeoutConnectSource(network string, addr string, timeout time.Duration) *clientConnectSource {
|
||||
return &clientConnectSource{
|
||||
kind: clientConnectSourceTimeout,
|
||||
network: network,
|
||||
addr: addr,
|
||||
dialFn: func(context.Context) (net.Conn, error) {
|
||||
return transport.DialTimeout(network, addr, timeout)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error)) *clientConnectSource {
|
||||
return &clientConnectSource{
|
||||
kind: clientConnectSourceFactory,
|
||||
dialFn: dialFn,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *clientConnectSource) clone() *clientConnectSource {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
out := *s
|
||||
return &out
|
||||
}
|
||||
|
||||
func (s *clientConnectSource) canReconnect() bool {
|
||||
return s != nil && s.dialFn != nil
|
||||
}
|
||||
|
||||
func (s *clientConnectSource) isUDP() bool {
|
||||
if s == nil {
|
||||
return false
|
||||
}
|
||||
return transport.IsUDPNetwork(s.network)
|
||||
}
|
||||
|
||||
func (s *clientConnectSource) dial(ctx context.Context) (net.Conn, error) {
|
||||
if s == nil || s.dialFn == nil {
|
||||
return nil, errClientReconnectSourceUnavailable
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return s.dialFn(ctx)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) setClientConnectSource(source *clientConnectSource) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if source == nil {
|
||||
c.connectSource.Store(nil)
|
||||
return
|
||||
}
|
||||
c.connectSource.Store(source.clone())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientConnectSourceSnapshot() *clientConnectSource {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
if source := c.connectSource.Load(); source != nil {
|
||||
return source.clone()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
package notify
|
||||
|
||||
func (c *ClientCommon) dispatchMsg(message Message) {
|
||||
switch message.TransferMsg.Type {
|
||||
case MSG_SYS_WAIT:
|
||||
fallthrough
|
||||
case MSG_SYS:
|
||||
c.sysMsg(message)
|
||||
return
|
||||
case MSG_KEY_CHANGE:
|
||||
fallthrough
|
||||
case MSG_SYS_REPLY:
|
||||
fallthrough
|
||||
case MSG_SYNC_REPLY:
|
||||
if c.getPendingWaitPool().deliver(message.ID, message) {
|
||||
return
|
||||
}
|
||||
fallthrough
|
||||
default:
|
||||
}
|
||||
if c.dispatchInternalTransferControl(message) {
|
||||
return
|
||||
}
|
||||
callFn := func(fn func(*Message)) {
|
||||
fn(&message)
|
||||
}
|
||||
fn, ok := c.linkFns[message.Key]
|
||||
if ok {
|
||||
callFn(fn)
|
||||
}
|
||||
if c.defaultFns != nil {
|
||||
callFn(c.defaultFns)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sysMsg(message Message) {
|
||||
switch message.Key {
|
||||
case "bye":
|
||||
if message.TransferMsg.Type == MSG_SYS_WAIT {
|
||||
c.setByeFromServer(true)
|
||||
message.Reply(nil)
|
||||
c.stopClientSession("recv stop signal from server", nil)
|
||||
return
|
||||
}
|
||||
c.stopClientSessionFromServer("recv stop signal from server", nil)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/starcrypto"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Deprecated: ExchangeKey drives the legacy RSA-based key exchange flow.
|
||||
// Prefer UseModernPSKClient.
|
||||
func (c *ClientCommon) ExchangeKey(newKey []byte) error {
|
||||
pubKey, err := starcrypto.DecodeRsaPublicKey(c.handshakeRsaPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newSendKey, err := starcrypto.RSAEncrypt(pubKey, newKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
data, err := c.sendWait(TransferMsg{
|
||||
ID: 19961127,
|
||||
Key: "sirius",
|
||||
Value: newSendKey,
|
||||
Type: MSG_KEY_CHANGE,
|
||||
}, time.Second*10)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if string(data.Value) != "success" {
|
||||
return errors.New("cannot exchange new aes-key")
|
||||
}
|
||||
c.SecretKey = newKey
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deprecated: aesRsaHello is the legacy RSA-based key exchange bootstrap.
|
||||
func aesRsaHello(c Client) error {
|
||||
newAesKey := []byte(fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me"))
|
||||
newAesKey = []byte(starcrypto.Md5Str(newAesKey))
|
||||
return c.ExchangeKey(newAesKey)
|
||||
}
|
||||
@@ -0,0 +1,158 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReconnectClientRejectsDirectConnSource(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
server.SetSecretKey(secret)
|
||||
})
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
|
||||
client.SetSecretKey(secret)
|
||||
if err := client.ConnectByConn(left); err != nil {
|
||||
t.Fatalf("ConnectByConn failed: %v", err)
|
||||
}
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
|
||||
err := ReconnectClient(context.Background(), client)
|
||||
if !errors.Is(err, errClientReconnectSourceUnavailable) {
|
||||
t.Fatalf("ReconnectClient error = %v, want %v", err, errClientReconnectSourceUnavailable)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconnectClientWithFactorySource(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||
client.SetSecretKey(secret)
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
server.SetSecretKey(secret)
|
||||
})
|
||||
|
||||
dialCount := 0
|
||||
var peers []net.Conn
|
||||
dialFn := func(context.Context) (net.Conn, error) {
|
||||
dialCount++
|
||||
left, right := net.Pipe()
|
||||
peers = append(peers, right)
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
return left, nil
|
||||
}
|
||||
|
||||
if err := client.ConnectByFactory(context.Background(), dialFn); err != nil {
|
||||
t.Fatalf("ConnectByFactory failed: %v", err)
|
||||
}
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
|
||||
before, err := GetClientRuntimeSnapshot(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientRuntimeSnapshot before reconnect failed: %v", err)
|
||||
}
|
||||
if !before.CanReconnect || before.ConnectSource != clientConnectSourceFactory {
|
||||
t.Fatalf("unexpected reconnect snapshot before reconnect: %+v", before)
|
||||
}
|
||||
|
||||
if err := ReconnectClient(context.Background(), client); err != nil {
|
||||
t.Fatalf("ReconnectClient failed: %v", err)
|
||||
}
|
||||
after, err := GetClientRuntimeSnapshot(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientRuntimeSnapshot after reconnect failed: %v", err)
|
||||
}
|
||||
if !after.Alive || !after.HasRuntimeConn || !after.CanReconnect {
|
||||
t.Fatalf("unexpected reconnect snapshot after reconnect: %+v", after)
|
||||
}
|
||||
if got, want := dialCount, 2; got != want {
|
||||
t.Fatalf("dial count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("final Stop failed: %v", err)
|
||||
}
|
||||
for _, peer := range peers {
|
||||
_ = peer.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestReconnectClientWithRetryRecordsRetryState(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||
client.SetSecretKey(secret)
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
server.SetSecretKey(secret)
|
||||
})
|
||||
|
||||
dialCount := 0
|
||||
wantErr := errors.New("dial failed once")
|
||||
var peers []net.Conn
|
||||
dialFn := func(context.Context) (net.Conn, error) {
|
||||
dialCount++
|
||||
if dialCount == 2 {
|
||||
return nil, wantErr
|
||||
}
|
||||
left, right := net.Pipe()
|
||||
peers = append(peers, right)
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
return left, nil
|
||||
}
|
||||
|
||||
if err := client.ConnectByFactory(context.Background(), dialFn); err != nil {
|
||||
t.Fatalf("ConnectByFactory failed: %v", err)
|
||||
}
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
|
||||
if err := ReconnectClientWithRetry(context.Background(), client, &ConnectRetryOptions{
|
||||
MaxAttempts: 3,
|
||||
BaseDelay: 0,
|
||||
MaxDelay: 0,
|
||||
}); err != nil {
|
||||
t.Fatalf("ReconnectClientWithRetry failed: %v", err)
|
||||
}
|
||||
snapshot, err := GetClientRuntimeSnapshot(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := snapshot.Retry.RetryEventTotal, uint64(1); got != want {
|
||||
t.Fatalf("retry events mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Retry.LastRetryAttempt, 1; got != want {
|
||||
t.Fatalf("last retry attempt mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Retry.LastRetryError, wantErr.Error(); got != want {
|
||||
t.Fatalf("last retry error mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if snapshot.Retry.LastResultError != "" {
|
||||
t.Fatalf("last result error should be empty, got %q", snapshot.Retry.LastResultError)
|
||||
}
|
||||
if got, want := dialCount, 3; got != want {
|
||||
t.Fatalf("dial count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("final Stop failed: %v", err)
|
||||
}
|
||||
for _, peer := range peers {
|
||||
_ = peer.Close()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package notify
|
||||
|
||||
import "context"
|
||||
|
||||
func (c *ClientCommon) SetRecordStreamHandler(fn func(RecordAcceptInfo) error) {
|
||||
runtime := c.getRecordRuntime()
|
||||
if runtime == nil {
|
||||
return
|
||||
}
|
||||
runtime.setHandler(fn)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error) {
|
||||
if c == nil {
|
||||
return nil, errStreamClientNil
|
||||
}
|
||||
opt = normalizeRecordOpenOptions(opt)
|
||||
stream, err := c.OpenStream(ctx, opt.Stream)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record, err := WrapStreamAsRecord(stream, opt)
|
||||
if err != nil {
|
||||
_ = stream.Reset(err)
|
||||
return nil, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) claimInboundRecordStream(stream *streamHandle) (bool, error) {
|
||||
if stream == nil || stream.Channel() != StreamRecordChannel {
|
||||
return false, nil
|
||||
}
|
||||
runtime := c.getRecordRuntime()
|
||||
if runtime == nil {
|
||||
return true, errRecordRuntimeNil
|
||||
}
|
||||
handler := runtime.handlerSnapshot()
|
||||
if handler == nil {
|
||||
return true, errRecordHandlerNotConfigured
|
||||
}
|
||||
record, err := WrapStreamAsRecord(stream, RecordOpenOptions{
|
||||
Stream: StreamOpenOptions{
|
||||
ID: stream.ID(),
|
||||
Channel: stream.Channel(),
|
||||
Metadata: stream.Metadata(),
|
||||
ReadTimeout: stream.readTimeoutSnapshot(),
|
||||
WriteTimeout: stream.writeTimeoutSnapshot(),
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
info := RecordAcceptInfo{
|
||||
ID: stream.ID(),
|
||||
Metadata: stream.Metadata(),
|
||||
TransportGeneration: stream.TransportGeneration(),
|
||||
RecordStream: record,
|
||||
}
|
||||
go func() {
|
||||
if err := handler(info); err != nil {
|
||||
_ = record.Reset(err)
|
||||
}
|
||||
}()
|
||||
return true, nil
|
||||
}
|
||||
@@ -0,0 +1,523 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/stario"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (c *ClientCommon) closeClientTransport() {
|
||||
c.closeClientTransportBinding(c.clientTransportBindingSnapshot())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) closeClientTransportConn(conn net.Conn) {
|
||||
if c == nil || conn == nil {
|
||||
return
|
||||
}
|
||||
_ = conn.Close()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) closeClientTransportBinding(binding *transportBinding) {
|
||||
if binding == nil {
|
||||
return
|
||||
}
|
||||
c.closeClientTransportConn(binding.connSnapshot())
|
||||
binding.stopBackgroundWorkers()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) beginClientSessionEpoch() uint64 {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
return atomic.AddUint64(&c.sessionEpoch, 1)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) currentClientSessionEpoch() uint64 {
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
return atomic.LoadUint64(&c.sessionEpoch)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) isClientSessionEpochCurrent(epoch uint64) bool {
|
||||
if c == nil || epoch == 0 {
|
||||
return false
|
||||
}
|
||||
return c.currentClientSessionEpoch() == epoch
|
||||
}
|
||||
|
||||
func (c *ClientCommon) stopClientSessionIfCurrent(epoch uint64, reason string, err error) bool {
|
||||
if !c.isClientSessionEpochCurrent(epoch) {
|
||||
return false
|
||||
}
|
||||
c.stopClientSession(reason, err)
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *ClientCommon) setByeFromServer(val bool) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.byeFromServer = val
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) resetClientStopState() {
|
||||
c.setByeFromServer(false)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) shouldSayGoodByeOnStop() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return !c.byeFromServer
|
||||
}
|
||||
|
||||
func (c *ClientCommon) stopClientSession(reason string, err error) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.markSessionStopped(reason, err)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) stopClientSessionFromServer(reason string, err error) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.setByeFromServer(true)
|
||||
c.markSessionStopped(reason, err)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) beginClientConnectAttempt() (func(success bool), error) {
|
||||
if !c.beginClientSessionStart() {
|
||||
return nil, errors.New("client already run")
|
||||
}
|
||||
return func(success bool) {
|
||||
if success {
|
||||
return
|
||||
}
|
||||
c.cleanupFailedClientStart()
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientCanAttachTransport() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
if !sessionIsAlive(&c.alive) {
|
||||
return false
|
||||
}
|
||||
if c.clientTransportAttachedSnapshot() {
|
||||
return false
|
||||
}
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return false
|
||||
}
|
||||
return rt.stopCtx != nil && rt.queue != nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) attachClientWithConnSource(conn net.Conn, source *clientConnectSource) error {
|
||||
if c == nil {
|
||||
return errors.New("client is nil")
|
||||
}
|
||||
if conn == nil {
|
||||
return errors.New("conn is nil")
|
||||
}
|
||||
if err := c.attachClientSessionTransport(conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return err
|
||||
}
|
||||
if err := c.bootstrapClientTransportRuntime(c.clientSessionRuntimeSnapshot(), true, false); err != nil {
|
||||
return err
|
||||
}
|
||||
c.setClientConnectSource(source)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Connect(network string, addr string) error {
|
||||
if err := c.validateSecurityConfiguration(); err != nil {
|
||||
return err
|
||||
}
|
||||
source := newClientNetworkConnectSource(network, addr)
|
||||
c.applySignalReliabilityTransportDefault(source.isUDP())
|
||||
if c.clientCanAttachTransport() {
|
||||
conn, err := source.dial(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.attachClientWithConnSource(conn, source)
|
||||
}
|
||||
finish, err := c.beginClientConnectAttempt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
started := false
|
||||
defer func() {
|
||||
finish(started)
|
||||
}()
|
||||
conn, err := source.dial(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.startClientWithConnSource(conn, source); err != nil {
|
||||
return err
|
||||
}
|
||||
started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error {
|
||||
if err := c.validateSecurityConfiguration(); err != nil {
|
||||
return err
|
||||
}
|
||||
source := newClientTimeoutConnectSource(network, addr, timeout)
|
||||
c.applySignalReliabilityTransportDefault(source.isUDP())
|
||||
if c.clientCanAttachTransport() {
|
||||
conn, err := source.dial(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.attachClientWithConnSource(conn, source)
|
||||
}
|
||||
finish, err := c.beginClientConnectAttempt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
started := false
|
||||
defer func() {
|
||||
finish(started)
|
||||
}()
|
||||
conn, err := source.dial(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := c.startClientWithConnSource(conn, source); err != nil {
|
||||
return err
|
||||
}
|
||||
started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) ConnectByConn(conn net.Conn) error {
|
||||
if err := c.validateSecurityConfiguration(); err != nil {
|
||||
return err
|
||||
}
|
||||
if conn == nil {
|
||||
return errors.New("conn is nil")
|
||||
}
|
||||
source := newClientConnConnectSource(conn)
|
||||
c.applySignalReliabilityTransportDefault(false)
|
||||
if c.clientCanAttachTransport() {
|
||||
return c.attachClientWithConnSource(conn, source)
|
||||
}
|
||||
finish, err := c.beginClientConnectAttempt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
started := false
|
||||
defer func() {
|
||||
finish(started)
|
||||
}()
|
||||
if err := c.startClientWithConnSource(conn, source); err != nil {
|
||||
return err
|
||||
}
|
||||
started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error {
|
||||
if err := c.validateSecurityConfiguration(); err != nil {
|
||||
return err
|
||||
}
|
||||
if dialFn == nil {
|
||||
return errors.New("dialFn is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
source := newClientFactoryConnectSource(dialFn)
|
||||
if c.clientCanAttachTransport() {
|
||||
c.applySignalReliabilityTransportDefault(false)
|
||||
conn, err := dialFn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if conn == nil {
|
||||
return errors.New("conn is nil")
|
||||
}
|
||||
return c.attachClientWithConnSource(conn, source)
|
||||
}
|
||||
finish, err := c.beginClientConnectAttempt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
started := false
|
||||
defer func() {
|
||||
finish(started)
|
||||
}()
|
||||
conn, err := dialFn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if conn == nil {
|
||||
return errors.New("conn is nil")
|
||||
}
|
||||
c.applySignalReliabilityTransportDefault(false)
|
||||
if err := c.startClientWithConnSource(conn, source); err != nil {
|
||||
return err
|
||||
}
|
||||
started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) startClientWithConn(conn net.Conn) error {
|
||||
return c.startClientWithConnSource(conn, newClientConnConnectSource(conn))
|
||||
}
|
||||
|
||||
func (c *ClientCommon) startClientWithConnSource(conn net.Conn, source *clientConnectSource) error {
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
epoch := c.beginClientSessionEpoch()
|
||||
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||
c.setClientConnectSource(source)
|
||||
rt := newClientSessionRuntime(conn, stopCtx, stopFn, queue, epoch)
|
||||
c.setClientSessionRuntime(rt)
|
||||
c.resetClientStopState()
|
||||
c.markSessionStarted()
|
||||
return c.clientPostInit(rt)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) monitorPool() {
|
||||
c.monitorPoolLoop(c.clientStopContextSnapshot())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) monitorPoolLoop(stopCtx context.Context) {
|
||||
if stopCtx == nil {
|
||||
return
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-stopCtx.Done():
|
||||
if c.clientStopContextSnapshot() == stopCtx {
|
||||
c.getPendingWaitPool().closeAll()
|
||||
c.getFileAckPool().closeAll()
|
||||
c.getSignalAckPool().closeAll()
|
||||
}
|
||||
return
|
||||
case <-time.After(time.Second * 30):
|
||||
}
|
||||
now := time.Now()
|
||||
c.getPendingWaitPool().cleanupExpired(int64(c.noFinSyncMsgMaxKeepSeconds), now)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientPostInit(rt *clientSessionRuntime) error {
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
go c.monitorPoolLoop(rt.stopCtx)
|
||||
if err := c.startClientTransportRuntime(rt); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.bootstrapClientTransportRuntime(rt, true, true)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) startClientTransportRuntime(rt *clientSessionRuntime) error {
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
transportStopCtx := rt.transportStopCtx
|
||||
if transportStopCtx == nil {
|
||||
transportStopCtx = rt.stopCtx
|
||||
}
|
||||
if c.useHeartBeat {
|
||||
go c.heartbeatLoop(transportStopCtx, rt.epoch)
|
||||
}
|
||||
go c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, rt.epoch)
|
||||
go c.loadMessageLoop(rt)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime, runKeyExchange bool, stopSessionOnFailure bool) error {
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
if runKeyExchange && !c.skipKeyExchange {
|
||||
if err := c.keyExchangeFn(c); err != nil {
|
||||
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err)
|
||||
}
|
||||
}
|
||||
if err := c.announceClientPeerIdentity(); err != nil {
|
||||
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) failClientTransportBootstrap(rt *clientSessionRuntime, stopSessionOnFailure bool, reason string, err error) error {
|
||||
if c == nil || rt == nil {
|
||||
return err
|
||||
}
|
||||
c.retireClientSessionRuntime(rt, true)
|
||||
c.closeClientTransportConn(rt.conn)
|
||||
if stopSessionOnFailure {
|
||||
c.stopClientSessionIfCurrent(rt.epoch, reason, err)
|
||||
return err
|
||||
}
|
||||
c.clearClientSessionRuntimeTransport()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Heartbeat() {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
epoch := rt.epoch
|
||||
if epoch == 0 {
|
||||
epoch = c.currentClientSessionEpoch()
|
||||
}
|
||||
transportStopCtx := rt.transportStopCtx
|
||||
if transportStopCtx == nil {
|
||||
transportStopCtx = rt.stopCtx
|
||||
}
|
||||
c.heartbeatLoop(transportStopCtx, epoch)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) heartbeatLoop(stopCtx context.Context, epoch uint64) {
|
||||
if stopCtx == nil {
|
||||
return
|
||||
}
|
||||
failedCount := 0
|
||||
for {
|
||||
select {
|
||||
case <-stopCtx.Done():
|
||||
return
|
||||
case <-time.After(c.heartbeatPeriod):
|
||||
}
|
||||
err := c.sendHeartbeat()
|
||||
var stop bool
|
||||
failedCount, stop = c.handleHeartbeatResultWithSession(epoch, err, failedCount)
|
||||
if stop {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) readMessage() {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
epoch := rt.epoch
|
||||
if epoch == 0 {
|
||||
epoch = c.currentClientSessionEpoch()
|
||||
}
|
||||
transportStopCtx := rt.transportStopCtx
|
||||
if transportStopCtx == nil {
|
||||
transportStopCtx = rt.stopCtx
|
||||
}
|
||||
c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, epoch)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) readMessageLoop(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, epoch uint64) {
|
||||
if stopCtx == nil {
|
||||
return
|
||||
}
|
||||
binding := newTransportBinding(conn, queue)
|
||||
dispatcher := c.clientInboundDispatcherSnapshot()
|
||||
buf := streamReadBuffer()
|
||||
for {
|
||||
select {
|
||||
case <-stopCtx.Done():
|
||||
c.closeClientTransportBinding(binding)
|
||||
return
|
||||
default:
|
||||
}
|
||||
readNum, data, err := c.readFromTransportBindingWithBuffer(binding, buf)
|
||||
if !c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, readNum, data, err, epoch, dispatcher) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sayGoodBye() error {
|
||||
_, err := c.sendWait(TransferMsg{
|
||||
ID: 10010,
|
||||
Key: "bye",
|
||||
Value: nil,
|
||||
Type: MSG_SYS_WAIT,
|
||||
}, time.Second*3)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) loadMessage() {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
c.loadMessageLoop(rt)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) loadMessageLoop(rt *clientSessionRuntime) {
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
stopCtx := rt.transportStopCtx
|
||||
if stopCtx == nil {
|
||||
stopCtx = rt.stopCtx
|
||||
}
|
||||
if stopCtx == nil {
|
||||
return
|
||||
}
|
||||
queue := rt.queue
|
||||
if rt.transport != nil {
|
||||
queue = rt.transport.queueSnapshot()
|
||||
}
|
||||
if queue == nil {
|
||||
return
|
||||
}
|
||||
dispatcher := rt.inboundDispatcher
|
||||
if dispatcher == nil {
|
||||
dispatcher = newInboundDispatcher()
|
||||
defer dispatcher.CloseAndWait()
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-stopCtx.Done():
|
||||
sessionStopping := rt.stopCtx != nil && rt.stopCtx.Err() != nil
|
||||
if sessionStopping && rt.inboundDispatcher != nil {
|
||||
rt.inboundDispatcher.CloseAndWait()
|
||||
}
|
||||
if sessionStopping && !rt.runtimeShouldSuppressGoodByeOnStop() && c.shouldSayGoodByeOnStop() {
|
||||
c.sayGoodBye()
|
||||
}
|
||||
c.closeClientTransportBinding(rt.transport)
|
||||
return
|
||||
case data, ok := <-queue.RestoreChan():
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
msg := data
|
||||
c.wg.Add(1)
|
||||
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
|
||||
defer c.wg.Done()
|
||||
now := time.Now()
|
||||
if err := c.dispatchInboundTransportPayload(msg.Msg, now); err != nil {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client decode envelope error", err)
|
||||
}
|
||||
}
|
||||
}) {
|
||||
c.wg.Done()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
+193
@@ -0,0 +1,193 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) {
|
||||
if err := c.ensureClientSendReady(); err != nil {
|
||||
return WaitMsg{}, err
|
||||
}
|
||||
var wait WaitMsg
|
||||
if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 {
|
||||
msg.ID = atomic.AddUint64(&c.msgID, 1)
|
||||
}
|
||||
env, err := wrapTransferMsgEnvelope(msg, c.sequenceEn)
|
||||
if err != nil {
|
||||
return WaitMsg{}, err
|
||||
}
|
||||
if requiresSignalReplyWait(msg) {
|
||||
wait = c.getPendingWaitPool().createAndStore(msg)
|
||||
}
|
||||
err = c.sendSignalEnvelopeMaybeReliable(env, msg)
|
||||
if err != nil {
|
||||
if requiresSignalReplyWait(msg) {
|
||||
c.getPendingWaitPool().removeAndClose(msg.ID)
|
||||
}
|
||||
return WaitMsg{}, err
|
||||
}
|
||||
return wait, err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendEnvelope(env Envelope) error {
|
||||
if err := c.ensureClientSendReady(); err != nil {
|
||||
return err
|
||||
}
|
||||
payload, err := c.encodeEnvelopePayload(env)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if batchedControlEnvelope(env) {
|
||||
return c.writeControlPayloadToTransport(payload)
|
||||
}
|
||||
return c.writePayloadToTransport(payload)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) dispatchEnvelope(env Envelope, now time.Time) {
|
||||
switch env.Kind {
|
||||
case EnvelopeSignalAck:
|
||||
if c.handleSignalAckEnvelope(env) {
|
||||
return
|
||||
}
|
||||
case EnvelopeStreamData:
|
||||
c.dispatchStreamEnvelope(env)
|
||||
return
|
||||
case EnvelopeSignal:
|
||||
transfer, err := unwrapTransferMsgEnvelope(env, c.sequenceDe)
|
||||
if err != nil {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client unwrap signal envelope error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if c.handleReceivedSignalReliability(transfer) {
|
||||
return
|
||||
}
|
||||
message := Message{
|
||||
ServerConn: c,
|
||||
TransferMsg: transfer,
|
||||
NetType: NET_CLIENT,
|
||||
Time: now,
|
||||
}
|
||||
c.dispatchMsg(message)
|
||||
case EnvelopeFileMeta, EnvelopeFileChunk, EnvelopeFileEnd, EnvelopeFileAbort, EnvelopeAck:
|
||||
c.dispatchFileEnvelope(env, now)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Send(key string, value MsgVal) error {
|
||||
_, err := c.send(TransferMsg{
|
||||
Key: key,
|
||||
Value: value,
|
||||
Type: MSG_ASYNC,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) {
|
||||
data, err := c.send(msg)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
stopCh := sessionStopChan(c.clientStopContextSnapshot())
|
||||
if timeout.Seconds() == 0 {
|
||||
msg, ok := <-data.Reply
|
||||
if !ok {
|
||||
return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c))
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
select {
|
||||
case <-time.After(timeout):
|
||||
c.getPendingWaitPool().removeAndClose(data.TransferMsg.ID)
|
||||
return Message{}, os.ErrDeadlineExceeded
|
||||
case <-stopCh:
|
||||
return Message{}, errServiceShutdown
|
||||
case msg, ok := <-data.Reply:
|
||||
if !ok {
|
||||
return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c))
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) {
|
||||
data, err := c.send(msg)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
stopCh := sessionStopChan(c.clientStopContextSnapshot())
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
c.getPendingWaitPool().removeAndClose(data.TransferMsg.ID)
|
||||
return Message{}, normalizeStreamDeadlineError(ctx.Err())
|
||||
case <-stopCh:
|
||||
return Message{}, errServiceShutdown
|
||||
case msg, ok := <-data.Reply:
|
||||
if !ok {
|
||||
return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c))
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) {
|
||||
data, err := c.sequenceEn(val)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
return c.sendCtx(TransferMsg{
|
||||
Key: key,
|
||||
Value: data,
|
||||
Type: MSG_SYNC_ASK,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendObj(key string, val interface{}) error {
|
||||
data, err := encode(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = c.send(TransferMsg{
|
||||
Key: key,
|
||||
Value: data,
|
||||
Type: MSG_ASYNC,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) {
|
||||
return c.sendCtx(TransferMsg{
|
||||
Key: key,
|
||||
Value: value,
|
||||
Type: MSG_SYNC_ASK,
|
||||
}, ctx)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) {
|
||||
return c.sendWait(TransferMsg{
|
||||
Key: key,
|
||||
Value: value,
|
||||
Type: MSG_SYNC_ASK,
|
||||
}, timeout)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) {
|
||||
data, err := c.sequenceEn(value)
|
||||
if err != nil {
|
||||
return Message{}, err
|
||||
}
|
||||
return c.SendWait(key, data, timeout)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) Reply(m Message, value MsgVal) error {
|
||||
return m.Reply(value)
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClientStopSessionIfCurrentEpoch(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
client.markSessionStarted()
|
||||
|
||||
staleEpoch := client.beginClientSessionEpoch()
|
||||
currentEpoch := client.beginClientSessionEpoch()
|
||||
|
||||
if client.stopClientSessionIfCurrent(staleEpoch, "stale", nil) {
|
||||
t.Fatal("stale epoch should not stop current session")
|
||||
}
|
||||
status := client.Status()
|
||||
if !status.Alive || status.Reason != "" || status.Err != nil {
|
||||
t.Fatalf("unexpected status after stale stop: %+v", status)
|
||||
}
|
||||
|
||||
if !client.stopClientSessionIfCurrent(currentEpoch, "current", nil) {
|
||||
t.Fatal("current epoch should stop session")
|
||||
}
|
||||
status = client.Status()
|
||||
if status.Alive || status.Reason != "current" || status.Err != nil {
|
||||
t.Fatalf("unexpected status after current stop: %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReadErrorWithStaleEpochDoesNotStopCurrentSession(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
client.markSessionStarted()
|
||||
|
||||
staleEpoch := client.beginClientSessionEpoch()
|
||||
currentEpoch := client.beginClientSessionEpoch()
|
||||
|
||||
readErr := errors.New("read failed")
|
||||
client.handleTransportReadResultWithSession(context.Background(), nil, nil, 0, nil, readErr, staleEpoch)
|
||||
|
||||
status := client.Status()
|
||||
if !status.Alive || status.Reason != "" || status.Err != nil {
|
||||
t.Fatalf("unexpected status after stale read error: %+v", status)
|
||||
}
|
||||
|
||||
client.handleTransportReadResultWithSession(context.Background(), nil, nil, 0, nil, readErr, currentEpoch)
|
||||
|
||||
status = client.Status()
|
||||
if status.Alive || status.Reason != "client read error" || !errors.Is(status.Err, readErr) {
|
||||
t.Fatalf("unexpected status after current read error: %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestHeartbeatFailureWithStaleEpochDoesNotStopCurrentSession(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
client.markSessionStarted()
|
||||
|
||||
staleEpoch := client.beginClientSessionEpoch()
|
||||
currentEpoch := client.beginClientSessionEpoch()
|
||||
heartbeatErr := errors.New("heartbeat failed")
|
||||
|
||||
failedCount, stop := client.handleHeartbeatResultWithSession(staleEpoch, heartbeatErr, 2)
|
||||
if failedCount != 3 || !stop {
|
||||
t.Fatalf("unexpected stale heartbeat result: failedCount=%d stop=%v", failedCount, stop)
|
||||
}
|
||||
status := client.Status()
|
||||
if !status.Alive || status.Reason != "" || status.Err != nil {
|
||||
t.Fatalf("unexpected status after stale heartbeat error: %+v", status)
|
||||
}
|
||||
|
||||
failedCount, stop = client.handleHeartbeatResultWithSession(currentEpoch, heartbeatErr, 2)
|
||||
if failedCount != 3 || !stop {
|
||||
t.Fatalf("unexpected current heartbeat result: failedCount=%d stop=%v", failedCount, stop)
|
||||
}
|
||||
status = client.Status()
|
||||
if status.Alive || status.Reason != "heartbeat failed more than 3 times" || status.Err == nil || status.Err.Error() != "heartbeat failed more than 3 times" {
|
||||
t.Fatalf("unexpected status after current heartbeat error: %+v", status)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,323 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/stario"
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type clientSessionRuntime struct {
|
||||
transport *transportBinding
|
||||
transportAttached bool
|
||||
conn net.Conn
|
||||
stopCtx context.Context
|
||||
stopFn context.CancelFunc
|
||||
transportStopCtx context.Context
|
||||
transportStopFn context.CancelFunc
|
||||
queue *stario.StarQueue
|
||||
inboundDispatcher *inboundDispatcher
|
||||
epoch uint64
|
||||
suppressGoodByeOnStop *atomic.Bool
|
||||
}
|
||||
|
||||
func newClientSessionRuntimeBase(stopCtx context.Context, stopFn context.CancelFunc) *clientSessionRuntime {
|
||||
return &clientSessionRuntime{
|
||||
stopCtx: stopCtx,
|
||||
stopFn: stopFn,
|
||||
inboundDispatcher: newInboundDispatcher(),
|
||||
suppressGoodByeOnStop: &atomic.Bool{},
|
||||
}
|
||||
}
|
||||
|
||||
func prepareClientSessionRuntime(rt *clientSessionRuntime) *clientSessionRuntime {
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
if rt.inboundDispatcher == nil {
|
||||
rt.inboundDispatcher = newInboundDispatcher()
|
||||
}
|
||||
if rt.suppressGoodByeOnStop == nil {
|
||||
rt.suppressGoodByeOnStop = &atomic.Bool{}
|
||||
}
|
||||
if rt.transport == nil && rt.conn != nil {
|
||||
rt.transport = newTransportBinding(rt.conn, rt.queue)
|
||||
}
|
||||
normalizeClientSessionRuntimeTransportState(rt)
|
||||
ensureClientSessionRuntimeTransportLifecycle(rt)
|
||||
return rt
|
||||
}
|
||||
|
||||
func (c *ClientCommon) setClientSessionRuntime(rt *clientSessionRuntime) {
|
||||
if c == nil || rt == nil {
|
||||
return
|
||||
}
|
||||
var oldBinding *transportBinding
|
||||
if prev := c.clientSessionRuntimeSnapshot(); prev != nil && prev.transport != nil && prev.transport != rt.transport {
|
||||
oldBinding = prev.transport
|
||||
}
|
||||
rt = prepareClientSessionRuntime(rt)
|
||||
c.sessionRuntime.Store(rt)
|
||||
c.stopCtx = rt.stopCtx
|
||||
c.stopFn = rt.stopFn
|
||||
if rt.transport != nil {
|
||||
c.queue = rt.transport.queueSnapshot()
|
||||
c.conn = rt.transport.connSnapshot()
|
||||
} else {
|
||||
c.queue = rt.queue
|
||||
c.conn = rt.conn
|
||||
}
|
||||
if oldBinding != nil {
|
||||
oldBinding.stopBackgroundWorkers()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) resetClientSessionRuntimeBase() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
c.sessionRuntime.Store(newClientSessionRuntimeBase(stopCtx, stopFn))
|
||||
c.conn = nil
|
||||
c.queue = nil
|
||||
c.stopCtx = stopCtx
|
||||
c.stopFn = stopFn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) cleanupFailedClientStart() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt != nil && rt.stopFn != nil {
|
||||
rt.stopFn()
|
||||
}
|
||||
c.cleanupClientSessionResources()
|
||||
c.rollbackClientSessionStart()
|
||||
c.resetClientSessionRuntimeBase()
|
||||
}
|
||||
|
||||
func newClientSessionRuntime(conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc, queue *stario.StarQueue, epoch uint64) *clientSessionRuntime {
|
||||
return prepareClientSessionRuntime(&clientSessionRuntime{
|
||||
transport: newTransportBinding(conn, queue),
|
||||
transportAttached: conn != nil,
|
||||
conn: conn,
|
||||
stopCtx: stopCtx,
|
||||
stopFn: stopFn,
|
||||
queue: queue,
|
||||
inboundDispatcher: newInboundDispatcher(),
|
||||
epoch: epoch,
|
||||
suppressGoodByeOnStop: &atomic.Bool{},
|
||||
})
|
||||
}
|
||||
|
||||
func (rt *clientSessionRuntime) runtimeShouldSuppressGoodByeOnStop() bool {
|
||||
if rt == nil || rt.suppressGoodByeOnStop == nil {
|
||||
return false
|
||||
}
|
||||
return rt.suppressGoodByeOnStop.Load()
|
||||
}
|
||||
|
||||
func (rt *clientSessionRuntime) markRuntimeSuppressGoodByeOnStop() {
|
||||
if rt == nil || rt.suppressGoodByeOnStop == nil {
|
||||
return
|
||||
}
|
||||
rt.suppressGoodByeOnStop.Store(true)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) retireClientSessionRuntime(rt *clientSessionRuntime, suppressGoodBye bool) {
|
||||
if c == nil || rt == nil {
|
||||
return
|
||||
}
|
||||
if suppressGoodBye {
|
||||
rt.markRuntimeSuppressGoodByeOnStop()
|
||||
}
|
||||
if rt.transportStopFn != nil {
|
||||
rt.transportStopFn()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clearClientSessionRuntimeTransport() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
if rt.transportStopFn != nil {
|
||||
rt.transportStopFn()
|
||||
}
|
||||
next := *rt
|
||||
next.transport = nil
|
||||
next.transportAttached = false
|
||||
next.conn = nil
|
||||
next.transportStopCtx = nil
|
||||
next.transportStopFn = nil
|
||||
c.setClientSessionRuntime(&next)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clearClientSessionRuntimeQueue() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
next := *rt
|
||||
next.queue = nil
|
||||
if next.transport != nil {
|
||||
next.transport = newTransportBinding(next.transport.connSnapshot(), nil)
|
||||
}
|
||||
c.setClientSessionRuntime(&next)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) attachClientSessionTransport(conn net.Conn) error {
|
||||
if c == nil {
|
||||
return errors.New("client is nil")
|
||||
}
|
||||
if conn == nil {
|
||||
return errors.New("conn is nil")
|
||||
}
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return errors.New("client session runtime is nil")
|
||||
}
|
||||
if rt.queue == nil {
|
||||
return errClientSessionQueueUnavailable
|
||||
}
|
||||
oldBinding := rt.transport
|
||||
if rt.transportStopFn != nil {
|
||||
rt.transportStopFn()
|
||||
}
|
||||
next := *rt
|
||||
next.transport = newTransportBinding(conn, rt.queue)
|
||||
next.transportAttached = true
|
||||
next.conn = conn
|
||||
next.transportStopCtx = nil
|
||||
next.transportStopFn = nil
|
||||
next.suppressGoodByeOnStop = &atomic.Bool{}
|
||||
c.setClientSessionRuntime(&next)
|
||||
if oldConn := oldBinding.connSnapshot(); oldConn != nil && oldConn != conn {
|
||||
_ = oldConn.Close()
|
||||
}
|
||||
return c.startClientTransportRuntime(c.clientSessionRuntimeSnapshot())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientSessionRuntimeSnapshot() *clientSessionRuntime {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
return c.sessionRuntime.Load()
|
||||
}
|
||||
|
||||
func normalizeClientSessionRuntimeTransportState(rt *clientSessionRuntime) {
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
if rt.transport != nil {
|
||||
rt.transportAttached = rt.transport.connSnapshot() != nil
|
||||
return
|
||||
}
|
||||
rt.transportAttached = rt.conn != nil
|
||||
}
|
||||
|
||||
func ensureClientSessionRuntimeTransportLifecycle(rt *clientSessionRuntime) {
|
||||
if rt == nil {
|
||||
return
|
||||
}
|
||||
if rt.conn == nil {
|
||||
rt.transportStopCtx = nil
|
||||
rt.transportStopFn = nil
|
||||
return
|
||||
}
|
||||
if rt.transportStopCtx != nil && rt.transportStopFn != nil {
|
||||
return
|
||||
}
|
||||
parent := rt.stopCtx
|
||||
if parent == nil {
|
||||
parent = context.Background()
|
||||
}
|
||||
rt.transportStopCtx, rt.transportStopFn = context.WithCancel(parent)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientTransportConnSnapshot() net.Conn {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
if rt.transport != nil {
|
||||
return rt.transport.connSnapshot()
|
||||
}
|
||||
return rt.conn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientInboundDispatcherSnapshot() *inboundDispatcher {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
return rt.inboundDispatcher
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientStopContextSnapshot() context.Context {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
return rt.stopCtx
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientStopFuncSnapshot() context.CancelFunc {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
return rt.stopFn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientQueueSnapshot() *stario.StarQueue {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
if rt.transport != nil {
|
||||
return rt.transport.queueSnapshot()
|
||||
}
|
||||
return rt.queue
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientTransportBindingSnapshot() *transportBinding {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
if rt.transport != nil {
|
||||
return rt.transport
|
||||
}
|
||||
if rt.conn == nil {
|
||||
return nil
|
||||
}
|
||||
return newTransportBinding(rt.conn, rt.queue)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientTransportStopContextSnapshot() context.Context {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return nil
|
||||
}
|
||||
if rt.transportStopCtx != nil {
|
||||
return rt.transportStopCtx
|
||||
}
|
||||
return rt.stopCtx
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientTransportAttachedSnapshot() bool {
|
||||
rt := c.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
return false
|
||||
}
|
||||
return rt.transportAttached
|
||||
}
|
||||
@@ -0,0 +1,352 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/stario"
|
||||
"context"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestClientWriteToTransportUsesRuntimeConn(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
fallbackLeft, fallbackRight := net.Pipe()
|
||||
defer fallbackLeft.Close()
|
||||
defer fallbackRight.Close()
|
||||
runtimeLeft, runtimeRight := net.Pipe()
|
||||
defer runtimeLeft.Close()
|
||||
defer runtimeRight.Close()
|
||||
|
||||
client.conn = fallbackLeft
|
||||
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
|
||||
defer runtimeCancel()
|
||||
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||
conn: runtimeLeft,
|
||||
stopCtx: runtimeCtx,
|
||||
stopFn: runtimeCancel,
|
||||
epoch: 1,
|
||||
})
|
||||
|
||||
payload := []byte("runtime-conn")
|
||||
recvCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, len(payload))
|
||||
_, err := io.ReadFull(runtimeRight, buf)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
recvCh <- buf
|
||||
}()
|
||||
|
||||
if err := client.writeToTransport(payload); err != nil {
|
||||
t.Fatalf("writeToTransport failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
t.Fatalf("runtime conn read failed: %v", err)
|
||||
case got := <-recvCh:
|
||||
if string(got) != string(payload) {
|
||||
t.Fatalf("runtime payload mismatch: got %q want %q", string(got), string(payload))
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("runtime conn did not receive payload")
|
||||
}
|
||||
|
||||
_ = fallbackRight.SetReadDeadline(time.Now().Add(20 * time.Millisecond))
|
||||
buf := make([]byte, 1)
|
||||
if _, err := fallbackRight.Read(buf); err == nil {
|
||||
t.Fatal("fallback conn should not receive payload when runtime conn is active")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientMarkSessionStoppedUsesRuntimeStopFn(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
if !client.beginClientSessionStart() {
|
||||
t.Fatal("beginClientSessionStart should succeed")
|
||||
}
|
||||
client.markSessionStarted()
|
||||
|
||||
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
|
||||
defer runtimeCancel()
|
||||
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||
stopCtx: runtimeCtx,
|
||||
stopFn: runtimeCancel,
|
||||
epoch: 1,
|
||||
})
|
||||
|
||||
fallbackCtx, fallbackCancel := context.WithCancel(context.Background())
|
||||
defer fallbackCancel()
|
||||
client.stopCtx = fallbackCtx
|
||||
client.stopFn = fallbackCancel
|
||||
|
||||
client.markSessionStopped("runtime stop", nil)
|
||||
|
||||
select {
|
||||
case <-runtimeCtx.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("runtime stop context should be canceled by markSessionStopped")
|
||||
}
|
||||
select {
|
||||
case <-fallbackCtx.Done():
|
||||
t.Fatal("fallback owner stop context should not be canceled when runtime stopFn is active")
|
||||
default:
|
||||
}
|
||||
rt := client.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
t.Fatal("runtime snapshot should remain available after stop")
|
||||
}
|
||||
if rt.conn != nil || rt.queue != nil {
|
||||
t.Fatalf("runtime transport should be cleared after stop: %+v", rt)
|
||||
}
|
||||
if rt.stopCtx == nil {
|
||||
t.Fatalf("runtime stop context should be preserved after stop: %+v", rt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientClearSessionRuntimeTransportPreservesStopState(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||
conn: left,
|
||||
stopCtx: stopCtx,
|
||||
stopFn: stopFn,
|
||||
queue: queue,
|
||||
epoch: 7,
|
||||
})
|
||||
|
||||
client.clearClientSessionRuntimeTransport()
|
||||
|
||||
rt := client.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
t.Fatal("runtime snapshot should remain after transport clear")
|
||||
}
|
||||
if rt.conn != nil {
|
||||
t.Fatalf("runtime conn should be cleared: %+v", rt)
|
||||
}
|
||||
if rt.queue != queue {
|
||||
t.Fatalf("runtime queue should be preserved across pure transport clear: got %v want %v", rt.queue, queue)
|
||||
}
|
||||
if rt.stopCtx != stopCtx || rt.stopFn == nil || rt.epoch != 7 {
|
||||
t.Fatalf("runtime control state should be preserved: %+v", rt)
|
||||
}
|
||||
if client.clientTransportAttachedSnapshot() {
|
||||
t.Fatal("client transport should be marked detached after runtime clear")
|
||||
}
|
||||
if got := client.clientQueueSnapshot(); got != queue {
|
||||
t.Fatalf("client queue snapshot should be preserved after transport clear: got %v want %v", got, queue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||
transport: newTransportBinding(left, queue),
|
||||
conn: left,
|
||||
stopCtx: stopCtx,
|
||||
stopFn: stopFn,
|
||||
queue: queue,
|
||||
epoch: 9,
|
||||
})
|
||||
|
||||
binding := client.clientTransportBindingSnapshot()
|
||||
if binding == nil {
|
||||
t.Fatal("runtime transport binding should exist")
|
||||
}
|
||||
if got := binding.connSnapshot(); got != left {
|
||||
t.Fatal("runtime transport binding conn should match runtime conn")
|
||||
}
|
||||
if got := binding.queueSnapshot(); got != queue {
|
||||
t.Fatal("runtime transport binding queue should match runtime queue")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetireClientSessionRuntimeCancelsTransportOnly(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
rt := newClientSessionRuntime(left, stopCtx, stopFn, queue, 3)
|
||||
client.setClientSessionRuntime(rt)
|
||||
client.retireClientSessionRuntime(rt, true)
|
||||
|
||||
transportStopCtx := client.clientTransportStopContextSnapshot()
|
||||
if transportStopCtx == nil {
|
||||
t.Fatal("transport stop context should exist")
|
||||
}
|
||||
select {
|
||||
case <-transportStopCtx.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("transport stop context should be canceled by retireClientSessionRuntime")
|
||||
}
|
||||
select {
|
||||
case <-client.clientStopContextSnapshot().Done():
|
||||
t.Fatal("logical stop context should remain active when only retiring transport")
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientClearSessionRuntimeTransportPreservesQueueForEncoding(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
UseLegacySecurityClient(client)
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||
conn: left,
|
||||
stopCtx: stopCtx,
|
||||
stopFn: stopFn,
|
||||
queue: queue,
|
||||
epoch: 8,
|
||||
})
|
||||
client.markSessionStarted()
|
||||
defer client.markSessionStopped("test done", nil)
|
||||
|
||||
client.clearClientSessionRuntimeTransport()
|
||||
|
||||
data, err := client.encodeEnvelope(newSignalAckEnvelope(1003))
|
||||
if err != nil {
|
||||
t.Fatalf("encodeEnvelope failed after pure transport clear: %v", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Fatal("encodeEnvelope should still return framed payload after pure transport clear")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAttachClientSessionTransportRebindsRuntimeAndDispatchesOnNewConn(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
UseLegacySecurityClient(client)
|
||||
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||
oldLeft, oldRight := net.Pipe()
|
||||
defer oldRight.Close()
|
||||
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||
conn: oldLeft,
|
||||
stopCtx: stopCtx,
|
||||
stopFn: stopFn,
|
||||
queue: queue,
|
||||
epoch: 11,
|
||||
suppressGoodByeOnStop: &atomic.Bool{},
|
||||
})
|
||||
client.markSessionStarted()
|
||||
defer client.markSessionStopped("test done", nil)
|
||||
|
||||
recvCh := make(chan Message, 1)
|
||||
client.SetLink("reattach", func(message *Message) {
|
||||
recvCh <- *message
|
||||
})
|
||||
|
||||
newLeft, newRight := net.Pipe()
|
||||
defer newRight.Close()
|
||||
if err := client.attachClientSessionTransport(newLeft); err != nil {
|
||||
t.Fatalf("attachClientSessionTransport failed: %v", err)
|
||||
}
|
||||
|
||||
rt := client.clientSessionRuntimeSnapshot()
|
||||
if rt == nil {
|
||||
t.Fatal("runtime snapshot should exist after attach")
|
||||
}
|
||||
if rt.conn != newLeft || !rt.transportAttached || rt.queue != queue || rt.epoch != 11 {
|
||||
t.Fatalf("attached runtime mismatch: %+v", rt)
|
||||
}
|
||||
|
||||
env, err := wrapTransferMsgEnvelope(TransferMsg{
|
||||
ID: 42,
|
||||
Key: "reattach",
|
||||
Value: []byte("ok"),
|
||||
Type: MSG_ASYNC,
|
||||
}, client.sequenceEn)
|
||||
if err != nil {
|
||||
t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
|
||||
}
|
||||
wire, err := client.encodeEnvelope(env)
|
||||
if err != nil {
|
||||
t.Fatalf("encodeEnvelope failed: %v", err)
|
||||
}
|
||||
if _, err := newRight.Write(wire); err != nil {
|
||||
t.Fatalf("new transport write failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case message := <-recvCh:
|
||||
if got, want := message.Key, "reattach"; got != want {
|
||||
t.Fatalf("message key mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := string(message.Value), "ok"; got != want {
|
||||
t.Fatalf("message value mismatch: got %q want %q", got, want)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("reattached transport did not dispatch message")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetClientSessionRuntimeStopsOldBindingWorkersOnReattach(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||
|
||||
oldLeft, oldRight := net.Pipe()
|
||||
defer oldLeft.Close()
|
||||
defer oldRight.Close()
|
||||
oldBinding := newTransportBinding(oldLeft, queue)
|
||||
oldSender := oldBinding.bulkBatchSenderSnapshot()
|
||||
|
||||
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||
transport: oldBinding,
|
||||
conn: oldLeft,
|
||||
stopCtx: stopCtx,
|
||||
stopFn: stopFn,
|
||||
queue: queue,
|
||||
epoch: 1,
|
||||
})
|
||||
|
||||
newLeft, newRight := net.Pipe()
|
||||
defer newLeft.Close()
|
||||
defer newRight.Close()
|
||||
newBinding := newTransportBinding(newLeft, queue)
|
||||
|
||||
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||
transport: newBinding,
|
||||
conn: newLeft,
|
||||
stopCtx: stopCtx,
|
||||
stopFn: stopFn,
|
||||
queue: queue,
|
||||
epoch: 2,
|
||||
})
|
||||
|
||||
err := oldSender.submit(context.Background(), []byte("payload"))
|
||||
if err != errTransportDetached {
|
||||
t.Fatalf("old sender submit after reattach = %v, want %v", err, errTransportDetached)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
)
|
||||
|
||||
func (c *ClientCommon) SetStreamHandler(fn func(StreamAcceptInfo) error) {
|
||||
runtime := c.getStreamRuntime()
|
||||
if runtime == nil {
|
||||
return
|
||||
}
|
||||
runtime.setHandler(fn)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error) {
|
||||
if c == nil {
|
||||
return nil, errStreamClientNil
|
||||
}
|
||||
runtime := c.getStreamRuntime()
|
||||
if runtime == nil {
|
||||
return nil, errStreamRuntimeNil
|
||||
}
|
||||
req := clientStreamRequest(runtime, opt)
|
||||
if req.StreamID == "" {
|
||||
return nil, errStreamIDEmpty
|
||||
}
|
||||
if _, exists := runtime.lookup(clientFileScope(), req.StreamID); exists {
|
||||
return nil, errStreamAlreadyExists
|
||||
}
|
||||
resp, err := sendStreamOpenClient(ctx, c, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.DataID != 0 {
|
||||
req.DataID = resp.DataID
|
||||
}
|
||||
stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot())
|
||||
stream.setClientSnapshotOwner(c)
|
||||
stream.setAddrSnapshot(c.clientStreamAddrSnapshot())
|
||||
if err := runtime.register(clientFileScope(), stream); err != nil {
|
||||
_, _ = sendStreamResetClient(context.Background(), c, StreamResetRequest{
|
||||
StreamID: req.StreamID,
|
||||
Error: err.Error(),
|
||||
})
|
||||
return nil, err
|
||||
}
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientStreamAddrSnapshot() (net.Addr, net.Addr) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
conn := c.clientTransportConnSnapshot()
|
||||
if conn == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return conn.LocalAddr(), conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func clientStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOpenRequest {
|
||||
id := opt.ID
|
||||
if id == "" && runtime != nil {
|
||||
id = runtime.nextID()
|
||||
}
|
||||
return normalizeStreamOpenRequest(StreamOpenRequest{
|
||||
StreamID: id,
|
||||
Channel: opt.Channel,
|
||||
Metadata: cloneStreamMetadata(opt.Metadata),
|
||||
ReadTimeout: opt.ReadTimeout,
|
||||
WriteTimeout: opt.WriteTimeout,
|
||||
})
|
||||
}
|
||||
|
||||
func clientStreamCloseSender(c *ClientCommon) streamCloseSender {
|
||||
return func(ctx context.Context, stream *streamHandle, full bool) error {
|
||||
_, err := sendStreamCloseClient(ctx, c, StreamCloseRequest{
|
||||
StreamID: stream.ID(),
|
||||
Full: full,
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func clientStreamResetSender(c *ClientCommon) streamResetSender {
|
||||
return func(ctx context.Context, stream *streamHandle, message string) error {
|
||||
_, err := sendStreamResetClient(ctx, c, StreamResetRequest{
|
||||
StreamID: stream.ID(),
|
||||
Error: message,
|
||||
})
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
func clientStreamDataSender(c *ClientCommon, epoch uint64) streamDataSender {
|
||||
return func(ctx context.Context, stream *streamHandle, chunk []byte) error {
|
||||
if c == nil {
|
||||
return errStreamClientNil
|
||||
}
|
||||
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
|
||||
return errTransportDetached
|
||||
}
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
if dataID := stream.dataIDSnapshot(); dataID != 0 {
|
||||
return c.sendFastStreamData(dataID, stream.nextOutboundDataSeq(), chunk)
|
||||
}
|
||||
return c.sendEnvelope(newStreamDataEnvelope(stream.ID(), chunk))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,206 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/stario"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func batchedControlEnvelope(env Envelope) bool {
|
||||
switch env.Kind {
|
||||
case EnvelopeSignal, EnvelopeSignalAck:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func writeDeadlineFromTimeout(timeout time.Duration) time.Time {
|
||||
if timeout <= 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Now().Add(timeout)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendHeartbeat() error {
|
||||
_, err := c.sendWait(TransferMsg{
|
||||
ID: 10000,
|
||||
Key: "heartbeat",
|
||||
Value: nil,
|
||||
Type: MSG_SYS_WAIT,
|
||||
}, time.Second*5)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) handleHeartbeatResult(err error, failedCount int) (int, bool) {
|
||||
return c.handleHeartbeatResultWithSession(c.currentClientSessionEpoch(), err, failedCount)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) handleHeartbeatResultWithSession(epoch uint64, err error, failedCount int) (int, bool) {
|
||||
if err == nil {
|
||||
c.lastHeartbeat = time.Now().Unix()
|
||||
return 0, false
|
||||
}
|
||||
if c.debugMode {
|
||||
fmt.Println("failed to recv heartbeat,timeout!")
|
||||
}
|
||||
failedCount++
|
||||
if failedCount < 3 {
|
||||
return failedCount, false
|
||||
}
|
||||
if c.debugMode {
|
||||
fmt.Println("heatbeat failed more than 3 times,stop client")
|
||||
}
|
||||
if !c.stopClientSessionIfCurrent(epoch, "heartbeat failed more than 3 times", errors.New("heartbeat failed more than 3 times")) {
|
||||
return failedCount, true
|
||||
}
|
||||
return failedCount, true
|
||||
}
|
||||
|
||||
func (c *ClientCommon) readFromTransport() (int, []byte, error) {
|
||||
return c.readFromTransportBinding(c.clientTransportBindingSnapshot())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) readFromTransportConn(conn net.Conn) (int, []byte, error) {
|
||||
return c.readFromTransportBinding(newTransportBinding(conn, nil))
|
||||
}
|
||||
|
||||
func (c *ClientCommon) readFromTransportBinding(binding *transportBinding) (int, []byte, error) {
|
||||
return c.readFromTransportBindingWithBuffer(binding, streamReadBuffer())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) readFromTransportBindingWithBuffer(binding *transportBinding, data []byte) (int, []byte, error) {
|
||||
if len(data) == 0 {
|
||||
data = streamReadBuffer()
|
||||
}
|
||||
if binding == nil {
|
||||
return 0, data, net.ErrClosed
|
||||
}
|
||||
conn := binding.connSnapshot()
|
||||
if conn == nil {
|
||||
return 0, data, net.ErrClosed
|
||||
}
|
||||
if c.maxReadTimeout.Seconds() != 0 {
|
||||
_ = conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout))
|
||||
}
|
||||
readNum, err := conn.Read(data)
|
||||
return readNum, data, err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) handleTransportReadResult(readNum int, data []byte, err error) bool {
|
||||
return c.handleTransportReadResultWithSession(c.clientStopContextSnapshot(), c.clientTransportConnSnapshot(), c.clientQueueSnapshot(), readNum, data, err, c.currentClientSessionEpoch())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) handleTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, readNum int, data []byte, err error, epoch uint64) bool {
|
||||
return c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, readNum, data, err, epoch, c.clientInboundDispatcherSnapshot())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) handleTransportReadResultWithSessionDispatcher(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, readNum int, data []byte, err error, epoch uint64, dispatcher *inboundDispatcher) bool {
|
||||
binding := newTransportBinding(conn, queue)
|
||||
if err == os.ErrDeadlineExceeded {
|
||||
if readNum != 0 && queue != nil {
|
||||
if !c.pushMessageFast(queue, data[:readNum], dispatcher) {
|
||||
queue.ParseMessage(data[:readNum], "b612")
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
if err != nil {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client read error", err)
|
||||
}
|
||||
select {
|
||||
case <-sessionStopChan(stopCtx):
|
||||
c.closeClientTransportBinding(binding)
|
||||
return false
|
||||
default:
|
||||
}
|
||||
c.stopClientSessionIfCurrent(epoch, "client read error", err)
|
||||
return false
|
||||
}
|
||||
if queue != nil {
|
||||
if !c.pushMessageFast(queue, data[:readNum], dispatcher) {
|
||||
queue.ParseMessage(data[:readNum], "b612")
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *ClientCommon) pushMessageFast(queue *stario.StarQueue, data []byte, dispatcher *inboundDispatcher) bool {
|
||||
if queue == nil || dispatcher == nil || len(data) == 0 {
|
||||
return false
|
||||
}
|
||||
if err := queue.ParseMessageOwned(data, "b612", func(msg stario.MsgQueue) error {
|
||||
payload := msg.Msg
|
||||
c.wg.Add(1)
|
||||
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
|
||||
defer c.wg.Done()
|
||||
now := time.Now()
|
||||
if err := c.dispatchInboundTransportPayload(payload, now); err != nil {
|
||||
if c.showError || c.debugMode {
|
||||
fmt.Println("client decode envelope error", err)
|
||||
}
|
||||
}
|
||||
}) {
|
||||
c.wg.Done()
|
||||
}
|
||||
return nil
|
||||
}); err != nil && (c.showError || c.debugMode) {
|
||||
fmt.Println("client parse inbound frame error", err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *ClientCommon) writeToTransport(data []byte) error {
|
||||
binding := c.clientTransportBindingSnapshot()
|
||||
if binding == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
return binding.withConnWriteLock(func(conn net.Conn) error {
|
||||
if c.maxWriteTimeout.Seconds() != 0 {
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
|
||||
}
|
||||
return writeFullToConnUnlocked(conn, data)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientCommon) writePayloadToTransport(payload []byte) error {
|
||||
binding := c.clientTransportBindingSnapshot()
|
||||
if binding == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
queue := binding.queueSnapshot()
|
||||
if queue == nil {
|
||||
return errClientSessionQueueUnavailable
|
||||
}
|
||||
return binding.withConnWriteLock(func(conn net.Conn) error {
|
||||
if c.maxWriteTimeout.Seconds() != 0 {
|
||||
_ = conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
|
||||
}
|
||||
return writeFramedPayloadUnlocked(conn, queue, payload)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientCommon) writeControlPayloadToTransport(payload []byte) error {
|
||||
binding := c.clientTransportBindingSnapshot()
|
||||
if binding == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
queue := binding.queueSnapshot()
|
||||
if queue == nil {
|
||||
return errClientSessionQueueUnavailable
|
||||
}
|
||||
conn := binding.connSnapshot()
|
||||
if conn == nil || isPacketTransportConn(conn) {
|
||||
return c.writePayloadToTransport(payload)
|
||||
}
|
||||
sender := binding.controlBatchSenderSnapshot()
|
||||
if sender == nil {
|
||||
return c.writePayloadToTransport(payload)
|
||||
}
|
||||
return sender.submit(payload, writeDeadlineFromTimeout(c.maxWriteTimeout))
|
||||
}
|
||||
@@ -2,28 +2,50 @@ package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
SetDefaultLink(func(message *Message))
|
||||
SetLink(string, func(*Message))
|
||||
SetFileHandler(func(FileEvent))
|
||||
SetStreamHandler(func(StreamAcceptInfo) error)
|
||||
SetRecordStreamHandler(func(RecordAcceptInfo) error)
|
||||
SetBulkHandler(func(BulkAcceptInfo) error)
|
||||
SetTransferHandler(func(TransferAcceptInfo) (TransferReceiveOptions, error))
|
||||
GetStreamConfig() StreamConfig
|
||||
SetStreamConfig(StreamConfig)
|
||||
SetTransferResumeStore(TransferResumeStore)
|
||||
RecoverTransferSnapshots(context.Context) error
|
||||
SetFileReceiveDir(dir string) error
|
||||
send(msg TransferMsg) (WaitMsg, error)
|
||||
sendEnvelope(env Envelope) error
|
||||
sendWait(msg TransferMsg, timeout time.Duration) (Message, error)
|
||||
Send(key string, value MsgVal) error
|
||||
SendWait(key string, value MsgVal, timeout time.Duration) (Message, error)
|
||||
SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error)
|
||||
SendCtx(ctx context.Context, key string, value MsgVal) (Message, error)
|
||||
Reply(m Message, value MsgVal) error
|
||||
// Deprecated: ExchangeKey drives the legacy RSA-based key exchange flow.
|
||||
// Prefer UseModernPSKClient.
|
||||
ExchangeKey(newKey []byte) error
|
||||
Connect(network string, addr string) error
|
||||
ConnectTimeout(network string, addr string, timeout time.Duration) error
|
||||
ConnectByConn(conn net.Conn) error
|
||||
ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error
|
||||
// Deprecated: SkipExchangeKey only controls the legacy RSA-based key exchange.
|
||||
SkipExchangeKey() bool
|
||||
// Deprecated: SetSkipExchangeKey only controls the legacy RSA-based key exchange.
|
||||
SetSkipExchangeKey(bool)
|
||||
|
||||
GetMsgEn() func([]byte, []byte) []byte
|
||||
// Deprecated: SetMsgEn overrides the transport codec directly.
|
||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||
SetMsgEn(func([]byte, []byte) []byte)
|
||||
GetMsgDe() func([]byte, []byte) []byte
|
||||
// Deprecated: SetMsgDe overrides the transport codec directly.
|
||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||
SetMsgDe(func([]byte, []byte) []byte)
|
||||
|
||||
Heartbeat()
|
||||
@@ -31,8 +53,12 @@ type Client interface {
|
||||
SetHeartbeatPeroid(duration time.Duration)
|
||||
|
||||
GetSecretKey() []byte
|
||||
// Deprecated: SetSecretKey injects a raw transport key directly.
|
||||
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||
SetSecretKey(key []byte)
|
||||
// Deprecated: RsaPubKey exposes the legacy RSA handshake key. Prefer UseModernPSKClient.
|
||||
RsaPubKey() []byte
|
||||
// Deprecated: SetRsaPubKey configures the legacy RSA handshake key. Prefer UseModernPSKClient.
|
||||
SetRsaPubKey([]byte)
|
||||
|
||||
Stop() error
|
||||
@@ -48,4 +74,9 @@ type Client interface {
|
||||
SetSequenceDe(func([]byte) (interface{}, error))
|
||||
SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error)
|
||||
SendObj(key string, val interface{}) error
|
||||
OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error)
|
||||
OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error)
|
||||
OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error)
|
||||
SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error)
|
||||
SendFile(ctx context.Context, filePath string) error
|
||||
}
|
||||
|
||||
@@ -0,0 +1,544 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var errInMemoryListenerClosed = errors.New("in-memory listener closed")
|
||||
|
||||
type inMemoryListener struct {
|
||||
closed chan struct{}
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newInMemoryListener() *inMemoryListener {
|
||||
return &inMemoryListener{
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *inMemoryListener) Accept() (net.Conn, error) {
|
||||
<-l.closed
|
||||
return nil, errInMemoryListenerClosed
|
||||
}
|
||||
|
||||
func (l *inMemoryListener) Close() error {
|
||||
l.once.Do(func() {
|
||||
close(l.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *inMemoryListener) Addr() net.Addr {
|
||||
return inMemoryAddr("in-memory-listener")
|
||||
}
|
||||
|
||||
type inMemoryAddr string
|
||||
|
||||
func (a inMemoryAddr) Network() string { return "in-memory" }
|
||||
func (a inMemoryAddr) String() string { return string(a) }
|
||||
|
||||
func TestConnectByConnRequiresModernPSK(t *testing.T) {
|
||||
client := NewClient()
|
||||
left, right := net.Pipe()
|
||||
defer left.Close()
|
||||
defer right.Close()
|
||||
|
||||
err := client.ConnectByConn(left)
|
||||
if !errors.Is(err, errModernPSKRequired) {
|
||||
t.Fatalf("ConnectByConn error = %v, want %v", err, errModernPSKRequired)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectByConnWithConfiguredSecurity(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
server.SetSecretKey(secret)
|
||||
})
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
|
||||
client.SetSecretKey(secret)
|
||||
if err := client.ConnectByConn(left); err != nil {
|
||||
t.Fatalf("ConnectByConn failed: %v", err)
|
||||
}
|
||||
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectByFactoryRequiresModernPSK(t *testing.T) {
|
||||
client := NewClient()
|
||||
called := false
|
||||
|
||||
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
||||
called = true
|
||||
left, right := net.Pipe()
|
||||
_ = right.Close()
|
||||
return left, nil
|
||||
})
|
||||
if !errors.Is(err, errModernPSKRequired) {
|
||||
t.Fatalf("ConnectByFactory error = %v, want %v", err, errModernPSKRequired)
|
||||
}
|
||||
if called {
|
||||
t.Fatal("dialFn should not be called before security validation passes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectByFactoryRejectsNilDialFn(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
||||
|
||||
err := client.ConnectByFactory(context.Background(), nil)
|
||||
if err == nil || err.Error() != "dialFn is nil" {
|
||||
t.Fatalf("ConnectByFactory nil dialFn error = %v, want dialFn is nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectByFactoryPropagatesDialError(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
||||
wantErr := errors.New("dial failed")
|
||||
|
||||
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
||||
return nil, wantErr
|
||||
})
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("ConnectByFactory error = %v, want %v", err, wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectByFactoryWithConfiguredSecurity(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
server.SetSecretKey(secret)
|
||||
})
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
|
||||
client.SetSecretKey(secret)
|
||||
if err := client.ConnectByFactory(nil, func(ctx context.Context) (net.Conn, error) {
|
||||
if ctx == nil {
|
||||
t.Fatal("ConnectByFactory should normalize nil context")
|
||||
}
|
||||
return left, nil
|
||||
}); err != nil {
|
||||
t.Fatalf("ConnectByFactory failed: %v", err)
|
||||
}
|
||||
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectByFactoryRejectsConcurrentStart(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
firstDialEntered := make(chan struct{}, 1)
|
||||
firstDone := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
firstDone <- client.ConnectByFactory(ctx, func(ctx context.Context) (net.Conn, error) {
|
||||
firstDialEntered <- struct{}{}
|
||||
<-ctx.Done()
|
||||
return nil, ctx.Err()
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-firstDialEntered:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("first connect attempt did not enter dialFn")
|
||||
}
|
||||
|
||||
secondDialCalled := false
|
||||
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
||||
secondDialCalled = true
|
||||
return nil, errors.New("second dial should not run")
|
||||
})
|
||||
if err == nil || err.Error() != "client already run" {
|
||||
t.Fatalf("concurrent ConnectByFactory error = %v, want client already run", err)
|
||||
}
|
||||
if secondDialCalled {
|
||||
t.Fatal("second dialFn should not be called during first connect start")
|
||||
}
|
||||
|
||||
cancel()
|
||||
select {
|
||||
case err = <-firstDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("first ConnectByFactory did not finish after cancel")
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("first ConnectByFactory error = %v, want %v", err, context.Canceled)
|
||||
}
|
||||
|
||||
wantErr := errors.New("dial after rollback")
|
||||
err = client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
||||
return nil, wantErr
|
||||
})
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("ConnectByFactory after rollback error = %v, want %v", err, wantErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectByConnReattachesDetachedAliveSession(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||
client.SetSecretKey(secret)
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
server.SetSecretKey(secret)
|
||||
})
|
||||
|
||||
firstLeft, firstRight := net.Pipe()
|
||||
defer firstRight.Close()
|
||||
bootstrapPeerAttachConnForTest(t, server, firstRight)
|
||||
if err := client.ConnectByConn(firstLeft); err != nil {
|
||||
t.Fatalf("initial ConnectByConn failed: %v", err)
|
||||
}
|
||||
before := client.clientSessionRuntimeSnapshot()
|
||||
if before == nil {
|
||||
t.Fatal("runtime should exist after initial connect")
|
||||
}
|
||||
initialEpoch := before.epoch
|
||||
initialStopCtx := before.stopCtx
|
||||
initialQueue := before.queue
|
||||
|
||||
client.clearClientSessionRuntimeTransport()
|
||||
|
||||
recvCh := make(chan Message, 1)
|
||||
client.SetLink("reattach-public", func(message *Message) {
|
||||
recvCh <- *message
|
||||
})
|
||||
|
||||
secondLeft, secondRight := net.Pipe()
|
||||
defer secondRight.Close()
|
||||
bootstrapPeerAttachConnForTest(t, server, secondRight)
|
||||
if err := client.ConnectByConn(secondLeft); err != nil {
|
||||
t.Fatalf("reattach ConnectByConn failed: %v", err)
|
||||
}
|
||||
|
||||
after := client.clientSessionRuntimeSnapshot()
|
||||
if after == nil {
|
||||
t.Fatal("runtime should exist after reattach")
|
||||
}
|
||||
if after.conn != secondLeft || after.queue != initialQueue || after.stopCtx != initialStopCtx || after.epoch != initialEpoch || !after.transportAttached {
|
||||
t.Fatalf("reattached runtime mismatch: %+v", after)
|
||||
}
|
||||
|
||||
env, err := wrapTransferMsgEnvelope(TransferMsg{
|
||||
ID: 88,
|
||||
Key: "reattach-public",
|
||||
Value: []byte("ok"),
|
||||
Type: MSG_ASYNC,
|
||||
}, client.sequenceEn)
|
||||
if err != nil {
|
||||
t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
|
||||
}
|
||||
wire, err := client.encodeEnvelope(env)
|
||||
if err != nil {
|
||||
t.Fatalf("encodeEnvelope failed: %v", err)
|
||||
}
|
||||
if _, err := secondRight.Write(wire); err != nil {
|
||||
t.Fatalf("reattached conn write failed: %v", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case msg := <-recvCh:
|
||||
if got, want := msg.Key, "reattach-public"; got != want {
|
||||
t.Fatalf("message key mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := string(msg.Value), "ok"; got != want {
|
||||
t.Fatalf("message value mismatch: got %q want %q", got, want)
|
||||
}
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("reattached public conn did not dispatch message")
|
||||
}
|
||||
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("final Stop failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectByFactoryReattachesDetachedAliveSessionAndUpdatesSource(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||
client.SetSecretKey(secret)
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
server.SetSecretKey(secret)
|
||||
})
|
||||
|
||||
firstLeft, firstRight := net.Pipe()
|
||||
defer firstRight.Close()
|
||||
bootstrapPeerAttachConnForTest(t, server, firstRight)
|
||||
if err := client.ConnectByConn(firstLeft); err != nil {
|
||||
t.Fatalf("initial ConnectByConn failed: %v", err)
|
||||
}
|
||||
before := client.clientSessionRuntimeSnapshot()
|
||||
if before == nil {
|
||||
t.Fatal("runtime should exist after initial connect")
|
||||
}
|
||||
initialEpoch := before.epoch
|
||||
|
||||
client.clearClientSessionRuntimeTransport()
|
||||
|
||||
var dialCount atomic.Int32
|
||||
secondLeft, secondRight := net.Pipe()
|
||||
defer secondRight.Close()
|
||||
bootstrapPeerAttachConnForTest(t, server, secondRight)
|
||||
if err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
||||
dialCount.Add(1)
|
||||
return secondLeft, nil
|
||||
}); err != nil {
|
||||
t.Fatalf("reattach ConnectByFactory failed: %v", err)
|
||||
}
|
||||
if got, want := dialCount.Load(), int32(1); got != want {
|
||||
t.Fatalf("dial count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
after := client.clientSessionRuntimeSnapshot()
|
||||
if after == nil {
|
||||
t.Fatal("runtime should exist after factory reattach")
|
||||
}
|
||||
if after.epoch != initialEpoch || after.conn != secondLeft || !after.transportAttached {
|
||||
t.Fatalf("reattached runtime mismatch: %+v", after)
|
||||
}
|
||||
snapshot, err := GetClientRuntimeSnapshot(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := snapshot.ConnectSource, clientConnectSourceFactory; got != want {
|
||||
t.Fatalf("connect source mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if !snapshot.CanReconnect {
|
||||
t.Fatalf("snapshot should be reconnectable after factory reattach: %+v", snapshot)
|
||||
}
|
||||
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("final Stop failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectByConnFailureCleansRuntimeAndAllowsRetry(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
UseLegacySecurityClient(client)
|
||||
failErr := errors.New("key exchange fail for test")
|
||||
client.keyExchangeFn = func(Client) error {
|
||||
return failErr
|
||||
}
|
||||
|
||||
left1, right1 := net.Pipe()
|
||||
defer right1.Close()
|
||||
err := client.ConnectByConn(left1)
|
||||
if !errors.Is(err, failErr) {
|
||||
t.Fatalf("ConnectByConn first error = %v, want %v", err, failErr)
|
||||
}
|
||||
status := client.Status()
|
||||
if status.Alive || status.Reason != "key exchange failed" || !errors.Is(status.Err, failErr) {
|
||||
t.Fatalf("unexpected status after failed key exchange: %+v", status)
|
||||
}
|
||||
select {
|
||||
case <-client.StopMonitorChan():
|
||||
t.Fatal("StopMonitorChan should remain open after failed connect cleanup")
|
||||
case <-time.After(20 * time.Millisecond):
|
||||
}
|
||||
|
||||
client.SetSkipExchangeKey(true)
|
||||
left2, right2 := net.Pipe()
|
||||
defer right2.Close()
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
UseLegacySecurityServer(server)
|
||||
})
|
||||
bootstrapPeerAttachConnForTest(t, server, right2)
|
||||
if err := client.ConnectByConn(left2); err != nil {
|
||||
t.Fatalf("ConnectByConn second attempt failed: %v", err)
|
||||
}
|
||||
if !client.Status().Alive {
|
||||
t.Fatalf("client should be alive after second ConnectByConn: %+v", client.Status())
|
||||
}
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenByListenerRequiresModernPSK(t *testing.T) {
|
||||
server := NewServer()
|
||||
listener := newInMemoryListener()
|
||||
defer listener.Close()
|
||||
|
||||
err := server.ListenByListener(listener)
|
||||
if !errors.Is(err, errModernPSKRequired) {
|
||||
t.Fatalf("ListenByListener error = %v, want %v", err, errModernPSKRequired)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenByListenerWithConfiguredSecurity(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
listener := newInMemoryListener()
|
||||
defer listener.Close()
|
||||
|
||||
server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
||||
if err := server.ListenByListener(listener); err != nil {
|
||||
t.Fatalf("ListenByListener failed: %v", err)
|
||||
}
|
||||
if !server.Status().Alive {
|
||||
t.Fatal("server should be alive after ListenByListener")
|
||||
}
|
||||
if err := server.Stop(); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenByListenerRejectsNil(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
||||
err := server.ListenByListener(nil)
|
||||
if err == nil || err.Error() != "listener is nil" {
|
||||
t.Fatalf("ListenByListener nil error = %v, want listener is nil", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReadMessagePreservesUserStopReason(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
left, right := net.Pipe()
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
|
||||
client.conn = left
|
||||
client.stopCtx = stopCtx
|
||||
client.stopFn = stopFn
|
||||
client.markSessionStarted()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
client.readMessage()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("Stop failed: %v", err)
|
||||
}
|
||||
_ = right.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("readMessage should exit after user stop")
|
||||
}
|
||||
|
||||
status := client.Status()
|
||||
if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil {
|
||||
t.Fatalf("unexpected status after user stop: %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReadMessagePreservesServerStopReason(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
left, right := net.Pipe()
|
||||
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||
defer stopFn()
|
||||
|
||||
client.conn = left
|
||||
client.stopCtx = stopCtx
|
||||
client.stopFn = stopFn
|
||||
client.markSessionStarted()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
client.readMessage()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
client.stopClientSessionFromServer("recv stop signal from server", nil)
|
||||
_ = right.Close()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("readMessage should exit after server stop")
|
||||
}
|
||||
|
||||
status := client.Status()
|
||||
if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
|
||||
t.Fatalf("unexpected status after server stop: %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientStopClientSessionFromServerDisablesGoodBye(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
client.markSessionStarted()
|
||||
|
||||
client.stopClientSessionFromServer("recv stop signal from server", nil)
|
||||
|
||||
if client.shouldSayGoodByeOnStop() {
|
||||
t.Fatal("server stop should disable goodbye on stop")
|
||||
}
|
||||
status := client.Status()
|
||||
if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
|
||||
t.Fatalf("unexpected status after server stop helper: %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientStopClientSessionKeepsGoodByeEnabled(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
client.markSessionStarted()
|
||||
|
||||
client.stopClientSession("recv stop signal from user", nil)
|
||||
|
||||
if !client.shouldSayGoodByeOnStop() {
|
||||
t.Fatal("local stop should keep goodbye enabled")
|
||||
}
|
||||
status := client.Status()
|
||||
if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil {
|
||||
t.Fatalf("unexpected status after local stop helper: %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientReadMessageLoopUsesProvidedStopCtx(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
|
||||
loopCtx, loopCancel := context.WithCancel(context.Background())
|
||||
loopCancel()
|
||||
|
||||
client.stopCtx = context.Background()
|
||||
client.conn = nil
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
client.readMessageLoop(loopCtx, left, nil, 1)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("readMessageLoop should exit when provided stopCtx is canceled")
|
||||
}
|
||||
|
||||
if _, err := right.Write([]byte("x")); err == nil {
|
||||
t.Fatal("peer conn should be closed when loop exits")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,257 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultConnectRetryAttempts = 3
|
||||
defaultConnectRetryBase = 200 * time.Millisecond
|
||||
defaultConnectRetryMax = 2 * time.Second
|
||||
)
|
||||
|
||||
type ConnectRetryOptions struct {
|
||||
MaxAttempts int
|
||||
BaseDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
ShouldRetry func(error) bool
|
||||
OnRetry func(ConnectRetryEvent)
|
||||
}
|
||||
|
||||
type ConnectRetryEvent struct {
|
||||
Attempt int
|
||||
MaxAttempts int
|
||||
Err error
|
||||
NextDelay time.Duration
|
||||
}
|
||||
|
||||
var (
|
||||
errConnectRetryClientNil = errors.New("connect retry client is nil")
|
||||
errConnectRetryServerNil = errors.New("connect retry server is nil")
|
||||
errConnectRetryFnNil = errors.New("connect retry fn is nil")
|
||||
errConnectRetryDialFnNil = errors.New("connect retry dialFn is nil")
|
||||
errClientReconnectNil = errors.New("client reconnect target is nil")
|
||||
errClientReconnectUnsupported = errors.New("client reconnect target type is unsupported")
|
||||
errClientReconnectActive = errors.New("client reconnect requires an inactive session")
|
||||
)
|
||||
|
||||
func DefaultConnectRetryOptions() ConnectRetryOptions {
|
||||
return ConnectRetryOptions{
|
||||
MaxAttempts: defaultConnectRetryAttempts,
|
||||
BaseDelay: defaultConnectRetryBase,
|
||||
MaxDelay: defaultConnectRetryMax,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeConnectRetryOptions(opts *ConnectRetryOptions) ConnectRetryOptions {
|
||||
cfg := DefaultConnectRetryOptions()
|
||||
if opts == nil {
|
||||
return cfg
|
||||
}
|
||||
if opts.MaxAttempts > 0 {
|
||||
cfg.MaxAttempts = opts.MaxAttempts
|
||||
}
|
||||
if opts.BaseDelay > 0 {
|
||||
cfg.BaseDelay = opts.BaseDelay
|
||||
}
|
||||
if opts.MaxDelay > 0 {
|
||||
cfg.MaxDelay = opts.MaxDelay
|
||||
}
|
||||
cfg.ShouldRetry = opts.ShouldRetry
|
||||
cfg.OnRetry = opts.OnRetry
|
||||
if cfg.MaxDelay < cfg.BaseDelay {
|
||||
cfg.MaxDelay = cfg.BaseDelay
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func RetryConnect(ctx context.Context, opts *ConnectRetryOptions, fn func(context.Context) error) error {
|
||||
if fn == nil {
|
||||
return errConnectRetryFnNil
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
cfg := normalizeConnectRetryOptions(opts)
|
||||
var lastErr error
|
||||
for attempt := 1; attempt <= cfg.MaxAttempts; attempt++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
lastErr = fn(ctx)
|
||||
if lastErr == nil {
|
||||
return nil
|
||||
}
|
||||
if cfg.ShouldRetry != nil && !cfg.ShouldRetry(lastErr) {
|
||||
return lastErr
|
||||
}
|
||||
if attempt >= cfg.MaxAttempts {
|
||||
break
|
||||
}
|
||||
delay := connectRetryBackoffDelay(cfg, attempt)
|
||||
if cfg.OnRetry != nil {
|
||||
cfg.OnRetry(ConnectRetryEvent{
|
||||
Attempt: attempt,
|
||||
MaxAttempts: cfg.MaxAttempts,
|
||||
Err: lastErr,
|
||||
NextDelay: delay,
|
||||
})
|
||||
}
|
||||
if err := waitConnectRetryDelay(ctx, delay); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func ConnectClientWithRetry(ctx context.Context, client Client, network string, addr string, opts *ConnectRetryOptions) error {
|
||||
if client == nil {
|
||||
return errConnectRetryClientNil
|
||||
}
|
||||
recorder, _ := any(client).(connectionRetryRecorder)
|
||||
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
|
||||
err := RetryConnect(ctx, retryOpts, func(context.Context) error {
|
||||
return client.Connect(network, addr)
|
||||
})
|
||||
if recorder != nil {
|
||||
recorder.recordConnectionRetryResult(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func ConnectClientFactoryWithRetry(ctx context.Context, client Client, dialFn func(context.Context) (net.Conn, error), opts *ConnectRetryOptions) error {
|
||||
if client == nil {
|
||||
return errConnectRetryClientNil
|
||||
}
|
||||
if dialFn == nil {
|
||||
return errConnectRetryDialFnNil
|
||||
}
|
||||
recorder, _ := any(client).(connectionRetryRecorder)
|
||||
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
|
||||
err := RetryConnect(ctx, retryOpts, func(ctx context.Context) error {
|
||||
return client.ConnectByFactory(ctx, dialFn)
|
||||
})
|
||||
if recorder != nil {
|
||||
recorder.recordConnectionRetryResult(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type clientReconnecter interface {
|
||||
reconnect(context.Context) error
|
||||
}
|
||||
|
||||
func ReconnectClient(ctx context.Context, client Client) error {
|
||||
if client == nil {
|
||||
return errClientReconnectNil
|
||||
}
|
||||
reconnecter, ok := any(client).(clientReconnecter)
|
||||
if !ok {
|
||||
return errClientReconnectUnsupported
|
||||
}
|
||||
return reconnecter.reconnect(ctx)
|
||||
}
|
||||
|
||||
func ReconnectClientWithRetry(ctx context.Context, client Client, opts *ConnectRetryOptions) error {
|
||||
if client == nil {
|
||||
return errConnectRetryClientNil
|
||||
}
|
||||
recorder, _ := any(client).(connectionRetryRecorder)
|
||||
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
|
||||
err := RetryConnect(ctx, retryOpts, func(ctx context.Context) error {
|
||||
return ReconnectClient(ctx, client)
|
||||
})
|
||||
if recorder != nil {
|
||||
recorder.recordConnectionRetryResult(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func ListenServerWithRetry(ctx context.Context, server Server, network string, addr string, opts *ConnectRetryOptions) error {
|
||||
if server == nil {
|
||||
return errConnectRetryServerNil
|
||||
}
|
||||
recorder, _ := any(server).(connectionRetryRecorder)
|
||||
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
|
||||
err := RetryConnect(ctx, retryOpts, func(context.Context) error {
|
||||
return server.Listen(network, addr)
|
||||
})
|
||||
if recorder != nil {
|
||||
recorder.recordConnectionRetryResult(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) reconnect(ctx context.Context) error {
|
||||
if c == nil {
|
||||
return errClientReconnectNil
|
||||
}
|
||||
if sessionIsAlive(&c.alive) {
|
||||
return errClientReconnectActive
|
||||
}
|
||||
source := c.clientConnectSourceSnapshot()
|
||||
if source == nil || !source.canReconnect() {
|
||||
return errClientReconnectSourceUnavailable
|
||||
}
|
||||
finish, err := c.beginClientConnectAttempt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
started := false
|
||||
defer func() {
|
||||
finish(started)
|
||||
}()
|
||||
if err := c.validateSecurityConfiguration(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.closeClientTransport()
|
||||
c.applySignalReliabilityTransportDefault(source.isUDP())
|
||||
conn, err := source.dial(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if conn == nil {
|
||||
return errors.New("conn is nil")
|
||||
}
|
||||
if err := c.startClientWithConnSource(conn, source); err != nil {
|
||||
return err
|
||||
}
|
||||
started = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func connectRetryBackoffDelay(cfg ConnectRetryOptions, failedAttempt int) time.Duration {
|
||||
delay := cfg.BaseDelay
|
||||
if delay <= 0 {
|
||||
return 0
|
||||
}
|
||||
for i := 1; i < failedAttempt; i++ {
|
||||
if delay >= cfg.MaxDelay/2 {
|
||||
return cfg.MaxDelay
|
||||
}
|
||||
delay *= 2
|
||||
}
|
||||
if delay > cfg.MaxDelay {
|
||||
return cfg.MaxDelay
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func waitConnectRetryDelay(ctx context.Context, delay time.Duration) error {
|
||||
if delay <= 0 {
|
||||
return nil
|
||||
}
|
||||
timer := time.NewTimer(delay)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ConnectionRetrySnapshot struct {
|
||||
RetryEventTotal uint64
|
||||
LastRetryAttempt int
|
||||
LastRetryDelay time.Duration
|
||||
LastRetryError string
|
||||
LastRetryAt time.Time
|
||||
LastResultError string
|
||||
LastResultAt time.Time
|
||||
}
|
||||
|
||||
type connectionRetryState struct {
|
||||
mu sync.Mutex
|
||||
|
||||
retryEventTotal uint64
|
||||
lastRetryAttempt int
|
||||
lastRetryDelay time.Duration
|
||||
lastRetryError string
|
||||
lastRetryAt time.Time
|
||||
lastResultError string
|
||||
lastResultAt time.Time
|
||||
}
|
||||
|
||||
func newConnectionRetryState() *connectionRetryState {
|
||||
return &connectionRetryState{}
|
||||
}
|
||||
|
||||
func (s *connectionRetryState) recordRetryEvent(event ConnectRetryEvent) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.retryEventTotal++
|
||||
s.lastRetryAttempt = event.Attempt
|
||||
s.lastRetryDelay = event.NextDelay
|
||||
if event.Err != nil {
|
||||
s.lastRetryError = event.Err.Error()
|
||||
} else {
|
||||
s.lastRetryError = ""
|
||||
}
|
||||
s.lastRetryAt = time.Now()
|
||||
}
|
||||
|
||||
func (s *connectionRetryState) recordResult(err error) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if err != nil {
|
||||
s.lastResultError = err.Error()
|
||||
} else {
|
||||
s.lastResultError = ""
|
||||
}
|
||||
s.lastResultAt = time.Now()
|
||||
}
|
||||
|
||||
func (s *connectionRetryState) snapshot() ConnectionRetrySnapshot {
|
||||
if s == nil {
|
||||
return ConnectionRetrySnapshot{}
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return ConnectionRetrySnapshot{
|
||||
RetryEventTotal: s.retryEventTotal,
|
||||
LastRetryAttempt: s.lastRetryAttempt,
|
||||
LastRetryDelay: s.lastRetryDelay,
|
||||
LastRetryError: s.lastRetryError,
|
||||
LastRetryAt: s.lastRetryAt,
|
||||
LastResultError: s.lastResultError,
|
||||
LastResultAt: s.lastResultAt,
|
||||
}
|
||||
}
|
||||
|
||||
type connectionRetryRecorder interface {
|
||||
recordConnectionRetryEvent(event ConnectRetryEvent)
|
||||
recordConnectionRetryResult(err error)
|
||||
}
|
||||
|
||||
func wrapConnectRetryOptionsWithRecorder(opts *ConnectRetryOptions, recorder connectionRetryRecorder) *ConnectRetryOptions {
|
||||
if recorder == nil {
|
||||
return opts
|
||||
}
|
||||
if opts == nil {
|
||||
return &ConnectRetryOptions{
|
||||
OnRetry: recorder.recordConnectionRetryEvent,
|
||||
}
|
||||
}
|
||||
next := *opts
|
||||
originOnRetry := next.OnRetry
|
||||
next.OnRetry = func(event ConnectRetryEvent) {
|
||||
recorder.recordConnectionRetryEvent(event)
|
||||
if originOnRetry != nil {
|
||||
originOnRetry(event)
|
||||
}
|
||||
}
|
||||
return &next
|
||||
}
|
||||
|
||||
func (c *ClientCommon) getConnectionRetryState() *connectionRetryState {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.connectionRetryState == nil {
|
||||
c.connectionRetryState = newConnectionRetryState()
|
||||
}
|
||||
return c.connectionRetryState
|
||||
}
|
||||
|
||||
func (c *ClientCommon) recordConnectionRetryEvent(event ConnectRetryEvent) {
|
||||
c.getConnectionRetryState().recordRetryEvent(event)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) recordConnectionRetryResult(err error) {
|
||||
c.getConnectionRetryState().recordResult(err)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) connectionRetrySnapshot() ConnectionRetrySnapshot {
|
||||
return c.getConnectionRetryState().snapshot()
|
||||
}
|
||||
|
||||
func (s *ServerCommon) getConnectionRetryState() *connectionRetryState {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.connectionRetryState == nil {
|
||||
s.connectionRetryState = newConnectionRetryState()
|
||||
}
|
||||
return s.connectionRetryState
|
||||
}
|
||||
|
||||
func (s *ServerCommon) recordConnectionRetryEvent(event ConnectRetryEvent) {
|
||||
s.getConnectionRetryState().recordRetryEvent(event)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) recordConnectionRetryResult(err error) {
|
||||
s.getConnectionRetryState().recordResult(err)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) connectionRetrySnapshot() ConnectionRetrySnapshot {
|
||||
return s.getConnectionRetryState().snapshot()
|
||||
}
|
||||
@@ -0,0 +1,309 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRetryConnectSucceedsAfterRetries(t *testing.T) {
|
||||
var attempts int
|
||||
wantErr := errors.New("dial failed")
|
||||
|
||||
err := RetryConnect(context.Background(), &ConnectRetryOptions{
|
||||
MaxAttempts: 4,
|
||||
BaseDelay: time.Millisecond,
|
||||
MaxDelay: 2 * time.Millisecond,
|
||||
}, func(context.Context) error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return wantErr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RetryConnect failed: %v", err)
|
||||
}
|
||||
if got, want := attempts, 3; got != want {
|
||||
t.Fatalf("attempts mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryConnectReturnsLastError(t *testing.T) {
|
||||
var attempts int
|
||||
wantErr := errors.New("connect failed")
|
||||
|
||||
err := RetryConnect(context.Background(), &ConnectRetryOptions{
|
||||
MaxAttempts: 3,
|
||||
BaseDelay: time.Millisecond,
|
||||
MaxDelay: time.Millisecond,
|
||||
}, func(context.Context) error {
|
||||
attempts++
|
||||
return wantErr
|
||||
})
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
|
||||
}
|
||||
if got, want := attempts, 3; got != want {
|
||||
t.Fatalf("attempts mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryConnectContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
var attempts int
|
||||
|
||||
err := RetryConnect(ctx, &ConnectRetryOptions{
|
||||
MaxAttempts: 3,
|
||||
BaseDelay: 100 * time.Millisecond,
|
||||
MaxDelay: 100 * time.Millisecond,
|
||||
}, func(context.Context) error {
|
||||
attempts++
|
||||
cancel()
|
||||
return errors.New("fail")
|
||||
})
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("RetryConnect error = %v, want context canceled", err)
|
||||
}
|
||||
if got, want := attempts, 1; got != want {
|
||||
t.Fatalf("attempts mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectRetryRejectsNilInputs(t *testing.T) {
|
||||
if err := RetryConnect(context.Background(), nil, nil); !errors.Is(err, errConnectRetryFnNil) {
|
||||
t.Fatalf("RetryConnect nil fn error = %v, want %v", err, errConnectRetryFnNil)
|
||||
}
|
||||
if err := ConnectClientWithRetry(context.Background(), nil, "tcp", "127.0.0.1:1", nil); !errors.Is(err, errConnectRetryClientNil) {
|
||||
t.Fatalf("ConnectClientWithRetry nil client error = %v, want %v", err, errConnectRetryClientNil)
|
||||
}
|
||||
if err := ConnectClientFactoryWithRetry(context.Background(), nil, nil, nil); !errors.Is(err, errConnectRetryClientNil) {
|
||||
t.Fatalf("ConnectClientFactoryWithRetry nil client error = %v, want %v", err, errConnectRetryClientNil)
|
||||
}
|
||||
if err := ConnectClientFactoryWithRetry(context.Background(), NewClient(), nil, nil); !errors.Is(err, errConnectRetryDialFnNil) {
|
||||
t.Fatalf("ConnectClientFactoryWithRetry nil dialFn error = %v, want %v", err, errConnectRetryDialFnNil)
|
||||
}
|
||||
if err := ListenServerWithRetry(context.Background(), nil, "tcp", "127.0.0.1:1", nil); !errors.Is(err, errConnectRetryServerNil) {
|
||||
t.Fatalf("ListenServerWithRetry nil server error = %v, want %v", err, errConnectRetryServerNil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectRetryBackoffDelayCapped(t *testing.T) {
|
||||
cfg := normalizeConnectRetryOptions(&ConnectRetryOptions{
|
||||
MaxAttempts: 5,
|
||||
BaseDelay: 10 * time.Millisecond,
|
||||
MaxDelay: 30 * time.Millisecond,
|
||||
})
|
||||
if got, want := connectRetryBackoffDelay(cfg, 1), 10*time.Millisecond; got != want {
|
||||
t.Fatalf("delay attempt1 mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := connectRetryBackoffDelay(cfg, 2), 20*time.Millisecond; got != want {
|
||||
t.Fatalf("delay attempt2 mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := connectRetryBackoffDelay(cfg, 3), 30*time.Millisecond; got != want {
|
||||
t.Fatalf("delay attempt3 mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := connectRetryBackoffDelay(cfg, 4), 30*time.Millisecond; got != want {
|
||||
t.Fatalf("delay attempt4 mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryConnectShouldRetryCanStopEarly(t *testing.T) {
|
||||
var attempts int
|
||||
wantErr := errors.New("not retriable")
|
||||
|
||||
err := RetryConnect(context.Background(), &ConnectRetryOptions{
|
||||
MaxAttempts: 5,
|
||||
BaseDelay: time.Millisecond,
|
||||
MaxDelay: 2 * time.Millisecond,
|
||||
ShouldRetry: func(err error) bool {
|
||||
return !errors.Is(err, wantErr)
|
||||
},
|
||||
}, func(context.Context) error {
|
||||
attempts++
|
||||
return wantErr
|
||||
})
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
|
||||
}
|
||||
if got, want := attempts, 1; got != want {
|
||||
t.Fatalf("attempts mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryConnectOnRetryHook(t *testing.T) {
|
||||
var events []ConnectRetryEvent
|
||||
wantErr := errors.New("dial failed")
|
||||
|
||||
err := RetryConnect(context.Background(), &ConnectRetryOptions{
|
||||
MaxAttempts: 3,
|
||||
BaseDelay: time.Millisecond,
|
||||
MaxDelay: 2 * time.Millisecond,
|
||||
OnRetry: func(event ConnectRetryEvent) {
|
||||
events = append(events, event)
|
||||
},
|
||||
}, func(context.Context) error {
|
||||
return wantErr
|
||||
})
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
|
||||
}
|
||||
if got, want := len(events), 2; got != want {
|
||||
t.Fatalf("retry events mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := events[0].Attempt, 1; got != want {
|
||||
t.Fatalf("event[0] attempt mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := events[0].MaxAttempts, 3; got != want {
|
||||
t.Fatalf("event[0] max attempts mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if !errors.Is(events[0].Err, wantErr) {
|
||||
t.Fatalf("event[0] err mismatch: got %v want %v", events[0].Err, wantErr)
|
||||
}
|
||||
if got, want := events[0].NextDelay, time.Millisecond; got != want {
|
||||
t.Fatalf("event[0] next delay mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := events[1].Attempt, 2; got != want {
|
||||
t.Fatalf("event[1] attempt mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := events[1].NextDelay, 2*time.Millisecond; got != want {
|
||||
t.Fatalf("event[1] next delay mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectClientFactoryWithRetryRecoversFromFailedStart(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
UseLegacySecurityClient(client)
|
||||
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||
UseLegacySecurityServer(server)
|
||||
})
|
||||
|
||||
wantErr := errors.New("key exchange failed on first attempt")
|
||||
keyExchangeAttempts := 0
|
||||
client.keyExchangeFn = func(Client) error {
|
||||
keyExchangeAttempts++
|
||||
if keyExchangeAttempts == 1 {
|
||||
return wantErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
dialAttempts := 0
|
||||
var peerConns []net.Conn
|
||||
dialFn := func(context.Context) (net.Conn, error) {
|
||||
dialAttempts++
|
||||
left, right := net.Pipe()
|
||||
peerConns = append(peerConns, right)
|
||||
bootstrapPeerAttachConnForTest(t, server, right)
|
||||
return left, nil
|
||||
}
|
||||
|
||||
err := ConnectClientFactoryWithRetry(context.Background(), client, dialFn, &ConnectRetryOptions{
|
||||
MaxAttempts: 3,
|
||||
BaseDelay: time.Millisecond,
|
||||
MaxDelay: time.Millisecond,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ConnectClientFactoryWithRetry failed: %v", err)
|
||||
}
|
||||
if got, want := dialAttempts, 2; got != want {
|
||||
t.Fatalf("dial attempts mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := keyExchangeAttempts, 2; got != want {
|
||||
t.Fatalf("key exchange attempts mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if status := client.Status(); !status.Alive {
|
||||
t.Fatalf("client should be alive after retry success: %+v", status)
|
||||
}
|
||||
runtimeSnapshot, err := GetClientRuntimeSnapshot(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := runtimeSnapshot.Retry.RetryEventTotal, uint64(1); got != want {
|
||||
t.Fatalf("client retry events mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := runtimeSnapshot.Retry.LastRetryAttempt, 1; got != want {
|
||||
t.Fatalf("client last retry attempt mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := runtimeSnapshot.Retry.LastRetryError, wantErr.Error(); got != want {
|
||||
t.Fatalf("client last retry error mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if runtimeSnapshot.Retry.LastRetryAt.IsZero() {
|
||||
t.Fatal("client last retry time should be recorded")
|
||||
}
|
||||
if runtimeSnapshot.Retry.LastResultError != "" {
|
||||
t.Fatalf("client last result error should be empty on success, got %q", runtimeSnapshot.Retry.LastResultError)
|
||||
}
|
||||
if runtimeSnapshot.Retry.LastResultAt.IsZero() {
|
||||
t.Fatal("client last result time should be recorded")
|
||||
}
|
||||
|
||||
client.setByeFromServer(true)
|
||||
if err := client.Stop(); err != nil {
|
||||
t.Fatalf("client Stop failed: %v", err)
|
||||
}
|
||||
for _, conn := range peerConns {
|
||||
_ = conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenServerWithRetryRecoversFromFailedStart(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
var retryEvents []ConnectRetryEvent
|
||||
|
||||
err := ListenServerWithRetry(context.Background(), server, "tcp", "127.0.0.1:0", &ConnectRetryOptions{
|
||||
MaxAttempts: 3,
|
||||
BaseDelay: time.Millisecond,
|
||||
MaxDelay: time.Millisecond,
|
||||
OnRetry: func(event ConnectRetryEvent) {
|
||||
retryEvents = append(retryEvents, event)
|
||||
if event.Attempt == 1 {
|
||||
UseLegacySecurityServer(server)
|
||||
}
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListenServerWithRetry failed: %v", err)
|
||||
}
|
||||
if status := server.Status(); !status.Alive {
|
||||
t.Fatalf("server should be alive after retry success: %+v", status)
|
||||
}
|
||||
if got := len(retryEvents); got < 1 {
|
||||
t.Fatal("OnRetry should be called at least once")
|
||||
}
|
||||
if got, want := retryEvents[0].Attempt, 1; got != want {
|
||||
t.Fatalf("retry event attempt mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if !errors.Is(retryEvents[0].Err, errModernPSKRequired) {
|
||||
t.Fatalf("retry event err mismatch: got %v want %v", retryEvents[0].Err, errModernPSKRequired)
|
||||
}
|
||||
runtimeSnapshot, err := GetServerRuntimeSnapshot(server)
|
||||
if err != nil {
|
||||
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := runtimeSnapshot.Retry.RetryEventTotal, uint64(1); got != want {
|
||||
t.Fatalf("server retry events mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := runtimeSnapshot.Retry.LastRetryAttempt, 1; got != want {
|
||||
t.Fatalf("server last retry attempt mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := runtimeSnapshot.Retry.LastRetryError, errModernPSKRequired.Error(); got != want {
|
||||
t.Fatalf("server last retry error mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if runtimeSnapshot.Retry.LastRetryAt.IsZero() {
|
||||
t.Fatal("server last retry time should be recorded")
|
||||
}
|
||||
if runtimeSnapshot.Retry.LastResultError != "" {
|
||||
t.Fatalf("server last result error should be empty on success, got %q", runtimeSnapshot.Retry.LastResultError)
|
||||
}
|
||||
if runtimeSnapshot.Retry.LastResultAt.IsZero() {
|
||||
t.Fatal("server last result time should be recorded")
|
||||
}
|
||||
|
||||
if err := server.Stop(); err != nil {
|
||||
t.Fatalf("server Stop failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const controlBatchMaxPayloads = 16
|
||||
|
||||
type controlBatchRequest struct {
|
||||
payload []byte
|
||||
deadline time.Time
|
||||
done chan error
|
||||
}
|
||||
|
||||
type controlBatchSender struct {
|
||||
binding *transportBinding
|
||||
reqCh chan controlBatchRequest
|
||||
stopCh chan struct{}
|
||||
doneCh chan struct{}
|
||||
|
||||
stopOnce sync.Once
|
||||
errMu sync.Mutex
|
||||
err error
|
||||
}
|
||||
|
||||
func newControlBatchSender(binding *transportBinding) *controlBatchSender {
|
||||
sender := &controlBatchSender{
|
||||
binding: binding,
|
||||
reqCh: make(chan controlBatchRequest, controlBatchMaxPayloads*4),
|
||||
stopCh: make(chan struct{}),
|
||||
doneCh: make(chan struct{}),
|
||||
}
|
||||
go sender.run()
|
||||
return sender
|
||||
}
|
||||
|
||||
func (s *controlBatchSender) submit(payload []byte, deadline time.Time) error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
req := controlBatchRequest{
|
||||
payload: payload,
|
||||
deadline: deadline,
|
||||
done: make(chan error, 1),
|
||||
}
|
||||
if err := s.errSnapshot(); err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
return s.stoppedErr()
|
||||
case s.reqCh <- req:
|
||||
}
|
||||
return <-req.done
|
||||
}
|
||||
|
||||
func (s *controlBatchSender) run() {
|
||||
defer close(s.doneCh)
|
||||
for {
|
||||
req, ok := s.nextRequest()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
batch := []controlBatchRequest{req}
|
||||
drain:
|
||||
for len(batch) < controlBatchMaxPayloads {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
s.failPending(s.stoppedErr())
|
||||
return
|
||||
case next := <-s.reqCh:
|
||||
batch = append(batch, next)
|
||||
default:
|
||||
break drain
|
||||
}
|
||||
}
|
||||
payloads := make([][]byte, 0, len(batch))
|
||||
for _, item := range batch {
|
||||
payloads = append(payloads, item.payload)
|
||||
}
|
||||
err := s.flush(payloads, controlBatchRequestsEarliestDeadline(batch))
|
||||
if err != nil {
|
||||
s.setErr(err)
|
||||
for _, item := range batch {
|
||||
item.done <- err
|
||||
}
|
||||
s.failPending(err)
|
||||
return
|
||||
}
|
||||
for _, item := range batch {
|
||||
item.done <- nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *controlBatchSender) nextRequest() (controlBatchRequest, bool) {
|
||||
select {
|
||||
case <-s.stopCh:
|
||||
s.failPending(s.stoppedErr())
|
||||
return controlBatchRequest{}, false
|
||||
case req := <-s.reqCh:
|
||||
return req, true
|
||||
}
|
||||
}
|
||||
|
||||
func controlBatchRequestsEarliestDeadline(batch []controlBatchRequest) time.Time {
|
||||
var deadline time.Time
|
||||
for _, item := range batch {
|
||||
if item.deadline.IsZero() {
|
||||
continue
|
||||
}
|
||||
if deadline.IsZero() || item.deadline.Before(deadline) {
|
||||
deadline = item.deadline
|
||||
}
|
||||
}
|
||||
return deadline
|
||||
}
|
||||
|
||||
func (s *controlBatchSender) flush(payloads [][]byte, deadline time.Time) error {
|
||||
if s == nil || s.binding == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
queue := s.binding.queueSnapshot()
|
||||
if queue == nil {
|
||||
return errTransportFrameQueueUnavailable
|
||||
}
|
||||
return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error {
|
||||
return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *controlBatchSender) stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
s.setErr(errTransportDetached)
|
||||
close(s.stopCh)
|
||||
})
|
||||
<-s.doneCh
|
||||
}
|
||||
|
||||
func (s *controlBatchSender) failPending(err error) {
|
||||
for {
|
||||
select {
|
||||
case item := <-s.reqCh:
|
||||
item.done <- err
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *controlBatchSender) setErr(err error) {
|
||||
if s == nil || err == nil {
|
||||
return
|
||||
}
|
||||
s.errMu.Lock()
|
||||
if s.err == nil {
|
||||
s.err = err
|
||||
}
|
||||
s.errMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *controlBatchSender) errSnapshot() error {
|
||||
if s == nil {
|
||||
return errTransportDetached
|
||||
}
|
||||
s.errMu.Lock()
|
||||
defer s.errMu.Unlock()
|
||||
return s.err
|
||||
}
|
||||
|
||||
func (s *controlBatchSender) stoppedErr() error {
|
||||
if err := s.errSnapshot(); err != nil {
|
||||
return err
|
||||
}
|
||||
return errTransportDetached
|
||||
}
|
||||
+48
-2
@@ -1,10 +1,13 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
itransfer "b612.me/notify/internal/transfer"
|
||||
"b612.me/starcrypto"
|
||||
"log"
|
||||
)
|
||||
|
||||
// Deprecated: legacy static RSA private key retained only for compatibility
|
||||
// with MSG_KEY_CHANGE.
|
||||
var defaultRsaKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIJKAIBAAKCAgEAxmeMqr9yfJFKZn26oe/HvC7bZXNLC9Nk55AuTkb4XuIoqXDb
|
||||
AJD2Y/p167oJLKIqL3edcj7h+oTfn6s79vxT0ZCEf37ILU0G+scRzVwYHiLMwOUC
|
||||
@@ -55,8 +58,10 @@ HKpWIdjFJK1EqSfcINe2YuoyUIulz9oG7ObRHD4D8jSPjA8Ete+XsBHGyOtUl09u
|
||||
X4u9uClhqjK+r1Tno2vw5yF6ZxfQtdWuL4W0UL1S8E+VO7vjTjNOYvgjAIpAM/gW
|
||||
sqjA2Qw52UZqhhLXoTfRvtJilxlXXhIRJSsnUoGiYVCQ/upjqJCClEvJfIWdGY/U
|
||||
I2CbFrwJcNvOG1lUsSM55JUmbrSWVPfo7yq2k9GCuFxOy2n/SVlvlQUcNkA=
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
-----END RSA PRIVATE KEY-----`)
|
||||
|
||||
// Deprecated: legacy static RSA public key retained only for compatibility
|
||||
// with MSG_KEY_CHANGE.
|
||||
var defaultRsaPubKey = []byte(`-----BEGIN PUBLIC KEY-----
|
||||
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAxmeMqr9yfJFKZn26oe/H
|
||||
vC7bZXNLC9Nk55AuTkb4XuIoqXDbAJD2Y/p167oJLKIqL3edcj7h+oTfn6s79vxT
|
||||
@@ -70,10 +75,13 @@ hq+q8YLcnKHvNKYVyCf/upExpAiArr88y/KbeKes0KorKkwMBnGUMTothWM25wHo
|
||||
zcurixNvP4UMWX7LWD7vOZZuNDQNutZYeTwdsniI3mTO9vlPWEK8JTfxBU7x9SeP
|
||||
UMJNDyjfDUJM8C2DOlyhGNPkgazOGdliH87tHkEy/7jJnGclgKmciiVPgwHfFx9G
|
||||
GoBHEfvmAoGGrk4qNbjm7JECAwEAAQ==
|
||||
-----END PUBLIC KEY-----`)
|
||||
-----END PUBLIC KEY-----`)
|
||||
|
||||
// Deprecated: legacy static AES key retained only for compatibility with the
|
||||
// old AES-CFB transport path.
|
||||
var defaultAesKey = []byte{0x19, 0x96, 0x11, 0x27, 228, 187, 187, 231, 142, 137, 230, 179, 189, 229, 184, 133}
|
||||
|
||||
// Deprecated: legacy AES-CFB transport codec retained only for compatibility.
|
||||
func defaultMsgEn(key []byte, d []byte) []byte {
|
||||
data, err := starcrypto.CustomEncryptAesCFB(d, key)
|
||||
if err != nil {
|
||||
@@ -83,6 +91,7 @@ func defaultMsgEn(key []byte, d []byte) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
// Deprecated: legacy AES-CFB transport codec retained only for compatibility.
|
||||
func defaultMsgDe(key []byte, d []byte) []byte {
|
||||
data, err := starcrypto.CustomDecryptAesCFB(d, key)
|
||||
if err != nil {
|
||||
@@ -94,4 +103,41 @@ func defaultMsgDe(key []byte, d []byte) []byte {
|
||||
|
||||
func init() {
|
||||
RegisterName("b612.me/notify.Transfer", TransferMsg{})
|
||||
RegisterName("b612.me/notify.Envelope", Envelope{})
|
||||
RegisterName("b612.me/notify.TransferRange", TransferRange{})
|
||||
RegisterName("b612.me/notify.TransferBeginRequest", TransferBeginRequest{})
|
||||
RegisterName("b612.me/notify.TransferBeginResponse", TransferBeginResponse{})
|
||||
RegisterName("b612.me/notify.TransferResumeRequest", TransferResumeRequest{})
|
||||
RegisterName("b612.me/notify.TransferResumeResponse", TransferResumeResponse{})
|
||||
RegisterName("b612.me/notify.TransferCommitRequest", TransferCommitRequest{})
|
||||
RegisterName("b612.me/notify.TransferCommitResponse", TransferCommitResponse{})
|
||||
RegisterName("b612.me/notify.TransferAbortRequest", TransferAbortRequest{})
|
||||
RegisterName("b612.me/notify.TransferAbortResponse", TransferAbortResponse{})
|
||||
RegisterName("b612.me/notify.StreamOpenRequest", StreamOpenRequest{})
|
||||
RegisterName("b612.me/notify.StreamOpenResponse", StreamOpenResponse{})
|
||||
RegisterName("b612.me/notify.StreamCloseRequest", StreamCloseRequest{})
|
||||
RegisterName("b612.me/notify.StreamCloseResponse", StreamCloseResponse{})
|
||||
RegisterName("b612.me/notify.StreamResetRequest", StreamResetRequest{})
|
||||
RegisterName("b612.me/notify.StreamResetResponse", StreamResetResponse{})
|
||||
RegisterName("b612.me/notify.BulkRange", BulkRange{})
|
||||
RegisterName("b612.me/notify.BulkOpenRequest", BulkOpenRequest{})
|
||||
RegisterName("b612.me/notify.BulkOpenResponse", BulkOpenResponse{})
|
||||
RegisterName("b612.me/notify.BulkCloseRequest", BulkCloseRequest{})
|
||||
RegisterName("b612.me/notify.BulkCloseResponse", BulkCloseResponse{})
|
||||
RegisterName("b612.me/notify.BulkResetRequest", BulkResetRequest{})
|
||||
RegisterName("b612.me/notify.BulkResetResponse", BulkResetResponse{})
|
||||
RegisterName("b612.me/notify.BulkReleaseRequest", BulkReleaseRequest{})
|
||||
RegisterName("b612.me/notify.bulkAttachRequest", bulkAttachRequest{})
|
||||
RegisterName("b612.me/notify.bulkAttachResponse", bulkAttachResponse{})
|
||||
RegisterName("b612.me/notify.peerAttachRequest", peerAttachRequest{})
|
||||
RegisterName("b612.me/notify.peerAttachResponse", peerAttachResponse{})
|
||||
RegisterName("b612.me/notify/transfer.Begin", itransfer.Begin{})
|
||||
RegisterName("b612.me/notify/transfer.BeginAck", itransfer.BeginAck{})
|
||||
RegisterName("b612.me/notify/transfer.Resume", itransfer.Resume{})
|
||||
RegisterName("b612.me/notify/transfer.ResumeAck", itransfer.ResumeAck{})
|
||||
RegisterName("b612.me/notify/transfer.Commit", itransfer.Commit{})
|
||||
RegisterName("b612.me/notify/transfer.CommitAck", itransfer.CommitAck{})
|
||||
RegisterName("b612.me/notify/transfer.Abort", itransfer.Abort{})
|
||||
RegisterName("b612.me/notify/transfer.Segment", itransfer.Segment{})
|
||||
RegisterName("b612.me/notify/transfer.Ack", itransfer.Ack{})
|
||||
}
|
||||
|
||||
@@ -0,0 +1,387 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type DiagnosticsResetCauseSummary struct {
|
||||
Total int
|
||||
TransportDetached int
|
||||
ServiceShutdown int
|
||||
Backpressure int
|
||||
Other int
|
||||
}
|
||||
|
||||
type DiagnosticsTransferTelemetrySummary struct {
|
||||
SourceReadBytes int64
|
||||
StreamWriteBytes int64
|
||||
SinkWriteBytes int64
|
||||
SourceReadDuration time.Duration
|
||||
StreamWriteDuration time.Duration
|
||||
SinkWriteDuration time.Duration
|
||||
SyncDuration time.Duration
|
||||
VerifyDuration time.Duration
|
||||
CommitDuration time.Duration
|
||||
CommitWaitDuration time.Duration
|
||||
WorkDuration time.Duration
|
||||
ObservedDuration time.Duration
|
||||
SourceReadThroughputBPS float64
|
||||
StreamWriteThroughputBPS float64
|
||||
SinkWriteThroughputBPS float64
|
||||
CommitWaitRatio float64
|
||||
}
|
||||
|
||||
type DiagnosticsSummary struct {
|
||||
LogicalCount int
|
||||
CurrentTransportCount int
|
||||
|
||||
StreamCount int
|
||||
ActiveStreamCount int
|
||||
StaleStreamCount int
|
||||
ResetStreamCount int
|
||||
|
||||
BulkCount int
|
||||
DedicatedBulkCount int
|
||||
ActiveBulkCount int
|
||||
StaleBulkCount int
|
||||
ResetBulkCount int
|
||||
|
||||
TransferCount int
|
||||
ActiveTransferCount int
|
||||
PausedTransferCount int
|
||||
DoneTransferCount int
|
||||
FailedTransferCount int
|
||||
AbortedTransferCount int
|
||||
|
||||
StreamResetCauses DiagnosticsResetCauseSummary
|
||||
BulkResetCauses DiagnosticsResetCauseSummary
|
||||
TransferTelemetry DiagnosticsTransferTelemetrySummary
|
||||
}
|
||||
|
||||
type ClientDiagnosticsSnapshot struct {
|
||||
Runtime ClientRuntimeSnapshot
|
||||
Streams []StreamSnapshot
|
||||
Bulks []BulkSnapshot
|
||||
Transfers []TransferSnapshot
|
||||
Summary DiagnosticsSummary
|
||||
}
|
||||
|
||||
type ServerDiagnosticsSnapshot struct {
|
||||
Runtime ServerRuntimeSnapshot
|
||||
Logicals []ClientConnRuntimeSnapshot
|
||||
CurrentTransports []TransportConnRuntimeSnapshot
|
||||
Streams []StreamSnapshot
|
||||
Bulks []BulkSnapshot
|
||||
Transfers []TransferSnapshot
|
||||
Summary DiagnosticsSummary
|
||||
}
|
||||
|
||||
var (
|
||||
errClientDiagnosticsSnapshotNil = errors.New("client diagnostics snapshot target is nil")
|
||||
errServerDiagnosticsSnapshotNil = errors.New("server diagnostics snapshot target is nil")
|
||||
)
|
||||
|
||||
func GetClientDiagnosticsSnapshot(c Client) (ClientDiagnosticsSnapshot, error) {
|
||||
if c == nil {
|
||||
return ClientDiagnosticsSnapshot{}, errClientDiagnosticsSnapshotNil
|
||||
}
|
||||
runtime, err := GetClientRuntimeSnapshot(c)
|
||||
if err != nil {
|
||||
return ClientDiagnosticsSnapshot{}, err
|
||||
}
|
||||
streams, err := GetClientStreamSnapshots(c)
|
||||
if err != nil {
|
||||
return ClientDiagnosticsSnapshot{}, err
|
||||
}
|
||||
bulks, err := GetClientBulkSnapshots(c)
|
||||
if err != nil {
|
||||
return ClientDiagnosticsSnapshot{}, err
|
||||
}
|
||||
transfers, err := GetClientTransferSnapshots(c)
|
||||
if err != nil {
|
||||
return ClientDiagnosticsSnapshot{}, err
|
||||
}
|
||||
snapshot := ClientDiagnosticsSnapshot{
|
||||
Runtime: runtime,
|
||||
Streams: streams,
|
||||
Bulks: bulks,
|
||||
Transfers: transfers,
|
||||
}
|
||||
snapshot.Summary = summarizeClientDiagnosticsSnapshot(snapshot)
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
func GetServerDiagnosticsSnapshot(s Server) (ServerDiagnosticsSnapshot, error) {
|
||||
if s == nil {
|
||||
return ServerDiagnosticsSnapshot{}, errServerDiagnosticsSnapshotNil
|
||||
}
|
||||
runtime, err := GetServerRuntimeSnapshot(s)
|
||||
if err != nil {
|
||||
return ServerDiagnosticsSnapshot{}, err
|
||||
}
|
||||
logicals, err := serverLogicalRuntimeSnapshots(s)
|
||||
if err != nil {
|
||||
return ServerDiagnosticsSnapshot{}, err
|
||||
}
|
||||
transports, err := serverCurrentTransportRuntimeSnapshots(s)
|
||||
if err != nil {
|
||||
return ServerDiagnosticsSnapshot{}, err
|
||||
}
|
||||
streams, err := GetServerStreamSnapshots(s)
|
||||
if err != nil {
|
||||
return ServerDiagnosticsSnapshot{}, err
|
||||
}
|
||||
bulks, err := GetServerBulkSnapshots(s)
|
||||
if err != nil {
|
||||
return ServerDiagnosticsSnapshot{}, err
|
||||
}
|
||||
transfers, err := GetServerTransferSnapshots(s)
|
||||
if err != nil {
|
||||
return ServerDiagnosticsSnapshot{}, err
|
||||
}
|
||||
snapshot := ServerDiagnosticsSnapshot{
|
||||
Runtime: runtime,
|
||||
Logicals: logicals,
|
||||
CurrentTransports: transports,
|
||||
Streams: streams,
|
||||
Bulks: bulks,
|
||||
Transfers: transfers,
|
||||
}
|
||||
snapshot.Summary = summarizeServerDiagnosticsSnapshot(snapshot)
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
func serverLogicalRuntimeSnapshots(s Server) ([]ClientConnRuntimeSnapshot, error) {
|
||||
if s == nil {
|
||||
return nil, errServerDiagnosticsSnapshotNil
|
||||
}
|
||||
logicals := s.GetLogicalConnList()
|
||||
out := make([]ClientConnRuntimeSnapshot, 0, len(logicals))
|
||||
for _, logical := range logicals {
|
||||
if logical == nil {
|
||||
continue
|
||||
}
|
||||
snapshot, err := GetLogicalConnRuntimeSnapshot(logical)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, snapshot)
|
||||
}
|
||||
sortClientConnRuntimeSnapshots(out)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func serverCurrentTransportRuntimeSnapshots(s Server) ([]TransportConnRuntimeSnapshot, error) {
|
||||
if s == nil {
|
||||
return nil, errServerDiagnosticsSnapshotNil
|
||||
}
|
||||
transports := s.GetCurrentTransportConnList()
|
||||
out := make([]TransportConnRuntimeSnapshot, 0, len(transports))
|
||||
for _, transport := range transports {
|
||||
if transport == nil {
|
||||
continue
|
||||
}
|
||||
snapshot, err := GetTransportConnRuntimeSnapshot(transport)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, snapshot)
|
||||
}
|
||||
sortTransportConnRuntimeSnapshots(out)
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func summarizeClientDiagnosticsSnapshot(snapshot ClientDiagnosticsSnapshot) DiagnosticsSummary {
|
||||
summary := DiagnosticsSummary{
|
||||
LogicalCount: diagnosticsLogicalCountFromClientRuntime(snapshot.Runtime),
|
||||
}
|
||||
if snapshot.Runtime.TransportAttached {
|
||||
summary.CurrentTransportCount = 1
|
||||
}
|
||||
summarizeStreamSnapshots(&summary, snapshot.Streams)
|
||||
summarizeBulkSnapshots(&summary, snapshot.Bulks)
|
||||
summarizeTransferSnapshots(&summary, snapshot.Transfers)
|
||||
return summary
|
||||
}
|
||||
|
||||
func summarizeServerDiagnosticsSnapshot(snapshot ServerDiagnosticsSnapshot) DiagnosticsSummary {
|
||||
summary := DiagnosticsSummary{
|
||||
LogicalCount: len(snapshot.Logicals),
|
||||
CurrentTransportCount: len(snapshot.CurrentTransports),
|
||||
}
|
||||
summarizeStreamSnapshots(&summary, snapshot.Streams)
|
||||
summarizeBulkSnapshots(&summary, snapshot.Bulks)
|
||||
summarizeTransferSnapshots(&summary, snapshot.Transfers)
|
||||
return summary
|
||||
}
|
||||
|
||||
func diagnosticsLogicalCountFromClientRuntime(runtime ClientRuntimeSnapshot) int {
|
||||
if runtime.Alive || runtime.SessionEpoch != 0 || runtime.TransportAttached || runtime.HasRuntimeConn || runtime.HasRuntimeQueue {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func summarizeStreamSnapshots(summary *DiagnosticsSummary, snapshots []StreamSnapshot) {
|
||||
if summary == nil {
|
||||
return
|
||||
}
|
||||
summary.StreamCount = len(snapshots)
|
||||
for _, snapshot := range snapshots {
|
||||
switch {
|
||||
case snapshot.ResetError != "":
|
||||
summary.ResetStreamCount++
|
||||
accumulateDiagnosticsResetCause(&summary.StreamResetCauses, snapshot.ResetError, errStreamBackpressureExceeded.Error())
|
||||
case streamSnapshotFinished(snapshot):
|
||||
case streamSnapshotBoundActive(snapshot):
|
||||
summary.ActiveStreamCount++
|
||||
default:
|
||||
summary.StaleStreamCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeBulkSnapshots(summary *DiagnosticsSummary, snapshots []BulkSnapshot) {
|
||||
if summary == nil {
|
||||
return
|
||||
}
|
||||
summary.BulkCount = len(snapshots)
|
||||
for _, snapshot := range snapshots {
|
||||
if snapshot.Dedicated {
|
||||
summary.DedicatedBulkCount++
|
||||
}
|
||||
switch {
|
||||
case snapshot.ResetError != "":
|
||||
summary.ResetBulkCount++
|
||||
accumulateDiagnosticsResetCause(&summary.BulkResetCauses, snapshot.ResetError, errBulkBackpressureExceeded.Error())
|
||||
case bulkSnapshotFinished(snapshot):
|
||||
case bulkSnapshotBoundActive(snapshot):
|
||||
summary.ActiveBulkCount++
|
||||
default:
|
||||
summary.StaleBulkCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeTransferSnapshots(summary *DiagnosticsSummary, snapshots []TransferSnapshot) {
|
||||
if summary == nil {
|
||||
return
|
||||
}
|
||||
summary.TransferCount = len(snapshots)
|
||||
for _, snapshot := range snapshots {
|
||||
switch snapshot.State {
|
||||
case TransferStateDone:
|
||||
summary.DoneTransferCount++
|
||||
case TransferStateFailed:
|
||||
summary.FailedTransferCount++
|
||||
case TransferStateAborted:
|
||||
summary.AbortedTransferCount++
|
||||
case TransferStatePaused:
|
||||
summary.PausedTransferCount++
|
||||
default:
|
||||
summary.ActiveTransferCount++
|
||||
}
|
||||
accumulateDiagnosticsTransferTelemetry(&summary.TransferTelemetry, snapshot)
|
||||
}
|
||||
finalizeDiagnosticsTransferTelemetry(&summary.TransferTelemetry)
|
||||
}
|
||||
|
||||
func streamSnapshotFinished(snapshot StreamSnapshot) bool {
|
||||
return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed
|
||||
}
|
||||
|
||||
func bulkSnapshotFinished(snapshot BulkSnapshot) bool {
|
||||
return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed
|
||||
}
|
||||
|
||||
func streamSnapshotBoundActive(snapshot StreamSnapshot) bool {
|
||||
return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent
|
||||
}
|
||||
|
||||
func bulkSnapshotBoundActive(snapshot BulkSnapshot) bool {
|
||||
return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent
|
||||
}
|
||||
|
||||
func accumulateDiagnosticsResetCause(summary *DiagnosticsResetCauseSummary, resetError string, backpressureError string) {
|
||||
if summary == nil || resetError == "" {
|
||||
return
|
||||
}
|
||||
summary.Total++
|
||||
if diagnosticsResetErrorMatches(resetError, errTransportDetached) {
|
||||
summary.TransportDetached++
|
||||
return
|
||||
}
|
||||
if diagnosticsResetErrorMatches(resetError, errServiceShutdown) {
|
||||
summary.ServiceShutdown++
|
||||
return
|
||||
}
|
||||
if resetError == backpressureError || strings.HasPrefix(resetError, backpressureError+":") {
|
||||
summary.Backpressure++
|
||||
return
|
||||
}
|
||||
summary.Other++
|
||||
}
|
||||
|
||||
func diagnosticsResetErrorMatches(resetError string, target error) bool {
|
||||
if resetError == "" || target == nil {
|
||||
return false
|
||||
}
|
||||
base := target.Error()
|
||||
return resetError == base || strings.HasPrefix(resetError, base+":")
|
||||
}
|
||||
|
||||
func accumulateDiagnosticsTransferTelemetry(summary *DiagnosticsTransferTelemetrySummary, snapshot TransferSnapshot) {
|
||||
if summary == nil {
|
||||
return
|
||||
}
|
||||
summary.SourceReadBytes += transferSummarySourceReadBytes(snapshot)
|
||||
summary.StreamWriteBytes += transferSummaryStreamWriteBytes(snapshot)
|
||||
summary.SinkWriteBytes += transferSummarySinkWriteBytes(snapshot)
|
||||
summary.SourceReadDuration += snapshot.SourceReadDuration
|
||||
summary.StreamWriteDuration += snapshot.StreamWriteDuration
|
||||
summary.SinkWriteDuration += snapshot.SinkWriteDuration
|
||||
summary.SyncDuration += snapshot.SyncDuration
|
||||
summary.VerifyDuration += snapshot.VerifyDuration
|
||||
summary.CommitDuration += snapshot.CommitDuration
|
||||
summary.CommitWaitDuration += snapshot.CommitWaitDuration
|
||||
}
|
||||
|
||||
func finalizeDiagnosticsTransferTelemetry(summary *DiagnosticsTransferTelemetrySummary) {
|
||||
if summary == nil {
|
||||
return
|
||||
}
|
||||
summary.WorkDuration = summary.SourceReadDuration + summary.StreamWriteDuration + summary.SinkWriteDuration +
|
||||
summary.SyncDuration + summary.VerifyDuration + summary.CommitDuration
|
||||
summary.ObservedDuration = summary.WorkDuration + summary.CommitWaitDuration
|
||||
summary.SourceReadThroughputBPS = throughputBytesPerSecond(summary.SourceReadBytes, summary.SourceReadDuration)
|
||||
summary.StreamWriteThroughputBPS = throughputBytesPerSecond(summary.StreamWriteBytes, summary.StreamWriteDuration)
|
||||
summary.SinkWriteThroughputBPS = throughputBytesPerSecond(summary.SinkWriteBytes, summary.SinkWriteDuration)
|
||||
summary.CommitWaitRatio = durationRatio(summary.CommitWaitDuration, summary.ObservedDuration)
|
||||
}
|
||||
|
||||
func sortClientConnRuntimeSnapshots(src []ClientConnRuntimeSnapshot) {
|
||||
sort.Slice(src, func(i, j int) bool {
|
||||
if src[i].ClientID != src[j].ClientID {
|
||||
return src[i].ClientID < src[j].ClientID
|
||||
}
|
||||
if src[i].TransportGeneration != src[j].TransportGeneration {
|
||||
return src[i].TransportGeneration < src[j].TransportGeneration
|
||||
}
|
||||
return src[i].RemoteAddress < src[j].RemoteAddress
|
||||
})
|
||||
}
|
||||
|
||||
func sortTransportConnRuntimeSnapshots(src []TransportConnRuntimeSnapshot) {
|
||||
sort.Slice(src, func(i, j int) bool {
|
||||
if src[i].ClientID != src[j].ClientID {
|
||||
return src[i].ClientID < src[j].ClientID
|
||||
}
|
||||
if src[i].TransportGeneration != src[j].TransportGeneration {
|
||||
return src[i].TransportGeneration < src[j].TransportGeneration
|
||||
}
|
||||
return src[i].RemoteAddress < src[j].RemoteAddress
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,417 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"math"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
itransfer "b612.me/notify/internal/transfer"
|
||||
)
|
||||
|
||||
func TestGetClientDiagnosticsSnapshotDefaults(t *testing.T) {
|
||||
client := NewClient()
|
||||
snapshot, err := GetClientDiagnosticsSnapshot(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientDiagnosticsSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := snapshot.Runtime.OwnerState, "idle"; got != want {
|
||||
t.Fatalf("Runtime.OwnerState = %q, want %q", got, want)
|
||||
}
|
||||
if len(snapshot.Streams) != 0 || len(snapshot.Bulks) != 0 || len(snapshot.Transfers) != 0 {
|
||||
t.Fatalf("default diagnostics should be empty: %+v", snapshot)
|
||||
}
|
||||
if snapshot.Summary != (DiagnosticsSummary{}) {
|
||||
t.Fatalf("default summary mismatch: %+v", snapshot.Summary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientDiagnosticsSnapshotAggregatesActiveState(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
|
||||
streamAcceptCh := make(chan StreamAcceptInfo, 1)
|
||||
bulkAcceptCh := make(chan BulkAcceptInfo, 1)
|
||||
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||||
streamAcceptCh <- info
|
||||
return nil
|
||||
})
|
||||
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||
bulkAcceptCh <- info
|
||||
return nil
|
||||
})
|
||||
|
||||
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||
t.Fatalf("server Listen failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = server.Stop()
|
||||
}()
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||
t.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Stop()
|
||||
}()
|
||||
|
||||
stream, err := client.OpenStream(context.Background(), StreamOpenOptions{
|
||||
ID: "diag-client-stream",
|
||||
Channel: StreamDataChannel,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("client OpenStream failed: %v", err)
|
||||
}
|
||||
waitAcceptedStream(t, streamAcceptCh, 2*time.Second)
|
||||
|
||||
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||
ID: "diag-client-bulk",
|
||||
Range: BulkRange{
|
||||
Length: 64,
|
||||
},
|
||||
ChunkSize: 16 * 1024,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("client OpenBulk failed: %v", err)
|
||||
}
|
||||
waitAcceptedBulk(t, bulkAcceptCh, 2*time.Second)
|
||||
|
||||
transferRuntime := client.getTransferRuntime()
|
||||
transferRuntime.ensureTransferDescriptor(fileTransferDirectionSend, clientFileScope(), clientFileScope(), 0, itransfer.Descriptor{
|
||||
ID: "diag-client-transfer-done",
|
||||
Channel: itransfer.DataChannel,
|
||||
Size: 32,
|
||||
Checksum: "sum-client",
|
||||
})
|
||||
transferRuntime.activate(fileTransferDirectionSend, clientFileScope(), "diag-client-transfer-done")
|
||||
transferRuntime.complete(fileTransferDirectionSend, clientFileScope(), "diag-client-transfer-done")
|
||||
|
||||
snapshot, err := GetClientDiagnosticsSnapshot(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientDiagnosticsSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := snapshot.Summary.LogicalCount, 1; got != want {
|
||||
t.Fatalf("LogicalCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.CurrentTransportCount, 1; got != want {
|
||||
t.Fatalf("CurrentTransportCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.StreamCount, 1; got != want {
|
||||
t.Fatalf("StreamCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.ActiveStreamCount, 1; got != want {
|
||||
t.Fatalf("ActiveStreamCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.BulkCount, 1; got != want {
|
||||
t.Fatalf("BulkCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.ActiveBulkCount, 1; got != want {
|
||||
t.Fatalf("ActiveBulkCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.TransferCount, 1; got != want {
|
||||
t.Fatalf("TransferCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.DoneTransferCount, 1; got != want {
|
||||
t.Fatalf("DoneTransferCount = %d, want %d", got, want)
|
||||
}
|
||||
if got := snapshot.Summary.StaleStreamCount + snapshot.Summary.ResetStreamCount + snapshot.Summary.StaleBulkCount + snapshot.Summary.ResetBulkCount + snapshot.Summary.FailedTransferCount; got != 0 {
|
||||
t.Fatalf("unexpected unhealthy counters in active snapshot: %+v", snapshot.Summary)
|
||||
}
|
||||
|
||||
_ = stream.Close()
|
||||
_ = bulk.Close()
|
||||
}
|
||||
|
||||
func TestGetServerDiagnosticsSnapshotAggregatesStaleAndResetState(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
|
||||
left, right := net.Pipe()
|
||||
defer right.Close()
|
||||
|
||||
logical := server.bootstrapAcceptedLogical("diag-server-peer", nil, left)
|
||||
if logical == nil {
|
||||
t.Fatal("bootstrapAcceptedLogical should return logical")
|
||||
}
|
||||
logical.markIdentityBound()
|
||||
logical.compatClientConn().markClientConnStreamTransport()
|
||||
transport := logical.CurrentTransportConn()
|
||||
if transport == nil {
|
||||
t.Fatal("CurrentTransportConn should return active transport")
|
||||
}
|
||||
scope := serverFileScope(logical)
|
||||
|
||||
streamStale := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{
|
||||
StreamID: "diag-stream-stale",
|
||||
DataID: 1,
|
||||
Channel: StreamDataChannel,
|
||||
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
|
||||
if err := server.getStreamRuntime().register(scope, streamStale); err != nil {
|
||||
t.Fatalf("register stale stream failed: %v", err)
|
||||
}
|
||||
streamReset := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{
|
||||
StreamID: "diag-stream-reset",
|
||||
DataID: 2,
|
||||
Channel: StreamDataChannel,
|
||||
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
|
||||
if err := server.getStreamRuntime().register(scope, streamReset); err != nil {
|
||||
t.Fatalf("register reset stream failed: %v", err)
|
||||
}
|
||||
streamReset.mu.Lock()
|
||||
streamReset.resetErr = errTransportDetached
|
||||
streamReset.mu.Unlock()
|
||||
|
||||
bulkStale := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{
|
||||
BulkID: "diag-bulk-stale",
|
||||
DataID: 3,
|
||||
Range: BulkRange{
|
||||
Length: 16,
|
||||
},
|
||||
ChunkSize: 32 * 1024,
|
||||
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, nil, nil)
|
||||
if err := server.getBulkRuntime().register(scope, bulkStale); err != nil {
|
||||
t.Fatalf("register stale bulk failed: %v", err)
|
||||
}
|
||||
bulkReset := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{
|
||||
BulkID: "diag-bulk-reset",
|
||||
DataID: 4,
|
||||
Dedicated: true,
|
||||
Range: BulkRange{
|
||||
Length: 16,
|
||||
},
|
||||
ChunkSize: 32 * 1024,
|
||||
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, nil, nil)
|
||||
if err := server.getBulkRuntime().register(scope, bulkReset); err != nil {
|
||||
t.Fatalf("register reset bulk failed: %v", err)
|
||||
}
|
||||
bulkReset.mu.Lock()
|
||||
bulkReset.resetErr = errTransportDetached
|
||||
bulkReset.mu.Unlock()
|
||||
|
||||
transferRuntime := server.getTransferRuntime()
|
||||
transferRuntime.ensureTransferDescriptor(fileTransferDirectionReceive, scope, scope, transport.TransportGeneration(), itransfer.Descriptor{
|
||||
ID: "diag-transfer-failed",
|
||||
Channel: itransfer.DataChannel,
|
||||
Size: 64,
|
||||
Checksum: "sum-server",
|
||||
})
|
||||
transferRuntime.activate(fileTransferDirectionReceive, scope, "diag-transfer-failed")
|
||||
transferRuntime.fail(fileTransferDirectionReceive, scope, "diag-transfer-failed", errors.New("boom"))
|
||||
|
||||
logical.markTransportDetached("heartbeat timeout", nil)
|
||||
logical.detachServerOwnedTransport()
|
||||
|
||||
snapshot, err := GetServerDiagnosticsSnapshot(server)
|
||||
if err != nil {
|
||||
t.Fatalf("GetServerDiagnosticsSnapshot failed: %v", err)
|
||||
}
|
||||
if got, want := len(snapshot.Logicals), 1; got != want {
|
||||
t.Fatalf("logical snapshot count = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := len(snapshot.CurrentTransports), 0; got != want {
|
||||
t.Fatalf("current transport snapshot count = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Runtime.DetachedClientCount, 1; got != want {
|
||||
t.Fatalf("DetachedClientCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.LogicalCount, 1; got != want {
|
||||
t.Fatalf("LogicalCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.CurrentTransportCount, 0; got != want {
|
||||
t.Fatalf("CurrentTransportCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.StreamCount, 2; got != want {
|
||||
t.Fatalf("StreamCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.StaleStreamCount, 1; got != want {
|
||||
t.Fatalf("StaleStreamCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.ResetStreamCount, 1; got != want {
|
||||
t.Fatalf("ResetStreamCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.StreamResetCauses.Total, 1; got != want {
|
||||
t.Fatalf("StreamResetCauses.Total = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.StreamResetCauses.TransportDetached, 1; got != want {
|
||||
t.Fatalf("StreamResetCauses.TransportDetached = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.BulkCount, 2; got != want {
|
||||
t.Fatalf("BulkCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.DedicatedBulkCount, 1; got != want {
|
||||
t.Fatalf("DedicatedBulkCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.StaleBulkCount, 1; got != want {
|
||||
t.Fatalf("StaleBulkCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.ResetBulkCount, 1; got != want {
|
||||
t.Fatalf("ResetBulkCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.BulkResetCauses.Total, 1; got != want {
|
||||
t.Fatalf("BulkResetCauses.Total = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.BulkResetCauses.TransportDetached, 1; got != want {
|
||||
t.Fatalf("BulkResetCauses.TransportDetached = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.TransferCount, 1; got != want {
|
||||
t.Fatalf("TransferCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Summary.FailedTransferCount, 1; got != want {
|
||||
t.Fatalf("FailedTransferCount = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiagnosticsSummaryClassifiesResetCauses(t *testing.T) {
|
||||
summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{
|
||||
Streams: []StreamSnapshot{
|
||||
{ResetError: errTransportDetached.Error()},
|
||||
{ResetError: errServiceShutdown.Error()},
|
||||
{ResetError: errStreamBackpressureExceeded.Error()},
|
||||
{ResetError: "stream boom"},
|
||||
},
|
||||
Bulks: []BulkSnapshot{
|
||||
{ResetError: errTransportDetached.Error()},
|
||||
{ResetError: errServiceShutdown.Error()},
|
||||
{ResetError: errBulkBackpressureExceeded.Error()},
|
||||
{ResetError: "bulk boom"},
|
||||
},
|
||||
})
|
||||
|
||||
if got, want := summary.ResetStreamCount, 4; got != want {
|
||||
t.Fatalf("ResetStreamCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.StreamResetCauses.Total, 4; got != want {
|
||||
t.Fatalf("StreamResetCauses.Total = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.StreamResetCauses.TransportDetached, 1; got != want {
|
||||
t.Fatalf("StreamResetCauses.TransportDetached = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.StreamResetCauses.ServiceShutdown, 1; got != want {
|
||||
t.Fatalf("StreamResetCauses.ServiceShutdown = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.StreamResetCauses.Backpressure, 1; got != want {
|
||||
t.Fatalf("StreamResetCauses.Backpressure = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.StreamResetCauses.Other, 1; got != want {
|
||||
t.Fatalf("StreamResetCauses.Other = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if got, want := summary.ResetBulkCount, 4; got != want {
|
||||
t.Fatalf("ResetBulkCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.BulkResetCauses.Total, 4; got != want {
|
||||
t.Fatalf("BulkResetCauses.Total = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.BulkResetCauses.TransportDetached, 1; got != want {
|
||||
t.Fatalf("BulkResetCauses.TransportDetached = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.BulkResetCauses.ServiceShutdown, 1; got != want {
|
||||
t.Fatalf("BulkResetCauses.ServiceShutdown = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.BulkResetCauses.Backpressure, 1; got != want {
|
||||
t.Fatalf("BulkResetCauses.Backpressure = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.BulkResetCauses.Other, 1; got != want {
|
||||
t.Fatalf("BulkResetCauses.Other = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiagnosticsSummaryAggregatesTransferTelemetry(t *testing.T) {
|
||||
summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{
|
||||
Transfers: []TransferSnapshot{
|
||||
{
|
||||
ID: "send-done",
|
||||
State: TransferStateDone,
|
||||
SentBytes: 2048,
|
||||
SourceReadDuration: 200 * time.Millisecond,
|
||||
StreamWriteDuration: 400 * time.Millisecond,
|
||||
CommitWaitDuration: 100 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
ID: "recv-failed",
|
||||
State: TransferStateFailed,
|
||||
ReceivedBytes: 1024,
|
||||
SinkWriteDuration: 250 * time.Millisecond,
|
||||
SyncDuration: 50 * time.Millisecond,
|
||||
VerifyDuration: 25 * time.Millisecond,
|
||||
CommitDuration: 75 * time.Millisecond,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
if got, want := summary.TransferCount, 2; got != want {
|
||||
t.Fatalf("TransferCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.DoneTransferCount, 1; got != want {
|
||||
t.Fatalf("DoneTransferCount = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := summary.FailedTransferCount, 1; got != want {
|
||||
t.Fatalf("FailedTransferCount = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
telemetry := summary.TransferTelemetry
|
||||
if got, want := telemetry.SourceReadBytes, int64(2048); got != want {
|
||||
t.Fatalf("SourceReadBytes = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := telemetry.StreamWriteBytes, int64(2048); got != want {
|
||||
t.Fatalf("StreamWriteBytes = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := telemetry.SinkWriteBytes, int64(1024); got != want {
|
||||
t.Fatalf("SinkWriteBytes = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := telemetry.SourceReadDuration, 200*time.Millisecond; got != want {
|
||||
t.Fatalf("SourceReadDuration = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := telemetry.StreamWriteDuration, 400*time.Millisecond; got != want {
|
||||
t.Fatalf("StreamWriteDuration = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := telemetry.SinkWriteDuration, 250*time.Millisecond; got != want {
|
||||
t.Fatalf("SinkWriteDuration = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := telemetry.SyncDuration, 50*time.Millisecond; got != want {
|
||||
t.Fatalf("SyncDuration = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := telemetry.VerifyDuration, 25*time.Millisecond; got != want {
|
||||
t.Fatalf("VerifyDuration = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := telemetry.CommitDuration, 75*time.Millisecond; got != want {
|
||||
t.Fatalf("CommitDuration = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := telemetry.CommitWaitDuration, 100*time.Millisecond; got != want {
|
||||
t.Fatalf("CommitWaitDuration = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := telemetry.WorkDuration, time.Second; got != want {
|
||||
t.Fatalf("WorkDuration = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := telemetry.ObservedDuration, 1100*time.Millisecond; got != want {
|
||||
t.Fatalf("ObservedDuration = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := telemetry.SourceReadThroughputBPS, 10240.0; math.Abs(got-want) > 0.001 {
|
||||
t.Fatalf("SourceReadThroughputBPS = %f, want %f", got, want)
|
||||
}
|
||||
if got, want := telemetry.StreamWriteThroughputBPS, 5120.0; math.Abs(got-want) > 0.001 {
|
||||
t.Fatalf("StreamWriteThroughputBPS = %f, want %f", got, want)
|
||||
}
|
||||
if got, want := telemetry.SinkWriteThroughputBPS, 4096.0; math.Abs(got-want) > 0.001 {
|
||||
t.Fatalf("SinkWriteThroughputBPS = %f, want %f", got, want)
|
||||
}
|
||||
if got, want := telemetry.CommitWaitRatio, 1.0/11.0; math.Abs(got-want) > 0.000001 {
|
||||
t.Fatalf("CommitWaitRatio = %f, want %f", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetDiagnosticsSnapshotRejectsNil(t *testing.T) {
|
||||
if _, err := GetClientDiagnosticsSnapshot(nil); !errors.Is(err, errClientDiagnosticsSnapshotNil) {
|
||||
t.Fatalf("GetClientDiagnosticsSnapshot nil error = %v, want %v", err, errClientDiagnosticsSnapshotNil)
|
||||
}
|
||||
if _, err := GetServerDiagnosticsSnapshot(nil); !errors.Is(err, errServerDiagnosticsSnapshotNil) {
|
||||
t.Fatalf("GetServerDiagnosticsSnapshot nil error = %v, want %v", err, errServerDiagnosticsSnapshotNil)
|
||||
}
|
||||
}
|
||||
+184
@@ -0,0 +1,184 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"b612.me/notify/internal/timeutil"
|
||||
crand "crypto/rand"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
type EnvelopeKind uint8
|
||||
|
||||
const (
|
||||
EnvelopeSignal EnvelopeKind = iota
|
||||
EnvelopeSignalAck
|
||||
EnvelopeStreamData
|
||||
EnvelopeFileMeta
|
||||
EnvelopeFileChunk
|
||||
EnvelopeFileEnd
|
||||
EnvelopeFileAbort
|
||||
EnvelopeAck
|
||||
)
|
||||
|
||||
type Envelope struct {
|
||||
Kind EnvelopeKind
|
||||
ID uint64
|
||||
Body []byte
|
||||
Stream StreamPacket
|
||||
File FilePacket
|
||||
}
|
||||
|
||||
type StreamPacket struct {
|
||||
StreamID string
|
||||
Chunk []byte
|
||||
}
|
||||
|
||||
type FilePacket struct {
|
||||
FileID string
|
||||
Name string
|
||||
Size int64
|
||||
Mode uint32
|
||||
ModTime int64
|
||||
Offset int64
|
||||
Chunk []byte
|
||||
Checksum string
|
||||
Error string
|
||||
Stage string
|
||||
}
|
||||
|
||||
func wrapTransferMsgEnvelope(msg TransferMsg, enFn func(interface{}) ([]byte, error)) (Envelope, error) {
|
||||
body, err := enFn(msg)
|
||||
if err != nil {
|
||||
return Envelope{}, err
|
||||
}
|
||||
return Envelope{
|
||||
Kind: EnvelopeSignal,
|
||||
ID: msg.ID,
|
||||
Body: body,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func unwrapTransferMsgEnvelope(env Envelope, deFn func([]byte) (interface{}, error)) (TransferMsg, error) {
|
||||
if env.Kind != EnvelopeSignal {
|
||||
return TransferMsg{}, errors.New("envelope kind is not signal")
|
||||
}
|
||||
data, err := deFn(env.Body)
|
||||
if err != nil {
|
||||
return TransferMsg{}, err
|
||||
}
|
||||
msg, ok := data.(TransferMsg)
|
||||
if !ok {
|
||||
return TransferMsg{}, errors.New("invalid signal envelope payload")
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func newSignalAckEnvelope(signalID uint64) Envelope {
|
||||
return Envelope{
|
||||
Kind: EnvelopeSignalAck,
|
||||
ID: signalID,
|
||||
}
|
||||
}
|
||||
|
||||
func newStreamDataEnvelope(streamID string, chunk []byte) Envelope {
|
||||
return Envelope{
|
||||
Kind: EnvelopeStreamData,
|
||||
Stream: StreamPacket{
|
||||
StreamID: streamID,
|
||||
Chunk: chunk,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newFileMetaEnvelope(fileID string, fileName string, fileSize int64, checksum string, mode uint32, modTime int64) Envelope {
|
||||
return Envelope{
|
||||
Kind: EnvelopeFileMeta,
|
||||
File: FilePacket{
|
||||
FileID: fileID,
|
||||
Name: filepath.Base(fileName),
|
||||
Size: fileSize,
|
||||
Mode: mode,
|
||||
ModTime: modTime,
|
||||
Checksum: checksum,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newFileChunkEnvelope(fileID string, offset int64, chunk []byte) Envelope {
|
||||
return Envelope{
|
||||
Kind: EnvelopeFileChunk,
|
||||
File: FilePacket{
|
||||
FileID: fileID,
|
||||
Offset: offset,
|
||||
Chunk: chunk,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newFileEndEnvelope(fileID string) Envelope {
|
||||
return Envelope{
|
||||
Kind: EnvelopeFileEnd,
|
||||
File: FilePacket{
|
||||
FileID: fileID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newFileAbortEnvelope(fileID string, stage string, offset int64, errMsg string) Envelope {
|
||||
return Envelope{
|
||||
Kind: EnvelopeFileAbort,
|
||||
File: FilePacket{
|
||||
FileID: fileID,
|
||||
Stage: stage,
|
||||
Offset: offset,
|
||||
Error: errMsg,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func newFileAckEnvelope(fileID string, stage string, offset int64, errMsg string) Envelope {
|
||||
return Envelope{
|
||||
Kind: EnvelopeAck,
|
||||
File: FilePacket{
|
||||
FileID: fileID,
|
||||
Stage: stage,
|
||||
Offset: offset,
|
||||
Error: errMsg,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var fileIDSerial uint64
|
||||
|
||||
func buildFileID(fileName string) string {
|
||||
base := fileIDBaseName(fileName)
|
||||
ts := uint64(timeutil.NowUnixNano())
|
||||
pid := uint64(os.Getpid())
|
||||
seq := atomic.AddUint64(&fileIDSerial, 1)
|
||||
rnd := uint64(randomFileIDSuffix())
|
||||
return fmt.Sprintf("%s-%x-%x-%x-%x", base, ts, pid, seq, rnd)
|
||||
}
|
||||
|
||||
func fileIDBaseName(fileName string) string {
|
||||
base := sanitizeFileName(filepath.Base(fileName))
|
||||
switch base {
|
||||
case "", ".", "/", "\\":
|
||||
return "unnamed"
|
||||
default:
|
||||
return base
|
||||
}
|
||||
}
|
||||
|
||||
func randomFileIDSuffix() uint32 {
|
||||
var buf [4]byte
|
||||
if _, err := crand.Read(buf[:]); err == nil {
|
||||
return binary.BigEndian.Uint32(buf[:])
|
||||
}
|
||||
seq := atomic.LoadUint64(&fileIDSerial)
|
||||
mix := uint64(timeutil.NowUnixNano()) ^ (seq << 1) ^ uint64(os.Getpid())
|
||||
return uint32(mix ^ (mix >> 32))
|
||||
}
|
||||
@@ -0,0 +1,52 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildFileIDUniqueAcrossBurst(t *testing.T) {
|
||||
const total = 512
|
||||
seen := make(map[string]struct{}, total)
|
||||
for i := 0; i < total; i++ {
|
||||
id := buildFileID("report.txt")
|
||||
if _, ok := seen[id]; ok {
|
||||
t.Fatalf("duplicate file id generated: %q", id)
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFileIDKeepsReadableBaseName(t *testing.T) {
|
||||
id := buildFileID("/tmp/demo/report.txt")
|
||||
if !strings.HasPrefix(id, "report.txt-") {
|
||||
t.Fatalf("unexpected file id prefix: %q", id)
|
||||
}
|
||||
|
||||
parts := strings.Split(id, "-")
|
||||
if got, want := len(parts), 5; got != want {
|
||||
t.Fatalf("unexpected file id segment count: got %d want %d, id=%q", got, want, id)
|
||||
}
|
||||
for _, part := range parts[1:] {
|
||||
if part == "" {
|
||||
t.Fatalf("unexpected empty file id segment: %q", id)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildFileIDFallsBackToUnnamedBase(t *testing.T) {
|
||||
id := buildFileID("")
|
||||
if !strings.HasPrefix(id, "unnamed-") {
|
||||
t.Fatalf("unexpected unnamed file id prefix: %q", id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewFileMetaEnvelopeKeepsOptionalMeta(t *testing.T) {
|
||||
env := newFileMetaEnvelope("file-1", "/tmp/demo/report.txt", 123, "sum", 0o640, 123456789)
|
||||
if got, want := env.File.Mode, uint32(0o640); got != want {
|
||||
t.Fatalf("mode mismatch: got %o want %o", got, want)
|
||||
}
|
||||
if got, want := env.File.ModTime, int64(123456789); got != want {
|
||||
t.Fatalf("modtime mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,41 @@
|
||||
# Signal Demo
|
||||
|
||||
`examples/signal` 演示 `notify` 的最小消息收发路径,覆盖服务端监听、客户端 `SendWait`、服务端 `Reply` 和并发请求。
|
||||
|
||||
## 功能
|
||||
|
||||
- `serve`:启动服务端并监听本地 IPC 端点
|
||||
- `signal`:发送消息并等待回包
|
||||
- 并发发送:`-n` 指定总请求数,`-c` 指定并发数
|
||||
|
||||
## 运行
|
||||
|
||||
在模块根目录执行:
|
||||
|
||||
```bash
|
||||
go run ./examples/signal serve
|
||||
```
|
||||
|
||||
另开终端发送单条消息:
|
||||
|
||||
```bash
|
||||
go run ./examples/signal signal --msg "hello"
|
||||
```
|
||||
|
||||
并发请求示例:
|
||||
|
||||
```bash
|
||||
go run ./examples/signal signal --msg "ping" --n 100 --c 10
|
||||
```
|
||||
|
||||
## 默认端点
|
||||
|
||||
- Windows:`network=npipe`,`addr=notify-signal-demo`
|
||||
- Linux:`network=unix`,`addr=/tmp/notify-signal-demo.sock`
|
||||
|
||||
可通过 `--addr` 覆盖默认地址。
|
||||
|
||||
## 说明
|
||||
|
||||
- 示例中使用固定 PSK,仅用于本地演示。
|
||||
- 示例的并发模式用于接口验证,不作为吞吐基准测试。
|
||||
@@ -0,0 +1,217 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"b612.me/notify"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultPipeName = "notify-signal-demo"
|
||||
defaultUnixSock = "/tmp/notify-signal-demo.sock"
|
||||
sharedSecret = "0123456789abcdef0123456789abcdef"
|
||||
)
|
||||
|
||||
func main() {
|
||||
args := os.Args[1:]
|
||||
if len(args) == 0 {
|
||||
if err := runServe(nil); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "serve failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
switch args[0] {
|
||||
case "serve", "server":
|
||||
if err := runServe(args[1:]); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "serve failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
case "signal":
|
||||
if err := runSignal(args[1:]); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "signal failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
case "-h", "--help", "help":
|
||||
printUsage()
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "unknown subcommand: %s\n", args[0])
|
||||
printUsage()
|
||||
os.Exit(2)
|
||||
}
|
||||
}
|
||||
|
||||
func runServe(args []string) error {
|
||||
network, defaultAddr := defaultEndpoint()
|
||||
|
||||
fs := flag.NewFlagSet("serve", flag.ContinueOnError)
|
||||
addr := fs.String("addr", defaultAddr, "listen address (windows: pipe name or \\\\.\\pipe\\name; linux: unix socket path)")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
srv := notify.NewServer()
|
||||
if err := notify.UseModernPSKServer(srv, []byte(sharedSecret), nil); err != nil {
|
||||
return fmt.Errorf("configure modern psk server: %w", err)
|
||||
}
|
||||
srv.SetLink("signal", func(msg *notify.Message) {
|
||||
content := string(msg.Value)
|
||||
fmt.Printf("[server] recv signal: %s\n", content)
|
||||
reply := fmt.Sprintf("ack from server: %s", content)
|
||||
if err := msg.Reply([]byte(reply)); err != nil {
|
||||
fmt.Printf("[server] reply error: %v\n", err)
|
||||
}
|
||||
})
|
||||
|
||||
cleanup, err := prepareEndpoint(network, *addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if err := srv.Listen(network, *addr); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("[server] listening on %s %s\n", network, *addr)
|
||||
|
||||
stopSig := make(chan os.Signal, 1)
|
||||
signal.Notify(stopSig, os.Interrupt, syscall.SIGTERM)
|
||||
<-stopSig
|
||||
|
||||
fmt.Println("[server] stopping...")
|
||||
return srv.Stop()
|
||||
}
|
||||
|
||||
func runSignal(args []string) error {
|
||||
network, defaultAddr := defaultEndpoint()
|
||||
|
||||
fs := flag.NewFlagSet("signal", flag.ContinueOnError)
|
||||
addr := fs.String("addr", defaultAddr, "target address")
|
||||
msg := fs.String("msg", "hello", "signal payload")
|
||||
count := fs.Int("n", 1, "total request count")
|
||||
concurrency := fs.Int("c", 1, "concurrency for requests")
|
||||
timeout := fs.Duration("timeout", 5*time.Second, "wait timeout per request")
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return err
|
||||
}
|
||||
if *count <= 0 {
|
||||
return errors.New("-n must be > 0")
|
||||
}
|
||||
if *concurrency <= 0 {
|
||||
return errors.New("-c must be > 0")
|
||||
}
|
||||
|
||||
if *count == 1 && *concurrency == 1 {
|
||||
reply, err := sendOne(network, *addr, *msg, *timeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("[client] recv reply: %s\n", reply)
|
||||
return nil
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
var wg sync.WaitGroup
|
||||
jobs := make(chan int)
|
||||
errCh := make(chan error, *count)
|
||||
|
||||
worker := func() {
|
||||
defer wg.Done()
|
||||
for i := range jobs {
|
||||
payload := fmt.Sprintf("%s #%d", *msg, i+1)
|
||||
reply, err := sendOne(network, *addr, payload, *timeout)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("job=%d: %w", i+1, err)
|
||||
continue
|
||||
}
|
||||
fmt.Printf("[client] job=%d reply=%s\n", i+1, reply)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < *concurrency; i++ {
|
||||
wg.Add(1)
|
||||
go worker()
|
||||
}
|
||||
for i := 0; i < *count; i++ {
|
||||
jobs <- i
|
||||
}
|
||||
close(jobs)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
failures := 0
|
||||
for err := range errCh {
|
||||
failures++
|
||||
fmt.Printf("[client] error: %v\n", err)
|
||||
}
|
||||
fmt.Printf("[client] done total=%d concurrency=%d failures=%d elapsed=%s\n", *count, *concurrency, failures, time.Since(start).Round(time.Millisecond))
|
||||
if failures > 0 {
|
||||
return fmt.Errorf("concurrent signal test finished with %d failures", failures)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func sendOne(network string, addr string, payload string, timeout time.Duration) (string, error) {
|
||||
cli := notify.NewClient()
|
||||
if err := notify.UseModernPSKClient(cli, []byte(sharedSecret), nil); err != nil {
|
||||
return "", fmt.Errorf("configure modern psk client: %w", err)
|
||||
}
|
||||
if err := cli.Connect(network, addr); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
_ = cli.Stop()
|
||||
}()
|
||||
|
||||
reply, err := cli.SendWait("signal", []byte(payload), timeout)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(reply.Value), nil
|
||||
}
|
||||
|
||||
func defaultEndpoint() (network string, addr string) {
|
||||
if runtime.GOOS == "windows" {
|
||||
return "npipe", defaultPipeName
|
||||
}
|
||||
return "unix", defaultUnixSock
|
||||
}
|
||||
|
||||
func prepareEndpoint(network string, addr string) (func(), error) {
|
||||
if network != "unix" {
|
||||
return func() {}, nil
|
||||
}
|
||||
if addr == "" {
|
||||
return nil, errors.New("unix socket path is empty")
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(addr), 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = os.Remove(addr)
|
||||
return func() {
|
||||
_ = os.Remove(addr)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func printUsage() {
|
||||
fmt.Println("Usage:")
|
||||
fmt.Println(" signal-demo serve [--addr <addr>]")
|
||||
fmt.Println(" signal-demo signal [--addr <addr>] [--msg <text>] [--n <count>] [--c <concurrency>] [--timeout <duration>]")
|
||||
fmt.Println("")
|
||||
fmt.Println("Defaults:")
|
||||
if runtime.GOOS == "windows" {
|
||||
fmt.Printf(" network=npipe addr=%s\n", defaultPipeName)
|
||||
} else {
|
||||
fmt.Printf(" network=unix addr=%s\n", defaultUnixSock)
|
||||
}
|
||||
}
|
||||
+177
@@ -0,0 +1,177 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
errFileAckCanceled = errors.New("file ack canceled")
|
||||
errFileAckTimeout = errors.New("file ack timeout")
|
||||
)
|
||||
|
||||
type fileAckWait struct {
|
||||
key string
|
||||
scope string
|
||||
pool *fileAckPool
|
||||
reply chan FileEvent
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
type fileAckPool struct {
|
||||
pool sync.Map
|
||||
}
|
||||
|
||||
func newFileAckPool() *fileAckPool {
|
||||
return &fileAckPool{}
|
||||
}
|
||||
|
||||
func fileAckKey(scope string, fileID string, stage string, offset int64) string {
|
||||
return normalizeFileScope(scope) + "|" + fileID + "|" + stage + "|" + formatInt(offset)
|
||||
}
|
||||
|
||||
func (p *fileAckPool) prepare(scope string, fileID string, stage string, offset int64) *fileAckWait {
|
||||
scope = normalizeFileScope(scope)
|
||||
wait := &fileAckWait{
|
||||
key: fileAckKey(scope, fileID, stage, offset),
|
||||
scope: scope,
|
||||
pool: p,
|
||||
reply: make(chan FileEvent, 1),
|
||||
}
|
||||
p.pool.Store(wait.key, wait)
|
||||
return wait
|
||||
}
|
||||
|
||||
func (p *fileAckPool) deliver(scope string, event FileEvent) bool {
|
||||
return p.deliverAny([]string{scope}, event)
|
||||
}
|
||||
|
||||
func (p *fileAckPool) deliverAny(scopes []string, event FileEvent) bool {
|
||||
if p == nil {
|
||||
return false
|
||||
}
|
||||
for _, scope := range scopes {
|
||||
key := fileAckKey(scope, event.Packet.FileID, event.Packet.Stage, event.Packet.Offset)
|
||||
data, ok := p.pool.LoadAndDelete(key)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
wait := data.(*fileAckWait)
|
||||
wait.deliver(event)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *fileAckWait) cancel() {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
if w.pool != nil {
|
||||
w.pool.pool.Delete(w.key)
|
||||
}
|
||||
w.closeReply()
|
||||
}
|
||||
|
||||
func (w *fileAckWait) deliver(event FileEvent) {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
w.closeOnce.Do(func() {
|
||||
select {
|
||||
case w.reply <- event:
|
||||
default:
|
||||
}
|
||||
close(w.reply)
|
||||
})
|
||||
}
|
||||
|
||||
func (w *fileAckWait) closeReply() {
|
||||
if w == nil {
|
||||
return
|
||||
}
|
||||
w.closeOnce.Do(func() {
|
||||
close(w.reply)
|
||||
})
|
||||
}
|
||||
|
||||
func (p *fileAckPool) waitPrepared(wait *fileAckWait, timeout time.Duration) error {
|
||||
if timeout <= 0 {
|
||||
timeout = defaultFileAckTimeout
|
||||
}
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case event, ok := <-wait.reply:
|
||||
if !ok {
|
||||
return errFileAckCanceled
|
||||
}
|
||||
if event.Err != nil {
|
||||
return event.Err
|
||||
}
|
||||
if event.Packet.Error != "" {
|
||||
return errors.New(event.Packet.Error)
|
||||
}
|
||||
return nil
|
||||
case <-timer.C:
|
||||
wait.cancel()
|
||||
return errFileAckTimeout
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fileAckPool) wait(scope string, fileID string, stage string, offset int64, timeout time.Duration) error {
|
||||
wait := p.prepare(scope, fileID, stage, offset)
|
||||
return p.waitPrepared(wait, timeout)
|
||||
}
|
||||
|
||||
func (p *fileAckPool) closeAll() {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
p.pool.Range(func(_, value interface{}) bool {
|
||||
value.(*fileAckWait).cancel()
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (p *fileAckPool) closeScope(scope string) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
scope = normalizeFileScope(scope)
|
||||
p.pool.Range(func(_, value interface{}) bool {
|
||||
wait := value.(*fileAckWait)
|
||||
if wait.scope == scope {
|
||||
wait.cancel()
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func (p *fileAckPool) closeScopeFamily(scope string) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
base := normalizeFileScope(scope)
|
||||
p.pool.Range(func(_, value interface{}) bool {
|
||||
wait := value.(*fileAckWait)
|
||||
if scopeBelongsToServerFileScope(wait.scope, base) {
|
||||
wait.cancel()
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
func formatInt(v int64) string {
|
||||
return strconv.FormatInt(v, 10)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) getFileAckPool() *fileAckPool {
|
||||
return c.getLogicalSessionState().fileAckWaits
|
||||
}
|
||||
|
||||
func (s *ServerCommon) getFileAckPool() *fileAckPool {
|
||||
return s.getLogicalSessionState().fileAckWaits
|
||||
}
|
||||
@@ -0,0 +1,200 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fileTransferRetryHooks struct {
|
||||
onRetry func(err error, attempt int)
|
||||
onTimeout func(err error, attempt int)
|
||||
}
|
||||
|
||||
func fileStageByKind(kind EnvelopeKind) string {
|
||||
switch kind {
|
||||
case EnvelopeFileMeta:
|
||||
return "meta"
|
||||
case EnvelopeFileChunk:
|
||||
return "chunk"
|
||||
case EnvelopeFileEnd:
|
||||
return "end"
|
||||
case EnvelopeFileAbort:
|
||||
return "abort"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendFileAck(src Envelope, processErr error) error {
|
||||
errMsg := ""
|
||||
if processErr != nil {
|
||||
errMsg = processErr.Error()
|
||||
}
|
||||
ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg)
|
||||
return c.sendEnvelope(ack)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFileAck(logical *LogicalConn, src Envelope, processErr error) error {
|
||||
if logical == nil {
|
||||
return s.sendFileAckTransport(nil, src, processErr)
|
||||
}
|
||||
return s.sendFileAckTransport(s.resolveOutboundTransport(logical), src, processErr)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFileAckTransport(transport *TransportConn, src Envelope, processErr error) error {
|
||||
errMsg := ""
|
||||
if processErr != nil {
|
||||
errMsg = processErr.Error()
|
||||
}
|
||||
ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg)
|
||||
return s.sendEnvelopeTransport(transport, ack)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFileAckInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, src Envelope, processErr error) error {
|
||||
if conn == nil {
|
||||
return s.sendFileAckTransport(transport, src, processErr)
|
||||
}
|
||||
errMsg := ""
|
||||
if processErr != nil {
|
||||
errMsg = processErr.Error()
|
||||
}
|
||||
ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg)
|
||||
return s.sendEnvelopeInboundTransport(logical, transport, conn, ack)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendFileAbort(fileID string, stage string, offset int64, cause error) error {
|
||||
errMsg := ""
|
||||
if cause != nil {
|
||||
errMsg = cause.Error()
|
||||
}
|
||||
return c.sendEnvelope(newFileAbortEnvelope(fileID, stage, offset, errMsg))
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFileAbort(logical *LogicalConn, fileID string, stage string, offset int64, cause error) error {
|
||||
if logical == nil {
|
||||
return s.sendFileAbortTransport(nil, fileID, stage, offset, cause)
|
||||
}
|
||||
return s.sendFileAbortTransport(s.resolveOutboundTransport(logical), fileID, stage, offset, cause)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFileAbortTransport(transport *TransportConn, fileID string, stage string, offset int64, cause error) error {
|
||||
errMsg := ""
|
||||
if cause != nil {
|
||||
errMsg = cause.Error()
|
||||
}
|
||||
return s.sendEnvelopeTransport(transport, newFileAbortEnvelope(fileID, stage, offset, errMsg))
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendFileEnvelopeWithAck(env Envelope, timeout time.Duration) error {
|
||||
pool := c.getFileAckPool()
|
||||
wait := pool.prepare(clientFileScope(), env.File.FileID, fileStageByKind(env.Kind), env.File.Offset)
|
||||
if err := c.sendEnvelope(env); err != nil {
|
||||
wait.cancel()
|
||||
return err
|
||||
}
|
||||
return pool.waitPrepared(wait, timeout)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFileEnvelopeWithAck(logical *LogicalConn, env Envelope, timeout time.Duration) error {
|
||||
if logical == nil {
|
||||
return s.sendFileEnvelopeWithAckTransport(nil, env, timeout)
|
||||
}
|
||||
return s.sendFileEnvelopeWithAckTransport(s.resolveOutboundTransport(logical), env, timeout)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFileEnvelopeWithAckTransport(transport *TransportConn, env Envelope, timeout time.Duration) error {
|
||||
pool := s.getFileAckPool()
|
||||
wait := pool.prepare(serverTransportScopeForTransport(transport), env.File.FileID, fileStageByKind(env.Kind), env.File.Offset)
|
||||
if err := s.sendEnvelopeTransport(transport, env); err != nil {
|
||||
wait.cancel()
|
||||
return err
|
||||
}
|
||||
return pool.waitPrepared(wait, timeout)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendFileEnvelopeReliable(ctx context.Context, env Envelope, cfg fileTransferConfig) error {
|
||||
state := c.getFileTransferState()
|
||||
scope := clientFileScope()
|
||||
stage := fileStageByKind(env.Kind)
|
||||
state.recordRuntimeStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
|
||||
return retryFileTransferSend(ctx, cfg, func(cfg fileTransferConfig) error {
|
||||
return c.sendFileEnvelopeWithAck(env, cfg.AckTimeout)
|
||||
}, fileTransferRetryHooks{
|
||||
onRetry: func(err error, _ int) {
|
||||
state.recordRuntimeRetry(fileTransferDirectionSend, scope, env.File.FileID)
|
||||
},
|
||||
onTimeout: func(err error, _ int) {
|
||||
state.recordRuntimeTimeout(fileTransferDirectionSend, scope, env.File.FileID)
|
||||
state.recordRuntimeFailureStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFileEnvelopeReliable(ctx context.Context, logical *LogicalConn, env Envelope, cfg fileTransferConfig) error {
|
||||
if logical == nil {
|
||||
return s.sendFileEnvelopeReliableTransport(ctx, nil, env, cfg)
|
||||
}
|
||||
return s.sendFileEnvelopeReliableTransport(ctx, s.resolveOutboundTransport(logical), env, cfg)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFileEnvelopeReliableTransport(ctx context.Context, transport *TransportConn, env Envelope, cfg fileTransferConfig) error {
|
||||
state := s.getFileTransferState()
|
||||
scope := serverTransportScopeForTransport(transport)
|
||||
stage := fileStageByKind(env.Kind)
|
||||
state.recordRuntimeStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
|
||||
return retryFileTransferSend(ctx, cfg, func(cfg fileTransferConfig) error {
|
||||
return s.sendFileEnvelopeWithAckTransport(transport, env, cfg.AckTimeout)
|
||||
}, fileTransferRetryHooks{
|
||||
onRetry: func(err error, _ int) {
|
||||
state.recordRuntimeRetry(fileTransferDirectionSend, scope, env.File.FileID)
|
||||
},
|
||||
onTimeout: func(err error, _ int) {
|
||||
state.recordRuntimeTimeout(fileTransferDirectionSend, scope, env.File.FileID)
|
||||
state.recordRuntimeFailureStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func retryFileTransferSend(ctx context.Context, cfg fileTransferConfig, send func(fileTransferConfig) error, hooks ...fileTransferRetryHooks) error {
|
||||
cfg = normalizeFileTransferConfig(cfg)
|
||||
var lastErr error
|
||||
hook := mergeFileTransferRetryHooks(hooks...)
|
||||
for attempt := 0; attempt < cfg.SendRetry; attempt++ {
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
}
|
||||
lastErr = send(cfg)
|
||||
if lastErr == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(lastErr, errFileAckTimeout) && hook.onTimeout != nil {
|
||||
hook.onTimeout(lastErr, attempt+1)
|
||||
}
|
||||
if attempt+1 < cfg.SendRetry && hook.onRetry != nil {
|
||||
hook.onRetry(lastErr, attempt+1)
|
||||
}
|
||||
}
|
||||
if lastErr == nil {
|
||||
lastErr = errors.New("file send failed")
|
||||
}
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func mergeFileTransferRetryHooks(hooks ...fileTransferRetryHooks) fileTransferRetryHooks {
|
||||
var merged fileTransferRetryHooks
|
||||
for _, hook := range hooks {
|
||||
if hook.onRetry != nil {
|
||||
merged.onRetry = hook.onRetry
|
||||
}
|
||||
if hook.onTimeout != nil {
|
||||
merged.onTimeout = hook.onTimeout
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
@@ -0,0 +1,107 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRetryFileTransferSendHonorsRetryCount(t *testing.T) {
|
||||
var attempts int
|
||||
|
||||
err := retryFileTransferSend(context.Background(), fileTransferConfig{
|
||||
SendRetry: 3,
|
||||
}, func(cfg fileTransferConfig) error {
|
||||
attempts++
|
||||
return errors.New("send failed")
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("retryFileTransferSend should return the last error")
|
||||
}
|
||||
if got, want := attempts, 3; got != want {
|
||||
t.Fatalf("attempt count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryFileTransferSendStopsAfterSuccess(t *testing.T) {
|
||||
var attempts int
|
||||
|
||||
err := retryFileTransferSend(context.Background(), fileTransferConfig{
|
||||
SendRetry: 5,
|
||||
}, func(cfg fileTransferConfig) error {
|
||||
attempts++
|
||||
if attempts == 3 {
|
||||
return nil
|
||||
}
|
||||
return errors.New("send failed")
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("retryFileTransferSend should stop after success: %v", err)
|
||||
}
|
||||
if got, want := attempts, 3; got != want {
|
||||
t.Fatalf("attempt count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryFileTransferSendHonorsContextCancel(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
var attempts int
|
||||
err := retryFileTransferSend(ctx, fileTransferConfig{
|
||||
SendRetry: 3,
|
||||
}, func(cfg fileTransferConfig) error {
|
||||
attempts++
|
||||
return nil
|
||||
})
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Fatalf("expected context canceled, got %v", err)
|
||||
}
|
||||
if got, want := attempts, 0; got != want {
|
||||
t.Fatalf("attempt count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetryFileTransferSendReportsRetryAndTimeoutHooks(t *testing.T) {
|
||||
var attempts int
|
||||
var retries int
|
||||
var timeouts int
|
||||
|
||||
err := retryFileTransferSend(context.Background(), fileTransferConfig{
|
||||
SendRetry: 3,
|
||||
}, func(cfg fileTransferConfig) error {
|
||||
attempts++
|
||||
if attempts < 3 {
|
||||
return errFileAckTimeout
|
||||
}
|
||||
return nil
|
||||
}, fileTransferRetryHooks{
|
||||
onRetry: func(err error, attempt int) {
|
||||
retries++
|
||||
if !errors.Is(err, errFileAckTimeout) {
|
||||
t.Fatalf("retry err = %v, want %v", err, errFileAckTimeout)
|
||||
}
|
||||
if attempt != retries {
|
||||
t.Fatalf("retry attempt = %d, want %d", attempt, retries)
|
||||
}
|
||||
},
|
||||
onTimeout: func(err error, attempt int) {
|
||||
timeouts++
|
||||
if !errors.Is(err, errFileAckTimeout) {
|
||||
t.Fatalf("timeout err = %v, want %v", err, errFileAckTimeout)
|
||||
}
|
||||
if attempt != timeouts {
|
||||
t.Fatalf("timeout attempt = %d, want %d", attempt, timeouts)
|
||||
}
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("retryFileTransferSend should succeed after timeout retries: %v", err)
|
||||
}
|
||||
if got, want := retries, 2; got != want {
|
||||
t.Fatalf("retry hook count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := timeouts, 2; got != want {
|
||||
t.Fatalf("timeout hook count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package notify
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestFileAckPoolPreparedWaitConsumesEarlyAck(t *testing.T) {
|
||||
pool := newFileAckPool()
|
||||
wait := pool.prepare("client:a", "file-1", "chunk", 64)
|
||||
|
||||
ok := pool.deliver("client:a", FileEvent{
|
||||
Packet: FilePacket{
|
||||
FileID: "file-1",
|
||||
Stage: "chunk",
|
||||
Offset: 64,
|
||||
},
|
||||
})
|
||||
if !ok {
|
||||
t.Fatalf("deliver should match prepared waiter")
|
||||
}
|
||||
|
||||
if err := pool.waitPrepared(wait, defaultFileAckTimeout); err != nil {
|
||||
t.Fatalf("waitPrepared failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileAckPoolPreparedWaitReturnsAckError(t *testing.T) {
|
||||
pool := newFileAckPool()
|
||||
wait := pool.prepare("client:a", "file-2", "meta", 0)
|
||||
|
||||
ok := pool.deliver("client:a", FileEvent{
|
||||
Packet: FilePacket{
|
||||
FileID: "file-2",
|
||||
Stage: "meta",
|
||||
Offset: 0,
|
||||
Error: "checksum mismatch",
|
||||
},
|
||||
})
|
||||
if !ok {
|
||||
t.Fatalf("deliver should match prepared waiter")
|
||||
}
|
||||
|
||||
err := pool.waitPrepared(wait, defaultFileAckTimeout)
|
||||
if err == nil {
|
||||
t.Fatal("waitPrepared should return ack error")
|
||||
}
|
||||
if got, want := err.Error(), "checksum mismatch"; got != want {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileAckPoolCancelRemovesPreparedWaiter(t *testing.T) {
|
||||
pool := newFileAckPool()
|
||||
wait := pool.prepare("client:a", "file-3", "end", 0)
|
||||
wait.cancel()
|
||||
|
||||
ok := pool.deliver("client:a", FileEvent{
|
||||
Packet: FilePacket{
|
||||
FileID: "file-3",
|
||||
Stage: "end",
|
||||
Offset: 0,
|
||||
},
|
||||
})
|
||||
if ok {
|
||||
t.Fatal("deliver should not match canceled waiter")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileAckPoolScopeIsolation(t *testing.T) {
|
||||
pool := newFileAckPool()
|
||||
waitA := pool.prepare("server:client-a", "file-4", "chunk", 128)
|
||||
waitB := pool.prepare("server:client-b", "file-4", "chunk", 128)
|
||||
|
||||
ok := pool.deliver("server:client-a", FileEvent{
|
||||
Packet: FilePacket{
|
||||
FileID: "file-4",
|
||||
Stage: "chunk",
|
||||
Offset: 128,
|
||||
},
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("deliver should match scopeA waiter")
|
||||
}
|
||||
|
||||
if err := pool.waitPrepared(waitA, defaultFileAckTimeout); err != nil {
|
||||
t.Fatalf("waitPrepared scopeA failed: %v", err)
|
||||
}
|
||||
|
||||
ok = pool.deliver("server:client-a", FileEvent{
|
||||
Packet: FilePacket{
|
||||
FileID: "file-4",
|
||||
Stage: "chunk",
|
||||
Offset: 128,
|
||||
},
|
||||
})
|
||||
if ok {
|
||||
t.Fatal("scopeA ack should not consume scopeB waiter")
|
||||
}
|
||||
|
||||
ok = pool.deliver("server:client-b", FileEvent{
|
||||
Packet: FilePacket{
|
||||
FileID: "file-4",
|
||||
Stage: "chunk",
|
||||
Offset: 128,
|
||||
},
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("deliver should match scopeB waiter")
|
||||
}
|
||||
|
||||
if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil {
|
||||
t.Fatalf("waitPrepared scopeB failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileAckPoolCloseAllCancelsPreparedWaiters(t *testing.T) {
|
||||
pool := newFileAckPool()
|
||||
wait := pool.prepare("client:a", "file-5", "chunk", 256)
|
||||
|
||||
pool.closeAll()
|
||||
|
||||
err := pool.waitPrepared(wait, defaultFileAckTimeout)
|
||||
if err == nil {
|
||||
t.Fatal("waitPrepared should return cancel error after closeAll")
|
||||
}
|
||||
if got, want := err.Error(), "file ack canceled"; got != want {
|
||||
t.Fatalf("unexpected error after closeAll: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileAckPoolCloseScopeCancelsMatchingWaitersOnly(t *testing.T) {
|
||||
pool := newFileAckPool()
|
||||
waitA := pool.prepare("server:client-a", "file-6", "chunk", 256)
|
||||
waitB := pool.prepare("server:client-b", "file-6", "chunk", 256)
|
||||
|
||||
pool.closeScope("server:client-a")
|
||||
|
||||
err := pool.waitPrepared(waitA, defaultFileAckTimeout)
|
||||
if err == nil {
|
||||
t.Fatal("scopeA waiter should be canceled")
|
||||
}
|
||||
if got, want := err.Error(), "file ack canceled"; got != want {
|
||||
t.Fatalf("unexpected scopeA error: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
ok := pool.deliver("server:client-b", FileEvent{
|
||||
Packet: FilePacket{
|
||||
FileID: "file-6",
|
||||
Stage: "chunk",
|
||||
Offset: 256,
|
||||
},
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("scopeB waiter should remain deliverable")
|
||||
}
|
||||
|
||||
if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil {
|
||||
t.Fatalf("waitPrepared scopeB failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerRemoveClientClosesScopedFileAckWaiters(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
clientA := &ClientConn{ClientID: "client-a"}
|
||||
clientB := &ClientConn{ClientID: "client-b"}
|
||||
pool := server.getFileAckPool()
|
||||
|
||||
waitA := pool.prepare(serverFileScope(clientA), "file-7", "end", 0)
|
||||
waitB := pool.prepare(serverFileScope(clientB), "file-7", "end", 0)
|
||||
|
||||
server.removeClient(clientA)
|
||||
|
||||
err := pool.waitPrepared(waitA, defaultFileAckTimeout)
|
||||
if err == nil {
|
||||
t.Fatal("clientA waiter should be canceled when client is removed")
|
||||
}
|
||||
if got, want := err.Error(), "file ack canceled"; got != want {
|
||||
t.Fatalf("unexpected clientA error: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
ok := pool.deliver(serverFileScope(clientB), FileEvent{
|
||||
Packet: FilePacket{
|
||||
FileID: "file-7",
|
||||
Stage: "end",
|
||||
Offset: 0,
|
||||
},
|
||||
})
|
||||
if !ok {
|
||||
t.Fatal("clientB waiter should remain deliverable")
|
||||
}
|
||||
|
||||
if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil {
|
||||
t.Fatalf("waitPrepared clientB failed: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (c *ClientCommon) dispatchFileEnvelope(env Envelope, now time.Time) {
|
||||
event := FileEvent{
|
||||
NetType: NET_CLIENT,
|
||||
ServerConn: c,
|
||||
Kind: env.Kind,
|
||||
Packet: env.File,
|
||||
Time: now,
|
||||
}
|
||||
pool := c.getFileReceivePool()
|
||||
switch env.Kind {
|
||||
case EnvelopeAck:
|
||||
event.Packet.Stage = env.File.Stage
|
||||
event.Packet.Error = env.File.Error
|
||||
event.Received = env.File.Offset
|
||||
if c.getFileAckPool().deliver(clientFileScope(), event) {
|
||||
return
|
||||
}
|
||||
case EnvelopeFileMeta:
|
||||
session, err := pool.onMeta(clientFileScope(), env.File, now)
|
||||
if session != nil {
|
||||
event.Path = session.tmpPath
|
||||
event.Received = session.received
|
||||
fillFileEventTiming(&event, session)
|
||||
}
|
||||
event.Err = err
|
||||
case EnvelopeFileChunk:
|
||||
session, err := pool.onChunk(clientFileScope(), env.File, now)
|
||||
if session != nil {
|
||||
event.Path = session.tmpPath
|
||||
event.Received = session.received
|
||||
fillFileEventTiming(&event, session)
|
||||
}
|
||||
event.Err = err
|
||||
case EnvelopeFileEnd:
|
||||
finalPath, session, err := pool.onEnd(clientFileScope(), env.File, now)
|
||||
if session != nil {
|
||||
event.Path = finalPath
|
||||
event.Received = session.received
|
||||
fillFileEventTiming(&event, session)
|
||||
}
|
||||
event.Err = err
|
||||
case EnvelopeFileAbort:
|
||||
session, err := pool.onAbort(clientFileScope(), env.File, now)
|
||||
event.Received = env.File.Offset
|
||||
if session != nil {
|
||||
event.Path = session.tmpPath
|
||||
fillFileEventTiming(&event, session)
|
||||
}
|
||||
event.Err = err
|
||||
default:
|
||||
}
|
||||
if env.Kind == EnvelopeFileMeta || env.Kind == EnvelopeFileChunk || env.Kind == EnvelopeFileEnd || env.Kind == EnvelopeFileAbort {
|
||||
if ackErr := c.sendFileAck(env, event.Err); ackErr != nil && event.Err == nil {
|
||||
event.Err = ackErr
|
||||
}
|
||||
}
|
||||
fillFileEventProgress(&event)
|
||||
c.publishReceivedFileEvent(event)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) dispatchFileEnvelope(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope, now time.Time) {
|
||||
if transport == nil && logical != nil {
|
||||
transport = logical.CurrentTransportConn()
|
||||
}
|
||||
event := FileEvent{
|
||||
LogicalConn: logical,
|
||||
NetType: NET_SERVER,
|
||||
TransportConn: transport,
|
||||
Kind: env.Kind,
|
||||
Packet: env.File,
|
||||
Time: now,
|
||||
}
|
||||
pool := s.getFileReceivePool()
|
||||
switch env.Kind {
|
||||
case EnvelopeAck:
|
||||
event.Packet.Stage = env.File.Stage
|
||||
event.Packet.Error = env.File.Error
|
||||
event.Received = env.File.Offset
|
||||
scopes := serverTransportDeliveryScopes(logical)
|
||||
if transport := fileEventTransportConnSnapshot(event); transport != nil {
|
||||
scopes = serverTransportDeliveryScopesForTransport(transport)
|
||||
}
|
||||
if s.getFileAckPool().deliverAny(scopes, event) {
|
||||
return
|
||||
}
|
||||
case EnvelopeFileMeta:
|
||||
session, err := pool.onMeta(serverFileScope(logical), env.File, now)
|
||||
if session != nil {
|
||||
event.Path = session.tmpPath
|
||||
event.Received = session.received
|
||||
fillFileEventTiming(&event, session)
|
||||
}
|
||||
event.Err = err
|
||||
case EnvelopeFileChunk:
|
||||
session, err := pool.onChunk(serverFileScope(logical), env.File, now)
|
||||
if session != nil {
|
||||
event.Path = session.tmpPath
|
||||
event.Received = session.received
|
||||
fillFileEventTiming(&event, session)
|
||||
}
|
||||
event.Err = err
|
||||
case EnvelopeFileEnd:
|
||||
finalPath, session, err := pool.onEnd(serverFileScope(logical), env.File, now)
|
||||
if session != nil {
|
||||
event.Path = finalPath
|
||||
event.Received = session.received
|
||||
fillFileEventTiming(&event, session)
|
||||
}
|
||||
event.Err = err
|
||||
case EnvelopeFileAbort:
|
||||
session, err := pool.onAbort(serverFileScope(logical), env.File, now)
|
||||
event.Received = env.File.Offset
|
||||
if session != nil {
|
||||
event.Path = session.tmpPath
|
||||
fillFileEventTiming(&event, session)
|
||||
}
|
||||
event.Err = err
|
||||
default:
|
||||
}
|
||||
if env.Kind == EnvelopeFileMeta || env.Kind == EnvelopeFileChunk || env.Kind == EnvelopeFileEnd || env.Kind == EnvelopeFileAbort {
|
||||
if ackErr := s.sendFileAckInbound(logical, transport, conn, env, event.Err); ackErr != nil && event.Err == nil {
|
||||
event.Err = ackErr
|
||||
}
|
||||
}
|
||||
fillFileEventProgress(&event)
|
||||
s.publishReceivedFileEvent(event)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) emitFileEvent(event FileEvent) {
|
||||
c.mu.Lock()
|
||||
handler := c.onFileEvent
|
||||
c.mu.Unlock()
|
||||
if handler == nil {
|
||||
return
|
||||
}
|
||||
handler(event)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) emitFileEvent(event FileEvent) {
|
||||
s.mu.Lock()
|
||||
handler := s.onFileEvent
|
||||
s.mu.Unlock()
|
||||
if handler == nil {
|
||||
return
|
||||
}
|
||||
handler(event)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) logFileEvent(role string, event FileEvent) {
|
||||
if !(c.debugMode || event.Err != nil) {
|
||||
return
|
||||
}
|
||||
fmt.Printf("%s file event kind=%d file_id=%s received=%d path=%s err=%v\n",
|
||||
role, event.Kind, event.Packet.FileID, event.Received, event.Path, event.Err)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) logFileEvent(role string, event FileEvent) {
|
||||
if !(s.debugMode || event.Err != nil) {
|
||||
return
|
||||
}
|
||||
fmt.Printf("%s file event kind=%d file_id=%s received=%d path=%s err=%v\n",
|
||||
role, event.Kind, event.Packet.FileID, event.Received, event.Path, event.Err)
|
||||
}
|
||||
+243
@@ -0,0 +1,243 @@
|
||||
package notify
|
||||
|
||||
import "time"
|
||||
|
||||
type FileEvent struct {
|
||||
NetType NetType
|
||||
LogicalConn *LogicalConn
|
||||
// Deprecated: ClientConn aliases LogicalConn for compatibility.
|
||||
ClientConn *ClientConn
|
||||
TransportConn *TransportConn
|
||||
ServerConn Client
|
||||
Kind EnvelopeKind
|
||||
Packet FilePacket
|
||||
Path string
|
||||
Received int64
|
||||
Total int64
|
||||
Percent float64
|
||||
Done bool
|
||||
StartedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
Duration time.Duration
|
||||
RateBPS float64
|
||||
StepDuration time.Duration
|
||||
InstantRateBPS float64
|
||||
Err error
|
||||
Time time.Time
|
||||
}
|
||||
|
||||
func normalizeFileEventTime(now time.Time) time.Time {
|
||||
if now.IsZero() {
|
||||
return time.Now()
|
||||
}
|
||||
return now
|
||||
}
|
||||
|
||||
func hydrateServerFileEventPeerFields(event FileEvent) FileEvent {
|
||||
if event.LogicalConn == nil {
|
||||
event.LogicalConn = logicalConnFromClient(event.ClientConn)
|
||||
}
|
||||
if event.ClientConn == nil {
|
||||
event.ClientConn = event.LogicalConn.compatClientConn()
|
||||
}
|
||||
if event.TransportConn == nil && event.LogicalConn != nil {
|
||||
event.TransportConn = event.LogicalConn.CurrentTransportConn()
|
||||
}
|
||||
return event
|
||||
}
|
||||
|
||||
func fileEventLogicalConnSnapshot(event FileEvent) *LogicalConn {
|
||||
if event.LogicalConn != nil {
|
||||
return event.LogicalConn
|
||||
}
|
||||
return logicalConnFromClient(event.ClientConn)
|
||||
}
|
||||
|
||||
func fileEventTransportConnSnapshot(event FileEvent) *TransportConn {
|
||||
if event.TransportConn != nil {
|
||||
return event.TransportConn
|
||||
}
|
||||
logical := fileEventLogicalConnSnapshot(event)
|
||||
if logical == nil {
|
||||
return nil
|
||||
}
|
||||
return logical.CurrentTransportConn()
|
||||
}
|
||||
|
||||
type fileEventTimeline struct {
|
||||
startedAt time.Time
|
||||
updatedAt time.Time
|
||||
previousUpdatedAt time.Time
|
||||
previousProgress int64
|
||||
}
|
||||
|
||||
func fillFileEventProgress(event *FileEvent) {
|
||||
if event == nil {
|
||||
return
|
||||
}
|
||||
event.Total = event.Packet.Size
|
||||
if event.Received < 0 {
|
||||
event.Received = 0
|
||||
}
|
||||
if event.Total > 0 && event.Received > event.Total {
|
||||
event.Received = event.Total
|
||||
}
|
||||
switch event.Kind {
|
||||
case EnvelopeFileEnd:
|
||||
event.Done = event.Err == nil
|
||||
if event.Done && event.Total > 0 {
|
||||
event.Received = event.Total
|
||||
}
|
||||
case EnvelopeFileAbort:
|
||||
event.Done = false
|
||||
}
|
||||
if event.Total <= 0 {
|
||||
if event.Done {
|
||||
event.Percent = 100
|
||||
}
|
||||
if !event.StartedAt.IsZero() && !event.UpdatedAt.IsZero() && !event.UpdatedAt.Before(event.StartedAt) {
|
||||
event.Duration = event.UpdatedAt.Sub(event.StartedAt)
|
||||
}
|
||||
return
|
||||
}
|
||||
event.Percent = float64(event.Received) * 100 / float64(event.Total)
|
||||
if event.Percent < 0 {
|
||||
event.Percent = 0
|
||||
}
|
||||
if event.Percent > 100 {
|
||||
event.Percent = 100
|
||||
}
|
||||
if !event.StartedAt.IsZero() && !event.UpdatedAt.IsZero() && !event.UpdatedAt.Before(event.StartedAt) {
|
||||
event.Duration = event.UpdatedAt.Sub(event.StartedAt)
|
||||
}
|
||||
if event.Duration > 0 && event.Received > 0 {
|
||||
event.RateBPS = float64(event.Received) / event.Duration.Seconds()
|
||||
}
|
||||
}
|
||||
|
||||
func fillFileEventTimeline(event *FileEvent, timeline fileEventTimeline) {
|
||||
if event == nil {
|
||||
return
|
||||
}
|
||||
event.StartedAt = timeline.startedAt
|
||||
event.UpdatedAt = timeline.updatedAt
|
||||
if !timeline.previousUpdatedAt.IsZero() && !timeline.updatedAt.Before(timeline.previousUpdatedAt) {
|
||||
event.StepDuration = timeline.updatedAt.Sub(timeline.previousUpdatedAt)
|
||||
}
|
||||
if delta := event.Received - timeline.previousProgress; delta > 0 && event.StepDuration > 0 {
|
||||
event.InstantRateBPS = float64(delta) / event.StepDuration.Seconds()
|
||||
}
|
||||
}
|
||||
|
||||
func fillFileEventTiming(event *FileEvent, session *fileReceiveSession) {
|
||||
if session == nil {
|
||||
return
|
||||
}
|
||||
fillFileEventTimeline(event, fileEventTimeline{
|
||||
startedAt: session.startedAt,
|
||||
updatedAt: session.updatedAt,
|
||||
previousUpdatedAt: session.previousUpdatedAt,
|
||||
previousProgress: session.previousReceived,
|
||||
})
|
||||
}
|
||||
|
||||
func fillFileSendEventTiming(event *FileEvent, session *fileSendSession) {
|
||||
if session == nil {
|
||||
return
|
||||
}
|
||||
fillFileEventTimeline(event, fileEventTimeline{
|
||||
startedAt: session.startedAt,
|
||||
updatedAt: session.updatedAt,
|
||||
previousUpdatedAt: session.previousUpdatedAt,
|
||||
previousProgress: session.previousSent,
|
||||
})
|
||||
}
|
||||
|
||||
func normalizeFileEventCallback(fn func(FileEvent)) func(FileEvent) {
|
||||
if fn == nil {
|
||||
return func(FileEvent) {}
|
||||
}
|
||||
return fn
|
||||
}
|
||||
|
||||
func (c *ClientCommon) setFileEventObserver(fn func(FileEvent)) {
|
||||
c.mu.Lock()
|
||||
c.fileEventObserver = normalizeFileEventCallback(fn)
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *ServerCommon) setFileEventObserver(fn func(FileEvent)) {
|
||||
s.mu.Lock()
|
||||
s.fileEventObserver = normalizeFileEventCallback(fn)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *ClientCommon) observeFileEvent(event FileEvent) {
|
||||
c.mu.Lock()
|
||||
observer := c.fileEventObserver
|
||||
c.mu.Unlock()
|
||||
normalizeFileEventCallback(observer)(event)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) observeFileEvent(event FileEvent) {
|
||||
s.mu.RLock()
|
||||
observer := s.fileEventObserver
|
||||
s.mu.RUnlock()
|
||||
normalizeFileEventCallback(observer)(hydrateServerFileEventPeerFields(event))
|
||||
}
|
||||
|
||||
func (c *ClientCommon) publishReceivedFileEvent(event FileEvent) {
|
||||
c.getFileTransferState().observe(fileTransferDirectionReceive, event)
|
||||
c.observeFileEvent(event)
|
||||
c.logFileEvent("client", event)
|
||||
c.emitFileEvent(event)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) publishReceivedFileEventMonitorOnly(event FileEvent) {
|
||||
c.getFileTransferState().observeMonitorOnly(fileTransferDirectionReceive, event)
|
||||
c.observeFileEvent(event)
|
||||
c.logFileEvent("client", event)
|
||||
c.emitFileEvent(event)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) publishReceivedFileEvent(event FileEvent) {
|
||||
event = hydrateServerFileEventPeerFields(event)
|
||||
s.getFileTransferState().observe(fileTransferDirectionReceive, event)
|
||||
s.observeFileEvent(event)
|
||||
s.logFileEvent("server", event)
|
||||
s.emitFileEvent(event)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) publishReceivedFileEventMonitorOnly(event FileEvent) {
|
||||
event = hydrateServerFileEventPeerFields(event)
|
||||
s.getFileTransferState().observeMonitorOnly(fileTransferDirectionReceive, event)
|
||||
s.observeFileEvent(event)
|
||||
s.logFileEvent("server", event)
|
||||
s.emitFileEvent(event)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) publishSendFileEvent(event FileEvent) {
|
||||
c.getFileTransferState().observe(fileTransferDirectionSend, event)
|
||||
c.observeFileEvent(event)
|
||||
c.logFileEvent("client-send", event)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) publishSendFileEventMonitorOnly(event FileEvent) {
|
||||
c.getFileTransferState().observeMonitorOnly(fileTransferDirectionSend, event)
|
||||
c.observeFileEvent(event)
|
||||
c.logFileEvent("client-send", event)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) publishSendFileEvent(event FileEvent) {
|
||||
event = hydrateServerFileEventPeerFields(event)
|
||||
s.getFileTransferState().observe(fileTransferDirectionSend, event)
|
||||
s.observeFileEvent(event)
|
||||
s.logFileEvent("server-send", event)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) publishSendFileEventMonitorOnly(event FileEvent) {
|
||||
event = hydrateServerFileEventPeerFields(event)
|
||||
s.getFileTransferState().observeMonitorOnly(fileTransferDirectionSend, event)
|
||||
s.observeFileEvent(event)
|
||||
s.logFileEvent("server-send", event)
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFillFileEventTimeline(t *testing.T) {
|
||||
event := FileEvent{
|
||||
Received: 150,
|
||||
}
|
||||
timeline := fileEventTimeline{
|
||||
startedAt: time.Unix(100, 0),
|
||||
updatedAt: time.Unix(110, 0),
|
||||
previousUpdatedAt: time.Unix(106, 0),
|
||||
previousProgress: 90,
|
||||
}
|
||||
|
||||
fillFileEventTimeline(&event, timeline)
|
||||
|
||||
if got, want := event.StartedAt, timeline.startedAt; !got.Equal(want) {
|
||||
t.Fatalf("startedAt mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := event.UpdatedAt, timeline.updatedAt; !got.Equal(want) {
|
||||
t.Fatalf("updatedAt mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := event.StepDuration, 4*time.Second; got != want {
|
||||
t.Fatalf("step duration mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := event.InstantRateBPS, 15.0; got != want {
|
||||
t.Fatalf("instant rate mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package notify
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestClientPublishSendFileEventObserverOnly(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
|
||||
var observed []FileEvent
|
||||
var handled []FileEvent
|
||||
client.setFileEventObserver(func(event FileEvent) {
|
||||
observed = append(observed, event)
|
||||
})
|
||||
client.SetFileHandler(func(event FileEvent) {
|
||||
handled = append(handled, event)
|
||||
})
|
||||
|
||||
event := FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "send-1", Size: 32},
|
||||
}
|
||||
client.publishSendFileEvent(event)
|
||||
|
||||
if got, want := len(observed), 1; got != want {
|
||||
t.Fatalf("observed count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := len(handled), 0; got != want {
|
||||
t.Fatalf("handled count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := observed[0].Packet.FileID, "send-1"; got != want {
|
||||
t.Fatalf("observed fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientPublishReceivedFileEventObserverAndHandler(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
|
||||
var observed []FileEvent
|
||||
var handled []FileEvent
|
||||
client.setFileEventObserver(func(event FileEvent) {
|
||||
observed = append(observed, event)
|
||||
})
|
||||
client.SetFileHandler(func(event FileEvent) {
|
||||
handled = append(handled, event)
|
||||
})
|
||||
|
||||
event := FileEvent{
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "recv-1", Size: 64},
|
||||
Received: 64,
|
||||
Done: true,
|
||||
}
|
||||
client.publishReceivedFileEvent(event)
|
||||
|
||||
if got, want := len(observed), 1; got != want {
|
||||
t.Fatalf("observed count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := len(handled), 1; got != want {
|
||||
t.Fatalf("handled count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := observed[0].Packet.FileID, "recv-1"; got != want {
|
||||
t.Fatalf("observed fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := handled[0].Packet.FileID, "recv-1"; got != want {
|
||||
t.Fatalf("handled fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerPublishSendFileEventObserverOnly(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
|
||||
var observed []FileEvent
|
||||
var handled []FileEvent
|
||||
server.setFileEventObserver(func(event FileEvent) {
|
||||
observed = append(observed, event)
|
||||
})
|
||||
server.SetFileHandler(func(event FileEvent) {
|
||||
handled = append(handled, event)
|
||||
})
|
||||
|
||||
event := FileEvent{
|
||||
Kind: EnvelopeFileMeta,
|
||||
Packet: FilePacket{FileID: "server-send-1", Size: 128},
|
||||
}
|
||||
server.publishSendFileEvent(event)
|
||||
|
||||
if got, want := len(observed), 1; got != want {
|
||||
t.Fatalf("observed count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := len(handled), 0; got != want {
|
||||
t.Fatalf("handled count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := observed[0].Packet.FileID, "server-send-1"; got != want {
|
||||
t.Fatalf("observed fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fileReceiveCheckpoint struct {
|
||||
FileID string `json:"file_id"`
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
Mode uint32 `json:"mode"`
|
||||
ModTime int64 `json:"mod_time"`
|
||||
Checksum string `json:"checksum"`
|
||||
Received int64 `json:"received"`
|
||||
TmpPath string `json:"tmp_path"`
|
||||
FinalPath string `json:"final_path"`
|
||||
StartedAt int64 `json:"started_at"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
PreviousUpdatedAt int64 `json:"previous_updated_at"`
|
||||
PreviousReceived int64 `json:"previous_received"`
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) restoreCheckpointLocked(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, bool, error) {
|
||||
checkpoint, ok, err := p.loadCheckpointLocked(scope, packet.FileID)
|
||||
if err != nil || !ok {
|
||||
return nil, ok, err
|
||||
}
|
||||
name := filepath.Base(packet.Name)
|
||||
if name == "." || name == "/" || name == "" {
|
||||
name = "unnamed.bin"
|
||||
}
|
||||
if checkpoint.FileID != packet.FileID || checkpoint.Name != name || checkpoint.Size != packet.Size || !strings.EqualFold(checkpoint.Checksum, packet.Checksum) {
|
||||
p.removeCheckpointLocked(scope, packet.FileID)
|
||||
if checkpoint.TmpPath != "" {
|
||||
_ = os.Remove(checkpoint.TmpPath)
|
||||
}
|
||||
return nil, false, nil
|
||||
}
|
||||
if checkpoint.TmpPath == "" {
|
||||
p.removeCheckpointLocked(scope, packet.FileID)
|
||||
return nil, false, nil
|
||||
}
|
||||
info, statErr := os.Stat(checkpoint.TmpPath)
|
||||
if statErr != nil {
|
||||
if checkpoint.FinalPath != "" && pathExists(checkpoint.FinalPath) {
|
||||
session := checkpoint.toSession(now)
|
||||
session.tmpPath = checkpoint.FinalPath
|
||||
session.finalPath = checkpoint.FinalPath
|
||||
session.received = session.size
|
||||
p.completed[fileReceiveKey(scope, packet.FileID)] = session.copy()
|
||||
p.removeCheckpointLocked(scope, packet.FileID)
|
||||
return session.copy(), true, nil
|
||||
}
|
||||
p.removeCheckpointLocked(scope, packet.FileID)
|
||||
return nil, false, nil
|
||||
}
|
||||
received := info.Size()
|
||||
if received < 0 {
|
||||
received = 0
|
||||
}
|
||||
if packet.Size > 0 && received > packet.Size {
|
||||
received = packet.Size
|
||||
}
|
||||
session := checkpoint.toSession(now)
|
||||
session.name = name
|
||||
session.mode = os.FileMode(packet.Mode)
|
||||
session.modTime = filePacketModTime(packet)
|
||||
session.checksum = packet.Checksum
|
||||
session.received = received
|
||||
if session.finalPath == "" || (session.finalPath != session.tmpPath && pathExists(session.finalPath)) {
|
||||
session.finalPath = p.uniqueFinalPathLocked(p.receiveDirLocked(), name, packet.FileID)
|
||||
}
|
||||
p.sessions[fileReceiveKey(scope, packet.FileID)] = session
|
||||
if session.received != checkpoint.Received || session.finalPath != checkpoint.FinalPath {
|
||||
if err := p.saveCheckpointLocked(scope, session); err != nil {
|
||||
return nil, true, err
|
||||
}
|
||||
}
|
||||
return session.copy(), true, nil
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) loadCheckpointLocked(scope string, fileID string) (fileReceiveCheckpoint, bool, error) {
|
||||
path := p.checkpointPathLocked(scope, fileID)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return fileReceiveCheckpoint{}, false, nil
|
||||
}
|
||||
return fileReceiveCheckpoint{}, false, err
|
||||
}
|
||||
var checkpoint fileReceiveCheckpoint
|
||||
if err := json.Unmarshal(data, &checkpoint); err != nil {
|
||||
_ = os.Remove(path)
|
||||
return fileReceiveCheckpoint{}, false, nil
|
||||
}
|
||||
return checkpoint, true, nil
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) saveCheckpointLocked(scope string, session *fileReceiveSession) error {
|
||||
if p == nil || session == nil || session.fileID == "" {
|
||||
return nil
|
||||
}
|
||||
path := p.checkpointPathLocked(scope, session.fileID)
|
||||
checkpoint := fileReceiveCheckpoint{
|
||||
FileID: session.fileID,
|
||||
Name: session.name,
|
||||
Size: session.size,
|
||||
Mode: uint32(session.mode.Perm()),
|
||||
ModTime: session.modTime.UnixNano(),
|
||||
Checksum: session.checksum,
|
||||
Received: session.received,
|
||||
TmpPath: session.tmpPath,
|
||||
FinalPath: session.finalPath,
|
||||
StartedAt: session.startedAt.UnixNano(),
|
||||
UpdatedAt: session.updatedAt.UnixNano(),
|
||||
PreviousUpdatedAt: session.previousUpdatedAt.UnixNano(),
|
||||
PreviousReceived: session.previousReceived,
|
||||
}
|
||||
data, err := json.Marshal(checkpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tmpPath := path + ".tmp"
|
||||
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Rename(tmpPath, path)
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) removeCheckpointLocked(scope string, fileID string) {
|
||||
if p == nil || fileID == "" {
|
||||
return
|
||||
}
|
||||
_ = os.Remove(p.checkpointPathLocked(scope, fileID))
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) checkpointPathLocked(scope string, fileID string) string {
|
||||
baseDir := p.receiveDirLocked()
|
||||
sum := sha256.Sum256([]byte(fileReceiveKey(scope, fileID)))
|
||||
return filepath.Join(baseDir, ".notify_recv_"+hex.EncodeToString(sum[:8])+".json")
|
||||
}
|
||||
|
||||
func (checkpoint fileReceiveCheckpoint) toSession(now time.Time) *fileReceiveSession {
|
||||
now = normalizeFileEventTime(now)
|
||||
session := &fileReceiveSession{
|
||||
fileID: checkpoint.FileID,
|
||||
name: checkpoint.Name,
|
||||
size: checkpoint.Size,
|
||||
mode: os.FileMode(checkpoint.Mode),
|
||||
modTime: time.Unix(0, checkpoint.ModTime),
|
||||
checksum: checkpoint.Checksum,
|
||||
received: checkpoint.Received,
|
||||
tmpPath: checkpoint.TmpPath,
|
||||
finalPath: checkpoint.FinalPath,
|
||||
previousReceived: checkpoint.PreviousReceived,
|
||||
}
|
||||
session.startedAt = unixNanoTime(checkpoint.StartedAt)
|
||||
session.updatedAt = unixNanoTime(checkpoint.UpdatedAt)
|
||||
session.previousUpdatedAt = unixNanoTime(checkpoint.PreviousUpdatedAt)
|
||||
if session.startedAt.IsZero() {
|
||||
session.startedAt = now
|
||||
}
|
||||
if session.updatedAt.IsZero() {
|
||||
session.updatedAt = now
|
||||
}
|
||||
return session
|
||||
}
|
||||
@@ -0,0 +1,147 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func computeFileChecksum(path string) (string, error) {
|
||||
fd, err := os.Open(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer fd.Close()
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, fd); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(h.Sum(nil)), nil
|
||||
}
|
||||
|
||||
func filePacketModTime(packet FilePacket) time.Time {
|
||||
if packet.ModTime <= 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
return time.Unix(0, packet.ModTime)
|
||||
}
|
||||
|
||||
func applyReceivedFileMeta(path string, mode os.FileMode, modTime time.Time) {
|
||||
if mode != 0 {
|
||||
_ = os.Chmod(path, mode.Perm())
|
||||
}
|
||||
if !modTime.IsZero() {
|
||||
_ = os.Chtimes(path, modTime, modTime)
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeFileName(name string) string {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return "unnamed"
|
||||
}
|
||||
trimmed = strings.ReplaceAll(trimmed, "/", "_")
|
||||
trimmed = strings.ReplaceAll(trimmed, "\\", "_")
|
||||
trimmed = strings.ReplaceAll(trimmed, ":", "_")
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func shortFileIDSuffix(fileID string) string {
|
||||
cleaned := sanitizeFileName(fileID)
|
||||
if len(cleaned) > 12 {
|
||||
return cleaned[:12]
|
||||
}
|
||||
if cleaned == "" {
|
||||
return "copy"
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func pathExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) receiveDirLocked() string {
|
||||
if p.dir != "" {
|
||||
return p.dir
|
||||
}
|
||||
return os.TempDir()
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) uniqueFinalPathLocked(baseDir string, name string, fileID string) string {
|
||||
cleanName := sanitizeFileName(filepath.Base(name))
|
||||
if cleanName == "" {
|
||||
cleanName = "unnamed.bin"
|
||||
}
|
||||
ext := filepath.Ext(cleanName)
|
||||
base := strings.TrimSuffix(cleanName, ext)
|
||||
candidate := filepath.Join(baseDir, cleanName)
|
||||
if !p.pathReservedLocked(candidate) && !pathExists(candidate) {
|
||||
return candidate
|
||||
}
|
||||
suffix := shortFileIDSuffix(fileID)
|
||||
candidate = filepath.Join(baseDir, fmt.Sprintf("%s.%s%s", base, suffix, ext))
|
||||
if !p.pathReservedLocked(candidate) && !pathExists(candidate) {
|
||||
return candidate
|
||||
}
|
||||
for i := 1; ; i++ {
|
||||
candidate = filepath.Join(baseDir, fmt.Sprintf("%s.%s.%d%s", base, suffix, i, ext))
|
||||
if !p.pathReservedLocked(candidate) && !pathExists(candidate) {
|
||||
return candidate
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) pathReservedLocked(path string) bool {
|
||||
for _, session := range p.sessions {
|
||||
if session.finalPath == path || session.tmpPath == path {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) trimCompletedLocked() {
|
||||
if p.completedLimit <= 0 || len(p.completed) <= p.completedLimit {
|
||||
return
|
||||
}
|
||||
for len(p.completed) > p.completedLimit {
|
||||
oldestKey := ""
|
||||
oldestTime := time.Time{}
|
||||
for key, session := range p.completed {
|
||||
candidateTime := completedFileReceiveTime(session)
|
||||
if oldestKey == "" || candidateTime.Before(oldestTime) || (candidateTime.Equal(oldestTime) && key < oldestKey) {
|
||||
oldestKey = key
|
||||
oldestTime = candidateTime
|
||||
}
|
||||
}
|
||||
if oldestKey == "" {
|
||||
return
|
||||
}
|
||||
delete(p.completed, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
func completedFileReceiveTime(session *fileReceiveSession) time.Time {
|
||||
if session == nil {
|
||||
return time.Time{}
|
||||
}
|
||||
if !session.updatedAt.IsZero() {
|
||||
return session.updatedAt
|
||||
}
|
||||
return session.startedAt
|
||||
}
|
||||
|
||||
func (s *fileReceiveSession) copy() *fileReceiveSession {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
dup := *s
|
||||
return &dup
|
||||
}
|
||||
@@ -0,0 +1,278 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fileReceiveSession struct {
|
||||
fileID string
|
||||
name string
|
||||
size int64
|
||||
mode os.FileMode
|
||||
modTime time.Time
|
||||
checksum string
|
||||
received int64
|
||||
tmpPath string
|
||||
finalPath string
|
||||
startedAt time.Time
|
||||
updatedAt time.Time
|
||||
previousUpdatedAt time.Time
|
||||
previousReceived int64
|
||||
}
|
||||
|
||||
const defaultFileReceiveCompletedLimit = 128
|
||||
|
||||
type fileReceivePool struct {
|
||||
mu sync.Mutex
|
||||
dir string
|
||||
sessions map[string]*fileReceiveSession
|
||||
completed map[string]*fileReceiveSession
|
||||
completedLimit int
|
||||
}
|
||||
|
||||
func fileReceiveKey(scope string, fileID string) string {
|
||||
return normalizeFileScope(scope) + "|" + fileID
|
||||
}
|
||||
|
||||
func newFileReceivePool() *fileReceivePool {
|
||||
return newFileReceivePoolWithConfig(defaultFileTransferConfig())
|
||||
}
|
||||
|
||||
func newFileReceivePoolWithConfig(cfg fileTransferConfig) *fileReceivePool {
|
||||
cfg = normalizeFileTransferConfig(cfg)
|
||||
return newFileReceivePoolWithCompletedLimit(cfg.ReceiveCompletedLimit)
|
||||
}
|
||||
|
||||
func newFileReceivePoolWithCompletedLimit(limit int) *fileReceivePool {
|
||||
if limit <= 0 {
|
||||
limit = defaultFileReceiveCompletedLimit
|
||||
}
|
||||
return &fileReceivePool{
|
||||
sessions: make(map[string]*fileReceiveSession),
|
||||
completed: make(map[string]*fileReceiveSession),
|
||||
completedLimit: limit,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) applyConfig(cfg fileTransferConfig) {
|
||||
if p == nil {
|
||||
return
|
||||
}
|
||||
cfg = normalizeFileTransferConfig(cfg)
|
||||
p.mu.Lock()
|
||||
p.completedLimit = cfg.ReceiveCompletedLimit
|
||||
p.trimCompletedLocked()
|
||||
p.mu.Unlock()
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) setDir(dir string) error {
|
||||
cleaned := strings.TrimSpace(dir)
|
||||
if cleaned == "" {
|
||||
p.mu.Lock()
|
||||
p.dir = ""
|
||||
p.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
cleaned = filepath.Clean(cleaned)
|
||||
if err := os.MkdirAll(cleaned, 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
info, err := os.Stat(cleaned)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return errors.New("file receive path is not a directory")
|
||||
}
|
||||
p.mu.Lock()
|
||||
p.dir = cleaned
|
||||
p.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) onMeta(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
|
||||
if packet.FileID == "" {
|
||||
return nil, errors.New("empty file id")
|
||||
}
|
||||
now = normalizeFileEventTime(now)
|
||||
sessionKey := fileReceiveKey(scope, packet.FileID)
|
||||
name := filepath.Base(packet.Name)
|
||||
if name == "." || name == "/" || name == "" {
|
||||
name = "unnamed.bin"
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if old, ok := p.completed[sessionKey]; ok {
|
||||
if old.name == name && old.size == packet.Size && old.checksum == packet.Checksum {
|
||||
return old.copy(), nil
|
||||
}
|
||||
delete(p.completed, sessionKey)
|
||||
}
|
||||
if old, ok := p.sessions[sessionKey]; ok {
|
||||
if old.name == name && old.size == packet.Size && old.checksum == packet.Checksum {
|
||||
return old.copy(), nil
|
||||
}
|
||||
_ = os.Remove(old.tmpPath)
|
||||
p.removeCheckpointLocked(scope, packet.FileID)
|
||||
delete(p.sessions, sessionKey)
|
||||
}
|
||||
if restored, ok, err := p.restoreCheckpointLocked(scope, packet, now); ok || err != nil {
|
||||
return restored, err
|
||||
}
|
||||
baseDir := p.receiveDirLocked()
|
||||
finalPath := p.uniqueFinalPathLocked(baseDir, name, packet.FileID)
|
||||
prefix := "notify_recv_" + sanitizeFileName(name) + "_"
|
||||
tmp, err := os.CreateTemp(baseDir, prefix+"*.part")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = tmp.Close()
|
||||
session := &fileReceiveSession{
|
||||
fileID: packet.FileID,
|
||||
name: name,
|
||||
size: packet.Size,
|
||||
mode: os.FileMode(packet.Mode),
|
||||
modTime: filePacketModTime(packet),
|
||||
checksum: packet.Checksum,
|
||||
received: 0,
|
||||
tmpPath: tmp.Name(),
|
||||
finalPath: finalPath,
|
||||
startedAt: now,
|
||||
updatedAt: now,
|
||||
}
|
||||
p.sessions[sessionKey] = session
|
||||
if err := p.saveCheckpointLocked(scope, session); err != nil {
|
||||
_ = os.Remove(session.tmpPath)
|
||||
delete(p.sessions, sessionKey)
|
||||
return nil, err
|
||||
}
|
||||
return session.copy(), nil
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) onChunk(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
|
||||
now = normalizeFileEventTime(now)
|
||||
sessionKey := fileReceiveKey(scope, packet.FileID)
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
session, ok := p.sessions[sessionKey]
|
||||
if !ok {
|
||||
if completed, ok := p.completed[sessionKey]; ok {
|
||||
return completed.copy(), nil
|
||||
}
|
||||
return nil, errors.New("unknown file id")
|
||||
}
|
||||
if packet.Offset < session.received {
|
||||
return session.copy(), nil
|
||||
}
|
||||
if packet.Offset > session.received {
|
||||
return nil, errors.New("chunk offset mismatch")
|
||||
}
|
||||
if len(packet.Chunk) == 0 {
|
||||
return session.copy(), nil
|
||||
}
|
||||
prevUpdatedAt := session.updatedAt
|
||||
prevReceived := session.received
|
||||
fd, err := os.OpenFile(session.tmpPath, os.O_WRONLY|os.O_APPEND, 0o600)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer fd.Close()
|
||||
n, err := fd.Write(packet.Chunk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
session.received += int64(n)
|
||||
session.previousUpdatedAt = prevUpdatedAt
|
||||
session.previousReceived = prevReceived
|
||||
session.updatedAt = now
|
||||
if err := p.saveCheckpointLocked(scope, session); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return session.copy(), nil
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) onEnd(scope string, packet FilePacket, now time.Time) (string, *fileReceiveSession, error) {
|
||||
now = normalizeFileEventTime(now)
|
||||
sessionKey := fileReceiveKey(scope, packet.FileID)
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
session, ok := p.sessions[sessionKey]
|
||||
if !ok {
|
||||
if completed, ok := p.completed[sessionKey]; ok {
|
||||
return completed.finalPath, completed.copy(), nil
|
||||
}
|
||||
return "", nil, errors.New("unknown file id")
|
||||
}
|
||||
if session.size > 0 && session.received != session.size {
|
||||
return "", session.copy(), errors.New("file size not match")
|
||||
}
|
||||
if session.checksum != "" {
|
||||
sum, err := computeFileChecksum(session.tmpPath)
|
||||
if err != nil {
|
||||
return "", session.copy(), err
|
||||
}
|
||||
if !strings.EqualFold(sum, session.checksum) {
|
||||
_ = os.Remove(session.tmpPath)
|
||||
delete(p.sessions, sessionKey)
|
||||
return "", session.copy(), errors.New("file checksum not match")
|
||||
}
|
||||
}
|
||||
finalPath := session.finalPath
|
||||
baseDir := filepath.Dir(session.tmpPath)
|
||||
if baseDir == "" || baseDir == "." {
|
||||
baseDir = p.receiveDirLocked()
|
||||
}
|
||||
if finalPath == "" || pathExists(finalPath) {
|
||||
finalPath = p.uniqueFinalPathLocked(baseDir, session.name, packet.FileID)
|
||||
}
|
||||
if err := os.Rename(session.tmpPath, finalPath); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
session.previousUpdatedAt = session.updatedAt
|
||||
session.previousReceived = session.received
|
||||
session.updatedAt = now
|
||||
applyReceivedFileMeta(finalPath, session.mode, session.modTime)
|
||||
delete(p.sessions, sessionKey)
|
||||
session.tmpPath = finalPath
|
||||
session.finalPath = finalPath
|
||||
p.removeCheckpointLocked(scope, packet.FileID)
|
||||
p.completed[sessionKey] = session.copy()
|
||||
p.trimCompletedLocked()
|
||||
return finalPath, session.copy(), nil
|
||||
}
|
||||
|
||||
func (p *fileReceivePool) onAbort(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
|
||||
now = normalizeFileEventTime(now)
|
||||
sessionKey := fileReceiveKey(scope, packet.FileID)
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
session, ok := p.sessions[sessionKey]
|
||||
if !ok {
|
||||
if completed, ok := p.completed[sessionKey]; ok {
|
||||
return completed.copy(), nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
session.previousUpdatedAt = session.updatedAt
|
||||
session.previousReceived = session.received
|
||||
session.updatedAt = now
|
||||
dup := session.copy()
|
||||
_ = os.Remove(session.tmpPath)
|
||||
p.removeCheckpointLocked(scope, packet.FileID)
|
||||
delete(p.sessions, sessionKey)
|
||||
delete(p.completed, sessionKey)
|
||||
return dup, nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) getFileReceivePool() *fileReceivePool {
|
||||
return c.getLogicalSessionState().fileReceives
|
||||
}
|
||||
|
||||
func (s *ServerCommon) getFileReceivePool() *fileReceivePool {
|
||||
return s.getLogicalSessionState().fileReceives
|
||||
}
|
||||
@@ -0,0 +1,520 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFileReceivePoolUsesConfiguredDirAndStableName(t *testing.T) {
|
||||
pool := newFileReceivePool()
|
||||
scope := "client:test"
|
||||
dir := t.TempDir()
|
||||
now := time.Now()
|
||||
if err := pool.setDir(dir); err != nil {
|
||||
t.Fatalf("setDir failed: %v", err)
|
||||
}
|
||||
|
||||
payload := []byte("hello notify")
|
||||
meta := FilePacket{
|
||||
FileID: "file-1",
|
||||
Name: "greeting.txt",
|
||||
Size: int64(len(payload)),
|
||||
Checksum: testFileChecksum(payload),
|
||||
}
|
||||
|
||||
session, err := pool.onMeta(scope, meta, now)
|
||||
if err != nil {
|
||||
t.Fatalf("onMeta failed: %v", err)
|
||||
}
|
||||
if got, want := filepath.Dir(session.tmpPath), dir; got != want {
|
||||
t.Fatalf("tmp dir mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := session.finalPath, filepath.Join(dir, "greeting.txt"); got != want {
|
||||
t.Fatalf("final path mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
session, err = pool.onChunk(scope, FilePacket{
|
||||
FileID: meta.FileID,
|
||||
Offset: 0,
|
||||
Chunk: payload,
|
||||
}, now.Add(time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("onChunk failed: %v", err)
|
||||
}
|
||||
if got, want := session.received, int64(len(payload)); got != want {
|
||||
t.Fatalf("received mismatch after chunk: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
finalPath, session, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("onEnd failed: %v", err)
|
||||
}
|
||||
if got, want := finalPath, filepath.Join(dir, "greeting.txt"); got != want {
|
||||
t.Fatalf("completed path mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := session.finalPath, finalPath; got != want {
|
||||
t.Fatalf("session final path mismatch: got %q want %q", got, want)
|
||||
}
|
||||
gotData, err := os.ReadFile(finalPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(gotData, payload) {
|
||||
t.Fatalf("completed file content mismatch: got %q want %q", gotData, payload)
|
||||
}
|
||||
|
||||
dupMeta, err := pool.onMeta(scope, meta, now.Add(3*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("duplicate onMeta failed: %v", err)
|
||||
}
|
||||
if got, want := dupMeta.finalPath, finalPath; got != want {
|
||||
t.Fatalf("duplicate meta final path mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
dupChunk, err := pool.onChunk(scope, FilePacket{
|
||||
FileID: meta.FileID,
|
||||
Offset: 0,
|
||||
Chunk: payload,
|
||||
}, now.Add(4*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("duplicate onChunk failed: %v", err)
|
||||
}
|
||||
if got, want := dupChunk.received, int64(len(payload)); got != want {
|
||||
t.Fatalf("duplicate chunk received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
dupPath, dupEnd, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(5*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("duplicate onEnd failed: %v", err)
|
||||
}
|
||||
if got, want := dupPath, finalPath; got != want {
|
||||
t.Fatalf("duplicate end path mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := dupEnd.finalPath, finalPath; got != want {
|
||||
t.Fatalf("duplicate end session final path mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileReceivePoolAvoidsOverwriteWhenFinalPathBecomesBusy(t *testing.T) {
|
||||
pool := newFileReceivePool()
|
||||
scope := "client:test"
|
||||
dir := t.TempDir()
|
||||
now := time.Now()
|
||||
if err := pool.setDir(dir); err != nil {
|
||||
t.Fatalf("setDir failed: %v", err)
|
||||
}
|
||||
|
||||
payload := []byte("new report payload")
|
||||
meta := FilePacket{
|
||||
FileID: "file-2",
|
||||
Name: "report.txt",
|
||||
Size: int64(len(payload)),
|
||||
Checksum: testFileChecksum(payload),
|
||||
}
|
||||
|
||||
session, err := pool.onMeta(scope, meta, now)
|
||||
if err != nil {
|
||||
t.Fatalf("onMeta failed: %v", err)
|
||||
}
|
||||
|
||||
occupiedPath := session.finalPath
|
||||
occupiedContent := []byte("existing report")
|
||||
if err := os.WriteFile(occupiedPath, occupiedContent, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile occupied path failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := pool.onChunk(scope, FilePacket{
|
||||
FileID: meta.FileID,
|
||||
Offset: 0,
|
||||
Chunk: payload,
|
||||
}, now.Add(time.Second)); err != nil {
|
||||
t.Fatalf("onChunk failed: %v", err)
|
||||
}
|
||||
|
||||
finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("onEnd failed: %v", err)
|
||||
}
|
||||
if finalPath == occupiedPath {
|
||||
t.Fatalf("expected final path to avoid occupied path %q", occupiedPath)
|
||||
}
|
||||
|
||||
gotOccupied, err := os.ReadFile(occupiedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile occupied path failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(gotOccupied, occupiedContent) {
|
||||
t.Fatalf("occupied file content changed: got %q want %q", gotOccupied, occupiedContent)
|
||||
}
|
||||
|
||||
gotFinal, err := os.ReadFile(finalPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile final path failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(gotFinal, payload) {
|
||||
t.Fatalf("final file content mismatch: got %q want %q", gotFinal, payload)
|
||||
}
|
||||
if got, want := filepath.Dir(finalPath), dir; got != want {
|
||||
t.Fatalf("final dir mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileReceivePoolAbortAfterCompletionKeepsDeliveredFile(t *testing.T) {
|
||||
pool := newFileReceivePool()
|
||||
scope := "client:test"
|
||||
dir := t.TempDir()
|
||||
now := time.Now()
|
||||
if err := pool.setDir(dir); err != nil {
|
||||
t.Fatalf("setDir failed: %v", err)
|
||||
}
|
||||
|
||||
payload := []byte("keep me")
|
||||
meta := FilePacket{
|
||||
FileID: "file-3",
|
||||
Name: "keep.txt",
|
||||
Size: int64(len(payload)),
|
||||
Checksum: testFileChecksum(payload),
|
||||
}
|
||||
|
||||
if _, err := pool.onMeta(scope, meta, now); err != nil {
|
||||
t.Fatalf("onMeta failed: %v", err)
|
||||
}
|
||||
if _, err := pool.onChunk(scope, FilePacket{
|
||||
FileID: meta.FileID,
|
||||
Offset: 0,
|
||||
Chunk: payload,
|
||||
}, now.Add(time.Second)); err != nil {
|
||||
t.Fatalf("onChunk failed: %v", err)
|
||||
}
|
||||
|
||||
finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("onEnd failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := pool.onAbort(scope, FilePacket{FileID: meta.FileID}, now.Add(3*time.Second)); err != nil {
|
||||
t.Fatalf("onAbort failed: %v", err)
|
||||
}
|
||||
|
||||
gotData, err := os.ReadFile(finalPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile final path after abort failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(gotData, payload) {
|
||||
t.Fatalf("final file content mismatch after abort: got %q want %q", gotData, payload)
|
||||
}
|
||||
|
||||
dupPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(4*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("duplicate onEnd after abort failed: %v", err)
|
||||
}
|
||||
if got, want := dupPath, finalPath; got != want {
|
||||
t.Fatalf("duplicate end path mismatch after abort: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileReceivePoolAppliesMetaModeAndModTime(t *testing.T) {
|
||||
pool := newFileReceivePool()
|
||||
scope := "client:test"
|
||||
dir := t.TempDir()
|
||||
now := time.Now()
|
||||
if err := pool.setDir(dir); err != nil {
|
||||
t.Fatalf("setDir failed: %v", err)
|
||||
}
|
||||
|
||||
payload := []byte("meta test")
|
||||
wantMode := os.FileMode(0o640)
|
||||
wantTime := time.Now().Add(-2 * time.Hour).Truncate(time.Second)
|
||||
meta := FilePacket{
|
||||
FileID: "file-meta",
|
||||
Name: "meta.txt",
|
||||
Size: int64(len(payload)),
|
||||
Checksum: testFileChecksum(payload),
|
||||
Mode: uint32(wantMode),
|
||||
ModTime: wantTime.UnixNano(),
|
||||
}
|
||||
if _, err := pool.onMeta(scope, meta, now); err != nil {
|
||||
t.Fatalf("onMeta failed: %v", err)
|
||||
}
|
||||
if _, err := pool.onChunk(scope, FilePacket{
|
||||
FileID: meta.FileID,
|
||||
Offset: 0,
|
||||
Chunk: payload,
|
||||
}, now.Add(time.Second)); err != nil {
|
||||
t.Fatalf("onChunk failed: %v", err)
|
||||
}
|
||||
|
||||
finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("onEnd failed: %v", err)
|
||||
}
|
||||
info, err := os.Stat(finalPath)
|
||||
if err != nil {
|
||||
t.Fatalf("Stat failed: %v", err)
|
||||
}
|
||||
if got, want := info.Mode().Perm(), wantMode; got != want {
|
||||
t.Fatalf("mode mismatch: got %o want %o", got, want)
|
||||
}
|
||||
gotMTime := info.ModTime().Truncate(time.Second)
|
||||
if got, want := gotMTime, wantTime; !got.Equal(want) {
|
||||
t.Fatalf("mtime mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileReceivePoolScopeIsolation(t *testing.T) {
|
||||
pool := newFileReceivePool()
|
||||
dir := t.TempDir()
|
||||
now := time.Now()
|
||||
if err := pool.setDir(dir); err != nil {
|
||||
t.Fatalf("setDir failed: %v", err)
|
||||
}
|
||||
|
||||
const sharedFileID = "shared-file-id"
|
||||
payloadA := []byte("from client A")
|
||||
payloadB := []byte("from client B")
|
||||
metaA := FilePacket{
|
||||
FileID: sharedFileID,
|
||||
Name: "shared.txt",
|
||||
Size: int64(len(payloadA)),
|
||||
Checksum: testFileChecksum(payloadA),
|
||||
}
|
||||
metaB := FilePacket{
|
||||
FileID: sharedFileID,
|
||||
Name: "shared.txt",
|
||||
Size: int64(len(payloadB)),
|
||||
Checksum: testFileChecksum(payloadB),
|
||||
}
|
||||
|
||||
scopeA := "server:client-a"
|
||||
scopeB := "server:client-b"
|
||||
if _, err := pool.onMeta(scopeA, metaA, now); err != nil {
|
||||
t.Fatalf("onMeta scopeA failed: %v", err)
|
||||
}
|
||||
if _, err := pool.onMeta(scopeB, metaB, now); err != nil {
|
||||
t.Fatalf("onMeta scopeB failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := pool.onChunk(scopeA, FilePacket{
|
||||
FileID: sharedFileID,
|
||||
Offset: 0,
|
||||
Chunk: payloadA,
|
||||
}, now.Add(time.Second)); err != nil {
|
||||
t.Fatalf("onChunk scopeA failed: %v", err)
|
||||
}
|
||||
if _, err := pool.onChunk(scopeB, FilePacket{
|
||||
FileID: sharedFileID,
|
||||
Offset: 0,
|
||||
Chunk: payloadB,
|
||||
}, now.Add(time.Second)); err != nil {
|
||||
t.Fatalf("onChunk scopeB failed: %v", err)
|
||||
}
|
||||
|
||||
finalPathA, _, err := pool.onEnd(scopeA, FilePacket{FileID: sharedFileID}, now.Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("onEnd scopeA failed: %v", err)
|
||||
}
|
||||
finalPathB, _, err := pool.onEnd(scopeB, FilePacket{FileID: sharedFileID}, now.Add(2*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("onEnd scopeB failed: %v", err)
|
||||
}
|
||||
if finalPathA == finalPathB {
|
||||
t.Fatalf("scope-isolated files should not share path: %q", finalPathA)
|
||||
}
|
||||
|
||||
gotA, err := os.ReadFile(finalPathA)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile scopeA failed: %v", err)
|
||||
}
|
||||
gotB, err := os.ReadFile(finalPathB)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile scopeB failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(gotA, payloadA) {
|
||||
t.Fatalf("scopeA content mismatch: got %q want %q", gotA, payloadA)
|
||||
}
|
||||
if !bytes.Equal(gotB, payloadB) {
|
||||
t.Fatalf("scopeB content mismatch: got %q want %q", gotB, payloadB)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileReceivePoolCompletedRetentionEvictsOldest(t *testing.T) {
|
||||
pool := newFileReceivePoolWithCompletedLimit(2)
|
||||
dir := t.TempDir()
|
||||
now := time.Now()
|
||||
scope := "client:test"
|
||||
if err := pool.setDir(dir); err != nil {
|
||||
t.Fatalf("setDir failed: %v", err)
|
||||
}
|
||||
|
||||
complete := func(fileID string, offset time.Duration) {
|
||||
payload := []byte("payload-" + fileID)
|
||||
meta := FilePacket{
|
||||
FileID: fileID,
|
||||
Name: fileID + ".txt",
|
||||
Size: int64(len(payload)),
|
||||
Checksum: testFileChecksum(payload),
|
||||
}
|
||||
eventTime := now.Add(offset)
|
||||
if _, err := pool.onMeta(scope, meta, eventTime); err != nil {
|
||||
t.Fatalf("onMeta %s failed: %v", fileID, err)
|
||||
}
|
||||
if _, err := pool.onChunk(scope, FilePacket{
|
||||
FileID: fileID,
|
||||
Offset: 0,
|
||||
Chunk: payload,
|
||||
}, eventTime.Add(time.Second)); err != nil {
|
||||
t.Fatalf("onChunk %s failed: %v", fileID, err)
|
||||
}
|
||||
if _, _, err := pool.onEnd(scope, FilePacket{FileID: fileID}, eventTime.Add(2*time.Second)); err != nil {
|
||||
t.Fatalf("onEnd %s failed: %v", fileID, err)
|
||||
}
|
||||
}
|
||||
|
||||
complete("done-1", 0)
|
||||
complete("done-2", 10*time.Second)
|
||||
|
||||
activePayload := []byte("still-active")
|
||||
if _, err := pool.onMeta(scope, FilePacket{
|
||||
FileID: "active-1",
|
||||
Name: "active-1.txt",
|
||||
Size: int64(len(activePayload)),
|
||||
Checksum: testFileChecksum(activePayload),
|
||||
}, now.Add(20*time.Second)); err != nil {
|
||||
t.Fatalf("onMeta active-1 failed: %v", err)
|
||||
}
|
||||
|
||||
complete("done-3", 30*time.Second)
|
||||
|
||||
if got, want := len(pool.completed), 2; got != want {
|
||||
t.Fatalf("completed size mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := len(pool.sessions), 1; got != want {
|
||||
t.Fatalf("active session size mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if _, ok := pool.sessions[fileReceiveKey(scope, "active-1")]; !ok {
|
||||
t.Fatal("active session should be retained")
|
||||
}
|
||||
if _, ok := pool.completed[fileReceiveKey(scope, "done-1")]; ok {
|
||||
t.Fatal("oldest completed session should be evicted")
|
||||
}
|
||||
if _, ok := pool.completed[fileReceiveKey(scope, "done-2")]; !ok {
|
||||
t.Fatal("newer completed session should be retained")
|
||||
}
|
||||
if _, ok := pool.completed[fileReceiveKey(scope, "done-3")]; !ok {
|
||||
t.Fatal("latest completed session should be retained")
|
||||
}
|
||||
|
||||
if _, _, err := pool.onEnd(scope, FilePacket{FileID: "done-1"}, now.Add(40*time.Second)); err == nil {
|
||||
t.Fatal("evicted completed session should no longer resolve duplicate end")
|
||||
}
|
||||
if _, _, err := pool.onEnd(scope, FilePacket{FileID: "done-3"}, now.Add(41*time.Second)); err != nil {
|
||||
t.Fatalf("latest completed session should still resolve duplicate end: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillFileEventProgress(t *testing.T) {
|
||||
event := FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{Size: 200},
|
||||
Received: 50,
|
||||
StartedAt: time.Unix(100, 0),
|
||||
UpdatedAt: time.Unix(102, 0),
|
||||
}
|
||||
fillFileEventProgress(&event)
|
||||
if got, want := event.Total, int64(200); got != want {
|
||||
t.Fatalf("total mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := event.Percent, 25.0; got != want {
|
||||
t.Fatalf("percent mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if event.Done {
|
||||
t.Fatal("chunk event should not be done")
|
||||
}
|
||||
if got, want := event.Duration, 2*time.Second; got != want {
|
||||
t.Fatalf("duration mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := event.RateBPS, 25.0; got != want {
|
||||
t.Fatalf("rate mismatch: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
endEvent := FileEvent{
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{Size: 200},
|
||||
Received: 180,
|
||||
StartedAt: time.Unix(200, 0),
|
||||
UpdatedAt: time.Unix(204, 0),
|
||||
}
|
||||
fillFileEventProgress(&endEvent)
|
||||
if !endEvent.Done {
|
||||
t.Fatal("end event should be done")
|
||||
}
|
||||
if got, want := endEvent.Received, int64(200); got != want {
|
||||
t.Fatalf("end received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := endEvent.Percent, 100.0; got != want {
|
||||
t.Fatalf("end percent mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := endEvent.Duration, 4*time.Second; got != want {
|
||||
t.Fatalf("end duration mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := endEvent.RateBPS, 50.0; got != want {
|
||||
t.Fatalf("end rate mismatch: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
abortEvent := FileEvent{
|
||||
Kind: EnvelopeFileAbort,
|
||||
Packet: FilePacket{Size: 200},
|
||||
Received: 60,
|
||||
StartedAt: time.Unix(300, 0),
|
||||
UpdatedAt: time.Unix(303, 0),
|
||||
}
|
||||
fillFileEventProgress(&abortEvent)
|
||||
if abortEvent.Done {
|
||||
t.Fatal("abort event should not be done")
|
||||
}
|
||||
if got, want := abortEvent.Percent, 30.0; got != want {
|
||||
t.Fatalf("abort percent mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := abortEvent.Duration, 3*time.Second; got != want {
|
||||
t.Fatalf("abort duration mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := abortEvent.RateBPS, 20.0; got != want {
|
||||
t.Fatalf("abort rate mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFillFileEventTiming(t *testing.T) {
|
||||
event := FileEvent{
|
||||
Received: 120,
|
||||
}
|
||||
session := &fileReceiveSession{
|
||||
startedAt: time.Unix(100, 0),
|
||||
updatedAt: time.Unix(110, 0),
|
||||
previousUpdatedAt: time.Unix(108, 0),
|
||||
previousReceived: 80,
|
||||
}
|
||||
fillFileEventTiming(&event, session)
|
||||
|
||||
if got, want := event.StartedAt, session.startedAt; !got.Equal(want) {
|
||||
t.Fatalf("startedAt mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := event.UpdatedAt, session.updatedAt; !got.Equal(want) {
|
||||
t.Fatalf("updatedAt mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := event.StepDuration, 2*time.Second; got != want {
|
||||
t.Fatalf("step duration mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := event.InstantRateBPS, 20.0; got != want {
|
||||
t.Fatalf("instant rate mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func testFileChecksum(data []byte) string {
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultFileScope = "default"
|
||||
clientFileDomain = "client"
|
||||
serverFileDomain = "server"
|
||||
serverTransportScopeSuffix = "#tg:"
|
||||
)
|
||||
|
||||
func normalizeFileScope(scope string) string {
|
||||
cleaned := strings.TrimSpace(scope)
|
||||
if cleaned == "" {
|
||||
return defaultFileScope
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func clientFileScope() string {
|
||||
return clientFileDomain
|
||||
}
|
||||
|
||||
func serverFileScope(peer any) string {
|
||||
logical := logicalConnFromPeer(peer)
|
||||
if logical == nil {
|
||||
return serverFileDomain + ":unknown"
|
||||
}
|
||||
id := strings.TrimSpace(logical.ID())
|
||||
if id == "" {
|
||||
return serverFileDomain + ":unknown"
|
||||
}
|
||||
return serverFileDomain + ":" + id
|
||||
}
|
||||
|
||||
func serverTransportScope(peer any) string {
|
||||
logical := logicalConnFromPeer(peer)
|
||||
if logical == nil {
|
||||
return serverFileDomain + ":unknown"
|
||||
}
|
||||
return serverTransportScopeByGeneration(logical, logical.transportGenerationSnapshot())
|
||||
}
|
||||
|
||||
func serverTransportScopeForTransport(transport *TransportConn) string {
|
||||
if transport == nil {
|
||||
return serverFileDomain + ":unknown"
|
||||
}
|
||||
return transport.transportScope()
|
||||
}
|
||||
|
||||
func serverTransportScopeByGeneration(peer any, generation uint64) string {
|
||||
base := serverFileScope(peer)
|
||||
if generation == 0 {
|
||||
return base
|
||||
}
|
||||
return base + serverTransportScopeSuffix + strconv.FormatUint(generation, 10)
|
||||
}
|
||||
|
||||
func serverTransportDeliveryScopes(peer any) []string {
|
||||
logical := logicalConnFromPeer(peer)
|
||||
if logical == nil {
|
||||
return []string{serverFileDomain + ":unknown"}
|
||||
}
|
||||
base := serverFileScope(logical)
|
||||
transport := serverTransportScope(logical)
|
||||
if transport == base {
|
||||
return []string{base}
|
||||
}
|
||||
return []string{transport, base}
|
||||
}
|
||||
|
||||
func serverTransportDeliveryScopesForTransport(transport *TransportConn) []string {
|
||||
if transport == nil {
|
||||
return []string{serverFileDomain + ":unknown"}
|
||||
}
|
||||
return transport.deliveryScopes()
|
||||
}
|
||||
|
||||
func scopeBelongsToServerFileScope(scope string, base string) bool {
|
||||
scope = normalizeFileScope(scope)
|
||||
base = normalizeFileScope(base)
|
||||
return scope == base || strings.HasPrefix(scope, base+serverTransportScopeSuffix)
|
||||
}
|
||||
+451
@@ -0,0 +1,451 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultFileChunkSize = 64 * 1024
|
||||
|
||||
type fileSendHooks struct {
|
||||
config fileTransferConfig
|
||||
startSession func(*fileSendSession)
|
||||
sendReliable func(context.Context, Envelope) error
|
||||
sendAbort func(fileID string, stage string, offset int64, cause error) error
|
||||
publishEvent func(FileEvent)
|
||||
}
|
||||
|
||||
type fileSendError struct {
|
||||
stage string
|
||||
offset int64
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *fileSendError) Error() string {
|
||||
if e == nil || e.err == nil {
|
||||
return ""
|
||||
}
|
||||
return e.err.Error()
|
||||
}
|
||||
|
||||
func (e *fileSendError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.err
|
||||
}
|
||||
|
||||
func (c *ClientCommon) SendFile(ctx context.Context, filePath string) error {
|
||||
target := transferSendTarget{
|
||||
runtime: c.getTransferRuntime(),
|
||||
runtimeScope: clientFileScope(),
|
||||
publicScope: clientFileScope(),
|
||||
transportGeneration: 0,
|
||||
sequenceEn: c.sequenceEn,
|
||||
sequenceDe: c.sequenceDe,
|
||||
openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) {
|
||||
return c.OpenStream(ctx, opt)
|
||||
},
|
||||
sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) {
|
||||
return SendTransferBeginClient(ctx, c, req)
|
||||
},
|
||||
sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) {
|
||||
return SendTransferResumeClient(ctx, c, req)
|
||||
},
|
||||
sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) {
|
||||
return SendTransferCommitClient(ctx, c, req)
|
||||
},
|
||||
sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) {
|
||||
return SendTransferAbortClient(ctx, c, req)
|
||||
},
|
||||
}
|
||||
return c.sendFileViaTransfer(ctx, filePath, target, func(event FileEvent) {
|
||||
event.NetType = NET_CLIENT
|
||||
event.ServerConn = c
|
||||
c.publishSendFileEventMonitorOnly(event)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *ServerCommon) SendFile(ctx context.Context, client *ClientConn, filePath string) error {
|
||||
return s.SendFileLogical(ctx, logicalConnFromClient(client), filePath)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) SendFileLogical(ctx context.Context, client *LogicalConn, filePath string) error {
|
||||
if client == nil {
|
||||
return s.SendFileTransport(ctx, nil, filePath)
|
||||
}
|
||||
return s.SendFileTransport(ctx, s.resolveOutboundTransport(client), filePath)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) SendFileTransport(ctx context.Context, transport *TransportConn, filePath string) error {
|
||||
if transport == nil {
|
||||
return transportDetachedErrorForTransport(transport)
|
||||
}
|
||||
logical := transport.logicalConnSnapshot()
|
||||
if logical == nil || !transport.Attached() || !transport.IsCurrent() {
|
||||
return transportDetachedErrorForTransport(transport)
|
||||
}
|
||||
target := transferSendTarget{
|
||||
runtime: s.getTransferRuntime(),
|
||||
runtimeScope: serverTransportScopeForTransport(transport),
|
||||
publicScope: serverFileScope(logical),
|
||||
transportGeneration: transport.TransportGeneration(),
|
||||
logical: logical,
|
||||
transport: transport,
|
||||
sequenceEn: s.sequenceEn,
|
||||
sequenceDe: s.sequenceDe,
|
||||
openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) {
|
||||
return s.OpenStreamTransport(ctx, transport, opt)
|
||||
},
|
||||
sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) {
|
||||
return SendTransferBeginTransport(ctx, s, transport, req)
|
||||
},
|
||||
sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) {
|
||||
return SendTransferResumeTransport(ctx, s, transport, req)
|
||||
},
|
||||
sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) {
|
||||
return SendTransferCommitTransport(ctx, s, transport, req)
|
||||
},
|
||||
sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) {
|
||||
return SendTransferAbortTransport(ctx, s, transport, req)
|
||||
},
|
||||
}
|
||||
return s.sendFileViaTransfer(ctx, filePath, target, func(event FileEvent) {
|
||||
event.NetType = NET_SERVER
|
||||
event.LogicalConn = logical
|
||||
event.TransportConn = transport
|
||||
s.publishSendFileEventMonitorOnly(event)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ClientCommon) sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error {
|
||||
return sendFileViaTransfer(ctx, filePath, target, publishEvent)
|
||||
}
|
||||
|
||||
func (s *ServerCommon) sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error {
|
||||
return sendFileViaTransfer(ctx, filePath, target, publishEvent)
|
||||
}
|
||||
|
||||
func sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
session, err := newFileSendSession(filePath, time.Now())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
session.fileID = buildStableFileTransferID(session)
|
||||
source, err := newTransferFileSource(filePath, session.size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer source.Close()
|
||||
if publishEvent != nil {
|
||||
hooks := transferSendHooks{
|
||||
onNegotiated: func(nextOffset int64, _ bool) {
|
||||
session.syncProgress(nextOffset, time.Now())
|
||||
publishEvent(session.onMetaSent(time.Now()))
|
||||
},
|
||||
onSegmentSent: func(offset int64, sentBytes int64) {
|
||||
event, chunkErr := session.onChunkSent(offset, sentBytes, time.Now())
|
||||
if chunkErr == nil {
|
||||
publishEvent(event)
|
||||
}
|
||||
},
|
||||
onCommitted: func() {
|
||||
publishEvent(session.onEndSent(time.Now()))
|
||||
},
|
||||
onAbort: func(stage string, offset int64, cause error) {
|
||||
publishEvent(session.onAbort(stage, offset, cause, time.Now()))
|
||||
},
|
||||
}
|
||||
handle, err := startTransferSendWithHooks(ctx, TransferSendOptions{
|
||||
Descriptor: buildFileTransferDescriptor(session),
|
||||
Source: source,
|
||||
ChunkSize: defaultFileChunkSize,
|
||||
VerifyChecksum: false,
|
||||
}, target, hooks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return handle.Wait(ctx)
|
||||
}
|
||||
handle, err := startTransferSend(ctx, TransferSendOptions{
|
||||
Descriptor: buildFileTransferDescriptor(session),
|
||||
Source: source,
|
||||
ChunkSize: defaultFileChunkSize,
|
||||
VerifyChecksum: false,
|
||||
}, target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return handle.Wait(ctx)
|
||||
}
|
||||
|
||||
func sendFileWithHooks(ctx context.Context, filePath string, hooks fileSendHooks) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
hooks.config = normalizeFileTransferConfig(hooks.config)
|
||||
session, err := newFileSendSession(filePath, time.Now())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if hooks.startSession != nil {
|
||||
hooks.startSession(session)
|
||||
}
|
||||
if err := sendFileMetaWithHooks(ctx, session, hooks); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sendFileChunksWithHooks(ctx, session, hooks); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := sendFileEndWithHooks(ctx, session, hooks); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func newFileSendSession(filePath string, now time.Time) (*fileSendSession, error) {
|
||||
fi, err := os.Stat(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fi.IsDir() {
|
||||
return nil, fmt.Errorf("file path is a directory: %s", filePath)
|
||||
}
|
||||
checksum, err := computeFileChecksum(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
now = normalizeFileEventTime(now)
|
||||
name := filepath.Base(filePath)
|
||||
if name == "" || name == "." || name == string(filepath.Separator) {
|
||||
name = "unnamed.bin"
|
||||
}
|
||||
return &fileSendSession{
|
||||
fileID: buildFileID(filePath),
|
||||
path: filePath,
|
||||
name: name,
|
||||
size: fi.Size(),
|
||||
mode: fi.Mode().Perm(),
|
||||
modTime: fi.ModTime(),
|
||||
checksum: checksum,
|
||||
startedAt: now,
|
||||
updatedAt: now,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type fileSendSession struct {
|
||||
fileID string
|
||||
path string
|
||||
name string
|
||||
size int64
|
||||
mode os.FileMode
|
||||
modTime time.Time
|
||||
checksum string
|
||||
sent int64
|
||||
startedAt time.Time
|
||||
updatedAt time.Time
|
||||
previousUpdatedAt time.Time
|
||||
previousSent int64
|
||||
}
|
||||
|
||||
func (s *fileSendSession) metaEnvelope() Envelope {
|
||||
return newFileMetaEnvelope(s.fileID, s.name, s.size, s.checksum, uint32(s.mode.Perm()), s.modTime.UnixNano())
|
||||
}
|
||||
|
||||
func (s *fileSendSession) chunkEnvelope(offset int64, chunk []byte) Envelope {
|
||||
return newFileChunkEnvelope(s.fileID, offset, chunk)
|
||||
}
|
||||
|
||||
func (s *fileSendSession) endEnvelope() Envelope {
|
||||
return newFileEndEnvelope(s.fileID)
|
||||
}
|
||||
|
||||
func (s *fileSendSession) filePacket() FilePacket {
|
||||
return FilePacket{
|
||||
FileID: s.fileID,
|
||||
Name: s.name,
|
||||
Size: s.size,
|
||||
Mode: uint32(s.mode.Perm()),
|
||||
ModTime: s.modTime.UnixNano(),
|
||||
Checksum: s.checksum,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fileSendSession) advance(delta int64, now time.Time) {
|
||||
now = normalizeFileEventTime(now)
|
||||
if s.startedAt.IsZero() {
|
||||
s.startedAt = now
|
||||
}
|
||||
s.previousUpdatedAt = s.updatedAt
|
||||
s.previousSent = s.sent
|
||||
s.updatedAt = now
|
||||
s.sent += delta
|
||||
if s.sent < 0 {
|
||||
s.sent = 0
|
||||
}
|
||||
if s.size > 0 && s.sent > s.size {
|
||||
s.sent = s.size
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fileSendSession) syncProgress(progress int64, now time.Time) {
|
||||
now = normalizeFileEventTime(now)
|
||||
if progress < 0 {
|
||||
progress = 0
|
||||
}
|
||||
if s.size > 0 && progress > s.size {
|
||||
progress = s.size
|
||||
}
|
||||
if s.startedAt.IsZero() {
|
||||
s.startedAt = now
|
||||
}
|
||||
s.previousUpdatedAt = s.updatedAt
|
||||
s.previousSent = s.sent
|
||||
s.updatedAt = now
|
||||
s.sent = progress
|
||||
}
|
||||
|
||||
func (s *fileSendSession) buildEvent(kind EnvelopeKind, packet FilePacket, err error, now time.Time) FileEvent {
|
||||
now = normalizeFileEventTime(now)
|
||||
if err != nil && packet.Error == "" {
|
||||
packet.Error = err.Error()
|
||||
}
|
||||
event := FileEvent{
|
||||
Kind: kind,
|
||||
Packet: packet,
|
||||
Path: s.path,
|
||||
Received: s.sent,
|
||||
Err: err,
|
||||
Time: now,
|
||||
}
|
||||
fillFileSendEventTiming(&event, s)
|
||||
fillFileEventProgress(&event)
|
||||
return event
|
||||
}
|
||||
|
||||
func (s *fileSendSession) onMetaSent(now time.Time) FileEvent {
|
||||
s.advance(0, now)
|
||||
return s.buildEvent(EnvelopeFileMeta, s.filePacket(), nil, now)
|
||||
}
|
||||
|
||||
func (s *fileSendSession) onChunkSent(offset int64, chunkSize int64, now time.Time) (FileEvent, error) {
|
||||
if offset != s.sent {
|
||||
return FileEvent{}, fmt.Errorf("file chunk offset mismatch: got %d want %d", offset, s.sent)
|
||||
}
|
||||
packet := s.filePacket()
|
||||
packet.Offset = offset
|
||||
s.advance(chunkSize, now)
|
||||
return s.buildEvent(EnvelopeFileChunk, packet, nil, now), nil
|
||||
}
|
||||
|
||||
func (s *fileSendSession) onEndSent(now time.Time) FileEvent {
|
||||
s.advance(0, now)
|
||||
return s.buildEvent(EnvelopeFileEnd, s.filePacket(), nil, now)
|
||||
}
|
||||
|
||||
func (s *fileSendSession) onAbort(stage string, offset int64, cause error, now time.Time) FileEvent {
|
||||
packet := s.filePacket()
|
||||
packet.Stage = stage
|
||||
packet.Offset = offset
|
||||
s.advance(0, now)
|
||||
return s.buildEvent(EnvelopeFileAbort, packet, cause, now)
|
||||
}
|
||||
|
||||
func sendFileMetaWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error {
|
||||
if err := hooks.sendReliable(ctx, session.metaEnvelope()); err != nil {
|
||||
return handleFileSendFailure(session, hooks, "meta", 0, err)
|
||||
}
|
||||
publishFileSendEvent(hooks, session.onMetaSent(time.Now()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func sendFileChunksWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error {
|
||||
fd, err := os.Open(session.path)
|
||||
if err != nil {
|
||||
return handleFileSendFailure(session, hooks, "chunk", session.sent, err)
|
||||
}
|
||||
defer fd.Close()
|
||||
streamErr := streamFileChunks(ctx, fd, hooks.config.ChunkSize, func(offset int64, chunk []byte) error {
|
||||
err := hooks.sendReliable(ctx, session.chunkEnvelope(offset, chunk))
|
||||
if err != nil {
|
||||
return &fileSendError{stage: "chunk", offset: offset, err: err}
|
||||
}
|
||||
event, stateErr := session.onChunkSent(offset, int64(len(chunk)), time.Now())
|
||||
if stateErr != nil {
|
||||
return &fileSendError{stage: "chunk", offset: offset, err: stateErr}
|
||||
}
|
||||
publishFileSendEvent(hooks, event)
|
||||
return nil
|
||||
})
|
||||
if streamErr == nil {
|
||||
return nil
|
||||
}
|
||||
var sendErr *fileSendError
|
||||
if errors.As(streamErr, &sendErr) {
|
||||
return handleFileSendFailure(session, hooks, sendErr.stage, sendErr.offset, sendErr.err)
|
||||
}
|
||||
return handleFileSendFailure(session, hooks, "chunk", session.sent, streamErr)
|
||||
}
|
||||
|
||||
func sendFileEndWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error {
|
||||
if err := hooks.sendReliable(ctx, session.endEnvelope()); err != nil {
|
||||
return handleFileSendFailure(session, hooks, "end", session.sent, err)
|
||||
}
|
||||
publishFileSendEvent(hooks, session.onEndSent(time.Now()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleFileSendFailure(session *fileSendSession, hooks fileSendHooks, stage string, offset int64, cause error) error {
|
||||
if session != nil && hooks.sendAbort != nil && session.fileID != "" {
|
||||
_ = hooks.sendAbort(session.fileID, stage, offset, cause)
|
||||
}
|
||||
if session != nil {
|
||||
publishFileSendEvent(hooks, session.onAbort(stage, offset, cause, time.Now()))
|
||||
}
|
||||
return cause
|
||||
}
|
||||
|
||||
func publishFileSendEvent(hooks fileSendHooks, event FileEvent) {
|
||||
if hooks.publishEvent != nil {
|
||||
hooks.publishEvent(event)
|
||||
}
|
||||
}
|
||||
|
||||
func streamFileChunks(ctx context.Context, reader io.Reader, chunkSize int, sendChunk func(offset int64, chunk []byte) error) error {
|
||||
if chunkSize <= 0 {
|
||||
chunkSize = defaultFileChunkSize
|
||||
}
|
||||
buf := make([]byte, chunkSize)
|
||||
var offset int64
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("file stream canceled: %w", ctx.Err())
|
||||
default:
|
||||
}
|
||||
n, readErr := reader.Read(buf)
|
||||
if n > 0 {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
if err := sendChunk(offset, chunk); err != nil {
|
||||
return err
|
||||
}
|
||||
offset += int64(n)
|
||||
}
|
||||
if readErr == nil {
|
||||
continue
|
||||
}
|
||||
if errors.Is(readErr, io.EOF) {
|
||||
return nil
|
||||
}
|
||||
return readErr
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,224 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFileSendSessionProgress(t *testing.T) {
|
||||
session := &fileSendSession{
|
||||
fileID: "file-1",
|
||||
path: "/tmp/demo.bin",
|
||||
name: "demo.bin",
|
||||
size: 200,
|
||||
checksum: "sum",
|
||||
startedAt: time.Unix(100, 0),
|
||||
updatedAt: time.Unix(100, 0),
|
||||
}
|
||||
|
||||
metaEvent := session.onMetaSent(time.Unix(100, 0))
|
||||
if got, want := metaEvent.Kind, EnvelopeFileMeta; got != want {
|
||||
t.Fatalf("meta kind mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := metaEvent.Total, int64(200); got != want {
|
||||
t.Fatalf("meta total mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := metaEvent.Received, int64(0); got != want {
|
||||
t.Fatalf("meta received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
chunkEvent, err := session.onChunkSent(0, 80, time.Unix(104, 0))
|
||||
if err != nil {
|
||||
t.Fatalf("onChunkSent failed: %v", err)
|
||||
}
|
||||
if got, want := chunkEvent.Received, int64(80); got != want {
|
||||
t.Fatalf("chunk received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := chunkEvent.Percent, 40.0; got != want {
|
||||
t.Fatalf("chunk percent mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := chunkEvent.Duration, 4*time.Second; got != want {
|
||||
t.Fatalf("chunk duration mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := chunkEvent.StepDuration, 4*time.Second; got != want {
|
||||
t.Fatalf("chunk step duration mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := chunkEvent.InstantRateBPS, 20.0; got != want {
|
||||
t.Fatalf("chunk instant rate mismatch: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
secondChunkEvent, err := session.onChunkSent(80, 120, time.Unix(108, 0))
|
||||
if err != nil {
|
||||
t.Fatalf("second onChunkSent failed: %v", err)
|
||||
}
|
||||
if got, want := secondChunkEvent.Received, int64(200); got != want {
|
||||
t.Fatalf("second chunk received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := secondChunkEvent.Percent, 100.0; got != want {
|
||||
t.Fatalf("second chunk percent mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := secondChunkEvent.RateBPS, 25.0; got != want {
|
||||
t.Fatalf("second chunk rate mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := secondChunkEvent.StepDuration, 4*time.Second; got != want {
|
||||
t.Fatalf("second chunk step duration mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := secondChunkEvent.InstantRateBPS, 30.0; got != want {
|
||||
t.Fatalf("second chunk instant rate mismatch: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
endEvent := session.onEndSent(time.Unix(110, 0))
|
||||
if !endEvent.Done {
|
||||
t.Fatal("end event should be done")
|
||||
}
|
||||
if got, want := endEvent.Received, int64(200); got != want {
|
||||
t.Fatalf("end received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := endEvent.Percent, 100.0; got != want {
|
||||
t.Fatalf("end percent mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := endEvent.Duration, 10*time.Second; got != want {
|
||||
t.Fatalf("end duration mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := endEvent.StepDuration, 2*time.Second; got != want {
|
||||
t.Fatalf("end step duration mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := endEvent.RateBPS, 20.0; got != want {
|
||||
t.Fatalf("end rate mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := endEvent.InstantRateBPS, 0.0; got != want {
|
||||
t.Fatalf("end instant rate mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileWithHooksLogsLocalProgress(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
filePath := filepath.Join(dir, "demo.txt")
|
||||
data := []byte("hello notify send progress")
|
||||
if err := os.WriteFile(filePath, data, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile failed: %v", err)
|
||||
}
|
||||
|
||||
var sentKinds []EnvelopeKind
|
||||
var events []FileEvent
|
||||
err := sendFileWithHooks(context.Background(), filePath, fileSendHooks{
|
||||
sendReliable: func(ctx context.Context, env Envelope) error {
|
||||
sentKinds = append(sentKinds, env.Kind)
|
||||
return nil
|
||||
},
|
||||
sendAbort: func(fileID string, stage string, offset int64, cause error) error {
|
||||
t.Fatalf("unexpected abort: fileID=%s stage=%s offset=%d err=%v", fileID, stage, offset, cause)
|
||||
return nil
|
||||
},
|
||||
publishEvent: func(event FileEvent) {
|
||||
events = append(events, event)
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("sendFileWithHooks failed: %v", err)
|
||||
}
|
||||
|
||||
if got, want := len(sentKinds), 3; got != want {
|
||||
t.Fatalf("sent kinds count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if sentKinds[0] != EnvelopeFileMeta || sentKinds[1] != EnvelopeFileChunk || sentKinds[2] != EnvelopeFileEnd {
|
||||
t.Fatalf("unexpected sent kinds: %v", sentKinds)
|
||||
}
|
||||
|
||||
if got, want := len(events), 3; got != want {
|
||||
t.Fatalf("event count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if events[0].Kind != EnvelopeFileMeta || events[1].Kind != EnvelopeFileChunk || events[2].Kind != EnvelopeFileEnd {
|
||||
t.Fatalf("unexpected event kinds: %+v", []EnvelopeKind{events[0].Kind, events[1].Kind, events[2].Kind})
|
||||
}
|
||||
if got, want := events[1].Received, int64(len(data)); got != want {
|
||||
t.Fatalf("chunk received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if !events[2].Done {
|
||||
t.Fatal("end event should be done")
|
||||
}
|
||||
if got, want := events[2].Received, int64(len(data)); got != want {
|
||||
t.Fatalf("end received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := events[2].Path, filePath; got != want {
|
||||
t.Fatalf("end path mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if events[0].Packet.FileID == "" {
|
||||
t.Fatal("fileID should not be empty")
|
||||
}
|
||||
if events[0].Packet.FileID != events[1].Packet.FileID || events[1].Packet.FileID != events[2].Packet.FileID {
|
||||
t.Fatalf("fileID should stay stable across events: %+v", events)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileWithHooksAbortOnChunkFailure(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
filePath := filepath.Join(dir, "demo.txt")
|
||||
data := []byte("hello notify send failure")
|
||||
if err := os.WriteFile(filePath, data, 0o644); err != nil {
|
||||
t.Fatalf("WriteFile failed: %v", err)
|
||||
}
|
||||
|
||||
wantErr := errors.New("chunk ack timeout")
|
||||
var abortFileID string
|
||||
var abortStage string
|
||||
var abortOffset int64
|
||||
var abortCause error
|
||||
var events []FileEvent
|
||||
|
||||
err := sendFileWithHooks(context.Background(), filePath, fileSendHooks{
|
||||
sendReliable: func(ctx context.Context, env Envelope) error {
|
||||
if env.Kind == EnvelopeFileChunk {
|
||||
return wantErr
|
||||
}
|
||||
return nil
|
||||
},
|
||||
sendAbort: func(fileID string, stage string, offset int64, cause error) error {
|
||||
abortFileID = fileID
|
||||
abortStage = stage
|
||||
abortOffset = offset
|
||||
abortCause = cause
|
||||
return nil
|
||||
},
|
||||
publishEvent: func(event FileEvent) {
|
||||
events = append(events, event)
|
||||
},
|
||||
})
|
||||
if !errors.Is(err, wantErr) {
|
||||
t.Fatalf("sendFileWithHooks error mismatch: got %v want %v", err, wantErr)
|
||||
}
|
||||
if abortFileID == "" {
|
||||
t.Fatal("abort should capture fileID")
|
||||
}
|
||||
if got, want := abortStage, "chunk"; got != want {
|
||||
t.Fatalf("abort stage mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := abortOffset, int64(0); got != want {
|
||||
t.Fatalf("abort offset mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if !errors.Is(abortCause, wantErr) {
|
||||
t.Fatalf("abort cause mismatch: got %v want %v", abortCause, wantErr)
|
||||
}
|
||||
if got, want := len(events), 2; got != want {
|
||||
t.Fatalf("event count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := events[0].Kind, EnvelopeFileMeta; got != want {
|
||||
t.Fatalf("first event kind mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := events[1].Kind, EnvelopeFileAbort; got != want {
|
||||
t.Fatalf("abort event kind mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := events[1].Packet.Stage, "chunk"; got != want {
|
||||
t.Fatalf("abort packet stage mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := events[1].Received, int64(0); got != want {
|
||||
t.Fatalf("abort received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if !errors.Is(events[1].Err, wantErr) {
|
||||
t.Fatalf("abort event error mismatch: got %v want %v", events[1].Err, wantErr)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,328 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
fileTransferMetadataKindKey = "_notify.file_adapter_kind"
|
||||
fileTransferMetadataKindValue = "file"
|
||||
fileTransferMetadataNameKey = "_notify.file_name"
|
||||
fileTransferMetadataModeKey = "_notify.file_mode"
|
||||
fileTransferMetadataModTimeKey = "_notify.file_mod_time"
|
||||
)
|
||||
|
||||
type transferFileSource struct {
|
||||
file *os.File
|
||||
size int64
|
||||
}
|
||||
|
||||
func newTransferFileSource(path string, size int64) (*transferFileSource, error) {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &transferFileSource{
|
||||
file: file,
|
||||
size: size,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *transferFileSource) ReadAt(p []byte, off int64) (int, error) {
|
||||
if s == nil || s.file == nil {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
return s.file.ReadAt(p, off)
|
||||
}
|
||||
|
||||
func (s *transferFileSource) Size() int64 {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
return s.size
|
||||
}
|
||||
|
||||
func (s *transferFileSource) Close() error {
|
||||
if s == nil || s.file == nil {
|
||||
return nil
|
||||
}
|
||||
return s.file.Close()
|
||||
}
|
||||
|
||||
type transferCloseWithError interface {
|
||||
CloseWithError(error) error
|
||||
}
|
||||
|
||||
type transferReceiveOffsetProvider interface {
|
||||
NextOffset() int64
|
||||
}
|
||||
|
||||
type fileTransferReceiveSink struct {
|
||||
pool *fileReceivePool
|
||||
scope string
|
||||
packet FilePacket
|
||||
publishEvent func(FileEvent)
|
||||
|
||||
mu sync.Mutex
|
||||
offset int64
|
||||
committed bool
|
||||
closed bool
|
||||
}
|
||||
|
||||
func newFileTransferReceiveSink(pool *fileReceivePool, scope string, packet FilePacket, publishEvent func(FileEvent)) (*fileTransferReceiveSink, error) {
|
||||
if pool == nil {
|
||||
return nil, errTransferSinkNil
|
||||
}
|
||||
now := time.Now()
|
||||
session, err := pool.onMeta(scope, packet, now)
|
||||
if publishEvent != nil {
|
||||
publishEvent(fileReceiveEventFromSession(EnvelopeFileMeta, packet, session, "", err, now))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fileTransferReceiveSink{
|
||||
pool: pool,
|
||||
scope: normalizeFileScope(scope),
|
||||
packet: packet,
|
||||
publishEvent: publishEvent,
|
||||
offset: session.received,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *fileTransferReceiveSink) NextOffset() int64 {
|
||||
if s == nil {
|
||||
return 0
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.offset
|
||||
}
|
||||
|
||||
func (s *fileTransferReceiveSink) WriteAt(p []byte, off int64) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
closed := s.closed
|
||||
s.mu.Unlock()
|
||||
if closed {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
now := time.Now()
|
||||
packet := s.packet
|
||||
packet.Offset = off
|
||||
packet.Chunk = append([]byte(nil), p...)
|
||||
session, err := s.pool.onChunk(s.scope, packet, now)
|
||||
if s.publishEvent != nil {
|
||||
s.publishEvent(fileReceiveEventFromSession(EnvelopeFileChunk, packet, session, "", err, now))
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
s.mu.Lock()
|
||||
if end := off + int64(len(p)); end > s.offset {
|
||||
s.offset = end
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (s *fileTransferReceiveSink) Sync(context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fileTransferReceiveSink) Commit(context.Context) error {
|
||||
s.mu.Lock()
|
||||
closed := s.closed
|
||||
s.mu.Unlock()
|
||||
if closed {
|
||||
return os.ErrClosed
|
||||
}
|
||||
now := time.Now()
|
||||
finalPath, session, err := s.pool.onEnd(s.scope, FilePacket{FileID: s.packet.FileID}, now)
|
||||
if s.publishEvent != nil {
|
||||
s.publishEvent(fileReceiveEventFromSession(EnvelopeFileEnd, s.packet, session, finalPath, err, now))
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.committed = true
|
||||
s.offset = s.packet.Size
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *fileTransferReceiveSink) Close() error {
|
||||
return s.closeWithError(nil, false)
|
||||
}
|
||||
|
||||
func (s *fileTransferReceiveSink) CloseWithError(err error) error {
|
||||
return s.closeWithError(err, true)
|
||||
}
|
||||
|
||||
func (s *fileTransferReceiveSink) closeWithError(err error, publish bool) error {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
s.mu.Lock()
|
||||
if s.closed {
|
||||
s.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.closed = true
|
||||
committed := s.committed
|
||||
offset := s.offset
|
||||
s.mu.Unlock()
|
||||
if committed {
|
||||
return nil
|
||||
}
|
||||
packet := FilePacket{
|
||||
FileID: s.packet.FileID,
|
||||
Offset: offset,
|
||||
}
|
||||
if err != nil {
|
||||
packet.Stage = "abort"
|
||||
packet.Error = err.Error()
|
||||
}
|
||||
now := time.Now()
|
||||
session, abortErr := s.pool.onAbort(s.scope, packet, now)
|
||||
if publish && err != nil && s.publishEvent != nil {
|
||||
s.publishEvent(fileReceiveEventFromSession(EnvelopeFileAbort, packet, session, "", firstErr(abortErr, err), now))
|
||||
}
|
||||
return abortErr
|
||||
}
|
||||
|
||||
func firstErr(primary error, fallback error) error {
|
||||
if primary != nil {
|
||||
return primary
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func fileReceiveEventFromSession(kind EnvelopeKind, packet FilePacket, session *fileReceiveSession, path string, err error, now time.Time) FileEvent {
|
||||
event := FileEvent{
|
||||
Kind: kind,
|
||||
Packet: packet,
|
||||
Time: now,
|
||||
Err: err,
|
||||
}
|
||||
switch kind {
|
||||
case EnvelopeFileAbort:
|
||||
event.Received = packet.Offset
|
||||
case EnvelopeFileEnd:
|
||||
event.Path = path
|
||||
}
|
||||
if session != nil {
|
||||
if event.Path == "" {
|
||||
if kind == EnvelopeFileEnd && session.finalPath != "" {
|
||||
event.Path = session.finalPath
|
||||
} else {
|
||||
event.Path = session.tmpPath
|
||||
}
|
||||
}
|
||||
if kind != EnvelopeFileAbort {
|
||||
event.Received = session.received
|
||||
}
|
||||
fillFileEventTiming(&event, session)
|
||||
}
|
||||
fillFileEventProgress(&event)
|
||||
return event
|
||||
}
|
||||
|
||||
func buildFileTransferDescriptor(session *fileSendSession) TransferDescriptor {
|
||||
return TransferDescriptor{
|
||||
ID: session.fileID,
|
||||
Channel: TransferChannelData,
|
||||
Size: session.size,
|
||||
Checksum: session.checksum,
|
||||
Metadata: map[string]string{
|
||||
fileTransferMetadataKindKey: fileTransferMetadataKindValue,
|
||||
fileTransferMetadataNameKey: session.name,
|
||||
fileTransferMetadataModeKey: strconv.FormatUint(uint64(session.mode.Perm()), 10),
|
||||
fileTransferMetadataModTimeKey: strconv.FormatInt(session.modTime.UnixNano(), 10),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func buildStableFileTransferID(session *fileSendSession) string {
|
||||
if session == nil {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(session.name + "|" + strconv.FormatInt(session.size, 10) + "|" + normalizeChecksum(session.checksum)))
|
||||
return fmt.Sprintf("%s-%s", fileIDBaseName(session.name), hex.EncodeToString(sum[:8]))
|
||||
}
|
||||
|
||||
func parseFileTransferPacket(desc TransferDescriptor) (FilePacket, bool) {
|
||||
if desc.Metadata[fileTransferMetadataKindKey] != fileTransferMetadataKindValue {
|
||||
return FilePacket{}, false
|
||||
}
|
||||
packet := FilePacket{
|
||||
FileID: desc.ID,
|
||||
Name: desc.Metadata[fileTransferMetadataNameKey],
|
||||
Size: desc.Size,
|
||||
Checksum: desc.Checksum,
|
||||
}
|
||||
if modeValue := desc.Metadata[fileTransferMetadataModeKey]; modeValue != "" {
|
||||
if mode, err := strconv.ParseUint(modeValue, 10, 32); err == nil {
|
||||
packet.Mode = uint32(mode)
|
||||
}
|
||||
}
|
||||
if modTimeValue := desc.Metadata[fileTransferMetadataModTimeKey]; modTimeValue != "" {
|
||||
if modTime, err := strconv.ParseInt(modTimeValue, 10, 64); err == nil {
|
||||
packet.ModTime = modTime
|
||||
}
|
||||
}
|
||||
return packet, packet.FileID != "" && packet.Name != ""
|
||||
}
|
||||
|
||||
func (c *ClientCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) {
|
||||
packet, ok := parseFileTransferPacket(info.Descriptor)
|
||||
if !ok {
|
||||
return TransferReceiveOptions{}, false, nil
|
||||
}
|
||||
sink, err := newFileTransferReceiveSink(c.getFileReceivePool(), clientFileScope(), packet, func(event FileEvent) {
|
||||
event.NetType = NET_CLIENT
|
||||
event.ServerConn = c
|
||||
c.publishReceivedFileEventMonitorOnly(event)
|
||||
})
|
||||
if err != nil {
|
||||
return TransferReceiveOptions{}, true, err
|
||||
}
|
||||
return TransferReceiveOptions{
|
||||
Descriptor: cloneTransferDescriptor(info.Descriptor),
|
||||
Sink: sink,
|
||||
VerifyChecksum: false,
|
||||
SyncOnCheckpoint: false,
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func (s *ServerCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) {
|
||||
packet, ok := parseFileTransferPacket(info.Descriptor)
|
||||
if !ok {
|
||||
return TransferReceiveOptions{}, false, nil
|
||||
}
|
||||
sink, err := newFileTransferReceiveSink(s.getFileReceivePool(), transferPublicScopeForPeer(info.LogicalConn), packet, func(event FileEvent) {
|
||||
event.NetType = NET_SERVER
|
||||
event.LogicalConn = info.LogicalConn
|
||||
event.TransportConn = info.TransportConn
|
||||
s.publishReceivedFileEventMonitorOnly(event)
|
||||
})
|
||||
if err != nil {
|
||||
return TransferReceiveOptions{}, true, err
|
||||
}
|
||||
return TransferReceiveOptions{
|
||||
Descriptor: cloneTransferDescriptor(info.Descriptor),
|
||||
Sink: sink,
|
||||
VerifyChecksum: false,
|
||||
SyncOnCheckpoint: false,
|
||||
}, true, nil
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSendFileUsesTransferKernelAndBuiltinFileReceiver(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||
}
|
||||
receiveDir := t.TempDir()
|
||||
if err := server.SetFileReceiveDir(receiveDir); err != nil {
|
||||
t.Fatalf("SetFileReceiveDir failed: %v", err)
|
||||
}
|
||||
var serverMu sync.Mutex
|
||||
var serverEvents []FileEvent
|
||||
server.SetFileHandler(func(event FileEvent) {
|
||||
serverMu.Lock()
|
||||
serverEvents = append(serverEvents, event)
|
||||
serverMu.Unlock()
|
||||
})
|
||||
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||
t.Fatalf("server Listen failed: %v", err)
|
||||
}
|
||||
defer func() { _ = server.Stop() }()
|
||||
|
||||
client := NewClient().(*ClientCommon)
|
||||
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||
}
|
||||
var clientMu sync.Mutex
|
||||
var clientEvents []FileEvent
|
||||
client.setFileEventObserver(func(event FileEvent) {
|
||||
clientMu.Lock()
|
||||
clientEvents = append(clientEvents, event)
|
||||
clientMu.Unlock()
|
||||
})
|
||||
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||
t.Fatalf("client Connect failed: %v", err)
|
||||
}
|
||||
defer func() { _ = client.Stop() }()
|
||||
|
||||
payload := bytes.Repeat([]byte("send-file-transfer-kernel-"), 1024)
|
||||
sendPath := filepath.Join(t.TempDir(), "payload.bin")
|
||||
if err := os.WriteFile(sendPath, payload, 0o600); err != nil {
|
||||
t.Fatalf("WriteFile failed: %v", err)
|
||||
}
|
||||
|
||||
if err := client.SendFile(context.Background(), sendPath); err != nil {
|
||||
t.Fatalf("SendFile failed: %v", err)
|
||||
}
|
||||
|
||||
receivedPath := waitForSingleFileInDir(t, receiveDir, 2*time.Second)
|
||||
received, err := os.ReadFile(receivedPath)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadFile failed: %v", err)
|
||||
}
|
||||
if !bytes.Equal(received, payload) {
|
||||
t.Fatalf("received payload mismatch: got %d want %d", len(received), len(payload))
|
||||
}
|
||||
|
||||
clientSnapshots, err := GetClientTransferSnapshots(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientTransferSnapshots failed: %v", err)
|
||||
}
|
||||
serverSnapshots, err := GetServerTransferSnapshots(server)
|
||||
if err != nil {
|
||||
t.Fatalf("GetServerTransferSnapshots failed: %v", err)
|
||||
}
|
||||
if !containsFileTransferSnapshot(clientSnapshots) {
|
||||
t.Fatalf("client snapshots do not contain file transfer metadata: %+v", clientSnapshots)
|
||||
}
|
||||
if !containsFileTransferSnapshot(serverSnapshots) {
|
||||
t.Fatalf("server snapshots do not contain file transfer metadata: %+v", serverSnapshots)
|
||||
}
|
||||
|
||||
clientMu.Lock()
|
||||
serverMu.Lock()
|
||||
defer clientMu.Unlock()
|
||||
defer serverMu.Unlock()
|
||||
if !containsFileEventKind(clientEvents, EnvelopeFileMeta) || !containsFileEventKind(clientEvents, EnvelopeFileEnd) {
|
||||
t.Fatalf("client file events missing meta/end: %+v", clientEvents)
|
||||
}
|
||||
if !containsFileEventKind(serverEvents, EnvelopeFileMeta) || !containsFileEventKind(serverEvents, EnvelopeFileEnd) {
|
||||
t.Fatalf("server file events missing meta/end: %+v", serverEvents)
|
||||
}
|
||||
}
|
||||
|
||||
func waitForSingleFileInDir(t *testing.T, dir string, timeout time.Duration) string {
|
||||
t.Helper()
|
||||
deadline := time.Now().Add(timeout)
|
||||
for time.Now().Before(deadline) {
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err == nil {
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
return filepath.Join(dir, entry.Name())
|
||||
}
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
}
|
||||
t.Fatalf("timed out waiting for received file in %s", dir)
|
||||
return ""
|
||||
}
|
||||
|
||||
func containsFileTransferSnapshot(list []TransferSnapshot) bool {
|
||||
for _, snapshot := range list {
|
||||
if snapshot.Metadata[fileTransferMetadataKindKey] == fileTransferMetadataKindValue && snapshot.State == TransferStateDone {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsFileEventKind(list []FileEvent, kind EnvelopeKind) bool {
|
||||
for _, event := range list {
|
||||
if event.Kind == kind {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
package notify
|
||||
|
||||
import "time"
|
||||
|
||||
const defaultFileSendRetry = 3
|
||||
|
||||
const defaultFileAckTimeout = 5 * time.Second
|
||||
|
||||
type fileTransferConfig struct {
|
||||
ChunkSize int
|
||||
AckTimeout time.Duration
|
||||
SendRetry int
|
||||
ReceiveCompletedLimit int
|
||||
MonitorCompletedLimit int
|
||||
}
|
||||
|
||||
func defaultFileTransferConfig() fileTransferConfig {
|
||||
return fileTransferConfig{
|
||||
ChunkSize: defaultFileChunkSize,
|
||||
AckTimeout: defaultFileAckTimeout,
|
||||
SendRetry: defaultFileSendRetry,
|
||||
ReceiveCompletedLimit: defaultFileReceiveCompletedLimit,
|
||||
MonitorCompletedLimit: defaultFileTransferCompletedLimit,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeFileTransferConfig(cfg fileTransferConfig) fileTransferConfig {
|
||||
defaults := defaultFileTransferConfig()
|
||||
if cfg.ChunkSize <= 0 {
|
||||
cfg.ChunkSize = defaults.ChunkSize
|
||||
}
|
||||
if cfg.AckTimeout <= 0 {
|
||||
cfg.AckTimeout = defaults.AckTimeout
|
||||
}
|
||||
if cfg.SendRetry <= 0 {
|
||||
cfg.SendRetry = defaults.SendRetry
|
||||
}
|
||||
if cfg.ReceiveCompletedLimit <= 0 {
|
||||
cfg.ReceiveCompletedLimit = defaults.ReceiveCompletedLimit
|
||||
}
|
||||
if cfg.MonitorCompletedLimit <= 0 {
|
||||
cfg.MonitorCompletedLimit = defaults.MonitorCompletedLimit
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (c *ClientCommon) getFileTransferConfig() fileTransferConfig {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.fileTransferCfg = normalizeFileTransferConfig(c.fileTransferCfg)
|
||||
return c.fileTransferCfg
|
||||
}
|
||||
|
||||
func (s *ServerCommon) getFileTransferConfig() fileTransferConfig {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.fileTransferCfg = normalizeFileTransferConfig(s.fileTransferCfg)
|
||||
return s.fileTransferCfg
|
||||
}
|
||||
|
||||
func (c *ClientCommon) setFileTransferConfig(cfg fileTransferConfig) {
|
||||
cfg = normalizeFileTransferConfig(cfg)
|
||||
c.mu.Lock()
|
||||
c.fileTransferCfg = cfg
|
||||
state := c.logicalSession
|
||||
c.mu.Unlock()
|
||||
if state != nil {
|
||||
state.applyFileTransferConfig(cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ServerCommon) setFileTransferConfig(cfg fileTransferConfig) {
|
||||
cfg = normalizeFileTransferConfig(cfg)
|
||||
s.mu.Lock()
|
||||
s.fileTransferCfg = cfg
|
||||
state := s.logicalSession
|
||||
s.mu.Unlock()
|
||||
if state != nil {
|
||||
state.applyFileTransferConfig(cfg)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestClientFileTransferConfigDefaults(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
|
||||
cfg := client.getFileTransferConfig()
|
||||
|
||||
if got, want := cfg.ChunkSize, defaultFileChunkSize; got != want {
|
||||
t.Fatalf("chunk size mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := cfg.AckTimeout, defaultFileAckTimeout; got != want {
|
||||
t.Fatalf("ack timeout mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := cfg.SendRetry, defaultFileSendRetry; got != want {
|
||||
t.Fatalf("send retry mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := cfg.ReceiveCompletedLimit, defaultFileReceiveCompletedLimit; got != want {
|
||||
t.Fatalf("receive completed limit mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := cfg.MonitorCompletedLimit, defaultFileTransferCompletedLimit; got != want {
|
||||
t.Fatalf("monitor completed limit mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerFileTransferConfigNormalization(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
|
||||
server.setFileTransferConfig(fileTransferConfig{})
|
||||
cfg := server.getFileTransferConfig()
|
||||
|
||||
if got, want := cfg.ChunkSize, defaultFileChunkSize; got != want {
|
||||
t.Fatalf("normalized chunk size mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := cfg.AckTimeout, defaultFileAckTimeout; got != want {
|
||||
t.Fatalf("normalized ack timeout mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := cfg.SendRetry, defaultFileSendRetry; got != want {
|
||||
t.Fatalf("normalized retry mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := cfg.ReceiveCompletedLimit, defaultFileReceiveCompletedLimit; got != want {
|
||||
t.Fatalf("normalized receive completed limit mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := cfg.MonitorCompletedLimit, defaultFileTransferCompletedLimit; got != want {
|
||||
t.Fatalf("normalized monitor completed limit mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientFileTransferConfigPropagatesRetentionLimits(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
|
||||
client.setFileTransferConfig(fileTransferConfig{
|
||||
ChunkSize: 64,
|
||||
AckTimeout: time.Second,
|
||||
SendRetry: 2,
|
||||
ReceiveCompletedLimit: 7,
|
||||
MonitorCompletedLimit: 9,
|
||||
})
|
||||
|
||||
if got, want := client.getFileReceivePool().completedLimit, 7; got != want {
|
||||
t.Fatalf("client receive pool completed limit mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := client.getFileTransferState().monitorView().completedLimit, 9; got != want {
|
||||
t.Fatalf("client transfer monitor completed limit mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSendFileWithHooksHonorsConfiguredChunkSize(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "payload.bin")
|
||||
if err := os.WriteFile(path, []byte("abcdefg"), 0o600); err != nil {
|
||||
t.Fatalf("write temp file failed: %v", err)
|
||||
}
|
||||
|
||||
var chunks []int
|
||||
err := sendFileWithHooks(context.Background(), path, fileSendHooks{
|
||||
config: fileTransferConfig{
|
||||
ChunkSize: 3,
|
||||
AckTimeout: time.Millisecond,
|
||||
SendRetry: 1,
|
||||
},
|
||||
sendReliable: func(ctx context.Context, env Envelope) error {
|
||||
if env.Kind == EnvelopeFileChunk {
|
||||
chunks = append(chunks, len(env.File.Chunk))
|
||||
}
|
||||
return nil
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("sendFileWithHooks failed: %v", err)
|
||||
}
|
||||
|
||||
if got, want := chunks, []int{3, 3, 1}; !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("chunk sizes mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,174 @@
|
||||
package notify
|
||||
|
||||
import "sync"
|
||||
|
||||
const defaultFileTransferCompletedLimit = 128
|
||||
|
||||
type fileTransferMonitor struct {
|
||||
mu sync.Mutex
|
||||
active map[string]fileTransferSnapshot
|
||||
completed map[string]fileTransferSnapshot
|
||||
runtimeActive map[string]fileTransferSnapshot
|
||||
runtimeCompleted map[string]fileTransferSnapshot
|
||||
completedLimit int
|
||||
}
|
||||
|
||||
func newFileTransferMonitor() *fileTransferMonitor {
|
||||
return newFileTransferMonitorWithConfig(defaultFileTransferConfig())
|
||||
}
|
||||
|
||||
func newFileTransferMonitorWithConfig(cfg fileTransferConfig) *fileTransferMonitor {
|
||||
cfg = normalizeFileTransferConfig(cfg)
|
||||
return newFileTransferMonitorWithCompletedLimit(cfg.MonitorCompletedLimit)
|
||||
}
|
||||
|
||||
func newFileTransferMonitorWithCompletedLimit(limit int) *fileTransferMonitor {
|
||||
if limit <= 0 {
|
||||
limit = defaultFileTransferCompletedLimit
|
||||
}
|
||||
return &fileTransferMonitor{
|
||||
active: make(map[string]fileTransferSnapshot),
|
||||
completed: make(map[string]fileTransferSnapshot),
|
||||
runtimeActive: make(map[string]fileTransferSnapshot),
|
||||
runtimeCompleted: make(map[string]fileTransferSnapshot),
|
||||
completedLimit: limit,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) applyConfig(cfg fileTransferConfig) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
cfg = normalizeFileTransferConfig(cfg)
|
||||
m.mu.Lock()
|
||||
m.completedLimit = cfg.MonitorCompletedLimit
|
||||
m.trimCompletedLocked()
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) observe(direction fileTransferDirection, event FileEvent) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if !isFileTransferObservable(event.Kind) {
|
||||
return
|
||||
}
|
||||
snapshot := fileTransferSnapshotFromEvent(direction, event)
|
||||
key := fileTransferMonitorKey(direction, snapshot.Scope, snapshot.FileID)
|
||||
runtimeKey := fileTransferRuntimeMonitorKey(direction, snapshot.RuntimeScope, snapshot.FileID)
|
||||
if key == "" || runtimeKey == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if isFileTransferTerminal(snapshot.Kind) {
|
||||
delete(m.active, key)
|
||||
m.completed[key] = snapshot
|
||||
delete(m.runtimeActive, runtimeKey)
|
||||
m.runtimeCompleted[runtimeKey] = snapshot
|
||||
m.trimCompletedLocked()
|
||||
return
|
||||
}
|
||||
delete(m.completed, key)
|
||||
m.active[key] = snapshot
|
||||
delete(m.runtimeCompleted, runtimeKey)
|
||||
m.runtimeActive[runtimeKey] = snapshot
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) activeSnapshots() []fileTransferSnapshot {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return sortedFileTransferSnapshots(m.active)
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) activeSnapshotsByDirection(direction fileTransferDirection) []fileTransferSnapshot {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return filteredFileTransferSnapshots(m.active, direction)
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) completedSnapshots() []fileTransferSnapshot {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return sortedFileTransferSnapshots(m.completed)
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) completedSnapshotsByDirection(direction fileTransferDirection) []fileTransferSnapshot {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return filteredFileTransferSnapshots(m.completed, direction)
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) latestSnapshot(direction fileTransferDirection, scope string, fileID string) (fileTransferSnapshot, bool) {
|
||||
if m == nil {
|
||||
return fileTransferSnapshot{}, false
|
||||
}
|
||||
key := fileTransferMonitorKey(direction, scope, fileID)
|
||||
if key == "" {
|
||||
return fileTransferSnapshot{}, false
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if snapshot, ok := m.active[key]; ok {
|
||||
return snapshot, true
|
||||
}
|
||||
snapshot, ok := m.completed[key]
|
||||
return snapshot, ok
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) snapshotsByFileID(fileID string) []fileTransferSnapshot {
|
||||
if m == nil || fileID == "" {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
latest := latestFileTransferSnapshotsLocked(m.active, m.completed)
|
||||
return filterFileTransferSnapshotsByFileID(latest, fileID)
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) snapshotsByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSnapshot {
|
||||
if m == nil || fileID == "" {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
latest := latestFileTransferSnapshotsLocked(m.active, m.completed)
|
||||
return filterFileTransferSnapshotsByDirectionAndFileID(latest, direction, fileID)
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) trimCompletedLocked() {
|
||||
trimFileTransferSnapshotsLocked(m.completed, m.completedLimit)
|
||||
trimFileTransferSnapshotsLocked(m.runtimeCompleted, m.completedLimit)
|
||||
}
|
||||
|
||||
func trimFileTransferSnapshotsLocked(snapshots map[string]fileTransferSnapshot, limit int) {
|
||||
if limit <= 0 || len(snapshots) <= limit {
|
||||
return
|
||||
}
|
||||
for len(snapshots) > limit {
|
||||
oldestKey := ""
|
||||
oldestSnapshot := fileTransferSnapshot{}
|
||||
for key, snapshot := range snapshots {
|
||||
if oldestKey == "" || fileTransferSnapshotOlder(snapshot, oldestSnapshot, key, oldestKey) {
|
||||
oldestKey = key
|
||||
oldestSnapshot = snapshot
|
||||
}
|
||||
}
|
||||
if oldestKey == "" {
|
||||
return
|
||||
}
|
||||
delete(snapshots, oldestKey)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,329 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestClientTransferMonitorTracksSendLifecycle(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
monitor := client.getFileTransferState().monitorView()
|
||||
now := time.Unix(100, 0)
|
||||
|
||||
client.publishSendFileEvent(FileEvent{
|
||||
NetType: NET_CLIENT,
|
||||
Kind: EnvelopeFileMeta,
|
||||
Packet: FilePacket{FileID: "send-1", Size: 100},
|
||||
Path: "/tmp/send-1.bin",
|
||||
Total: 100,
|
||||
Time: now,
|
||||
})
|
||||
client.publishSendFileEvent(FileEvent{
|
||||
NetType: NET_CLIENT,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "send-1", Size: 100},
|
||||
Path: "/tmp/send-1.bin",
|
||||
Received: 40,
|
||||
Total: 100,
|
||||
Percent: 40,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now.Add(2 * time.Second),
|
||||
Duration: 2 * time.Second,
|
||||
RateBPS: 20,
|
||||
Time: now.Add(2 * time.Second),
|
||||
StepDuration: 2 * time.Second,
|
||||
})
|
||||
|
||||
active := monitor.activeSnapshots()
|
||||
if got, want := len(active), 1; got != want {
|
||||
t.Fatalf("active count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := active[0].Direction, fileTransferDirectionSend; got != want {
|
||||
t.Fatalf("direction mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := active[0].Scope, clientFileScope(); got != want {
|
||||
t.Fatalf("scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := active[0].Received, int64(40); got != want {
|
||||
t.Fatalf("received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
snapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "send-1")
|
||||
if !ok {
|
||||
t.Fatal("latest snapshot should exist while active")
|
||||
}
|
||||
if got, want := snapshot.Kind, EnvelopeFileChunk; got != want {
|
||||
t.Fatalf("latest active kind mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := snapshot.Received, int64(40); got != want {
|
||||
t.Fatalf("latest active received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
client.publishSendFileEvent(FileEvent{
|
||||
NetType: NET_CLIENT,
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "send-1", Size: 100},
|
||||
Path: "/tmp/send-1.bin",
|
||||
Received: 100,
|
||||
Total: 100,
|
||||
Percent: 100,
|
||||
Done: true,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now.Add(4 * time.Second),
|
||||
Duration: 4 * time.Second,
|
||||
RateBPS: 25,
|
||||
Time: now.Add(4 * time.Second),
|
||||
})
|
||||
|
||||
active = monitor.activeSnapshots()
|
||||
if got, want := len(active), 0; got != want {
|
||||
t.Fatalf("active count after end mismatch: got %d want %d", got, want)
|
||||
}
|
||||
completed := monitor.completedSnapshots()
|
||||
if got, want := len(completed), 1; got != want {
|
||||
t.Fatalf("completed count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := completed[0].Done, true; got != want {
|
||||
t.Fatalf("done mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := completed[0].Received, int64(100); got != want {
|
||||
t.Fatalf("completed received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
snapshot, ok = monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "send-1")
|
||||
if !ok {
|
||||
t.Fatal("latest snapshot should exist after completion")
|
||||
}
|
||||
if got, want := snapshot.Kind, EnvelopeFileEnd; got != want {
|
||||
t.Fatalf("latest completed kind mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := snapshot.Done, true; got != want {
|
||||
t.Fatalf("latest completed done mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerTransferMonitorUsesClientScope(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
monitor := server.getFileTransferState().monitorView()
|
||||
client := &ClientConn{ClientID: "client-1"}
|
||||
now := time.Unix(200, 0)
|
||||
|
||||
server.publishReceivedFileEvent(FileEvent{
|
||||
NetType: NET_SERVER,
|
||||
ClientConn: client,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "recv-1", Size: 50},
|
||||
Path: "/tmp/recv-1.part",
|
||||
Received: 20,
|
||||
Total: 50,
|
||||
Percent: 40,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now.Add(time.Second),
|
||||
Duration: time.Second,
|
||||
RateBPS: 20,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
|
||||
active := monitor.activeSnapshots()
|
||||
if got, want := len(active), 1; got != want {
|
||||
t.Fatalf("active count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := active[0].Direction, fileTransferDirectionReceive; got != want {
|
||||
t.Fatalf("direction mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := active[0].Scope, serverFileScope(client); got != want {
|
||||
t.Fatalf("scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := active[0].FileID, "recv-1"; got != want {
|
||||
t.Fatalf("fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransferMonitorDirectionQueries(t *testing.T) {
|
||||
monitor := newFileTransferMonitor()
|
||||
now := time.Unix(300, 0)
|
||||
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared", Size: 10},
|
||||
Received: 4,
|
||||
Total: 10,
|
||||
Time: now,
|
||||
})
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared", Size: 10},
|
||||
Received: 7,
|
||||
Total: 10,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
|
||||
sendSnapshots := monitor.activeSnapshotsByDirection(fileTransferDirectionSend)
|
||||
if got, want := len(sendSnapshots), 1; got != want {
|
||||
t.Fatalf("send snapshots count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := sendSnapshots[0].Received, int64(4); got != want {
|
||||
t.Fatalf("send snapshot received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
recvSnapshots := monitor.activeSnapshotsByDirection(fileTransferDirectionReceive)
|
||||
if got, want := len(recvSnapshots), 1; got != want {
|
||||
t.Fatalf("recv snapshots count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := recvSnapshots[0].Received, int64(7); got != want {
|
||||
t.Fatalf("recv snapshot received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
sendSnapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "shared")
|
||||
if !ok {
|
||||
t.Fatal("send latest snapshot should exist")
|
||||
}
|
||||
if got, want := sendSnapshot.Received, int64(4); got != want {
|
||||
t.Fatalf("send latest received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
recvSnapshot, ok := monitor.latestSnapshot(fileTransferDirectionReceive, clientFileScope(), "shared")
|
||||
if !ok {
|
||||
t.Fatal("recv latest snapshot should exist")
|
||||
}
|
||||
if got, want := recvSnapshot.Received, int64(7); got != want {
|
||||
t.Fatalf("recv latest received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransferMonitorSnapshotsByFileID(t *testing.T) {
|
||||
monitor := newFileTransferMonitor()
|
||||
now := time.Unix(400, 0)
|
||||
serverClientA := &ClientConn{ClientID: "client-a"}
|
||||
serverClientB := &ClientConn{ClientID: "client-b"}
|
||||
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||
Received: 8,
|
||||
Total: 20,
|
||||
Time: now,
|
||||
})
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClientA,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||
Received: 12,
|
||||
Total: 20,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClientB,
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||
Received: 20,
|
||||
Total: 20,
|
||||
Done: true,
|
||||
Time: now.Add(2 * time.Second),
|
||||
})
|
||||
|
||||
allSnapshots := monitor.snapshotsByFileID("shared")
|
||||
if got, want := len(allSnapshots), 3; got != want {
|
||||
t.Fatalf("all snapshots count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := allSnapshots[0].Direction, fileTransferDirectionReceive; got != want {
|
||||
t.Fatalf("first snapshot direction mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := allSnapshots[0].Scope, serverFileScope(serverClientA); got != want {
|
||||
t.Fatalf("first snapshot scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := allSnapshots[1].Scope, serverFileScope(serverClientB); got != want {
|
||||
t.Fatalf("second snapshot scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := allSnapshots[2].Direction, fileTransferDirectionSend; got != want {
|
||||
t.Fatalf("third snapshot direction mismatch: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
recvSnapshots := monitor.snapshotsByDirectionAndFileID(fileTransferDirectionReceive, "shared")
|
||||
if got, want := len(recvSnapshots), 2; got != want {
|
||||
t.Fatalf("recv snapshots count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := recvSnapshots[0].Scope, serverFileScope(serverClientA); got != want {
|
||||
t.Fatalf("recv first scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := recvSnapshots[1].Scope, serverFileScope(serverClientB); got != want {
|
||||
t.Fatalf("recv second scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := recvSnapshots[1].Done, true; got != want {
|
||||
t.Fatalf("recv completed snapshot mismatch: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
sendSnapshots := monitor.snapshotsByDirectionAndFileID(fileTransferDirectionSend, "shared")
|
||||
if got, want := len(sendSnapshots), 1; got != want {
|
||||
t.Fatalf("send snapshots count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := sendSnapshots[0].Received, int64(8); got != want {
|
||||
t.Fatalf("send snapshot received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransferMonitorCompletedRetentionEvictsOldest(t *testing.T) {
|
||||
monitor := newFileTransferMonitorWithCompletedLimit(2)
|
||||
now := time.Unix(500, 0)
|
||||
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "active-1", Size: 10},
|
||||
Received: 3,
|
||||
Total: 10,
|
||||
Time: now,
|
||||
})
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "done-1", Size: 10},
|
||||
Received: 10,
|
||||
Total: 10,
|
||||
Done: true,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "done-2", Size: 10},
|
||||
Received: 10,
|
||||
Total: 10,
|
||||
Done: true,
|
||||
Time: now.Add(2 * time.Second),
|
||||
})
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "done-3", Size: 10},
|
||||
Received: 10,
|
||||
Total: 10,
|
||||
Done: true,
|
||||
Time: now.Add(3 * time.Second),
|
||||
})
|
||||
|
||||
active := monitor.activeSnapshots()
|
||||
if got, want := len(active), 1; got != want {
|
||||
t.Fatalf("active count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := active[0].FileID, "active-1"; got != want {
|
||||
t.Fatalf("active fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
completed := monitor.completedSnapshots()
|
||||
if got, want := len(completed), 2; got != want {
|
||||
t.Fatalf("completed count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := completed[0].FileID, "done-2"; got != want {
|
||||
t.Fatalf("first completed fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := completed[1].FileID, "done-3"; got != want {
|
||||
t.Fatalf("second completed fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
if _, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "done-1"); ok {
|
||||
t.Fatal("oldest completed snapshot should be evicted")
|
||||
}
|
||||
if _, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "done-3"); !ok {
|
||||
t.Fatal("latest completed snapshot should be retained")
|
||||
}
|
||||
if snapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "active-1"); !ok {
|
||||
t.Fatal("active snapshot should remain available")
|
||||
} else if got, want := snapshot.Kind, EnvelopeFileChunk; got != want {
|
||||
t.Fatalf("active latest kind mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,283 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
)
|
||||
|
||||
type FileTransferSummary struct {
|
||||
Direction TransferDirection
|
||||
Scope string
|
||||
RuntimeScope string
|
||||
TransportGeneration uint64
|
||||
NetType NetType
|
||||
Kind EnvelopeKind
|
||||
FileID string
|
||||
Path string
|
||||
Received int64
|
||||
Total int64
|
||||
Percent float64
|
||||
Active bool
|
||||
Terminal bool
|
||||
Done bool
|
||||
Failed bool
|
||||
Err error
|
||||
StartedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
Duration time.Duration
|
||||
RateBPS float64
|
||||
StepDuration time.Duration
|
||||
InstantRateBPS float64
|
||||
Time time.Time
|
||||
Stage string
|
||||
}
|
||||
|
||||
type FileTransferSummaryGroup struct {
|
||||
Send []FileTransferSummary
|
||||
Receive []FileTransferSummary
|
||||
}
|
||||
|
||||
type FileTransferSummaryQuery struct {
|
||||
Scope string
|
||||
RuntimeScope string
|
||||
TransportGeneration uint64
|
||||
MatchTransportGeneration bool
|
||||
}
|
||||
|
||||
type clientFileTransferSummaryReader interface {
|
||||
clientFileTransferActiveSummaries() FileTransferSummaryGroup
|
||||
clientFileTransferCompletedSummaries() FileTransferSummaryGroup
|
||||
clientFileTransferFailedSummaries() FileTransferSummaryGroup
|
||||
clientFileTransferLatestByFileID(string) FileTransferSummaryGroup
|
||||
clientFileTransferLatestByFileIDQuery(string, FileTransferSummaryQuery) FileTransferSummaryGroup
|
||||
}
|
||||
|
||||
type serverFileTransferSummaryReader interface {
|
||||
serverFileTransferActiveSummaries() FileTransferSummaryGroup
|
||||
serverFileTransferCompletedSummaries() FileTransferSummaryGroup
|
||||
serverFileTransferFailedSummaries() FileTransferSummaryGroup
|
||||
serverFileTransferLatestByFileID(string) FileTransferSummaryGroup
|
||||
serverFileTransferLatestByFileIDQuery(string, FileTransferSummaryQuery) FileTransferSummaryGroup
|
||||
}
|
||||
|
||||
var (
|
||||
errClientFileTransferSummaryNil = errors.New("client file transfer summary target is nil")
|
||||
errServerFileTransferSummaryNil = errors.New("server file transfer summary target is nil")
|
||||
errClientFileTransferSummaryUnsupported = errors.New("client file transfer summary target type is unsupported")
|
||||
errServerFileTransferSummaryUnsupported = errors.New("server file transfer summary target type is unsupported")
|
||||
)
|
||||
|
||||
func GetClientFileTransferActiveSummaries(c Client) (FileTransferSummaryGroup, error) {
|
||||
if c == nil {
|
||||
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
|
||||
}
|
||||
reader, ok := any(c).(clientFileTransferSummaryReader)
|
||||
if !ok {
|
||||
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
|
||||
}
|
||||
return reader.clientFileTransferActiveSummaries(), nil
|
||||
}
|
||||
|
||||
func GetServerFileTransferActiveSummaries(s Server) (FileTransferSummaryGroup, error) {
|
||||
if s == nil {
|
||||
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
|
||||
}
|
||||
reader, ok := any(s).(serverFileTransferSummaryReader)
|
||||
if !ok {
|
||||
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
|
||||
}
|
||||
return reader.serverFileTransferActiveSummaries(), nil
|
||||
}
|
||||
|
||||
func GetClientFileTransferCompletedSummaries(c Client) (FileTransferSummaryGroup, error) {
|
||||
if c == nil {
|
||||
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
|
||||
}
|
||||
reader, ok := any(c).(clientFileTransferSummaryReader)
|
||||
if !ok {
|
||||
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
|
||||
}
|
||||
return reader.clientFileTransferCompletedSummaries(), nil
|
||||
}
|
||||
|
||||
func GetServerFileTransferCompletedSummaries(s Server) (FileTransferSummaryGroup, error) {
|
||||
if s == nil {
|
||||
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
|
||||
}
|
||||
reader, ok := any(s).(serverFileTransferSummaryReader)
|
||||
if !ok {
|
||||
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
|
||||
}
|
||||
return reader.serverFileTransferCompletedSummaries(), nil
|
||||
}
|
||||
|
||||
func GetClientFileTransferFailedSummaries(c Client) (FileTransferSummaryGroup, error) {
|
||||
if c == nil {
|
||||
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
|
||||
}
|
||||
reader, ok := any(c).(clientFileTransferSummaryReader)
|
||||
if !ok {
|
||||
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
|
||||
}
|
||||
return reader.clientFileTransferFailedSummaries(), nil
|
||||
}
|
||||
|
||||
func GetServerFileTransferFailedSummaries(s Server) (FileTransferSummaryGroup, error) {
|
||||
if s == nil {
|
||||
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
|
||||
}
|
||||
reader, ok := any(s).(serverFileTransferSummaryReader)
|
||||
if !ok {
|
||||
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
|
||||
}
|
||||
return reader.serverFileTransferFailedSummaries(), nil
|
||||
}
|
||||
|
||||
func GetClientFileTransferLatestByFileID(c Client, fileID string) (FileTransferSummaryGroup, error) {
|
||||
if c == nil {
|
||||
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
|
||||
}
|
||||
reader, ok := any(c).(clientFileTransferSummaryReader)
|
||||
if !ok {
|
||||
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
|
||||
}
|
||||
return reader.clientFileTransferLatestByFileID(fileID), nil
|
||||
}
|
||||
|
||||
func GetServerFileTransferLatestByFileID(s Server, fileID string) (FileTransferSummaryGroup, error) {
|
||||
if s == nil {
|
||||
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
|
||||
}
|
||||
reader, ok := any(s).(serverFileTransferSummaryReader)
|
||||
if !ok {
|
||||
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
|
||||
}
|
||||
return reader.serverFileTransferLatestByFileID(fileID), nil
|
||||
}
|
||||
|
||||
func GetClientFileTransferLatestByFileIDQuery(c Client, fileID string, query FileTransferSummaryQuery) (FileTransferSummaryGroup, error) {
|
||||
if c == nil {
|
||||
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
|
||||
}
|
||||
reader, ok := any(c).(clientFileTransferSummaryReader)
|
||||
if !ok {
|
||||
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
|
||||
}
|
||||
return reader.clientFileTransferLatestByFileIDQuery(fileID, query), nil
|
||||
}
|
||||
|
||||
func GetServerFileTransferLatestByFileIDQuery(s Server, fileID string, query FileTransferSummaryQuery) (FileTransferSummaryGroup, error) {
|
||||
if s == nil {
|
||||
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
|
||||
}
|
||||
reader, ok := any(s).(serverFileTransferSummaryReader)
|
||||
if !ok {
|
||||
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
|
||||
}
|
||||
return reader.serverFileTransferLatestByFileIDQuery(fileID, query), nil
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientFileTransferActiveSummaries() FileTransferSummaryGroup {
|
||||
return publicFileTransferSummaryGroup(c.getFileTransferState().active())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientFileTransferCompletedSummaries() FileTransferSummaryGroup {
|
||||
return publicFileTransferSummaryGroup(c.getFileTransferState().completed())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientFileTransferFailedSummaries() FileTransferSummaryGroup {
|
||||
return publicFileTransferSummaryGroup(c.getFileTransferState().failed())
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientFileTransferLatestByFileID(fileID string) FileTransferSummaryGroup {
|
||||
return publicFileTransferSummaryGroup(c.getFileTransferState().latestByFileID(fileID))
|
||||
}
|
||||
|
||||
func (c *ClientCommon) clientFileTransferLatestByFileIDQuery(fileID string, query FileTransferSummaryQuery) FileTransferSummaryGroup {
|
||||
return publicFileTransferSummaryGroup(c.getFileTransferState().latestByFileIDQuery(fileID, internalFileTransferSummaryQuery(query)))
|
||||
}
|
||||
|
||||
func (s *ServerCommon) serverFileTransferActiveSummaries() FileTransferSummaryGroup {
|
||||
return publicFileTransferSummaryGroup(s.getFileTransferState().active())
|
||||
}
|
||||
|
||||
func (s *ServerCommon) serverFileTransferCompletedSummaries() FileTransferSummaryGroup {
|
||||
return publicFileTransferSummaryGroup(s.getFileTransferState().completed())
|
||||
}
|
||||
|
||||
func (s *ServerCommon) serverFileTransferFailedSummaries() FileTransferSummaryGroup {
|
||||
return publicFileTransferSummaryGroup(s.getFileTransferState().failed())
|
||||
}
|
||||
|
||||
func (s *ServerCommon) serverFileTransferLatestByFileID(fileID string) FileTransferSummaryGroup {
|
||||
return publicFileTransferSummaryGroup(s.getFileTransferState().latestByFileID(fileID))
|
||||
}
|
||||
|
||||
func (s *ServerCommon) serverFileTransferLatestByFileIDQuery(fileID string, query FileTransferSummaryQuery) FileTransferSummaryGroup {
|
||||
return publicFileTransferSummaryGroup(s.getFileTransferState().latestByFileIDQuery(fileID, internalFileTransferSummaryQuery(query)))
|
||||
}
|
||||
|
||||
func publicFileTransferSummaryGroup(src fileTransferSummaryGroup) FileTransferSummaryGroup {
|
||||
return FileTransferSummaryGroup{
|
||||
Send: publicFileTransferSummaries(src.Send),
|
||||
Receive: publicFileTransferSummaries(src.Receive),
|
||||
}
|
||||
}
|
||||
|
||||
func publicFileTransferSummaries(src []fileTransferSummary) []FileTransferSummary {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]FileTransferSummary, 0, len(src))
|
||||
for _, summary := range src {
|
||||
out = append(out, publicFileTransferSummary(summary))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func publicFileTransferSummary(summary fileTransferSummary) FileTransferSummary {
|
||||
return FileTransferSummary{
|
||||
Direction: publicFileTransferDirection(summary.Direction),
|
||||
Scope: summary.Scope,
|
||||
RuntimeScope: summary.RuntimeScope,
|
||||
TransportGeneration: summary.TransportGeneration,
|
||||
NetType: summary.NetType,
|
||||
Kind: summary.Kind,
|
||||
FileID: summary.FileID,
|
||||
Path: summary.Path,
|
||||
Received: summary.Received,
|
||||
Total: summary.Total,
|
||||
Percent: summary.Percent,
|
||||
Active: summary.Active,
|
||||
Terminal: summary.Terminal,
|
||||
Done: summary.Done,
|
||||
Failed: summary.Failed,
|
||||
Err: summary.Err,
|
||||
StartedAt: summary.StartedAt,
|
||||
UpdatedAt: summary.UpdatedAt,
|
||||
Duration: summary.Duration,
|
||||
RateBPS: summary.RateBPS,
|
||||
StepDuration: summary.StepDuration,
|
||||
InstantRateBPS: summary.InstantRateBPS,
|
||||
Time: summary.Time,
|
||||
Stage: summary.Stage,
|
||||
}
|
||||
}
|
||||
|
||||
func publicFileTransferDirection(direction fileTransferDirection) TransferDirection {
|
||||
switch direction {
|
||||
case fileTransferDirectionReceive:
|
||||
return TransferDirectionReceive
|
||||
default:
|
||||
return TransferDirectionSend
|
||||
}
|
||||
}
|
||||
|
||||
func internalFileTransferSummaryQuery(query FileTransferSummaryQuery) fileTransferSummaryQuery {
|
||||
return fileTransferSummaryQuery{
|
||||
Scope: query.Scope,
|
||||
RuntimeScope: query.RuntimeScope,
|
||||
TransportGeneration: query.TransportGeneration,
|
||||
MatchTransportGeneration: query.MatchTransportGeneration,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,202 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestGetClientFileTransferSummariesRejectNil(t *testing.T) {
|
||||
if _, err := GetClientFileTransferActiveSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) {
|
||||
t.Fatalf("GetClientFileTransferActiveSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil)
|
||||
}
|
||||
if _, err := GetClientFileTransferCompletedSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) {
|
||||
t.Fatalf("GetClientFileTransferCompletedSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil)
|
||||
}
|
||||
if _, err := GetClientFileTransferFailedSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) {
|
||||
t.Fatalf("GetClientFileTransferFailedSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil)
|
||||
}
|
||||
if _, err := GetClientFileTransferLatestByFileID(nil, "x"); !errors.Is(err, errClientFileTransferSummaryNil) {
|
||||
t.Fatalf("GetClientFileTransferLatestByFileID nil error = %v, want %v", err, errClientFileTransferSummaryNil)
|
||||
}
|
||||
if _, err := GetClientFileTransferLatestByFileIDQuery(nil, "x", FileTransferSummaryQuery{}); !errors.Is(err, errClientFileTransferSummaryNil) {
|
||||
t.Fatalf("GetClientFileTransferLatestByFileIDQuery nil error = %v, want %v", err, errClientFileTransferSummaryNil)
|
||||
}
|
||||
if _, err := GetServerFileTransferActiveSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) {
|
||||
t.Fatalf("GetServerFileTransferActiveSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil)
|
||||
}
|
||||
if _, err := GetServerFileTransferCompletedSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) {
|
||||
t.Fatalf("GetServerFileTransferCompletedSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil)
|
||||
}
|
||||
if _, err := GetServerFileTransferFailedSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) {
|
||||
t.Fatalf("GetServerFileTransferFailedSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil)
|
||||
}
|
||||
if _, err := GetServerFileTransferLatestByFileID(nil, "x"); !errors.Is(err, errServerFileTransferSummaryNil) {
|
||||
t.Fatalf("GetServerFileTransferLatestByFileID nil error = %v, want %v", err, errServerFileTransferSummaryNil)
|
||||
}
|
||||
if _, err := GetServerFileTransferLatestByFileIDQuery(nil, "x", FileTransferSummaryQuery{}); !errors.Is(err, errServerFileTransferSummaryNil) {
|
||||
t.Fatalf("GetServerFileTransferLatestByFileIDQuery nil error = %v, want %v", err, errServerFileTransferSummaryNil)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientFileTransferSummariesPublicAPI(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
now := time.Unix(2000, 0)
|
||||
|
||||
client.publishSendFileEvent(FileEvent{
|
||||
NetType: NET_CLIENT,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "client-public", Size: 16},
|
||||
Received: 6,
|
||||
Total: 16,
|
||||
Percent: 37.5,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now.Add(time.Second),
|
||||
Duration: time.Second,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
|
||||
active, err := GetClientFileTransferActiveSummaries(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientFileTransferActiveSummaries failed: %v", err)
|
||||
}
|
||||
if got, want := len(active.Send), 1; got != want {
|
||||
t.Fatalf("active send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := active.Send[0].RuntimeScope, clientFileScope(); got != want {
|
||||
t.Fatalf("active runtime scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got := active.Send[0].TransportGeneration; got != 0 {
|
||||
t.Fatalf("active transport generation mismatch: got %d want 0", got)
|
||||
}
|
||||
|
||||
latest, err := GetClientFileTransferLatestByFileID(client, "client-public")
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientFileTransferLatestByFileID failed: %v", err)
|
||||
}
|
||||
if got, want := len(latest.Send), 1; got != want {
|
||||
t.Fatalf("latest send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := latest.Send[0].Direction, TransferDirectionSend; got != want {
|
||||
t.Fatalf("latest direction mismatch: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
query, err := GetClientFileTransferLatestByFileIDQuery(client, "client-public", FileTransferSummaryQuery{
|
||||
RuntimeScope: clientFileScope(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientFileTransferLatestByFileIDQuery failed: %v", err)
|
||||
}
|
||||
if got, want := len(query.Send), 1; got != want {
|
||||
t.Fatalf("query send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
client.publishSendFileEvent(FileEvent{
|
||||
NetType: NET_CLIENT,
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "client-public", Size: 16},
|
||||
Received: 16,
|
||||
Total: 16,
|
||||
Percent: 100,
|
||||
Done: true,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now.Add(2 * time.Second),
|
||||
Duration: 2 * time.Second,
|
||||
Time: now.Add(2 * time.Second),
|
||||
})
|
||||
|
||||
completed, err := GetClientFileTransferCompletedSummaries(client)
|
||||
if err != nil {
|
||||
t.Fatalf("GetClientFileTransferCompletedSummaries failed: %v", err)
|
||||
}
|
||||
if got, want := len(completed.Send), 1; got != want {
|
||||
t.Fatalf("completed send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := completed.Send[0].Done, true; got != want {
|
||||
t.Fatalf("completed done mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetServerFileTransferLatestByFileIDQueryResolvesTransportGenerationPublicAPI(t *testing.T) {
|
||||
server := NewServer().(*ServerCommon)
|
||||
now := time.Unix(2100, 0)
|
||||
serverClient := &ClientConn{ClientID: "public-gen"}
|
||||
serverClient.markClientConnIdentityBound()
|
||||
serverClient.markClientConnStreamTransport()
|
||||
serverClient.markClientConnTransportAttached()
|
||||
|
||||
server.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared-public", Size: 20},
|
||||
Received: 5,
|
||||
Total: 20,
|
||||
Time: now,
|
||||
})
|
||||
firstRuntimeScope := serverTransportScope(serverClient)
|
||||
logicalScope := serverFileScope(serverClient)
|
||||
|
||||
serverClient.markClientConnTransportDetached("read error", nil)
|
||||
serverClient.markClientConnTransportAttached()
|
||||
|
||||
server.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared-public", Size: 20},
|
||||
Received: 9,
|
||||
Total: 20,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
secondRuntimeScope := serverTransportScope(serverClient)
|
||||
|
||||
legacy, err := GetServerFileTransferLatestByFileID(server, "shared-public")
|
||||
if err != nil {
|
||||
t.Fatalf("GetServerFileTransferLatestByFileID failed: %v", err)
|
||||
}
|
||||
if got, want := len(legacy.Receive), 1; got != want {
|
||||
t.Fatalf("legacy receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := legacy.Receive[0].TransportGeneration, uint64(2); got != want {
|
||||
t.Fatalf("legacy receive generation mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
allRuntime, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{
|
||||
Scope: logicalScope,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetServerFileTransferLatestByFileIDQuery scope failed: %v", err)
|
||||
}
|
||||
if got, want := len(allRuntime.Receive), 2; got != want {
|
||||
t.Fatalf("runtime receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
gen1, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{
|
||||
Scope: logicalScope,
|
||||
TransportGeneration: 1,
|
||||
MatchTransportGeneration: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetServerFileTransferLatestByFileIDQuery generation-1 failed: %v", err)
|
||||
}
|
||||
if got, want := len(gen1.Receive), 1; got != want {
|
||||
t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := gen1.Receive[0].RuntimeScope, firstRuntimeScope; got != want {
|
||||
t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
gen2, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{
|
||||
Scope: logicalScope,
|
||||
TransportGeneration: 2,
|
||||
MatchTransportGeneration: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("GetServerFileTransferLatestByFileIDQuery generation-2 failed: %v", err)
|
||||
}
|
||||
if got, want := len(gen2.Receive), 1; got != want {
|
||||
t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := gen2.Receive[0].RuntimeScope, secondRuntimeScope; got != want {
|
||||
t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,146 @@
|
||||
package notify
|
||||
|
||||
import "sort"
|
||||
|
||||
type fileTransferSummaryGroup struct {
|
||||
Send []fileTransferSummary
|
||||
Receive []fileTransferSummary
|
||||
}
|
||||
|
||||
type fileTransferSummaryQuery struct {
|
||||
Scope string
|
||||
RuntimeScope string
|
||||
TransportGeneration uint64
|
||||
MatchTransportGeneration bool
|
||||
}
|
||||
|
||||
type fileTransferQuery struct {
|
||||
monitor *fileTransferMonitor
|
||||
}
|
||||
|
||||
func newFileTransferQuery(m *fileTransferMonitor) fileTransferQuery {
|
||||
return fileTransferQuery{monitor: m}
|
||||
}
|
||||
|
||||
func (q fileTransferQuery) active() fileTransferSummaryGroup {
|
||||
if q.monitor == nil {
|
||||
return fileTransferSummaryGroup{}
|
||||
}
|
||||
return groupFileTransferSummaries(q.monitor.activeSummaries())
|
||||
}
|
||||
|
||||
func (q fileTransferQuery) completed() fileTransferSummaryGroup {
|
||||
if q.monitor == nil {
|
||||
return fileTransferSummaryGroup{}
|
||||
}
|
||||
return groupFileTransferSummaries(filterFileTransferSummaries(q.monitor.completedSummaries(), func(summary fileTransferSummary) bool {
|
||||
return summary.Done && !summary.Failed
|
||||
}))
|
||||
}
|
||||
|
||||
func (q fileTransferQuery) failed() fileTransferSummaryGroup {
|
||||
return groupFileTransferSummaries(filterFileTransferSummaries(latestFileTransferSummaries(q.monitor), func(summary fileTransferSummary) bool {
|
||||
return summary.Failed
|
||||
}))
|
||||
}
|
||||
|
||||
func (q fileTransferQuery) latestByFileID(fileID string) fileTransferSummaryGroup {
|
||||
if q.monitor == nil || fileID == "" {
|
||||
return fileTransferSummaryGroup{}
|
||||
}
|
||||
return groupFileTransferSummaries(q.monitor.summariesByFileID(fileID))
|
||||
}
|
||||
|
||||
func (q fileTransferQuery) latestSendByFileID(fileID string) []fileTransferSummary {
|
||||
if q.monitor == nil || fileID == "" {
|
||||
return nil
|
||||
}
|
||||
return q.monitor.summariesByDirectionAndFileID(fileTransferDirectionSend, fileID)
|
||||
}
|
||||
|
||||
func (q fileTransferQuery) latestReceiveByFileID(fileID string) []fileTransferSummary {
|
||||
if q.monitor == nil || fileID == "" {
|
||||
return nil
|
||||
}
|
||||
return q.monitor.summariesByDirectionAndFileID(fileTransferDirectionReceive, fileID)
|
||||
}
|
||||
|
||||
func (q fileTransferQuery) latestByFileIDQuery(fileID string, query fileTransferSummaryQuery) fileTransferSummaryGroup {
|
||||
if q.monitor == nil || fileID == "" {
|
||||
return fileTransferSummaryGroup{}
|
||||
}
|
||||
return groupFileTransferSummaries(filterFileTransferSummaries(q.monitor.runtimeSummariesByFileID(fileID), func(summary fileTransferSummary) bool {
|
||||
return fileTransferSummaryQueryMatch(summary, query)
|
||||
}))
|
||||
}
|
||||
|
||||
func (q fileTransferQuery) latestSendByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
|
||||
if q.monitor == nil || fileID == "" {
|
||||
return nil
|
||||
}
|
||||
return filterFileTransferSummaries(q.monitor.runtimeSummariesByDirectionAndFileID(fileTransferDirectionSend, fileID), func(summary fileTransferSummary) bool {
|
||||
return fileTransferSummaryQueryMatch(summary, query)
|
||||
})
|
||||
}
|
||||
|
||||
func (q fileTransferQuery) latestReceiveByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
|
||||
if q.monitor == nil || fileID == "" {
|
||||
return nil
|
||||
}
|
||||
return filterFileTransferSummaries(q.monitor.runtimeSummariesByDirectionAndFileID(fileTransferDirectionReceive, fileID), func(summary fileTransferSummary) bool {
|
||||
return fileTransferSummaryQueryMatch(summary, query)
|
||||
})
|
||||
}
|
||||
|
||||
func latestFileTransferSummaries(m *fileTransferMonitor) []fileTransferSummary {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
summaries := append([]fileTransferSummary{}, m.activeSummaries()...)
|
||||
summaries = append(summaries, m.completedSummaries()...)
|
||||
sort.Slice(summaries, func(i int, j int) bool {
|
||||
return fileTransferSummarySortKey(summaries[i]) < fileTransferSummarySortKey(summaries[j])
|
||||
})
|
||||
return summaries
|
||||
}
|
||||
|
||||
func fileTransferSummarySortKey(summary fileTransferSummary) string {
|
||||
return fileTransferMonitorKey(summary.Direction, summary.Scope, summary.FileID)
|
||||
}
|
||||
|
||||
func groupFileTransferSummaries(src []fileTransferSummary) fileTransferSummaryGroup {
|
||||
var group fileTransferSummaryGroup
|
||||
for _, summary := range src {
|
||||
switch summary.Direction {
|
||||
case fileTransferDirectionReceive:
|
||||
group.Receive = append(group.Receive, summary)
|
||||
case fileTransferDirectionSend:
|
||||
group.Send = append(group.Send, summary)
|
||||
}
|
||||
}
|
||||
return group
|
||||
}
|
||||
|
||||
func filterFileTransferSummaries(src []fileTransferSummary, keep func(fileTransferSummary) bool) []fileTransferSummary {
|
||||
out := make([]fileTransferSummary, 0, len(src))
|
||||
for _, summary := range src {
|
||||
if !keep(summary) {
|
||||
continue
|
||||
}
|
||||
out = append(out, summary)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func fileTransferSummaryQueryMatch(summary fileTransferSummary, query fileTransferSummaryQuery) bool {
|
||||
if query.Scope != "" && normalizeFileScope(summary.Scope) != normalizeFileScope(query.Scope) {
|
||||
return false
|
||||
}
|
||||
if query.RuntimeScope != "" && normalizeFileScope(summary.RuntimeScope) != normalizeFileScope(query.RuntimeScope) {
|
||||
return false
|
||||
}
|
||||
if query.MatchTransportGeneration && summary.TransportGeneration != query.TransportGeneration {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -0,0 +1,248 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFileTransferQueryActiveCompletedAndFailed(t *testing.T) {
|
||||
monitor := newFileTransferMonitor()
|
||||
query := newFileTransferQuery(monitor)
|
||||
now := time.Unix(800, 0)
|
||||
serverClient := &ClientConn{ClientID: "client-a"}
|
||||
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "active-send", Size: 10},
|
||||
Received: 4,
|
||||
Total: 10,
|
||||
Time: now,
|
||||
})
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "done-recv", Size: 12},
|
||||
Received: 12,
|
||||
Total: 12,
|
||||
Done: true,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileAbort,
|
||||
Packet: FilePacket{FileID: "failed-send", Size: 8, Stage: "chunk"},
|
||||
Received: 3,
|
||||
Total: 8,
|
||||
Time: now.Add(2 * time.Second),
|
||||
Err: errString("send failed"),
|
||||
})
|
||||
|
||||
active := query.active()
|
||||
if got, want := len(active.Send), 1; got != want {
|
||||
t.Fatalf("active send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := active.Send[0].FileID, "active-send"; got != want {
|
||||
t.Fatalf("active send fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := len(active.Receive), 0; got != want {
|
||||
t.Fatalf("active receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
completed := query.completed()
|
||||
if got, want := len(completed.Send), 0; got != want {
|
||||
t.Fatalf("completed send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := len(completed.Receive), 1; got != want {
|
||||
t.Fatalf("completed receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := completed.Receive[0].FileID, "done-recv"; got != want {
|
||||
t.Fatalf("completed receive fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := completed.Receive[0].Done, true; got != want {
|
||||
t.Fatalf("completed receive done mismatch: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
failed := query.failed()
|
||||
if got, want := len(failed.Send), 1; got != want {
|
||||
t.Fatalf("failed send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := failed.Send[0].FileID, "failed-send"; got != want {
|
||||
t.Fatalf("failed send fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := failed.Send[0].Failed, true; got != want {
|
||||
t.Fatalf("failed send flag mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := len(failed.Receive), 0; got != want {
|
||||
t.Fatalf("failed receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileTransferQueryLatestByFileID(t *testing.T) {
|
||||
monitor := newFileTransferMonitor()
|
||||
query := newFileTransferQuery(monitor)
|
||||
now := time.Unix(900, 0)
|
||||
serverClientA := &ClientConn{ClientID: "client-a"}
|
||||
serverClientB := &ClientConn{ClientID: "client-b"}
|
||||
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||
Received: 6,
|
||||
Total: 20,
|
||||
Time: now,
|
||||
})
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClientA,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||
Received: 9,
|
||||
Total: 20,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClientB,
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||
Received: 20,
|
||||
Total: 20,
|
||||
Done: true,
|
||||
Time: now.Add(2 * time.Second),
|
||||
})
|
||||
|
||||
group := query.latestByFileID("shared")
|
||||
if got, want := len(group.Send), 1; got != want {
|
||||
t.Fatalf("group send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := group.Send[0].FileID, "shared"; got != want {
|
||||
t.Fatalf("group send fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := len(group.Receive), 2; got != want {
|
||||
t.Fatalf("group receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := group.Receive[0].Scope, serverFileScope(serverClientA); got != want {
|
||||
t.Fatalf("first receive scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := group.Receive[1].Scope, serverFileScope(serverClientB); got != want {
|
||||
t.Fatalf("second receive scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
send := query.latestSendByFileID("shared")
|
||||
if got, want := len(send), 1; got != want {
|
||||
t.Fatalf("send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := send[0].Received, int64(6); got != want {
|
||||
t.Fatalf("send received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
receive := query.latestReceiveByFileID("shared")
|
||||
if got, want := len(receive), 2; got != want {
|
||||
t.Fatalf("receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := receive[0].Scope, serverFileScope(serverClientA); got != want {
|
||||
t.Fatalf("receive first scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := receive[1].Done, true; got != want {
|
||||
t.Fatalf("receive second done mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientTransferQueryFollowsPublishedEvents(t *testing.T) {
|
||||
client := NewClient().(*ClientCommon)
|
||||
now := time.Unix(1000, 0)
|
||||
|
||||
client.publishSendFileEvent(FileEvent{
|
||||
NetType: NET_CLIENT,
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "client-done", Size: 16},
|
||||
Received: 16,
|
||||
Total: 16,
|
||||
Done: true,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now.Add(time.Second),
|
||||
Duration: time.Second,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
|
||||
completed := client.getFileTransferState().completed()
|
||||
if got, want := len(completed.Send), 1; got != want {
|
||||
t.Fatalf("client completed send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := completed.Send[0].FileID, "client-done"; got != want {
|
||||
t.Fatalf("client completed send fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileTransferQueryLatestByFileIDQueryResolvesTransportGeneration(t *testing.T) {
|
||||
monitor := newFileTransferMonitor()
|
||||
query := newFileTransferQuery(monitor)
|
||||
now := time.Unix(960, 0)
|
||||
serverClient := &ClientConn{ClientID: "client-gen"}
|
||||
serverClient.markClientConnIdentityBound()
|
||||
serverClient.markClientConnStreamTransport()
|
||||
serverClient.markClientConnTransportAttached()
|
||||
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||
Received: 5,
|
||||
Total: 20,
|
||||
Time: now,
|
||||
})
|
||||
firstRuntimeScope := serverTransportScope(serverClient)
|
||||
logicalScope := serverFileScope(serverClient)
|
||||
|
||||
serverClient.markClientConnTransportDetached("read error", nil)
|
||||
serverClient.markClientConnTransportAttached()
|
||||
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||
Received: 9,
|
||||
Total: 20,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
secondRuntimeScope := serverTransportScope(serverClient)
|
||||
if secondRuntimeScope == firstRuntimeScope {
|
||||
t.Fatalf("runtime scope should change across transport generations: got %q", secondRuntimeScope)
|
||||
}
|
||||
|
||||
legacy := query.latestReceiveByFileID("shared")
|
||||
if got, want := len(legacy), 1; got != want {
|
||||
t.Fatalf("legacy receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := legacy[0].TransportGeneration, uint64(2); got != want {
|
||||
t.Fatalf("legacy receive generation mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
runtimeAll := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
|
||||
Scope: logicalScope,
|
||||
})
|
||||
if got, want := len(runtimeAll), 2; got != want {
|
||||
t.Fatalf("runtime receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
gen1 := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
|
||||
Scope: logicalScope,
|
||||
TransportGeneration: 1,
|
||||
MatchTransportGeneration: true,
|
||||
})
|
||||
if got, want := len(gen1), 1; got != want {
|
||||
t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := gen1[0].RuntimeScope, firstRuntimeScope; got != want {
|
||||
t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
gen2 := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
|
||||
Scope: logicalScope,
|
||||
TransportGeneration: 2,
|
||||
MatchTransportGeneration: true,
|
||||
})
|
||||
if got, want := len(gen2), 1; got != want {
|
||||
t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := gen2[0].RuntimeScope, secondRuntimeScope; got != want {
|
||||
t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,190 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fileTransferDirection uint8
|
||||
|
||||
const (
|
||||
fileTransferDirectionReceive fileTransferDirection = iota
|
||||
fileTransferDirectionSend
|
||||
)
|
||||
|
||||
type fileTransferSnapshot struct {
|
||||
Direction fileTransferDirection
|
||||
Scope string
|
||||
RuntimeScope string
|
||||
TransportGeneration uint64
|
||||
NetType NetType
|
||||
Kind EnvelopeKind
|
||||
FileID string
|
||||
Path string
|
||||
Received int64
|
||||
Total int64
|
||||
Percent float64
|
||||
Done bool
|
||||
Err error
|
||||
StartedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
Duration time.Duration
|
||||
RateBPS float64
|
||||
StepDuration time.Duration
|
||||
InstantRateBPS float64
|
||||
Time time.Time
|
||||
Stage string
|
||||
}
|
||||
|
||||
func fileTransferMonitorScope(event FileEvent) string {
|
||||
if logical := fileEventLogicalConnSnapshot(event); logical != nil {
|
||||
return serverFileScope(logical)
|
||||
}
|
||||
return clientFileScope()
|
||||
}
|
||||
|
||||
func fileTransferRuntimeScope(event FileEvent) string {
|
||||
if event.TransportConn != nil {
|
||||
return serverTransportScopeForTransport(event.TransportConn)
|
||||
}
|
||||
if logical := fileEventLogicalConnSnapshot(event); logical != nil {
|
||||
return serverTransportScope(logical)
|
||||
}
|
||||
return clientFileScope()
|
||||
}
|
||||
|
||||
func fileTransferTransportGeneration(event FileEvent) uint64 {
|
||||
if event.TransportConn != nil {
|
||||
return event.TransportConn.TransportGeneration()
|
||||
}
|
||||
logical := fileEventLogicalConnSnapshot(event)
|
||||
if logical == nil {
|
||||
return 0
|
||||
}
|
||||
return logical.transportGenerationSnapshot()
|
||||
}
|
||||
|
||||
func fileTransferMonitorKey(direction fileTransferDirection, scope string, fileID string) string {
|
||||
if fileID == "" {
|
||||
return ""
|
||||
}
|
||||
return strconv.Itoa(int(direction)) + "|" + scope + "|" + fileID
|
||||
}
|
||||
|
||||
func fileTransferRuntimeMonitorKey(direction fileTransferDirection, runtimeScope string, fileID string) string {
|
||||
return fileTransferMonitorKey(direction, normalizeFileScope(runtimeScope), fileID)
|
||||
}
|
||||
|
||||
func fileTransferSnapshotFromEvent(direction fileTransferDirection, event FileEvent) fileTransferSnapshot {
|
||||
return fileTransferSnapshot{
|
||||
Direction: direction,
|
||||
Scope: fileTransferMonitorScope(event),
|
||||
RuntimeScope: fileTransferRuntimeScope(event),
|
||||
TransportGeneration: fileTransferTransportGeneration(event),
|
||||
NetType: event.NetType,
|
||||
Kind: event.Kind,
|
||||
FileID: event.Packet.FileID,
|
||||
Path: event.Path,
|
||||
Received: event.Received,
|
||||
Total: event.Total,
|
||||
Percent: event.Percent,
|
||||
Done: event.Done,
|
||||
Err: event.Err,
|
||||
StartedAt: event.StartedAt,
|
||||
UpdatedAt: event.UpdatedAt,
|
||||
Duration: event.Duration,
|
||||
RateBPS: event.RateBPS,
|
||||
StepDuration: event.StepDuration,
|
||||
InstantRateBPS: event.InstantRateBPS,
|
||||
Time: event.Time,
|
||||
Stage: event.Packet.Stage,
|
||||
}
|
||||
}
|
||||
|
||||
func isFileTransferTerminal(kind EnvelopeKind) bool {
|
||||
return kind == EnvelopeFileEnd || kind == EnvelopeFileAbort
|
||||
}
|
||||
|
||||
func isFileTransferObservable(kind EnvelopeKind) bool {
|
||||
return kind == EnvelopeFileMeta || kind == EnvelopeFileChunk || kind == EnvelopeFileEnd || kind == EnvelopeFileAbort
|
||||
}
|
||||
|
||||
func sortedFileTransferSnapshots(src map[string]fileTransferSnapshot) []fileTransferSnapshot {
|
||||
keys := make([]string, 0, len(src))
|
||||
for key := range src {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
out := make([]fileTransferSnapshot, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
out = append(out, src[key])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func latestFileTransferSnapshotsLocked(active map[string]fileTransferSnapshot, completed map[string]fileTransferSnapshot) []fileTransferSnapshot {
|
||||
merged := make(map[string]fileTransferSnapshot, len(active)+len(completed))
|
||||
for key, snapshot := range completed {
|
||||
merged[key] = snapshot
|
||||
}
|
||||
for key, snapshot := range active {
|
||||
merged[key] = snapshot
|
||||
}
|
||||
return sortedFileTransferSnapshots(merged)
|
||||
}
|
||||
|
||||
func filteredFileTransferSnapshots(src map[string]fileTransferSnapshot, direction fileTransferDirection) []fileTransferSnapshot {
|
||||
out := make([]fileTransferSnapshot, 0, len(src))
|
||||
for _, snapshot := range sortedFileTransferSnapshots(src) {
|
||||
if snapshot.Direction != direction {
|
||||
continue
|
||||
}
|
||||
out = append(out, snapshot)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func filterFileTransferSnapshotsByFileID(src []fileTransferSnapshot, fileID string) []fileTransferSnapshot {
|
||||
out := make([]fileTransferSnapshot, 0, len(src))
|
||||
for _, snapshot := range src {
|
||||
if snapshot.FileID != fileID {
|
||||
continue
|
||||
}
|
||||
out = append(out, snapshot)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func filterFileTransferSnapshotsByDirectionAndFileID(src []fileTransferSnapshot, direction fileTransferDirection, fileID string) []fileTransferSnapshot {
|
||||
out := make([]fileTransferSnapshot, 0, len(src))
|
||||
for _, snapshot := range src {
|
||||
if snapshot.Direction != direction || snapshot.FileID != fileID {
|
||||
continue
|
||||
}
|
||||
out = append(out, snapshot)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func fileTransferSnapshotOlder(candidate fileTransferSnapshot, current fileTransferSnapshot, candidateKey string, currentKey string) bool {
|
||||
candidateTime := fileTransferSnapshotCompletedTime(candidate)
|
||||
currentTime := fileTransferSnapshotCompletedTime(current)
|
||||
if candidateTime.Before(currentTime) {
|
||||
return true
|
||||
}
|
||||
if currentTime.Before(candidateTime) {
|
||||
return false
|
||||
}
|
||||
return candidateKey < currentKey
|
||||
}
|
||||
|
||||
func fileTransferSnapshotCompletedTime(snapshot fileTransferSnapshot) time.Time {
|
||||
if !snapshot.Time.IsZero() {
|
||||
return snapshot.Time
|
||||
}
|
||||
if !snapshot.UpdatedAt.IsZero() {
|
||||
return snapshot.UpdatedAt
|
||||
}
|
||||
return snapshot.StartedAt
|
||||
}
|
||||
@@ -0,0 +1,302 @@
|
||||
package notify
|
||||
|
||||
import itransfer "b612.me/notify/internal/transfer"
|
||||
|
||||
type fileTransferState struct {
|
||||
monitor *fileTransferMonitor
|
||||
query fileTransferQuery
|
||||
runtime *transferRuntime
|
||||
}
|
||||
|
||||
func newFileTransferState() *fileTransferState {
|
||||
return newFileTransferStateWithConfig(defaultFileTransferConfig())
|
||||
}
|
||||
|
||||
func newFileTransferStateWithConfig(cfg fileTransferConfig) *fileTransferState {
|
||||
monitor := newFileTransferMonitorWithConfig(cfg)
|
||||
return &fileTransferState{
|
||||
monitor: monitor,
|
||||
query: newFileTransferQuery(monitor),
|
||||
runtime: newTransferRuntime(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fileTransferState) observe(direction fileTransferDirection, event FileEvent) {
|
||||
if s == nil || s.monitor == nil {
|
||||
return
|
||||
}
|
||||
s.monitor.observe(direction, event)
|
||||
s.observeRuntime(direction, event)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) observeMonitorOnly(direction fileTransferDirection, event FileEvent) {
|
||||
if s == nil || s.monitor == nil {
|
||||
return
|
||||
}
|
||||
s.monitor.observe(direction, event)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) applyConfig(cfg fileTransferConfig) {
|
||||
if s == nil || s.monitor == nil {
|
||||
return
|
||||
}
|
||||
s.monitor.applyConfig(cfg)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) monitorView() *fileTransferMonitor {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.monitor
|
||||
}
|
||||
|
||||
func (s *fileTransferState) active() fileTransferSummaryGroup {
|
||||
if s == nil {
|
||||
return fileTransferSummaryGroup{}
|
||||
}
|
||||
return s.query.active()
|
||||
}
|
||||
|
||||
func (s *fileTransferState) completed() fileTransferSummaryGroup {
|
||||
if s == nil {
|
||||
return fileTransferSummaryGroup{}
|
||||
}
|
||||
return s.query.completed()
|
||||
}
|
||||
|
||||
func (s *fileTransferState) failed() fileTransferSummaryGroup {
|
||||
if s == nil {
|
||||
return fileTransferSummaryGroup{}
|
||||
}
|
||||
return s.query.failed()
|
||||
}
|
||||
|
||||
func (s *fileTransferState) latest(direction fileTransferDirection, scope string, fileID string) (fileTransferSummary, bool) {
|
||||
if s == nil || s.monitor == nil {
|
||||
return fileTransferSummary{}, false
|
||||
}
|
||||
return s.monitor.latestSummary(direction, scope, fileID)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) latestByFileID(fileID string) fileTransferSummaryGroup {
|
||||
if s == nil {
|
||||
return fileTransferSummaryGroup{}
|
||||
}
|
||||
return s.query.latestByFileID(fileID)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) latestSendByFileID(fileID string) []fileTransferSummary {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.query.latestSendByFileID(fileID)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) latestReceiveByFileID(fileID string) []fileTransferSummary {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.query.latestReceiveByFileID(fileID)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) latestByFileIDQuery(fileID string, query fileTransferSummaryQuery) fileTransferSummaryGroup {
|
||||
if s == nil {
|
||||
return fileTransferSummaryGroup{}
|
||||
}
|
||||
return s.query.latestByFileIDQuery(fileID, query)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) latestSendByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.query.latestSendByFileIDQuery(fileID, query)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) latestReceiveByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return s.query.latestReceiveByFileIDQuery(fileID, query)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) observeRuntime(direction fileTransferDirection, event FileEvent) {
|
||||
if s == nil || s.runtime == nil || event.Packet.FileID == "" {
|
||||
return
|
||||
}
|
||||
runtimeScope := transferRuntimeScopeForEvent(event)
|
||||
publicScope := transferRuntimePublicScopeForEvent(event)
|
||||
transportGeneration := transferRuntimeTransportGenerationForEvent(event)
|
||||
s.ensureRuntimeTransfer(direction, runtimeScope, publicScope, transportGeneration, event)
|
||||
s.recordRuntimeStage(direction, runtimeScope, event.Packet.FileID, runtimeTransferStageForEvent(event))
|
||||
switch event.Kind {
|
||||
case EnvelopeFileChunk:
|
||||
s.runtime.activate(direction, runtimeScope, event.Packet.FileID)
|
||||
s.syncRuntimeProgress(direction, runtimeScope, event)
|
||||
case EnvelopeFileEnd:
|
||||
s.runtime.activate(direction, runtimeScope, event.Packet.FileID)
|
||||
s.syncRuntimeProgress(direction, runtimeScope, event)
|
||||
switch direction {
|
||||
case fileTransferDirectionSend:
|
||||
s.runtime.beginCommit(direction, runtimeScope, event.Packet.FileID)
|
||||
case fileTransferDirectionReceive:
|
||||
s.runtime.beginVerify(direction, runtimeScope, event.Packet.FileID)
|
||||
}
|
||||
s.runtime.complete(direction, runtimeScope, event.Packet.FileID)
|
||||
case EnvelopeFileAbort:
|
||||
s.syncRuntimeProgress(direction, runtimeScope, event)
|
||||
s.recordRuntimeFailureStage(direction, runtimeScope, event.Packet.FileID, event.Packet.Stage)
|
||||
s.runtime.abort(direction, runtimeScope, event.Packet.FileID, event.Err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fileTransferState) ensureRuntimeTransfer(direction fileTransferDirection, runtimeScope string, publicScope string, transportGeneration uint64, event FileEvent) {
|
||||
if s == nil || s.runtime == nil || event.Packet.FileID == "" {
|
||||
return
|
||||
}
|
||||
s.runtime.ensureTransferDescriptor(direction, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{
|
||||
ID: event.Packet.FileID,
|
||||
Channel: itransfer.DataChannel,
|
||||
Size: event.Packet.Size,
|
||||
Checksum: event.Packet.Checksum,
|
||||
Metadata: buildKernelTransferMetadata(event),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *fileTransferState) startRuntimeSendSession(runtimeScope string, publicScope string, transportGeneration uint64, session *fileSendSession) {
|
||||
if s == nil || s.runtime == nil || session == nil || session.fileID == "" {
|
||||
return
|
||||
}
|
||||
s.runtime.ensureTransferDescriptor(fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{
|
||||
ID: session.fileID,
|
||||
Channel: itransfer.DataChannel,
|
||||
Size: session.size,
|
||||
Checksum: session.checksum,
|
||||
Metadata: itransfer.Metadata{
|
||||
"name": session.name,
|
||||
"path": session.path,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func buildKernelTransferMetadata(event FileEvent) itransfer.Metadata {
|
||||
metadata := make(itransfer.Metadata)
|
||||
if event.Packet.Name != "" {
|
||||
metadata["name"] = event.Packet.Name
|
||||
}
|
||||
if event.Path != "" {
|
||||
metadata["path"] = event.Path
|
||||
}
|
||||
if len(metadata) == 0 {
|
||||
return nil
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
|
||||
func (s *fileTransferState) syncRuntimeProgress(direction fileTransferDirection, scope string, event FileEvent) {
|
||||
if s == nil || s.runtime == nil {
|
||||
return
|
||||
}
|
||||
snapshot, ok := s.runtimeSnapshot(direction, scope, event.Packet.FileID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
progress := event.Received
|
||||
if progress < 0 {
|
||||
progress = 0
|
||||
}
|
||||
switch direction {
|
||||
case fileTransferDirectionReceive:
|
||||
if delta := progress - snapshot.ReceivedBytes; delta > 0 {
|
||||
s.runtime.recordReceive(direction, scope, event.Packet.FileID, delta)
|
||||
}
|
||||
default:
|
||||
if delta := progress - snapshot.SentBytes; delta > 0 {
|
||||
s.runtime.recordSend(direction, scope, event.Packet.FileID, delta)
|
||||
}
|
||||
s.runtime.setAckedBytes(direction, scope, event.Packet.FileID, progress)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fileTransferState) recordRuntimeRetry(direction fileTransferDirection, scope string, fileID string) {
|
||||
if s == nil || s.runtime == nil || fileID == "" {
|
||||
return
|
||||
}
|
||||
s.runtime.recordRetry(direction, scope, fileID)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) recordRuntimeTimeout(direction fileTransferDirection, scope string, fileID string) {
|
||||
if s == nil || s.runtime == nil || fileID == "" {
|
||||
return
|
||||
}
|
||||
s.runtime.recordTimeout(direction, scope, fileID)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) recordRuntimeStage(direction fileTransferDirection, scope string, fileID string, stage string) {
|
||||
if s == nil || s.runtime == nil || fileID == "" || stage == "" {
|
||||
return
|
||||
}
|
||||
s.runtime.recordStage(direction, scope, fileID, stage)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) recordRuntimeFailureStage(direction fileTransferDirection, scope string, fileID string, stage string) {
|
||||
if s == nil || s.runtime == nil || fileID == "" || stage == "" {
|
||||
return
|
||||
}
|
||||
s.runtime.recordFailureStage(direction, scope, fileID, stage)
|
||||
}
|
||||
|
||||
func (s *fileTransferState) runtimeSnapshot(direction fileTransferDirection, scope string, transferID string) (itransfer.Snapshot, bool) {
|
||||
if s == nil || s.runtime == nil || transferID == "" {
|
||||
return itransfer.Snapshot{}, false
|
||||
}
|
||||
return s.runtime.snapshot(direction, scope, transferID)
|
||||
}
|
||||
|
||||
func transferRuntimeScopeForEvent(event FileEvent) string {
|
||||
if event.TransportConn != nil {
|
||||
return serverTransportScopeForTransport(event.TransportConn)
|
||||
}
|
||||
if logical := fileEventLogicalConnSnapshot(event); logical != nil {
|
||||
return serverTransportScope(logical)
|
||||
}
|
||||
return clientFileScope()
|
||||
}
|
||||
|
||||
func transferRuntimePublicScopeForEvent(event FileEvent) string {
|
||||
return fileTransferMonitorScope(event)
|
||||
}
|
||||
|
||||
func transferRuntimeTransportGenerationForEvent(event FileEvent) uint64 {
|
||||
if event.TransportConn != nil {
|
||||
return event.TransportConn.TransportGeneration()
|
||||
}
|
||||
logical := fileEventLogicalConnSnapshot(event)
|
||||
if logical == nil {
|
||||
return 0
|
||||
}
|
||||
return logical.transportGenerationSnapshot()
|
||||
}
|
||||
|
||||
func runtimeTransferStageForEvent(event FileEvent) string {
|
||||
if event.Packet.Stage != "" {
|
||||
return event.Packet.Stage
|
||||
}
|
||||
return fileStageByKind(event.Kind)
|
||||
}
|
||||
|
||||
func (c *ClientCommon) getTransferRuntime() *transferRuntime {
|
||||
return c.getFileTransferState().runtime
|
||||
}
|
||||
|
||||
func (s *ServerCommon) getTransferRuntime() *transferRuntime {
|
||||
return s.getFileTransferState().runtime
|
||||
}
|
||||
|
||||
func (c *ClientCommon) getFileTransferState() *fileTransferState {
|
||||
return c.getLogicalSessionState().fileTransfers
|
||||
}
|
||||
|
||||
func (s *ServerCommon) getFileTransferState() *fileTransferState {
|
||||
return s.getLogicalSessionState().fileTransfers
|
||||
}
|
||||
@@ -0,0 +1,371 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
itransfer "b612.me/notify/internal/transfer"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFileTransferStateObserveFeedsQuery(t *testing.T) {
|
||||
state := newFileTransferState()
|
||||
now := time.Unix(1100, 0)
|
||||
|
||||
state.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "state-active", Size: 32},
|
||||
Received: 10,
|
||||
Total: 32,
|
||||
Time: now,
|
||||
})
|
||||
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||
Kind: EnvelopeFileAbort,
|
||||
Packet: FilePacket{FileID: "state-failed", Size: 16, Stage: "chunk"},
|
||||
Received: 6,
|
||||
Total: 16,
|
||||
Time: now.Add(time.Second),
|
||||
Err: errString("receive failed"),
|
||||
})
|
||||
|
||||
active := state.active()
|
||||
if got, want := len(active.Send), 1; got != want {
|
||||
t.Fatalf("active send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := active.Send[0].FileID, "state-active"; got != want {
|
||||
t.Fatalf("active send fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
failed := state.failed()
|
||||
if got, want := len(failed.Receive), 1; got != want {
|
||||
t.Fatalf("failed receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := failed.Receive[0].FileID, "state-failed"; got != want {
|
||||
t.Fatalf("failed receive fileID mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := failed.Receive[0].Failed, true; got != want {
|
||||
t.Fatalf("failed receive flag mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileTransferStateLatestHelpers(t *testing.T) {
|
||||
state := newFileTransferState()
|
||||
now := time.Unix(1200, 0)
|
||||
serverClient := &ClientConn{ClientID: "client-a"}
|
||||
|
||||
state.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "state-shared", Size: 40},
|
||||
Received: 15,
|
||||
Total: 40,
|
||||
Time: now,
|
||||
})
|
||||
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "state-shared", Size: 40},
|
||||
Received: 40,
|
||||
Total: 40,
|
||||
Done: true,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
|
||||
summary, ok := state.latest(fileTransferDirectionSend, clientFileScope(), "state-shared")
|
||||
if !ok {
|
||||
t.Fatal("latest send summary should exist")
|
||||
}
|
||||
if got, want := summary.Received, int64(15); got != want {
|
||||
t.Fatalf("latest send received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
group := state.latestByFileID("state-shared")
|
||||
if got, want := len(group.Send), 1; got != want {
|
||||
t.Fatalf("latest group send count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := len(group.Receive), 1; got != want {
|
||||
t.Fatalf("latest group receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := group.Receive[0].Scope, serverFileScope(serverClient); got != want {
|
||||
t.Fatalf("latest group receive scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
send := state.latestSendByFileID("state-shared")
|
||||
if got, want := len(send), 1; got != want {
|
||||
t.Fatalf("latest send list count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
receive := state.latestReceiveByFileID("state-shared")
|
||||
if got, want := len(receive), 1; got != want {
|
||||
t.Fatalf("latest receive list count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := receive[0].Done, true; got != want {
|
||||
t.Fatalf("latest receive done mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileTransferStateObserveFeedsTransferRuntime(t *testing.T) {
|
||||
state := newFileTransferState()
|
||||
now := time.Unix(1300, 0)
|
||||
|
||||
state.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileMeta,
|
||||
Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"},
|
||||
Path: "/tmp/demo.bin",
|
||||
Time: now,
|
||||
})
|
||||
state.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"},
|
||||
Received: 8,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
state.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"},
|
||||
Received: 8,
|
||||
Done: true,
|
||||
Time: now.Add(2 * time.Second),
|
||||
})
|
||||
|
||||
sendSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), "kernel-send")
|
||||
if !ok {
|
||||
t.Fatal("send snapshot should exist")
|
||||
}
|
||||
if got, want := sendSnapshot.State, itransfer.StateDone; got != want {
|
||||
t.Fatalf("send state = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := sendSnapshot.Direction, itransfer.DirectionSend; got != want {
|
||||
t.Fatalf("send direction = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := sendSnapshot.SentBytes, int64(8); got != want {
|
||||
t.Fatalf("send bytes = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := sendSnapshot.AckedBytes, int64(8); got != want {
|
||||
t.Fatalf("send acked bytes = %d, want %d", got, want)
|
||||
}
|
||||
if got := sendSnapshot.Metadata["name"]; got != "demo.bin" {
|
||||
t.Fatalf("send metadata name = %q, want demo.bin", got)
|
||||
}
|
||||
|
||||
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||
Kind: EnvelopeFileMeta,
|
||||
Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
|
||||
Time: now,
|
||||
})
|
||||
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
|
||||
Received: 6,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
|
||||
Received: 6,
|
||||
Done: true,
|
||||
Time: now.Add(2 * time.Second),
|
||||
})
|
||||
|
||||
recvSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, clientFileScope(), "kernel-recv")
|
||||
if !ok {
|
||||
t.Fatal("receive snapshot should exist")
|
||||
}
|
||||
if got, want := recvSnapshot.State, itransfer.StateDone; got != want {
|
||||
t.Fatalf("receive state = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := recvSnapshot.Direction, itransfer.DirectionReceive; got != want {
|
||||
t.Fatalf("receive direction = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := recvSnapshot.ReceivedBytes, int64(6); got != want {
|
||||
t.Fatalf("receive bytes = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileTransferStateRuntimeResilienceStats(t *testing.T) {
|
||||
state := newFileTransferState()
|
||||
session := &fileSendSession{
|
||||
fileID: "kernel-retry",
|
||||
path: "/tmp/retry.bin",
|
||||
name: "retry.bin",
|
||||
size: 5,
|
||||
checksum: "sum-retry",
|
||||
}
|
||||
|
||||
state.startRuntimeSendSession(clientFileScope(), clientFileScope(), 0, session)
|
||||
state.recordRuntimeTimeout(fileTransferDirectionSend, clientFileScope(), session.fileID)
|
||||
state.recordRuntimeRetry(fileTransferDirectionSend, clientFileScope(), session.fileID)
|
||||
state.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileAbort,
|
||||
Packet: FilePacket{FileID: session.fileID, Name: session.name, Size: session.size, Checksum: session.checksum, Stage: "meta"},
|
||||
Received: 0,
|
||||
Err: errString("ack timeout"),
|
||||
Time: time.Unix(1400, 0),
|
||||
})
|
||||
|
||||
snapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), session.fileID)
|
||||
if !ok {
|
||||
t.Fatal("runtime snapshot should exist")
|
||||
}
|
||||
if got, want := snapshot.TimeoutCount, 1; got != want {
|
||||
t.Fatalf("timeout count = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.RetryCount, 1; got != want {
|
||||
t.Fatalf("retry count = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.State, itransfer.StateAborted; got != want {
|
||||
t.Fatalf("state = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := snapshot.LastError, "ack timeout"; got != want {
|
||||
t.Fatalf("last error = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := snapshot.Stage, "meta"; got != want {
|
||||
t.Fatalf("stage = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := snapshot.LastFailureStage, "meta"; got != want {
|
||||
t.Fatalf("last failure stage = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileTransferStateRuntimeSeparatesScopeAndDirection(t *testing.T) {
|
||||
state := newFileTransferState()
|
||||
now := time.Unix(1450, 0)
|
||||
serverClient := &ClientConn{ClientID: "client-b"}
|
||||
|
||||
state.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileMeta,
|
||||
Packet: FilePacket{FileID: "shared-id", Name: "send.bin", Size: 4, Checksum: "sum-send"},
|
||||
Time: now,
|
||||
})
|
||||
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileMeta,
|
||||
Packet: FilePacket{FileID: "shared-id", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
|
||||
sendSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), "shared-id")
|
||||
if !ok {
|
||||
t.Fatal("send snapshot should exist")
|
||||
}
|
||||
if got, want := sendSnapshot.Direction, itransfer.DirectionSend; got != want {
|
||||
t.Fatalf("send direction = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
recvSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, serverTransportScope(serverClient), "shared-id")
|
||||
if !ok {
|
||||
t.Fatal("receive snapshot should exist")
|
||||
}
|
||||
if got, want := recvSnapshot.Direction, itransfer.DirectionReceive; got != want {
|
||||
t.Fatalf("receive direction = %v, want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileTransferStateRuntimeSeparatesServerTransportGenerations(t *testing.T) {
|
||||
state := newFileTransferState()
|
||||
now := time.Unix(1500, 0)
|
||||
serverClient := &ClientConn{ClientID: "client-gen"}
|
||||
serverClient.markClientConnIdentityBound()
|
||||
serverClient.markClientConnStreamTransport()
|
||||
serverClient.markClientConnTransportAttached()
|
||||
|
||||
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileMeta,
|
||||
Packet: FilePacket{FileID: "shared-transfer", Name: "recv-a.bin", Size: 4, Checksum: "sum-a"},
|
||||
Time: now,
|
||||
})
|
||||
|
||||
firstScope := serverTransportScope(serverClient)
|
||||
firstSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, firstScope, "shared-transfer")
|
||||
if !ok {
|
||||
t.Fatal("first generation snapshot should exist")
|
||||
}
|
||||
if got, want := firstSnapshot.Metadata[transferMetadataScopeKey], serverFileScope(serverClient); got != want {
|
||||
t.Fatalf("first generation public scope metadata = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
serverClient.markClientConnTransportDetached("read error", nil)
|
||||
serverClient.markClientConnTransportAttached()
|
||||
|
||||
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileMeta,
|
||||
Packet: FilePacket{FileID: "shared-transfer", Name: "recv-b.bin", Size: 6, Checksum: "sum-b"},
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
|
||||
secondScope := serverTransportScope(serverClient)
|
||||
if secondScope == firstScope {
|
||||
t.Fatalf("runtime scope should change across transport generations: got %q", secondScope)
|
||||
}
|
||||
secondSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, secondScope, "shared-transfer")
|
||||
if !ok {
|
||||
t.Fatal("second generation snapshot should exist")
|
||||
}
|
||||
if got, want := transferSnapshotRuntimeScope(secondSnapshot.Metadata), secondScope; got != want {
|
||||
t.Fatalf("second generation runtime scope metadata = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := transferSnapshotTransportGeneration(secondSnapshot.Metadata), uint64(2); got != want {
|
||||
t.Fatalf("second generation transport generation metadata = %d, want %d", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileTransferStateLatestByFileIDQueryResolvesTransportGeneration(t *testing.T) {
|
||||
state := newFileTransferState()
|
||||
now := time.Unix(1510, 0)
|
||||
serverClient := &ClientConn{ClientID: "client-query-gen"}
|
||||
serverClient.markClientConnIdentityBound()
|
||||
serverClient.markClientConnStreamTransport()
|
||||
serverClient.markClientConnTransportAttached()
|
||||
|
||||
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared", Size: 30},
|
||||
Received: 6,
|
||||
Total: 30,
|
||||
Time: now,
|
||||
})
|
||||
firstRuntimeScope := serverTransportScope(serverClient)
|
||||
logicalScope := serverFileScope(serverClient)
|
||||
|
||||
serverClient.markClientConnTransportDetached("read error", nil)
|
||||
serverClient.markClientConnTransportAttached()
|
||||
|
||||
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClient,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "shared", Size: 30},
|
||||
Received: 10,
|
||||
Total: 30,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
secondRuntimeScope := serverTransportScope(serverClient)
|
||||
|
||||
legacy := state.latestReceiveByFileID("shared")
|
||||
if got, want := len(legacy), 1; got != want {
|
||||
t.Fatalf("legacy receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
gen1 := state.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
|
||||
Scope: logicalScope,
|
||||
TransportGeneration: 1,
|
||||
MatchTransportGeneration: true,
|
||||
})
|
||||
if got, want := len(gen1), 1; got != want {
|
||||
t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := gen1[0].RuntimeScope, firstRuntimeScope; got != want {
|
||||
t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
|
||||
gen2 := state.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
|
||||
Scope: logicalScope,
|
||||
TransportGeneration: 2,
|
||||
MatchTransportGeneration: true,
|
||||
})
|
||||
if got, want := len(gen2), 1; got != want {
|
||||
t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := gen2[0].RuntimeScope, secondRuntimeScope; got != want {
|
||||
t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,210 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fileTransferSummary struct {
|
||||
Direction fileTransferDirection
|
||||
Scope string
|
||||
RuntimeScope string
|
||||
TransportGeneration uint64
|
||||
NetType NetType
|
||||
Kind EnvelopeKind
|
||||
FileID string
|
||||
Path string
|
||||
Received int64
|
||||
Total int64
|
||||
Percent float64
|
||||
Active bool
|
||||
Terminal bool
|
||||
Done bool
|
||||
Failed bool
|
||||
Err error
|
||||
StartedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
Duration time.Duration
|
||||
RateBPS float64
|
||||
StepDuration time.Duration
|
||||
InstantRateBPS float64
|
||||
Time time.Time
|
||||
Stage string
|
||||
}
|
||||
|
||||
type fileTransferSummaryRecord struct {
|
||||
snapshot fileTransferSnapshot
|
||||
active bool
|
||||
}
|
||||
|
||||
func fileTransferSummaryFromSnapshot(snapshot fileTransferSnapshot, active bool) fileTransferSummary {
|
||||
return fileTransferSummary{
|
||||
Direction: snapshot.Direction,
|
||||
Scope: snapshot.Scope,
|
||||
RuntimeScope: snapshot.RuntimeScope,
|
||||
TransportGeneration: snapshot.TransportGeneration,
|
||||
NetType: snapshot.NetType,
|
||||
Kind: snapshot.Kind,
|
||||
FileID: snapshot.FileID,
|
||||
Path: snapshot.Path,
|
||||
Received: snapshot.Received,
|
||||
Total: snapshot.Total,
|
||||
Percent: snapshot.Percent,
|
||||
Active: active,
|
||||
Terminal: !active && isFileTransferTerminal(snapshot.Kind),
|
||||
Done: snapshot.Done,
|
||||
Failed: snapshot.Kind == EnvelopeFileAbort || snapshot.Err != nil,
|
||||
Err: snapshot.Err,
|
||||
StartedAt: snapshot.StartedAt,
|
||||
UpdatedAt: snapshot.UpdatedAt,
|
||||
Duration: snapshot.Duration,
|
||||
RateBPS: snapshot.RateBPS,
|
||||
StepDuration: snapshot.StepDuration,
|
||||
InstantRateBPS: snapshot.InstantRateBPS,
|
||||
Time: snapshot.Time,
|
||||
Stage: snapshot.Stage,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) activeSummaries() []fileTransferSummary {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return summariesFromSnapshots(sortedFileTransferSnapshots(m.active), true)
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) completedSummaries() []fileTransferSummary {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return summariesFromSnapshots(sortedFileTransferSnapshots(m.completed), false)
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) latestSummary(direction fileTransferDirection, scope string, fileID string) (fileTransferSummary, bool) {
|
||||
if m == nil {
|
||||
return fileTransferSummary{}, false
|
||||
}
|
||||
key := fileTransferMonitorKey(direction, scope, fileID)
|
||||
if key == "" {
|
||||
return fileTransferSummary{}, false
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if snapshot, ok := m.active[key]; ok {
|
||||
return fileTransferSummaryFromSnapshot(snapshot, true), true
|
||||
}
|
||||
snapshot, ok := m.completed[key]
|
||||
if !ok {
|
||||
return fileTransferSummary{}, false
|
||||
}
|
||||
return fileTransferSummaryFromSnapshot(snapshot, false), true
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) summariesByFileID(fileID string) []fileTransferSummary {
|
||||
if m == nil || fileID == "" {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return summariesFromRecords(filterFileTransferSummaryRecordsByFileID(latestFileTransferSummaryRecordsLocked(m.active, m.completed), fileID))
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) summariesByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSummary {
|
||||
if m == nil || fileID == "" {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return summariesFromRecords(filterFileTransferSummaryRecordsByDirectionAndFileID(latestFileTransferSummaryRecordsLocked(m.active, m.completed), direction, fileID))
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) runtimeSummariesByFileID(fileID string) []fileTransferSummary {
|
||||
if m == nil || fileID == "" {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return summariesFromRecords(filterFileTransferSummaryRecordsByFileID(latestFileTransferSummaryRecordsLocked(m.runtimeActive, m.runtimeCompleted), fileID))
|
||||
}
|
||||
|
||||
func (m *fileTransferMonitor) runtimeSummariesByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSummary {
|
||||
if m == nil || fileID == "" {
|
||||
return nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
return summariesFromRecords(filterFileTransferSummaryRecordsByDirectionAndFileID(latestFileTransferSummaryRecordsLocked(m.runtimeActive, m.runtimeCompleted), direction, fileID))
|
||||
}
|
||||
|
||||
func latestFileTransferSummaryRecordsLocked(active map[string]fileTransferSnapshot, completed map[string]fileTransferSnapshot) []fileTransferSummaryRecord {
|
||||
keys := make([]string, 0, len(active)+len(completed))
|
||||
seen := make(map[string]struct{}, len(active)+len(completed))
|
||||
for key := range completed {
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
for key := range active {
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
out := make([]fileTransferSummaryRecord, 0, len(keys))
|
||||
for _, key := range keys {
|
||||
if snapshot, ok := active[key]; ok {
|
||||
out = append(out, fileTransferSummaryRecord{snapshot: snapshot, active: true})
|
||||
continue
|
||||
}
|
||||
if snapshot, ok := completed[key]; ok {
|
||||
out = append(out, fileTransferSummaryRecord{snapshot: snapshot, active: false})
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func summariesFromSnapshots(src []fileTransferSnapshot, active bool) []fileTransferSummary {
|
||||
out := make([]fileTransferSummary, 0, len(src))
|
||||
for _, snapshot := range src {
|
||||
out = append(out, fileTransferSummaryFromSnapshot(snapshot, active))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func summariesFromRecords(src []fileTransferSummaryRecord) []fileTransferSummary {
|
||||
out := make([]fileTransferSummary, 0, len(src))
|
||||
for _, record := range src {
|
||||
out = append(out, fileTransferSummaryFromSnapshot(record.snapshot, record.active))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func filterFileTransferSummaryRecordsByFileID(src []fileTransferSummaryRecord, fileID string) []fileTransferSummaryRecord {
|
||||
out := make([]fileTransferSummaryRecord, 0, len(src))
|
||||
for _, record := range src {
|
||||
if record.snapshot.FileID != fileID {
|
||||
continue
|
||||
}
|
||||
out = append(out, record)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func filterFileTransferSummaryRecordsByDirectionAndFileID(src []fileTransferSummaryRecord, direction fileTransferDirection, fileID string) []fileTransferSummaryRecord {
|
||||
out := make([]fileTransferSummaryRecord, 0, len(src))
|
||||
for _, record := range src {
|
||||
if record.snapshot.Direction != direction || record.snapshot.FileID != fileID {
|
||||
continue
|
||||
}
|
||||
out = append(out, record)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTransferMonitorLatestSummaryPrefersActive(t *testing.T) {
|
||||
monitor := newFileTransferMonitor()
|
||||
now := time.Unix(500, 0)
|
||||
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "summary-1", Size: 30},
|
||||
Received: 12,
|
||||
Total: 30,
|
||||
Percent: 40,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now.Add(time.Second),
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
|
||||
summary, ok := monitor.latestSummary(fileTransferDirectionSend, clientFileScope(), "summary-1")
|
||||
if !ok {
|
||||
t.Fatal("latest summary should exist while active")
|
||||
}
|
||||
if got, want := summary.Active, true; got != want {
|
||||
t.Fatalf("active summary mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := summary.Terminal, false; got != want {
|
||||
t.Fatalf("terminal summary mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := summary.Received, int64(12); got != want {
|
||||
t.Fatalf("active summary received mismatch: got %d want %d", got, want)
|
||||
}
|
||||
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "summary-1", Size: 30},
|
||||
Received: 30,
|
||||
Total: 30,
|
||||
Percent: 100,
|
||||
Done: true,
|
||||
StartedAt: now,
|
||||
UpdatedAt: now.Add(2 * time.Second),
|
||||
Time: now.Add(2 * time.Second),
|
||||
})
|
||||
|
||||
summary, ok = monitor.latestSummary(fileTransferDirectionSend, clientFileScope(), "summary-1")
|
||||
if !ok {
|
||||
t.Fatal("latest summary should exist after completion")
|
||||
}
|
||||
if got, want := summary.Active, false; got != want {
|
||||
t.Fatalf("completed summary active mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := summary.Terminal, true; got != want {
|
||||
t.Fatalf("completed summary terminal mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := summary.Done, true; got != want {
|
||||
t.Fatalf("completed summary done mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransferMonitorSummariesByFileID(t *testing.T) {
|
||||
monitor := newFileTransferMonitor()
|
||||
now := time.Unix(600, 0)
|
||||
serverClientA := &ClientConn{ClientID: "client-a"}
|
||||
serverClientB := &ClientConn{ClientID: "client-b"}
|
||||
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "summary-shared", Size: 20},
|
||||
Received: 8,
|
||||
Total: 20,
|
||||
Time: now,
|
||||
})
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClientA,
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "summary-shared", Size: 20},
|
||||
Received: 12,
|
||||
Total: 20,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
ClientConn: serverClientB,
|
||||
Kind: EnvelopeFileAbort,
|
||||
Packet: FilePacket{FileID: "summary-shared", Size: 20, Stage: "chunk"},
|
||||
Received: 14,
|
||||
Total: 20,
|
||||
Time: now.Add(2 * time.Second),
|
||||
Err: errString("recv failed"),
|
||||
})
|
||||
|
||||
summaries := monitor.summariesByFileID("summary-shared")
|
||||
if got, want := len(summaries), 3; got != want {
|
||||
t.Fatalf("summaries count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := summaries[0].Scope, serverFileScope(serverClientA); got != want {
|
||||
t.Fatalf("first summary scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := summaries[0].Active, true; got != want {
|
||||
t.Fatalf("first summary active mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := summaries[1].Scope, serverFileScope(serverClientB); got != want {
|
||||
t.Fatalf("second summary scope mismatch: got %q want %q", got, want)
|
||||
}
|
||||
if got, want := summaries[1].Failed, true; got != want {
|
||||
t.Fatalf("second summary failed mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := summaries[1].Terminal, true; got != want {
|
||||
t.Fatalf("second summary terminal mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := summaries[2].Direction, fileTransferDirectionSend; got != want {
|
||||
t.Fatalf("third summary direction mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransferMonitorActiveAndCompletedSummaries(t *testing.T) {
|
||||
monitor := newFileTransferMonitor()
|
||||
now := time.Unix(700, 0)
|
||||
|
||||
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||
Kind: EnvelopeFileChunk,
|
||||
Packet: FilePacket{FileID: "active-1", Size: 10},
|
||||
Received: 3,
|
||||
Total: 10,
|
||||
Time: now,
|
||||
})
|
||||
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||
Kind: EnvelopeFileEnd,
|
||||
Packet: FilePacket{FileID: "done-1", Size: 10},
|
||||
Received: 10,
|
||||
Total: 10,
|
||||
Done: true,
|
||||
Time: now.Add(time.Second),
|
||||
})
|
||||
|
||||
active := monitor.activeSummaries()
|
||||
if got, want := len(active), 1; got != want {
|
||||
t.Fatalf("active summaries count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := active[0].Active, true; got != want {
|
||||
t.Fatalf("active summary state mismatch: got %v want %v", got, want)
|
||||
}
|
||||
|
||||
completed := monitor.completedSummaries()
|
||||
if got, want := len(completed), 1; got != want {
|
||||
t.Fatalf("completed summaries count mismatch: got %d want %d", got, want)
|
||||
}
|
||||
if got, want := completed[0].Active, false; got != want {
|
||||
t.Fatalf("completed summary state mismatch: got %v want %v", got, want)
|
||||
}
|
||||
if got, want := completed[0].Done, true; got != want {
|
||||
t.Fatalf("completed summary done mismatch: got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
type errString string
|
||||
|
||||
func (e errString) Error() string {
|
||||
return string(e)
|
||||
}
|
||||
@@ -1,8 +1,16 @@
|
||||
module b612.me/notify
|
||||
|
||||
go 1.16
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
b612.me/starcrypto v0.0.5
|
||||
b612.me/stario v0.0.10
|
||||
b612.me/starcrypto v1.0.2
|
||||
b612.me/stario v0.1.0
|
||||
github.com/Microsoft/go-winio v0.6.2
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/emmansun/gmsm v0.41.1 // indirect
|
||||
golang.org/x/crypto v0.48.0 // indirect
|
||||
golang.org/x/sys v0.41.0 // indirect
|
||||
golang.org/x/term v0.40.0 // indirect
|
||||
)
|
||||
|
||||
@@ -1,75 +1,22 @@
|
||||
b612.me/starcrypto v0.0.5 h1:Aa4pRDO2lBH2Aw+vz8NuUtRb73J8z5aOa9SImBY5sq4=
|
||||
b612.me/starcrypto v0.0.5/go.mod h1:pF5A16p8r/h1G0x7ZNmmAF6K1sdIMpbCUxn2WGC8gZ0=
|
||||
b612.me/stario v0.0.0-20240818091810-d528a583f4b2 h1:SxN1WDZsEBQFTnLaKbc7Z+91uyWhUB4cKHo5Ucztyh0=
|
||||
b612.me/stario v0.0.0-20240818091810-d528a583f4b2/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk=
|
||||
b612.me/stario v0.0.10/go.mod h1:1Owmu9jzKWgs4VsmeI8YWlGwLrCwPNM/bYpxkyn+MMk=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
|
||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||
golang.org/x/crypto v0.26.0 h1:RrRspgV4mU+YwB4FYnuBoKsUapNIL5cohGAmSH3azsw=
|
||||
golang.org/x/crypto v0.26.0/go.mod h1:GY7jblb9wI+FOo5y8/S2oY4zWP07AkOJ4+jxCqdqn54=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
|
||||
golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.8.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.23.0 h1:YfKFowiIMvtgl1UERQoTPPToxltDeZfbj4H7dVUCwmM=
|
||||
golang.org/x/sys v0.23.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo=
|
||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/term v0.18.0/go.mod h1:ILwASektA3OnRv7amZ1xhE/KTR+u50pbXfZ03+6Nx58=
|
||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||
golang.org/x/term v0.23.0 h1:F6D4vR+EHoL9/sWAWgAR1H2DcHr4PareCbAaCo1RpuU=
|
||||
golang.org/x/term v0.23.0/go.mod h1:DgV24QBUrK6jhZXl+20l6UWznPlwAHm1Q1mGHtydmSk=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/text v0.17.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
b612.me/starcrypto v1.0.2 h1:6f8YHNMHZPwxDSRxY2OJeMP4ExKa/cakLIO04f0gLhE=
|
||||
b612.me/starcrypto v1.0.2/go.mod h1:I7oYTmQgnVPj5S5yKwoTyqkItq1HgF9XdJT/v3qs5QE=
|
||||
b612.me/stario v0.1.0 h1:V1uA7fLYzgTadOXpnyPaFC3z0MAKFIM/RKXzZUDXvL4=
|
||||
b612.me/stario v0.1.0/go.mod h1:7kjE69oFqNrca0P72L5+ZbTV09QGJ2N3bBY3qeFXOGc=
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/emmansun/gmsm v0.41.1 h1:mD1MqmaXTEqt+9UVmDpRYvcEMIa5vuslFEnw7IWp6/w=
|
||||
github.com/emmansun/gmsm v0.41.1/go.mod h1:FD1EQk4XcSMkahZFzNwFoI/uXzAlODB9JVsJ9G5N7Do=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const defaultInboundDispatchSource = "_notify.default_inbound_source"
|
||||
|
||||
type inboundDispatcher struct {
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
workers map[string]*inboundDispatchWorker
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
type inboundDispatchWorker struct {
|
||||
queue []func()
|
||||
running bool
|
||||
}
|
||||
|
||||
func newInboundDispatcher() *inboundDispatcher {
|
||||
return &inboundDispatcher{
|
||||
workers: make(map[string]*inboundDispatchWorker),
|
||||
}
|
||||
}
|
||||
|
||||
func (d *inboundDispatcher) Dispatch(source string, fn func()) bool {
|
||||
if d == nil || fn == nil {
|
||||
return false
|
||||
}
|
||||
if source == "" {
|
||||
source = defaultInboundDispatchSource
|
||||
}
|
||||
d.mu.Lock()
|
||||
if d.closed {
|
||||
d.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
worker := d.workers[source]
|
||||
if worker == nil {
|
||||
worker = &inboundDispatchWorker{}
|
||||
d.workers[source] = worker
|
||||
}
|
||||
worker.queue = append(worker.queue, fn)
|
||||
if worker.running {
|
||||
d.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
worker.running = true
|
||||
d.wg.Add(1)
|
||||
d.mu.Unlock()
|
||||
go d.run(source, worker)
|
||||
return true
|
||||
}
|
||||
|
||||
func (d *inboundDispatcher) run(source string, worker *inboundDispatchWorker) {
|
||||
defer d.wg.Done()
|
||||
for {
|
||||
d.mu.Lock()
|
||||
if len(worker.queue) == 0 {
|
||||
worker.running = false
|
||||
if current := d.workers[source]; current == worker {
|
||||
delete(d.workers, source)
|
||||
}
|
||||
d.mu.Unlock()
|
||||
return
|
||||
}
|
||||
fn := worker.queue[0]
|
||||
worker.queue[0] = nil
|
||||
worker.queue = worker.queue[1:]
|
||||
d.mu.Unlock()
|
||||
fn()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *inboundDispatcher) CloseAndWait() {
|
||||
if d == nil {
|
||||
return
|
||||
}
|
||||
d.mu.Lock()
|
||||
d.closed = true
|
||||
d.mu.Unlock()
|
||||
d.wg.Wait()
|
||||
}
|
||||
|
||||
func clientInboundDispatchSource() string {
|
||||
return "client"
|
||||
}
|
||||
|
||||
func serverInboundDispatchSource(source interface{}) string {
|
||||
switch data := source.(type) {
|
||||
case serverInboundSource:
|
||||
return serverInboundDispatchSourceKey(data)
|
||||
case *serverInboundSource:
|
||||
if data == nil {
|
||||
return defaultInboundDispatchSource
|
||||
}
|
||||
return serverInboundDispatchSourceKey(*data)
|
||||
case net.Conn:
|
||||
return fmt.Sprintf("conn:%p", data)
|
||||
case string:
|
||||
if data == "" {
|
||||
return defaultInboundDispatchSource
|
||||
}
|
||||
return "peer:" + data
|
||||
default:
|
||||
return defaultInboundDispatchSource
|
||||
}
|
||||
}
|
||||
|
||||
func serverInboundDispatchSourceKey(source serverInboundSource) string {
|
||||
if source.Conn != nil {
|
||||
return fmt.Sprintf("conn:%p:%d", source.Conn, source.TransportGeneration)
|
||||
}
|
||||
if source.Logical != nil {
|
||||
return fmt.Sprintf("logical:%s:%d", source.Logical.ID(), source.TransportGeneration)
|
||||
}
|
||||
if source.Source != "" {
|
||||
return fmt.Sprintf("peer:%s:%d", source.Source, source.TransportGeneration)
|
||||
}
|
||||
if source.RemoteAddr != nil {
|
||||
return fmt.Sprintf("addr:%s:%d", source.RemoteAddr.String(), source.TransportGeneration)
|
||||
}
|
||||
return defaultInboundDispatchSource
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package notify
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestInboundDispatcherSerializesPerSource(t *testing.T) {
|
||||
dispatcher := newInboundDispatcher()
|
||||
defer dispatcher.CloseAndWait()
|
||||
|
||||
firstStarted := make(chan struct{}, 1)
|
||||
secondStarted := make(chan struct{}, 1)
|
||||
otherStarted := make(chan struct{}, 1)
|
||||
releaseFirst := make(chan struct{})
|
||||
|
||||
var mu sync.Mutex
|
||||
var order []string
|
||||
|
||||
record := func(step string) {
|
||||
mu.Lock()
|
||||
order = append(order, step)
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
if !dispatcher.Dispatch("alpha", func() {
|
||||
record("alpha-1-start")
|
||||
firstStarted <- struct{}{}
|
||||
<-releaseFirst
|
||||
record("alpha-1-end")
|
||||
}) {
|
||||
t.Fatal("dispatch alpha-1 failed")
|
||||
}
|
||||
if !dispatcher.Dispatch("alpha", func() {
|
||||
record("alpha-2-start")
|
||||
secondStarted <- struct{}{}
|
||||
record("alpha-2-end")
|
||||
}) {
|
||||
t.Fatal("dispatch alpha-2 failed")
|
||||
}
|
||||
if !dispatcher.Dispatch("beta", func() {
|
||||
record("beta-1-start")
|
||||
otherStarted <- struct{}{}
|
||||
record("beta-1-end")
|
||||
}) {
|
||||
t.Fatal("dispatch beta-1 failed")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-firstStarted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for alpha-1")
|
||||
}
|
||||
select {
|
||||
case <-otherStarted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for beta-1")
|
||||
}
|
||||
select {
|
||||
case <-secondStarted:
|
||||
t.Fatal("alpha-2 started before alpha-1 finished")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(releaseFirst)
|
||||
|
||||
select {
|
||||
case <-secondStarted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for alpha-2")
|
||||
}
|
||||
|
||||
dispatcher.CloseAndWait()
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
if len(order) == 0 {
|
||||
t.Fatal("dispatch order is empty")
|
||||
}
|
||||
alpha1Start := indexOfString(order, "alpha-1-start")
|
||||
alpha1End := indexOfString(order, "alpha-1-end")
|
||||
alpha2Start := indexOfString(order, "alpha-2-start")
|
||||
beta1Start := indexOfString(order, "beta-1-start")
|
||||
if alpha1Start < 0 || alpha1End < 0 || alpha2Start < 0 || beta1Start < 0 {
|
||||
t.Fatalf("unexpected order trace: %v", order)
|
||||
}
|
||||
if alpha2Start < alpha1End {
|
||||
t.Fatalf("alpha source was not serialized: %v", order)
|
||||
}
|
||||
if beta1Start > alpha1End {
|
||||
t.Fatalf("beta source did not run in parallel window: %v", order)
|
||||
}
|
||||
}
|
||||
|
||||
func indexOfString(list []string, target string) int {
|
||||
for idx, item := range list {
|
||||
if item == target {
|
||||
return idx
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
package notify
|
||||
|
||||
import "b612.me/starcrypto"
|
||||
|
||||
var integrationSharedSecret = []byte("notify-integration-modern-psk")
|
||||
|
||||
func integrationModernPSKOptions() *ModernPSKOptions {
|
||||
return &ModernPSKOptions{
|
||||
Salt: []byte("notify-integration-modern-psk-salt"),
|
||||
AAD: []byte("notify-integration-modern-psk-aad"),
|
||||
Argon2Params: starcrypto.Argon2Params{
|
||||
Time: 1,
|
||||
Memory: 8,
|
||||
Threads: 1,
|
||||
KeyLen: 32,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
package codec
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
)
|
||||
|
||||
func Register(data interface{}) {
|
||||
gob.Register(data)
|
||||
}
|
||||
|
||||
func RegisterName(name string, data interface{}) {
|
||||
gob.RegisterName(name, data)
|
||||
}
|
||||
|
||||
func RegisterAll(data []interface{}) {
|
||||
for _, v := range data {
|
||||
gob.Register(v)
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterNames(data map[string]interface{}) {
|
||||
for k, v := range data {
|
||||
gob.RegisterName(k, v)
|
||||
}
|
||||
}
|
||||
|
||||
func Encode(src interface{}) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
enc := gob.NewEncoder(&buf)
|
||||
err := enc.Encode(&src)
|
||||
return buf.Bytes(), err
|
||||
}
|
||||
|
||||
func Decode(src []byte) (interface{}, error) {
|
||||
dec := gob.NewDecoder(bytes.NewReader(src))
|
||||
var dst interface{}
|
||||
err := dec.Decode(&dst)
|
||||
return dst, err
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
package timeutil
|
||||
|
||||
import "time"
|
||||
|
||||
func NowUnixNano() int64 {
|
||||
return time.Now().UnixNano()
|
||||
}
|
||||
@@ -0,0 +1,366 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTransferIDEmpty = errors.New("transfer id is empty")
|
||||
ErrTransferExists = errors.New("transfer already exists")
|
||||
ErrTransferNotFound = errors.New("transfer not found")
|
||||
ErrTransferBytesInvalid = errors.New("transfer bytes must be non-negative")
|
||||
)
|
||||
|
||||
type Manager struct {
|
||||
mu sync.Mutex
|
||||
now func() time.Time
|
||||
transfers map[string]*managedTransfer
|
||||
}
|
||||
|
||||
type managedTransfer struct {
|
||||
snapshot Snapshot
|
||||
}
|
||||
|
||||
func NewManager() *Manager {
|
||||
return NewManagerWithClock(time.Now)
|
||||
}
|
||||
|
||||
func NewManagerWithClock(now func() time.Time) *Manager {
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
return &Manager{
|
||||
now: now,
|
||||
transfers: make(map[string]*managedTransfer),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) StartOutgoing(desc Descriptor) (Snapshot, error) {
|
||||
return m.start(desc, DirectionSend, StateNegotiating)
|
||||
}
|
||||
|
||||
func (m *Manager) StartIncoming(desc Descriptor) (Snapshot, error) {
|
||||
return m.start(desc, DirectionReceive, StatePrepared)
|
||||
}
|
||||
|
||||
func (m *Manager) Activate(id string) (Snapshot, error) {
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
snapshot.State = StateActive
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) Pause(id string) (Snapshot, error) {
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
if snapshot.State.Terminal() {
|
||||
return nil
|
||||
}
|
||||
snapshot.State = StatePaused
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) Resume(id string, confirmedBytes int64) (Snapshot, error) {
|
||||
if confirmedBytes < 0 {
|
||||
return Snapshot{}, ErrTransferBytesInvalid
|
||||
}
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
switch snapshot.Direction {
|
||||
case DirectionSend:
|
||||
if confirmedBytes > snapshot.SentBytes {
|
||||
snapshot.SentBytes = confirmedBytes
|
||||
}
|
||||
snapshot.AckedBytes = confirmedBytes
|
||||
if snapshot.Size > 0 && snapshot.AckedBytes > snapshot.Size {
|
||||
snapshot.AckedBytes = snapshot.Size
|
||||
}
|
||||
case DirectionReceive:
|
||||
if confirmedBytes > snapshot.ReceivedBytes {
|
||||
snapshot.ReceivedBytes = confirmedBytes
|
||||
}
|
||||
if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size {
|
||||
snapshot.ReceivedBytes = snapshot.Size
|
||||
}
|
||||
}
|
||||
snapshot.State = StateActive
|
||||
snapshot.InflightBytes = inflightBytes(*snapshot)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) RecordSend(id string, sentBytes int64) (Snapshot, error) {
|
||||
if sentBytes < 0 {
|
||||
return Snapshot{}, ErrTransferBytesInvalid
|
||||
}
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
snapshot.SentBytes += sentBytes
|
||||
if snapshot.Size > 0 && snapshot.SentBytes > snapshot.Size {
|
||||
snapshot.SentBytes = snapshot.Size
|
||||
}
|
||||
snapshot.InflightBytes = inflightBytes(*snapshot)
|
||||
if !snapshot.State.Terminal() {
|
||||
snapshot.State = StateActive
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) RecordReceive(id string, recvBytes int64) (Snapshot, error) {
|
||||
if recvBytes < 0 {
|
||||
return Snapshot{}, ErrTransferBytesInvalid
|
||||
}
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
snapshot.ReceivedBytes += recvBytes
|
||||
if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size {
|
||||
snapshot.ReceivedBytes = snapshot.Size
|
||||
}
|
||||
if !snapshot.State.Terminal() {
|
||||
snapshot.State = StateActive
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) SetAckedBytes(id string, ackedBytes int64) (Snapshot, error) {
|
||||
if ackedBytes < 0 {
|
||||
return Snapshot{}, ErrTransferBytesInvalid
|
||||
}
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
snapshot.AckedBytes = ackedBytes
|
||||
if snapshot.Size > 0 && snapshot.AckedBytes > snapshot.Size {
|
||||
snapshot.AckedBytes = snapshot.Size
|
||||
}
|
||||
if snapshot.AckedBytes > snapshot.SentBytes {
|
||||
snapshot.SentBytes = snapshot.AckedBytes
|
||||
}
|
||||
snapshot.InflightBytes = inflightBytes(*snapshot)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) BeginCommit(id string) (Snapshot, error) {
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
snapshot.State = StateCommitting
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) BeginVerify(id string) (Snapshot, error) {
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
snapshot.State = StateVerifying
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) Complete(id string) (Snapshot, error) {
|
||||
now := m.currentTime()
|
||||
return m.updateWithTime(id, now, func(snapshot *Snapshot, now time.Time) error {
|
||||
snapshot.State = StateDone
|
||||
snapshot.CompletedAt = now.UnixNano()
|
||||
snapshot.InflightBytes = inflightBytes(*snapshot)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) Abort(id string, err error) (Snapshot, error) {
|
||||
return m.finishWithError(id, StateAborted, err)
|
||||
}
|
||||
|
||||
func (m *Manager) Fail(id string, err error) (Snapshot, error) {
|
||||
return m.finishWithError(id, StateFailed, err)
|
||||
}
|
||||
|
||||
func (m *Manager) RecordRetry(id string) (Snapshot, error) {
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
snapshot.RetryCount++
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) RecordTimeout(id string) (Snapshot, error) {
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
snapshot.TimeoutCount++
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) SetStage(id string, stage string) (Snapshot, error) {
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
snapshot.Stage = stage
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) SetFailureStage(id string, stage string) (Snapshot, error) {
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
snapshot.LastFailureStage = stage
|
||||
if stage != "" {
|
||||
snapshot.Stage = stage
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) MergeMetadata(id string, metadata Metadata) (Snapshot, error) {
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
if len(metadata) == 0 {
|
||||
return nil
|
||||
}
|
||||
if snapshot.Metadata == nil {
|
||||
snapshot.Metadata = make(Metadata, len(metadata))
|
||||
}
|
||||
for key, value := range metadata {
|
||||
if value == "" {
|
||||
delete(snapshot.Metadata, key)
|
||||
continue
|
||||
}
|
||||
snapshot.Metadata[key] = value
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) RecordTelemetry(id string, delta TelemetryDelta) (Snapshot, error) {
|
||||
return m.update(id, func(snapshot *Snapshot) error {
|
||||
if delta.SourceReadDuration > 0 {
|
||||
snapshot.SourceReadDuration += delta.SourceReadDuration
|
||||
}
|
||||
if delta.StreamWriteDuration > 0 {
|
||||
snapshot.StreamWriteDuration += delta.StreamWriteDuration
|
||||
}
|
||||
if delta.SinkWriteDuration > 0 {
|
||||
snapshot.SinkWriteDuration += delta.SinkWriteDuration
|
||||
}
|
||||
if delta.SyncDuration > 0 {
|
||||
snapshot.SyncDuration += delta.SyncDuration
|
||||
}
|
||||
if delta.VerifyDuration > 0 {
|
||||
snapshot.VerifyDuration += delta.VerifyDuration
|
||||
}
|
||||
if delta.CommitDuration > 0 {
|
||||
snapshot.CommitDuration += delta.CommitDuration
|
||||
}
|
||||
if delta.CommitWaitDuration > 0 {
|
||||
snapshot.CommitWaitDuration += delta.CommitWaitDuration
|
||||
}
|
||||
if delta.SourceReadCount > 0 {
|
||||
snapshot.SourceReadCount += delta.SourceReadCount
|
||||
}
|
||||
if delta.StreamWriteCount > 0 {
|
||||
snapshot.StreamWriteCount += delta.StreamWriteCount
|
||||
}
|
||||
if delta.SinkWriteCount > 0 {
|
||||
snapshot.SinkWriteCount += delta.SinkWriteCount
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) Snapshot(id string) (Snapshot, bool) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
transfer, ok := m.transfers[id]
|
||||
if !ok {
|
||||
return Snapshot{}, false
|
||||
}
|
||||
return cloneSnapshot(transfer.snapshot), true
|
||||
}
|
||||
|
||||
func (m *Manager) Snapshots() []Snapshot {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
out := make([]Snapshot, 0, len(m.transfers))
|
||||
for _, transfer := range m.transfers {
|
||||
out = append(out, cloneSnapshot(transfer.snapshot))
|
||||
}
|
||||
sort.Slice(out, func(i int, j int) bool {
|
||||
return out[i].ID < out[j].ID
|
||||
})
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) Restore(snapshot Snapshot) (Snapshot, error) {
|
||||
if snapshot.ID == "" {
|
||||
return Snapshot{}, ErrTransferIDEmpty
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.transfers[snapshot.ID] = &managedTransfer{snapshot: cloneSnapshot(snapshot)}
|
||||
return cloneSnapshot(snapshot), nil
|
||||
}
|
||||
|
||||
func (m *Manager) start(desc Descriptor, direction Direction, state State) (Snapshot, error) {
|
||||
if desc.ID == "" {
|
||||
return Snapshot{}, ErrTransferIDEmpty
|
||||
}
|
||||
now := m.currentTime()
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if _, exists := m.transfers[desc.ID]; exists {
|
||||
return Snapshot{}, ErrTransferExists
|
||||
}
|
||||
snapshot := Snapshot{
|
||||
ID: desc.ID,
|
||||
Direction: direction,
|
||||
Channel: normalizeChannel(desc.Channel),
|
||||
State: state,
|
||||
Size: desc.Size,
|
||||
Checksum: desc.Checksum,
|
||||
Metadata: cloneMetadata(desc.Metadata),
|
||||
StartedAt: now.UnixNano(),
|
||||
UpdatedAt: now.UnixNano(),
|
||||
}
|
||||
m.transfers[desc.ID] = &managedTransfer{snapshot: snapshot}
|
||||
return cloneSnapshot(snapshot), nil
|
||||
}
|
||||
|
||||
func (m *Manager) finishWithError(id string, state State, err error) (Snapshot, error) {
|
||||
now := m.currentTime()
|
||||
return m.updateWithTime(id, now, func(snapshot *Snapshot, now time.Time) error {
|
||||
snapshot.State = state
|
||||
snapshot.CompletedAt = now.UnixNano()
|
||||
if err != nil {
|
||||
snapshot.LastError = err.Error()
|
||||
}
|
||||
snapshot.InflightBytes = inflightBytes(*snapshot)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) update(id string, fn func(*Snapshot) error) (Snapshot, error) {
|
||||
return m.updateWithTime(id, m.currentTime(), func(snapshot *Snapshot, _ time.Time) error {
|
||||
return fn(snapshot)
|
||||
})
|
||||
}
|
||||
|
||||
func (m *Manager) updateWithTime(id string, now time.Time, fn func(*Snapshot, time.Time) error) (Snapshot, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
transfer, ok := m.transfers[id]
|
||||
if !ok {
|
||||
return Snapshot{}, ErrTransferNotFound
|
||||
}
|
||||
snapshot := &transfer.snapshot
|
||||
if err := fn(snapshot, now); err != nil {
|
||||
return Snapshot{}, err
|
||||
}
|
||||
snapshot.UpdatedAt = now.UnixNano()
|
||||
return cloneSnapshot(*snapshot), nil
|
||||
}
|
||||
|
||||
func (m *Manager) currentTime() time.Time {
|
||||
return m.now()
|
||||
}
|
||||
|
||||
func inflightBytes(snapshot Snapshot) int64 {
|
||||
if snapshot.Direction != DirectionSend {
|
||||
return 0
|
||||
}
|
||||
if snapshot.SentBytes <= snapshot.AckedBytes {
|
||||
return 0
|
||||
}
|
||||
return snapshot.SentBytes - snapshot.AckedBytes
|
||||
}
|
||||
@@ -0,0 +1,193 @@
|
||||
package transfer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type fakeClock struct {
|
||||
now time.Time
|
||||
}
|
||||
|
||||
func (f *fakeClock) Now() time.Time {
|
||||
return f.now
|
||||
}
|
||||
|
||||
func (f *fakeClock) Advance(d time.Duration) {
|
||||
f.now = f.now.Add(d)
|
||||
}
|
||||
|
||||
func TestManagerOutgoingLifecycle(t *testing.T) {
|
||||
clock := &fakeClock{now: time.Unix(100, 0)}
|
||||
manager := NewManagerWithClock(clock.Now)
|
||||
|
||||
snapshot, err := manager.StartOutgoing(Descriptor{
|
||||
ID: "tx-1",
|
||||
Size: 100,
|
||||
Checksum: "sum-1",
|
||||
Metadata: Metadata{"kind": "file"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartOutgoing failed: %v", err)
|
||||
}
|
||||
if got, want := snapshot.State, StateNegotiating; got != want {
|
||||
t.Fatalf("start state = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := snapshot.Channel, DataChannel; got != want {
|
||||
t.Fatalf("channel = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
clock.Advance(time.Second)
|
||||
if _, err := manager.Activate("tx-1"); err != nil {
|
||||
t.Fatalf("Activate failed: %v", err)
|
||||
}
|
||||
clock.Advance(time.Second)
|
||||
if _, err := manager.RecordSend("tx-1", 60); err != nil {
|
||||
t.Fatalf("RecordSend failed: %v", err)
|
||||
}
|
||||
clock.Advance(time.Second)
|
||||
if _, err := manager.SetAckedBytes("tx-1", 40); err != nil {
|
||||
t.Fatalf("SetAckedBytes failed: %v", err)
|
||||
}
|
||||
if _, err := manager.RecordRetry("tx-1"); err != nil {
|
||||
t.Fatalf("RecordRetry failed: %v", err)
|
||||
}
|
||||
if _, err := manager.RecordTimeout("tx-1"); err != nil {
|
||||
t.Fatalf("RecordTimeout failed: %v", err)
|
||||
}
|
||||
if _, err := manager.Pause("tx-1"); err != nil {
|
||||
t.Fatalf("Pause failed: %v", err)
|
||||
}
|
||||
clock.Advance(time.Second)
|
||||
if _, err := manager.Resume("tx-1", 40); err != nil {
|
||||
t.Fatalf("Resume failed: %v", err)
|
||||
}
|
||||
if _, err := manager.BeginCommit("tx-1"); err != nil {
|
||||
t.Fatalf("BeginCommit failed: %v", err)
|
||||
}
|
||||
clock.Advance(time.Second)
|
||||
snapshot, err = manager.Complete("tx-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Complete failed: %v", err)
|
||||
}
|
||||
|
||||
if got, want := snapshot.State, StateDone; got != want {
|
||||
t.Fatalf("complete state = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := snapshot.SentBytes, int64(60); got != want {
|
||||
t.Fatalf("sent bytes = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.AckedBytes, int64(40); got != want {
|
||||
t.Fatalf("acked bytes = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.InflightBytes, int64(20); got != want {
|
||||
t.Fatalf("inflight bytes = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.RetryCount, 1; got != want {
|
||||
t.Fatalf("retry count = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.TimeoutCount, 1; got != want {
|
||||
t.Fatalf("timeout count = %d, want %d", got, want)
|
||||
}
|
||||
if _, err := manager.SetStage("tx-1", "chunk"); err != nil {
|
||||
t.Fatalf("SetStage failed: %v", err)
|
||||
}
|
||||
if _, err := manager.SetFailureStage("tx-1", "chunk"); err != nil {
|
||||
t.Fatalf("SetFailureStage failed: %v", err)
|
||||
}
|
||||
if snapshot.CompletedAt == 0 {
|
||||
t.Fatal("completed timestamp should be set")
|
||||
}
|
||||
if got := snapshot.Metadata["kind"]; got != "file" {
|
||||
t.Fatalf("metadata kind = %q, want file", got)
|
||||
}
|
||||
snapshot, ok := manager.Snapshot("tx-1")
|
||||
if !ok {
|
||||
t.Fatal("snapshot should still exist")
|
||||
}
|
||||
if got, want := snapshot.Stage, "chunk"; got != want {
|
||||
t.Fatalf("stage = %q, want %q", got, want)
|
||||
}
|
||||
if got, want := snapshot.LastFailureStage, "chunk"; got != want {
|
||||
t.Fatalf("last failure stage = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerIncomingResumeAndVerify(t *testing.T) {
|
||||
clock := &fakeClock{now: time.Unix(200, 0)}
|
||||
manager := NewManagerWithClock(clock.Now)
|
||||
|
||||
snapshot, err := manager.StartIncoming(Descriptor{
|
||||
ID: "rx-1",
|
||||
Channel: ControlChannel,
|
||||
Size: 64,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartIncoming failed: %v", err)
|
||||
}
|
||||
if got, want := snapshot.State, StatePrepared; got != want {
|
||||
t.Fatalf("prepared state = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
clock.Advance(time.Second)
|
||||
snapshot, err = manager.Resume("rx-1", 16)
|
||||
if err != nil {
|
||||
t.Fatalf("Resume failed: %v", err)
|
||||
}
|
||||
if got, want := snapshot.ReceivedBytes, int64(16); got != want {
|
||||
t.Fatalf("received bytes after resume = %d, want %d", got, want)
|
||||
}
|
||||
|
||||
if _, err := manager.RecordReceive("rx-1", 20); err != nil {
|
||||
t.Fatalf("RecordReceive failed: %v", err)
|
||||
}
|
||||
if _, err := manager.BeginVerify("rx-1"); err != nil {
|
||||
t.Fatalf("BeginVerify failed: %v", err)
|
||||
}
|
||||
clock.Advance(time.Second)
|
||||
snapshot, err = manager.Complete("rx-1")
|
||||
if err != nil {
|
||||
t.Fatalf("Complete failed: %v", err)
|
||||
}
|
||||
if got, want := snapshot.State, StateDone; got != want {
|
||||
t.Fatalf("complete state = %v, want %v", got, want)
|
||||
}
|
||||
if got, want := snapshot.ReceivedBytes, int64(36); got != want {
|
||||
t.Fatalf("received bytes = %d, want %d", got, want)
|
||||
}
|
||||
if got, want := snapshot.Channel, ControlChannel; got != want {
|
||||
t.Fatalf("channel = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerValidatesIDsAndSortedSnapshots(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
if _, err := manager.StartOutgoing(Descriptor{}); !errors.Is(err, ErrTransferIDEmpty) {
|
||||
t.Fatalf("empty id error = %v, want %v", err, ErrTransferIDEmpty)
|
||||
}
|
||||
if _, err := manager.StartOutgoing(Descriptor{ID: "b"}); err != nil {
|
||||
t.Fatalf("StartOutgoing b failed: %v", err)
|
||||
}
|
||||
if _, err := manager.StartIncoming(Descriptor{ID: "a"}); err != nil {
|
||||
t.Fatalf("StartIncoming a failed: %v", err)
|
||||
}
|
||||
if _, err := manager.StartOutgoing(Descriptor{ID: "b"}); !errors.Is(err, ErrTransferExists) {
|
||||
t.Fatalf("duplicate id error = %v, want %v", err, ErrTransferExists)
|
||||
}
|
||||
if _, err := manager.RecordSend("missing", 1); !errors.Is(err, ErrTransferNotFound) {
|
||||
t.Fatalf("missing transfer error = %v, want %v", err, ErrTransferNotFound)
|
||||
}
|
||||
if _, err := manager.RecordReceive("a", -1); !errors.Is(err, ErrTransferBytesInvalid) {
|
||||
t.Fatalf("negative bytes error = %v, want %v", err, ErrTransferBytesInvalid)
|
||||
}
|
||||
|
||||
snapshots := manager.Snapshots()
|
||||
if len(snapshots) != 2 {
|
||||
t.Fatalf("snapshot count = %d, want 2", len(snapshots))
|
||||
}
|
||||
if snapshots[0].ID != "a" || snapshots[1].ID != "b" {
|
||||
t.Fatalf("snapshot order = [%s %s], want [a b]", snapshots[0].ID, snapshots[1].ID)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,188 @@
|
||||
package transfer
|
||||
|
||||
import "time"
|
||||
|
||||
type Channel string
|
||||
|
||||
const (
|
||||
ControlChannel Channel = "control"
|
||||
DataChannel Channel = "data"
|
||||
)
|
||||
|
||||
type Direction uint8
|
||||
|
||||
const (
|
||||
DirectionSend Direction = iota
|
||||
DirectionReceive
|
||||
)
|
||||
|
||||
type State uint8
|
||||
|
||||
const (
|
||||
StateInit State = iota
|
||||
StateNegotiating
|
||||
StatePrepared
|
||||
StateActive
|
||||
StatePaused
|
||||
StateCommitting
|
||||
StateVerifying
|
||||
StateDone
|
||||
StateAborted
|
||||
StateFailed
|
||||
)
|
||||
|
||||
func (s State) Terminal() bool {
|
||||
switch s {
|
||||
case StateDone, StateAborted, StateFailed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
type Range struct {
|
||||
Offset int64
|
||||
Length int64
|
||||
}
|
||||
|
||||
type Metadata map[string]string
|
||||
|
||||
type TelemetryDelta struct {
|
||||
SourceReadDuration time.Duration
|
||||
StreamWriteDuration time.Duration
|
||||
SinkWriteDuration time.Duration
|
||||
SyncDuration time.Duration
|
||||
VerifyDuration time.Duration
|
||||
CommitDuration time.Duration
|
||||
CommitWaitDuration time.Duration
|
||||
SourceReadCount int
|
||||
StreamWriteCount int
|
||||
SinkWriteCount int
|
||||
}
|
||||
|
||||
type Descriptor struct {
|
||||
ID string
|
||||
Direction Direction
|
||||
Channel Channel
|
||||
Size int64
|
||||
Checksum string
|
||||
Metadata Metadata
|
||||
}
|
||||
|
||||
type Snapshot struct {
|
||||
ID string
|
||||
Direction Direction
|
||||
Channel Channel
|
||||
State State
|
||||
Stage string
|
||||
LastFailureStage string
|
||||
Size int64
|
||||
Checksum string
|
||||
Metadata Metadata
|
||||
SentBytes int64
|
||||
AckedBytes int64
|
||||
ReceivedBytes int64
|
||||
InflightBytes int64
|
||||
RetryCount int
|
||||
TimeoutCount int
|
||||
LastError string
|
||||
SourceReadDuration time.Duration
|
||||
StreamWriteDuration time.Duration
|
||||
SinkWriteDuration time.Duration
|
||||
SyncDuration time.Duration
|
||||
VerifyDuration time.Duration
|
||||
CommitDuration time.Duration
|
||||
CommitWaitDuration time.Duration
|
||||
SourceReadCount int
|
||||
StreamWriteCount int
|
||||
SinkWriteCount int
|
||||
StartedAt int64
|
||||
UpdatedAt int64
|
||||
CompletedAt int64
|
||||
}
|
||||
|
||||
type Begin struct {
|
||||
TransferID string
|
||||
Channel Channel
|
||||
Size int64
|
||||
Checksum string
|
||||
Metadata Metadata
|
||||
}
|
||||
|
||||
type BeginAck struct {
|
||||
TransferID string
|
||||
Accepted bool
|
||||
NextOffset int64
|
||||
Missing []Range
|
||||
Error string
|
||||
}
|
||||
|
||||
type Resume struct {
|
||||
TransferID string
|
||||
}
|
||||
|
||||
type ResumeAck struct {
|
||||
TransferID string
|
||||
Accepted bool
|
||||
NextOffset int64
|
||||
Missing []Range
|
||||
Error string
|
||||
}
|
||||
|
||||
type Commit struct {
|
||||
TransferID string
|
||||
Size int64
|
||||
Checksum string
|
||||
}
|
||||
|
||||
type CommitAck struct {
|
||||
TransferID string
|
||||
Accepted bool
|
||||
Error string
|
||||
}
|
||||
|
||||
type Abort struct {
|
||||
TransferID string
|
||||
Stage string
|
||||
Offset int64
|
||||
Error string
|
||||
}
|
||||
|
||||
type Segment struct {
|
||||
TransferID string
|
||||
Channel Channel
|
||||
Offset int64
|
||||
Payload []byte
|
||||
Flags uint32
|
||||
}
|
||||
|
||||
type Ack struct {
|
||||
TransferID string
|
||||
NextOffset int64
|
||||
Missing []Range
|
||||
Final bool
|
||||
Error string
|
||||
}
|
||||
|
||||
func normalizeChannel(channel Channel) Channel {
|
||||
if channel == "" {
|
||||
return DataChannel
|
||||
}
|
||||
return channel
|
||||
}
|
||||
|
||||
func cloneMetadata(src Metadata) Metadata {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make(Metadata, len(src))
|
||||
for key, value := range src {
|
||||
dst[key] = value
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
func cloneSnapshot(src Snapshot) Snapshot {
|
||||
src.Metadata = cloneMetadata(src.Metadata)
|
||||
return src
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
//go:build !windows
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func dialNamedPipe(_ string, _ *time.Duration) (net.Conn, error) {
|
||||
return nil, ErrNamedPipeUnsupported
|
||||
}
|
||||
|
||||
func listenNamedPipe(_ string) (net.Listener, error) {
|
||||
return nil, ErrNamedPipeUnsupported
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
//go:build windows
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"github.com/Microsoft/go-winio"
|
||||
)
|
||||
|
||||
func dialNamedPipe(addr string, timeout *time.Duration) (net.Conn, error) {
|
||||
return winio.DialPipe(NormalizeNamedPipeAddr(addr), timeout)
|
||||
}
|
||||
|
||||
func listenNamedPipe(addr string) (net.Listener, error) {
|
||||
return winio.ListenPipe(NormalizeNamedPipeAddr(addr), &winio.PipeConfig{
|
||||
MessageMode: false,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrNamedPipeUnsupported = errors.New("named pipe transport is only supported on windows")
|
||||
|
||||
func IsUDPNetwork(network string) bool {
|
||||
return strings.Contains(strings.ToLower(strings.TrimSpace(network)), "udp")
|
||||
}
|
||||
|
||||
func IsNamedPipeNetwork(network string) bool {
|
||||
switch strings.ToLower(strings.TrimSpace(network)) {
|
||||
case "npipe", "pipe", "namedpipe", "named-pipe":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func Dial(network string, addr string) (net.Conn, error) {
|
||||
if IsNamedPipeNetwork(network) {
|
||||
return dialNamedPipe(addr, nil)
|
||||
}
|
||||
return net.Dial(network, addr)
|
||||
}
|
||||
|
||||
func DialTimeout(network string, addr string, timeout time.Duration) (net.Conn, error) {
|
||||
if IsNamedPipeNetwork(network) {
|
||||
return dialNamedPipe(addr, &timeout)
|
||||
}
|
||||
return net.DialTimeout(network, addr, timeout)
|
||||
}
|
||||
|
||||
func Listen(network string, addr string) (net.Listener, error) {
|
||||
if IsNamedPipeNetwork(network) {
|
||||
return listenNamedPipe(addr)
|
||||
}
|
||||
return net.Listen(network, addr)
|
||||
}
|
||||
|
||||
func NormalizeNamedPipeAddr(addr string) string {
|
||||
trimmed := strings.TrimSpace(addr)
|
||||
if trimmed == "" {
|
||||
return trimmed
|
||||
}
|
||||
if strings.HasPrefix(trimmed, `\\.\pipe\`) {
|
||||
return trimmed
|
||||
}
|
||||
if strings.HasPrefix(trimmed, `//./pipe/`) {
|
||||
return `\\.\pipe\` + strings.TrimPrefix(trimmed, `//./pipe/`)
|
||||
}
|
||||
trimmed = strings.TrimPrefix(trimmed, `\\`)
|
||||
trimmed = strings.TrimPrefix(trimmed, `//`)
|
||||
trimmed = strings.TrimPrefix(trimmed, `.\pipe\`)
|
||||
trimmed = strings.TrimPrefix(trimmed, `./pipe/`)
|
||||
trimmed = strings.TrimPrefix(trimmed, `pipe\`)
|
||||
trimmed = strings.TrimPrefix(trimmed, `pipe/`)
|
||||
trimmed = strings.TrimLeft(strings.ReplaceAll(trimmed, "/", `\`), `\`)
|
||||
return `\\.\pipe\` + trimmed
|
||||
}
|
||||
|
||||
func ConnRemoteAddrString(conn net.Conn) string {
|
||||
if conn == nil {
|
||||
return "unknown"
|
||||
}
|
||||
addr := conn.RemoteAddr()
|
||||
if addr == nil {
|
||||
return "unknown"
|
||||
}
|
||||
if value := addr.String(); value != "" {
|
||||
return value
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
//go:build !windows
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDialNamedPipeUnsupportedOnNonWindows(t *testing.T) {
|
||||
_, err := DialTimeout("npipe", "notify-demo", time.Millisecond)
|
||||
if !errors.Is(err, ErrNamedPipeUnsupported) {
|
||||
t.Fatalf("DialTimeout error = %v, want %v", err, ErrNamedPipeUnsupported)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListenNamedPipeUnsupportedOnNonWindows(t *testing.T) {
|
||||
_, err := Listen("npipe", "notify-demo")
|
||||
if !errors.Is(err, ErrNamedPipeUnsupported) {
|
||||
t.Fatalf("Listen error = %v, want %v", err, ErrNamedPipeUnsupported)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
package transport
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNamedPipeNetworkAliases(t *testing.T) {
|
||||
tests := []struct {
|
||||
network string
|
||||
want bool
|
||||
}{
|
||||
{network: "npipe", want: true},
|
||||
{network: "pipe", want: true},
|
||||
{network: "namedpipe", want: true},
|
||||
{network: "named-pipe", want: true},
|
||||
{network: "tcp", want: false},
|
||||
{network: "unix", want: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if got := IsNamedPipeNetwork(tt.network); got != tt.want {
|
||||
t.Fatalf("IsNamedPipeNetwork(%q) = %v, want %v", tt.network, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeNamedPipeAddr(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
want string
|
||||
}{
|
||||
{name: "short-name", addr: "notify-demo", want: `\\.\pipe\notify-demo`},
|
||||
{name: "pipe-prefix", addr: `pipe\notify-demo`, want: `\\.\pipe\notify-demo`},
|
||||
{name: "slash-prefix", addr: "//./pipe/notify-demo", want: `\\.\pipe\notify-demo`},
|
||||
{name: "normalized", addr: `\\.\pipe\notify-demo`, want: `\\.\pipe\notify-demo`},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := NormalizeNamedPipeAddr(tt.addr); got != tt.want {
|
||||
t.Fatalf("NormalizeNamedPipeAddr(%q) = %q, want %q", tt.addr, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
//go:build windows
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestNamedPipeRoundTripByteMode(t *testing.T) {
|
||||
pipeName := fmt.Sprintf("notify-npipe-test-%d", time.Now().UnixNano())
|
||||
listener, err := Listen("npipe", pipeName)
|
||||
if err != nil {
|
||||
t.Fatalf("Listen failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = listener.Close()
|
||||
}()
|
||||
|
||||
serverErr := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
serverErr <- err
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
buf := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
serverErr <- err
|
||||
return
|
||||
}
|
||||
if got, want := string(buf), "ping"; got != want {
|
||||
serverErr <- fmt.Errorf("server got %q, want %q", got, want)
|
||||
return
|
||||
}
|
||||
if _, err := conn.Write([]byte("pong")); err != nil {
|
||||
serverErr <- err
|
||||
return
|
||||
}
|
||||
serverErr <- nil
|
||||
}()
|
||||
|
||||
conn, err := DialTimeout("npipe", pipeName, 2*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("DialTimeout failed: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = conn.Close()
|
||||
}()
|
||||
|
||||
if _, err := conn.Write([]byte("ping")); err != nil {
|
||||
t.Fatalf("client write failed: %v", err)
|
||||
}
|
||||
|
||||
reply := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, reply); err != nil {
|
||||
t.Fatalf("client read failed: %v", err)
|
||||
}
|
||||
if got, want := string(reply), "pong"; got != want {
|
||||
t.Fatalf("client got %q, want %q", got, want)
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-serverErr:
|
||||
if err != nil {
|
||||
t.Fatalf("server error: %v", err)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("server timeout")
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user