Compare commits
19 Commits
newgen
...
v1.3.0-rc1
| Author | SHA1 | Date | |
|---|---|---|---|
|
09d972c7b7
|
|||
| d14d13c393 | |||
| 85803f75df | |||
| 48f630564f | |||
| a81d74ac45 | |||
| a2ab64a372 | |||
| 48bbc5b776 | |||
| 9065a12b99 | |||
| 996f94eef0 | |||
|
2db6102668
|
|||
| 72c3bc1c1c | |||
|
f51d2c7137
|
|||
|
bd262df8ea
|
|||
|
3964cd05b0
|
|||
|
555bc3653e
|
|||
| bdcfcd05db | |||
| 164c412c24 | |||
| 07e374b83f | |||
| 79dcaaf249 |
@@ -0,0 +1,8 @@
|
|||||||
|
.sentrux/
|
||||||
|
agent_readme.md
|
||||||
|
target.md
|
||||||
|
notify_plan.md
|
||||||
|
.gocache
|
||||||
|
.gocache/
|
||||||
|
.tmp_*/
|
||||||
|
.idea
|
||||||
@@ -0,0 +1,201 @@
|
|||||||
|
Apache License
|
||||||
|
Version 2.0, January 2004
|
||||||
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
|
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||||
|
|
||||||
|
1. Definitions.
|
||||||
|
|
||||||
|
"License" shall mean the terms and conditions for use, reproduction,
|
||||||
|
and distribution as defined by Sections 1 through 9 of this document.
|
||||||
|
|
||||||
|
"Licensor" shall mean the copyright owner or entity authorized by
|
||||||
|
the copyright owner that is granting the License.
|
||||||
|
|
||||||
|
"Legal Entity" shall mean the union of the acting entity and all
|
||||||
|
other entities that control, are controlled by, or are under common
|
||||||
|
control with that entity. For the purposes of this definition,
|
||||||
|
"control" means (i) the power, direct or indirect, to cause the
|
||||||
|
direction or management of such entity, whether by contract or
|
||||||
|
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||||
|
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
|
"You" (or "Your") shall mean an individual or Legal Entity
|
||||||
|
exercising permissions granted by this License.
|
||||||
|
|
||||||
|
"Source" form shall mean the preferred form for making modifications,
|
||||||
|
including but not limited to software source code, documentation
|
||||||
|
source, and configuration files.
|
||||||
|
|
||||||
|
"Object" form shall mean any form resulting from mechanical
|
||||||
|
transformation or translation of a Source form, including but
|
||||||
|
not limited to compiled object code, generated documentation,
|
||||||
|
and conversions to other media types.
|
||||||
|
|
||||||
|
"Work" shall mean the work of authorship, whether in Source or
|
||||||
|
Object form, made available under the License, as indicated by a
|
||||||
|
copyright notice that is included in or attached to the work
|
||||||
|
(an example is provided in the Appendix below).
|
||||||
|
|
||||||
|
"Derivative Works" shall mean any work, whether in Source or Object
|
||||||
|
form, that is based on (or derived from) the Work and for which the
|
||||||
|
editorial revisions, annotations, elaborations, or other modifications
|
||||||
|
represent, as a whole, an original work of authorship. For the purposes
|
||||||
|
of this License, Derivative Works shall not include works that remain
|
||||||
|
separable from, or merely link (or bind by name) to the interfaces of,
|
||||||
|
the Work and Derivative Works thereof.
|
||||||
|
|
||||||
|
"Contribution" shall mean any work of authorship, including
|
||||||
|
the original version of the Work and any modifications or additions
|
||||||
|
to that Work or Derivative Works thereof, that is intentionally
|
||||||
|
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||||
|
or by an individual or Legal Entity authorized to submit on behalf of
|
||||||
|
the copyright owner. For the purposes of this definition, "submitted"
|
||||||
|
means any form of electronic, verbal, or written communication sent
|
||||||
|
to the Licensor or its representatives, including but not limited to
|
||||||
|
communication on electronic mailing lists, source code control systems,
|
||||||
|
and issue tracking systems that are managed by, or on behalf of, the
|
||||||
|
Licensor for the purpose of discussing and improving the Work, but
|
||||||
|
excluding communication that is conspicuously marked or otherwise
|
||||||
|
designated in writing by the copyright owner as "Not a Contribution."
|
||||||
|
|
||||||
|
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||||
|
on behalf of whom a Contribution has been received by Licensor and
|
||||||
|
subsequently incorporated within the Work.
|
||||||
|
|
||||||
|
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
copyright license to reproduce, prepare Derivative Works of,
|
||||||
|
publicly display, publicly perform, sublicense, and distribute the
|
||||||
|
Work and such Derivative Works in Source or Object form.
|
||||||
|
|
||||||
|
3. Grant of Patent License. Subject to the terms and conditions of
|
||||||
|
this License, each Contributor hereby grants to You a perpetual,
|
||||||
|
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||||
|
(except as stated in this section) patent license to make, have made,
|
||||||
|
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||||
|
where such license applies only to those patent claims licensable
|
||||||
|
by such Contributor that are necessarily infringed by their
|
||||||
|
Contribution(s) alone or by combination of their Contribution(s)
|
||||||
|
with the Work to which such Contribution(s) was submitted. If You
|
||||||
|
institute patent litigation against any entity (including a
|
||||||
|
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||||
|
or a Contribution incorporated within the Work constitutes direct
|
||||||
|
or contributory patent infringement, then any patent licenses
|
||||||
|
granted to You under this License for that Work shall terminate
|
||||||
|
as of the date such litigation is filed.
|
||||||
|
|
||||||
|
4. Redistribution. You may reproduce and distribute copies of the
|
||||||
|
Work or Derivative Works thereof in any medium, with or without
|
||||||
|
modifications, and in Source or Object form, provided that You
|
||||||
|
meet the following conditions:
|
||||||
|
|
||||||
|
(a) You must give any other recipients of the Work or
|
||||||
|
Derivative Works a copy of this License; and
|
||||||
|
|
||||||
|
(b) You must cause any modified files to carry prominent notices
|
||||||
|
stating that You changed the files; and
|
||||||
|
|
||||||
|
(c) You must retain, in the Source form of any Derivative Works
|
||||||
|
that You distribute, all copyright, patent, trademark, and
|
||||||
|
attribution notices from the Source form of the Work,
|
||||||
|
excluding those notices that do not pertain to any part of
|
||||||
|
the Derivative Works; and
|
||||||
|
|
||||||
|
(d) If the Work includes a "NOTICE" text file as part of its
|
||||||
|
distribution, then any Derivative Works that You distribute must
|
||||||
|
include a readable copy of the attribution notices contained
|
||||||
|
within such NOTICE file, excluding those notices that do not
|
||||||
|
pertain to any part of the Derivative Works, in at least one
|
||||||
|
of the following places: within a NOTICE text file distributed
|
||||||
|
as part of the Derivative Works; within the Source form or
|
||||||
|
documentation, if provided along with the Derivative Works; or,
|
||||||
|
within a display generated by the Derivative Works, if and
|
||||||
|
wherever such third-party notices normally appear. The contents
|
||||||
|
of the NOTICE file are for informational purposes only and
|
||||||
|
do not modify the License. You may add Your own attribution
|
||||||
|
notices within Derivative Works that You distribute, alongside
|
||||||
|
or as an addendum to the NOTICE text from the Work, provided
|
||||||
|
that such additional attribution notices cannot be construed
|
||||||
|
as modifying the License.
|
||||||
|
|
||||||
|
You may add Your own copyright statement to Your modifications and
|
||||||
|
may provide additional or different license terms and conditions
|
||||||
|
for use, reproduction, or distribution of Your modifications, or
|
||||||
|
for any such Derivative Works as a whole, provided Your use,
|
||||||
|
reproduction, and distribution of the Work otherwise complies with
|
||||||
|
the conditions stated in this License.
|
||||||
|
|
||||||
|
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||||
|
any Contribution intentionally submitted for inclusion in the Work
|
||||||
|
by You to the Licensor shall be under the terms and conditions of
|
||||||
|
this License, without any additional terms or conditions.
|
||||||
|
Notwithstanding the above, nothing herein shall supersede or modify
|
||||||
|
the terms of any separate license agreement you may have executed
|
||||||
|
with Licensor regarding such Contributions.
|
||||||
|
|
||||||
|
6. Trademarks. This License does not grant permission to use the trade
|
||||||
|
names, trademarks, service marks, or product names of the Licensor,
|
||||||
|
except as required for reasonable and customary use in describing the
|
||||||
|
origin of the Work and reproducing the content of the NOTICE file.
|
||||||
|
|
||||||
|
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||||
|
agreed to in writing, Licensor provides the Work (and each
|
||||||
|
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||||
|
implied, including, without limitation, any warranties or conditions
|
||||||
|
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||||
|
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||||
|
appropriateness of using or redistributing the Work and assume any
|
||||||
|
risks associated with Your exercise of permissions under this License.
|
||||||
|
|
||||||
|
8. Limitation of Liability. In no event and under no legal theory,
|
||||||
|
whether in tort (including negligence), contract, or otherwise,
|
||||||
|
unless required by applicable law (such as deliberate and grossly
|
||||||
|
negligent acts) or agreed to in writing, shall any Contributor be
|
||||||
|
liable to You for damages, including any direct, indirect, special,
|
||||||
|
incidental, or consequential damages of any character arising as a
|
||||||
|
result of this License or out of the use or inability to use the
|
||||||
|
Work (including but not limited to damages for loss of goodwill,
|
||||||
|
work stoppage, computer failure or malfunction, or any and all
|
||||||
|
other commercial damages or losses), even if such Contributor
|
||||||
|
has been advised of the possibility of such damages.
|
||||||
|
|
||||||
|
9. Accepting Warranty or Additional Liability. While redistributing
|
||||||
|
the Work or Derivative Works thereof, You may choose to offer,
|
||||||
|
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||||
|
or other liability obligations and/or rights consistent with this
|
||||||
|
License. However, in accepting such obligations, You may act only
|
||||||
|
on Your own behalf and on Your sole responsibility, not on behalf
|
||||||
|
of any other Contributor, and only if You agree to indemnify,
|
||||||
|
defend, and hold each Contributor harmless for any liability
|
||||||
|
incurred by, or claims asserted against, such Contributor by reason
|
||||||
|
of your accepting any such warranty or additional liability.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
APPENDIX: How to apply the Apache License to your work.
|
||||||
|
|
||||||
|
To apply the Apache License to your work, attach the following
|
||||||
|
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||||
|
replaced with your own identifying information. (Don't include
|
||||||
|
the brackets!) The text should be enclosed in the appropriate
|
||||||
|
comment syntax for the file format. We also recommend that a
|
||||||
|
file or class name and description of purpose be included on the
|
||||||
|
same "printed page" as the copyright notice for easier
|
||||||
|
identification within third-party archives.
|
||||||
|
|
||||||
|
Copyright 2026 starnet contributors
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
@@ -0,0 +1,133 @@
|
|||||||
|
# notify
|
||||||
|
|
||||||
|
`b612.me/notify` 是一个面向点对点直连场景的 Go 通信基础包,覆盖消息信令、流式传输、批量数据通道和文件传输内核能力。
|
||||||
|
|
||||||
|
## 模块定位
|
||||||
|
|
||||||
|
- 消息面:`Send`、`SendWait`、`Reply`、`SetLink`
|
||||||
|
- 流式数据面:`OpenStream`
|
||||||
|
- 记录流数据面:`OpenRecordStream`
|
||||||
|
- 批量数据面:`OpenBulk`(`shared` / `dedicated`)
|
||||||
|
- 文件传输内核:transfer control / progress / resume
|
||||||
|
- 会话模型:`LogicalConn`(逻辑会话)与 `TransportConn`(物理承载)分离
|
||||||
|
|
||||||
|
## 版本要求
|
||||||
|
|
||||||
|
- Go `1.24+`
|
||||||
|
|
||||||
|
## 安全初始化要求
|
||||||
|
|
||||||
|
`Client` / `Server` 在 `Connect` / `Listen` 前必须完成安全配置。默认使用现代 PSK 方案。
|
||||||
|
|
||||||
|
- 客户端:`UseModernPSKClient`
|
||||||
|
- 服务端:`UseModernPSKServer`
|
||||||
|
|
||||||
|
未配置时会返回 `errModernPSKRequired`。
|
||||||
|
|
||||||
|
## 快速开始
|
||||||
|
|
||||||
|
服务端:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"b612.me/notify"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
srv := notify.NewServer()
|
||||||
|
if err := notify.UseModernPSKServer(srv, []byte("shared-secret"), nil); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
srv.SetLink("ping", func(msg *notify.Message) {
|
||||||
|
_ = msg.Reply([]byte("pong"))
|
||||||
|
})
|
||||||
|
if err := srv.Listen("tcp", "127.0.0.1:28080"); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
select {}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
客户端:
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"b612.me/notify"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
cli := notify.NewClient()
|
||||||
|
if err := notify.UseModernPSKClient(cli, []byte("shared-secret"), nil); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
if err := cli.Connect("tcp", "127.0.0.1:28080"); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer cli.Stop()
|
||||||
|
|
||||||
|
reply, err := cli.SendWait("ping", []byte("hello"), 5*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
log.Printf("reply=%s", string(reply.Value))
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 传输与 IPC
|
||||||
|
|
||||||
|
- `tcp`
|
||||||
|
- `udp`
|
||||||
|
- `unix`
|
||||||
|
- `npipe`(Windows)
|
||||||
|
|
||||||
|
示例目录:
|
||||||
|
|
||||||
|
- [examples/signal](/mnt/c/coding/gocode/src/b612.me/notify/examples/signal)
|
||||||
|
|
||||||
|
## 现代 PSK 与兼容入口
|
||||||
|
|
||||||
|
现代方案特性:
|
||||||
|
|
||||||
|
- 共享密钥派生(Argon2id)
|
||||||
|
- 消息层加密(AES-GCM)
|
||||||
|
- `stream` / `bulk` fast path 复用现代编码栈
|
||||||
|
|
||||||
|
兼容入口仍保留,但属于历史路径:
|
||||||
|
|
||||||
|
- `UseLegacySecurityClient`
|
||||||
|
- `UseLegacySecurityServer`
|
||||||
|
- `ExchangeKey`
|
||||||
|
- `SetSecretKey`
|
||||||
|
- `SetMsgEn` / `SetMsgDe`
|
||||||
|
|
||||||
|
## 发布前检查
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export SENTRUX_SKIP_GRAMMAR_DOWNLOAD='1'
|
||||||
|
sentrux check .
|
||||||
|
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go test ./...
|
||||||
|
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go test -race ./...
|
||||||
|
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache go vet ./...
|
||||||
|
```
|
||||||
|
|
||||||
|
手工 soak 测试(可选):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
env GOCACHE=/tmp/b612-gocache GOMODCACHE=/tmp/b612-gomodcache \
|
||||||
|
go test -tags notify_manual_soak -run 'Test_ServerTuAndClientCommon|Test_normal|Test_normal_udp'
|
||||||
|
```
|
||||||
|
|
||||||
|
## 兼容性说明
|
||||||
|
|
||||||
|
- 对外主入口保留:`NewClient`、`NewServer`、`Connect`、`Listen`、`SetLink`、`SetDefaultLink`、`Send`、`SendWait`、`SendObj`、`Reply`、`Stop`
|
||||||
|
- 内部主对象已迁移为 `LogicalConn` / `TransportConn`
|
||||||
|
- `ClientConn` 作为兼容适配层继续保留
|
||||||
@@ -0,0 +1,266 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
bulkBatchMaxPayloads = 16
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
bulkBatchRequestQueued int32 = iota
|
||||||
|
bulkBatchRequestStarted
|
||||||
|
bulkBatchRequestCanceled
|
||||||
|
)
|
||||||
|
|
||||||
|
type bulkBatchRequestState struct {
|
||||||
|
value atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkBatchRequest struct {
|
||||||
|
ctx context.Context
|
||||||
|
payload []byte
|
||||||
|
deadline time.Time
|
||||||
|
done chan error
|
||||||
|
state *bulkBatchRequestState
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkBatchSender struct {
|
||||||
|
binding *transportBinding
|
||||||
|
reqCh chan bulkBatchRequest
|
||||||
|
stopCh chan struct{}
|
||||||
|
doneCh chan struct{}
|
||||||
|
|
||||||
|
stopOnce sync.Once
|
||||||
|
errMu sync.Mutex
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBulkBatchSender(binding *transportBinding) *bulkBatchSender {
|
||||||
|
sender := &bulkBatchSender{
|
||||||
|
binding: binding,
|
||||||
|
reqCh: make(chan bulkBatchRequest, bulkBatchMaxPayloads*4),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go sender.run()
|
||||||
|
return sender
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkBatchSender) submit(ctx context.Context, payload []byte) error {
|
||||||
|
if s == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
req := bulkBatchRequest{
|
||||||
|
ctx: ctx,
|
||||||
|
payload: payload,
|
||||||
|
done: make(chan error, 1),
|
||||||
|
state: &bulkBatchRequestState{},
|
||||||
|
}
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
req.deadline = deadline
|
||||||
|
}
|
||||||
|
if err := s.errSnapshot(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return normalizeStreamDeadlineError(ctx.Err())
|
||||||
|
case <-s.stopCh:
|
||||||
|
return s.stoppedErr()
|
||||||
|
case s.reqCh <- req:
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-req.done:
|
||||||
|
return err
|
||||||
|
case <-ctx.Done():
|
||||||
|
if req.tryCancel() {
|
||||||
|
return normalizeStreamDeadlineError(ctx.Err())
|
||||||
|
}
|
||||||
|
return <-req.done
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkBatchSender) run() {
|
||||||
|
defer close(s.doneCh)
|
||||||
|
for {
|
||||||
|
req, ok := s.nextRequest()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
batch := []bulkBatchRequest{req}
|
||||||
|
drain:
|
||||||
|
for len(batch) < bulkBatchMaxPayloads {
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
s.failPending(s.stoppedErr())
|
||||||
|
return
|
||||||
|
case next := <-s.reqCh:
|
||||||
|
batch = append(batch, next)
|
||||||
|
default:
|
||||||
|
break drain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
active, payloads := activeBulkBatchRequests(batch)
|
||||||
|
if len(active) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
deadline := bulkBatchRequestsEarliestDeadline(active)
|
||||||
|
err := s.flush(payloads, deadline)
|
||||||
|
if err != nil {
|
||||||
|
s.setErr(err)
|
||||||
|
for _, item := range active {
|
||||||
|
item.done <- err
|
||||||
|
}
|
||||||
|
s.failPending(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, item := range active {
|
||||||
|
item.done <- err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkBatchSender) nextRequest() (bulkBatchRequest, bool) {
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
s.failPending(s.stoppedErr())
|
||||||
|
return bulkBatchRequest{}, false
|
||||||
|
case req := <-s.reqCh:
|
||||||
|
return req, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func activeBulkBatchRequests(batch []bulkBatchRequest) ([]bulkBatchRequest, [][]byte) {
|
||||||
|
active := make([]bulkBatchRequest, 0, len(batch))
|
||||||
|
payloads := make([][]byte, 0, len(batch))
|
||||||
|
for _, item := range batch {
|
||||||
|
if !item.tryStart() {
|
||||||
|
item.done <- item.canceledErr()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := item.contextErr(); err != nil {
|
||||||
|
item.done <- err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
active = append(active, item)
|
||||||
|
payloads = append(payloads, item.payload)
|
||||||
|
}
|
||||||
|
return active, payloads
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkBatchRequestsEarliestDeadline(batch []bulkBatchRequest) time.Time {
|
||||||
|
var deadline time.Time
|
||||||
|
for _, item := range batch {
|
||||||
|
if item.deadline.IsZero() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if deadline.IsZero() || item.deadline.Before(deadline) {
|
||||||
|
deadline = item.deadline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return deadline
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r bulkBatchRequest) contextErr() error {
|
||||||
|
if r.ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-r.ctx.Done():
|
||||||
|
return normalizeStreamDeadlineError(r.ctx.Err())
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r bulkBatchRequest) tryStart() bool {
|
||||||
|
if r.state == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestStarted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r bulkBatchRequest) tryCancel() bool {
|
||||||
|
if r.state == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return r.state.value.CompareAndSwap(bulkBatchRequestQueued, bulkBatchRequestCanceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r bulkBatchRequest) canceledErr() error {
|
||||||
|
if err := r.contextErr(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return context.Canceled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkBatchSender) flush(payloads [][]byte, deadline time.Time) error {
|
||||||
|
if s == nil || s.binding == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
queue := s.binding.queueSnapshot()
|
||||||
|
if queue == nil {
|
||||||
|
return errTransportFrameQueueUnavailable
|
||||||
|
}
|
||||||
|
return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error {
|
||||||
|
return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkBatchSender) stop() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.stopOnce.Do(func() {
|
||||||
|
s.setErr(errTransportDetached)
|
||||||
|
close(s.stopCh)
|
||||||
|
})
|
||||||
|
<-s.doneCh
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkBatchSender) failPending(err error) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case item := <-s.reqCh:
|
||||||
|
item.done <- err
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkBatchSender) setErr(err error) {
|
||||||
|
if s == nil || err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.errMu.Lock()
|
||||||
|
if s.err == nil {
|
||||||
|
s.err = err
|
||||||
|
}
|
||||||
|
s.errMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkBatchSender) errSnapshot() error {
|
||||||
|
if s == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
s.errMu.Lock()
|
||||||
|
defer s.errMu.Unlock()
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkBatchSender) stoppedErr() error {
|
||||||
|
if err := s.errSnapshot(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
@@ -0,0 +1,392 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkBulkTCPThroughput(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "chunk_256KiB",
|
||||||
|
payloadSize: 256 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chunk_512KiB",
|
||||||
|
payloadSize: 512 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chunk_768KiB",
|
||||||
|
payloadSize: 768 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chunk_1MiB",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chunk_2MiB",
|
||||||
|
payloadSize: 2 * 1024 * 1024,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkBulkTCPThroughput(b, tc.payloadSize, false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBulkTCPThroughputDedicated(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "chunk_256KiB",
|
||||||
|
payloadSize: 256 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chunk_512KiB",
|
||||||
|
payloadSize: 512 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chunk_768KiB",
|
||||||
|
payloadSize: 768 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chunk_1MiB",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "chunk_2MiB",
|
||||||
|
payloadSize: 2 * 1024 * 1024,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkBulkTCPThroughput(b, tc.payloadSize, true)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBulkTCPThroughputConcurrent(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
concurrency int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "bulks_2_512KiB",
|
||||||
|
payloadSize: 512 * 1024,
|
||||||
|
concurrency: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bulks_4_512KiB",
|
||||||
|
payloadSize: 512 * 1024,
|
||||||
|
concurrency: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bulks_2_1MiB",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
concurrency: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bulks_4_1MiB",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
concurrency: 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, false)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBulkTCPThroughputConcurrentDedicated(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
concurrency int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "bulks_2_512KiB",
|
||||||
|
payloadSize: 512 * 1024,
|
||||||
|
concurrency: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bulks_4_512KiB",
|
||||||
|
payloadSize: 512 * 1024,
|
||||||
|
concurrency: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bulks_2_1MiB",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
concurrency: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bulks_4_1MiB",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
concurrency: 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkBulkTCPThroughputConcurrent(b, tc.payloadSize, tc.concurrency, true)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkBulkTCPThroughput(b *testing.B, payloadSize int, dedicated bool) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
b.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
acceptCh := make(chan BulkAcceptInfo, 1)
|
||||||
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
|
acceptCh <- info
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||||
|
b.Fatalf("server Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
b.Cleanup(func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
b.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||||
|
b.Fatalf("client Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
b.Cleanup(func() {
|
||||||
|
_ = client.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
totalBytes := int64(payloadSize)
|
||||||
|
if b.N > 1 {
|
||||||
|
totalBytes = int64(payloadSize) * int64(b.N)
|
||||||
|
}
|
||||||
|
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||||
|
Range: BulkRange{
|
||||||
|
Offset: 0,
|
||||||
|
Length: totalBytes,
|
||||||
|
},
|
||||||
|
ChunkSize: payloadSize,
|
||||||
|
Dedicated: dedicated,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("client OpenBulk failed: %v", err)
|
||||||
|
}
|
||||||
|
accepted := waitBenchmarkAcceptedBulk(b, acceptCh, 5*time.Second)
|
||||||
|
|
||||||
|
drainDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(io.Discard, accepted.Bulk)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
drainDone <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
drainDone <- nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
payload := make([]byte, payloadSize)
|
||||||
|
for i := range payload {
|
||||||
|
payload[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(payloadSize))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
n, err := bulk.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("bulk Write failed at iter %d: %v", i, err)
|
||||||
|
}
|
||||||
|
if n != len(payload) {
|
||||||
|
b.Fatalf("bulk Write bytes mismatch at iter %d: got %d want %d", i, n, len(payload))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.StopTimer()
|
||||||
|
|
||||||
|
if err := bulk.CloseWrite(); err != nil {
|
||||||
|
b.Fatalf("bulk CloseWrite failed: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-drainDone:
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("server drain failed: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
b.Fatal("timed out waiting for server drain")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = accepted.Bulk.Close()
|
||||||
|
_ = bulk.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkBulkTCPThroughputConcurrent(b *testing.B, payloadSize int, concurrency int, dedicated bool) {
|
||||||
|
b.Helper()
|
||||||
|
if concurrency <= 0 {
|
||||||
|
b.Fatal("concurrency must be > 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
b.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
acceptCh := make(chan BulkAcceptInfo, concurrency*2)
|
||||||
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
|
acceptCh <- info
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||||
|
b.Fatalf("server Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
b.Cleanup(func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
b.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||||
|
b.Fatalf("client Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
b.Cleanup(func() {
|
||||||
|
_ = client.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
bulks := make([]Bulk, 0, concurrency)
|
||||||
|
acceptedBulks := make([]Bulk, 0, concurrency)
|
||||||
|
totalBytes := int64(payloadSize)
|
||||||
|
if b.N > 1 {
|
||||||
|
totalBytes = int64(payloadSize) * int64(b.N)
|
||||||
|
}
|
||||||
|
for index := 0; index < concurrency; index++ {
|
||||||
|
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||||
|
Range: BulkRange{
|
||||||
|
Offset: int64(index) * totalBytes,
|
||||||
|
Length: totalBytes,
|
||||||
|
},
|
||||||
|
ChunkSize: payloadSize,
|
||||||
|
Dedicated: dedicated,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("client OpenBulk failed for bulk %d: %v", index, err)
|
||||||
|
}
|
||||||
|
bulks = append(bulks, bulk)
|
||||||
|
accepted := waitBenchmarkAcceptedBulk(b, acceptCh, 5*time.Second)
|
||||||
|
acceptedBulks = append(acceptedBulks, accepted.Bulk)
|
||||||
|
}
|
||||||
|
|
||||||
|
drainDone := make(chan error, concurrency)
|
||||||
|
for _, acceptedBulk := range acceptedBulks {
|
||||||
|
bulk := acceptedBulk
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(io.Discard, bulk)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
drainDone <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
drainDone <- nil
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := make([]byte, payloadSize)
|
||||||
|
for i := range payload {
|
||||||
|
payload[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(payloadSize))
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errCh := make(chan error, concurrency)
|
||||||
|
for index, bulk := range bulks {
|
||||||
|
count := b.N / concurrency
|
||||||
|
if index < b.N%concurrency {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(bulk Bulk, count int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
n, err := bulk.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n != len(payload) {
|
||||||
|
errCh <- errors.New("bulk write bytes mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(bulk, count)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(errCh)
|
||||||
|
|
||||||
|
for err := range errCh {
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("concurrent bulk write failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.StopTimer()
|
||||||
|
|
||||||
|
for index, bulk := range bulks {
|
||||||
|
if err := bulk.CloseWrite(); err != nil {
|
||||||
|
b.Fatalf("bulk %d CloseWrite failed: %v", index, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for index := 0; index < concurrency; index++ {
|
||||||
|
select {
|
||||||
|
case err := <-drainDone:
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("server drain failed: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
b.Fatalf("timed out waiting for server drain %d/%d", index+1, concurrency)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, bulk := range acceptedBulks {
|
||||||
|
_ = bulk.Close()
|
||||||
|
}
|
||||||
|
for _, bulk := range bulks {
|
||||||
|
_ = bulk.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitBenchmarkAcceptedBulk(tb testing.TB, ch <-chan BulkAcceptInfo, timeout time.Duration) BulkAcceptInfo {
|
||||||
|
tb.Helper()
|
||||||
|
select {
|
||||||
|
case info := <-ch:
|
||||||
|
return info
|
||||||
|
case <-time.After(timeout):
|
||||||
|
tb.Fatalf("timed out waiting for accepted bulk after %v", timeout)
|
||||||
|
return BulkAcceptInfo{}
|
||||||
|
}
|
||||||
|
}
|
||||||
+702
@@ -0,0 +1,702 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BulkOpenRequest struct {
|
||||||
|
BulkID string
|
||||||
|
DataID uint64
|
||||||
|
Range BulkRange
|
||||||
|
Metadata BulkMetadata
|
||||||
|
ReadTimeout time.Duration
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
Dedicated bool
|
||||||
|
AttachToken string
|
||||||
|
ChunkSize int
|
||||||
|
WindowBytes int
|
||||||
|
MaxInFlight int
|
||||||
|
}
|
||||||
|
|
||||||
|
type BulkOpenResponse struct {
|
||||||
|
BulkID string
|
||||||
|
DataID uint64
|
||||||
|
Accepted bool
|
||||||
|
Dedicated bool
|
||||||
|
AttachToken string
|
||||||
|
TransportGeneration uint64
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type BulkCloseRequest struct {
|
||||||
|
BulkID string
|
||||||
|
Full bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type BulkCloseResponse struct {
|
||||||
|
BulkID string
|
||||||
|
Accepted bool
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type BulkResetRequest struct {
|
||||||
|
BulkID string
|
||||||
|
DataID uint64
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type BulkResetResponse struct {
|
||||||
|
BulkID string
|
||||||
|
Accepted bool
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type BulkReleaseRequest struct {
|
||||||
|
BulkID string
|
||||||
|
DataID uint64
|
||||||
|
Bytes int64
|
||||||
|
Chunks int
|
||||||
|
}
|
||||||
|
|
||||||
|
func bindClientBulkControl(c *ClientCommon) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.SetLink(BulkOpenSignalKey, func(msg *Message) {
|
||||||
|
c.handleInboundBulkOpen(msg)
|
||||||
|
})
|
||||||
|
c.SetLink(BulkCloseSignalKey, func(msg *Message) {
|
||||||
|
c.handleInboundBulkClose(msg)
|
||||||
|
})
|
||||||
|
c.SetLink(BulkResetSignalKey, func(msg *Message) {
|
||||||
|
c.handleInboundBulkReset(msg)
|
||||||
|
})
|
||||||
|
c.SetLink(BulkReleaseSignalKey, func(msg *Message) {
|
||||||
|
c.handleInboundBulkRelease(msg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func bindServerBulkControl(s *ServerCommon) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.SetLink(BulkOpenSignalKey, func(msg *Message) {
|
||||||
|
s.handleInboundBulkOpen(msg)
|
||||||
|
})
|
||||||
|
s.SetLink(BulkCloseSignalKey, func(msg *Message) {
|
||||||
|
s.handleInboundBulkClose(msg)
|
||||||
|
})
|
||||||
|
s.SetLink(BulkResetSignalKey, func(msg *Message) {
|
||||||
|
s.handleInboundBulkReset(msg)
|
||||||
|
})
|
||||||
|
s.SetLink(BulkReleaseSignalKey, func(msg *Message) {
|
||||||
|
s.handleInboundBulkRelease(msg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) handleInboundBulkOpen(msg *Message) {
|
||||||
|
req, err := decodeBulkOpenRequest(msg)
|
||||||
|
resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated}
|
||||||
|
if err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Dedicated {
|
||||||
|
if err := clientDedicatedBulkSupportError(c); err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
runtime := c.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
resp.Error = errBulkRuntimeNil.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
scope := clientFileScope()
|
||||||
|
if req.DataID == 0 {
|
||||||
|
req.DataID = runtime.nextDataID()
|
||||||
|
resp.DataID = req.DataID
|
||||||
|
}
|
||||||
|
if req.Dedicated && req.AttachToken == "" {
|
||||||
|
req.AttachToken = newBulkAttachToken()
|
||||||
|
}
|
||||||
|
resp.AttachToken = req.AttachToken
|
||||||
|
bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, scope, req, c.currentClientSessionEpoch(), nil, nil, 0, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c))
|
||||||
|
bulk.setClientSnapshotOwner(c)
|
||||||
|
if err := runtime.register(scope, bulk); err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler := runtime.handlerSnapshot()
|
||||||
|
if handler == nil {
|
||||||
|
bulk.markReset(errBulkHandlerNotConfigured)
|
||||||
|
resp.Error = errBulkHandlerNotConfigured.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Dedicated {
|
||||||
|
if err := c.attachDedicatedBulkSidecar(context.Background(), bulk); err != nil {
|
||||||
|
bulk.markReset(err)
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info := BulkAcceptInfo{
|
||||||
|
ID: bulk.ID(),
|
||||||
|
Range: bulk.Range(),
|
||||||
|
Metadata: bulk.Metadata(),
|
||||||
|
Dedicated: bulk.Dedicated(),
|
||||||
|
TransportGeneration: bulk.TransportGeneration(),
|
||||||
|
Bulk: bulk,
|
||||||
|
}
|
||||||
|
if err := handler(info); err != nil {
|
||||||
|
bulk.markReset(err)
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Accepted = true
|
||||||
|
resp.DataID = bulk.dataIDSnapshot()
|
||||||
|
resp.TransportGeneration = bulk.TransportGeneration()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) handleInboundBulkOpen(msg *Message) {
|
||||||
|
req, err := decodeBulkOpenRequest(msg)
|
||||||
|
resp := BulkOpenResponse{BulkID: req.BulkID, DataID: req.DataID, Dedicated: req.Dedicated}
|
||||||
|
if err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime := s.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
resp.Error = errBulkRuntimeNil.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logical := messageLogicalConnSnapshot(msg)
|
||||||
|
if logical == nil {
|
||||||
|
resp.Error = errBulkLogicalConnNil.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
transport := messageTransportConnSnapshot(msg)
|
||||||
|
if req.Dedicated {
|
||||||
|
if err := logicalDedicatedBulkSupportError(logical); err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if transport != nil {
|
||||||
|
if err := transportDedicatedBulkSupportError(transport); err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scope := serverFileScope(logical)
|
||||||
|
if req.DataID == 0 {
|
||||||
|
req.DataID = runtime.nextDataID()
|
||||||
|
resp.DataID = req.DataID
|
||||||
|
}
|
||||||
|
if req.Dedicated && req.AttachToken == "" {
|
||||||
|
req.AttachToken = newBulkAttachToken()
|
||||||
|
}
|
||||||
|
resp.AttachToken = req.AttachToken
|
||||||
|
bulk := newBulkHandle(logical.stopContextSnapshot(), runtime, scope, req, 0, logical, transport, bulkTransportGeneration(logical, transport), serverBulkCloseSender(s, logical, transport), serverBulkResetSender(s, logical, transport), serverBulkDataSender(s, transport), serverBulkWriteSender(s, logical, transport), serverBulkReleaseSender(s, logical, transport))
|
||||||
|
if err := runtime.register(scope, bulk); err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler := runtime.handlerSnapshot()
|
||||||
|
if handler == nil {
|
||||||
|
bulk.markReset(errBulkHandlerNotConfigured)
|
||||||
|
resp.Error = errBulkHandlerNotConfigured.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
info := BulkAcceptInfo{
|
||||||
|
ID: bulk.ID(),
|
||||||
|
Range: bulk.Range(),
|
||||||
|
Metadata: bulk.Metadata(),
|
||||||
|
Dedicated: bulk.Dedicated(),
|
||||||
|
LogicalConn: logical,
|
||||||
|
TransportConn: transport,
|
||||||
|
TransportGeneration: bulk.TransportGeneration(),
|
||||||
|
Bulk: bulk,
|
||||||
|
}
|
||||||
|
if err := handler(info); err != nil {
|
||||||
|
bulk.markReset(err)
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Accepted = true
|
||||||
|
resp.DataID = bulk.dataIDSnapshot()
|
||||||
|
resp.TransportGeneration = bulk.TransportGeneration()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) handleInboundBulkClose(msg *Message) {
|
||||||
|
req, err := decodeBulkCloseRequest(msg)
|
||||||
|
resp := BulkCloseResponse{BulkID: req.BulkID}
|
||||||
|
if err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime := c.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
resp.Error = errBulkRuntimeNil.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
|
||||||
|
if !ok {
|
||||||
|
resp.Error = errBulkNotFound.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Full {
|
||||||
|
bulk.markPeerClosed()
|
||||||
|
} else {
|
||||||
|
bulk.markRemoteClosed()
|
||||||
|
}
|
||||||
|
resp.Accepted = true
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) handleInboundBulkClose(msg *Message) {
|
||||||
|
req, err := decodeBulkCloseRequest(msg)
|
||||||
|
resp := BulkCloseResponse{BulkID: req.BulkID}
|
||||||
|
if err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime := s.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
resp.Error = errBulkRuntimeNil.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logical := messageLogicalConnSnapshot(msg)
|
||||||
|
scope := serverFileScope(logical)
|
||||||
|
bulk, ok := runtime.lookup(scope, req.BulkID)
|
||||||
|
if !ok {
|
||||||
|
resp.Error = errBulkNotFound.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Full {
|
||||||
|
bulk.markPeerClosed()
|
||||||
|
} else {
|
||||||
|
bulk.markRemoteClosed()
|
||||||
|
}
|
||||||
|
resp.Accepted = true
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) handleInboundBulkReset(msg *Message) {
|
||||||
|
req, err := decodeBulkResetRequest(msg)
|
||||||
|
resp := BulkResetResponse{BulkID: req.BulkID}
|
||||||
|
if err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime := c.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
resp.Error = errBulkRuntimeNil.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
|
||||||
|
if !ok && req.DataID != 0 {
|
||||||
|
bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
resp.Error = errBulkNotFound.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp.BulkID == "" {
|
||||||
|
resp.BulkID = bulk.ID()
|
||||||
|
}
|
||||||
|
bulk.markReset(bulkResetError(bulkRemoteResetError(req.Error)))
|
||||||
|
resp.Accepted = true
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) handleInboundBulkRelease(msg *Message) {
|
||||||
|
req, err := decodeBulkReleaseRequest(msg)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime := c.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bulk, ok := runtime.lookup(clientFileScope(), req.BulkID)
|
||||||
|
if !ok && req.DataID != 0 {
|
||||||
|
bulk, ok = runtime.lookupByDataID(clientFileScope(), req.DataID)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bulk.releaseOutboundWindow(req.Bytes, req.Chunks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) handleInboundBulkReset(msg *Message) {
|
||||||
|
req, err := decodeBulkResetRequest(msg)
|
||||||
|
resp := BulkResetResponse{BulkID: req.BulkID}
|
||||||
|
if err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime := s.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
resp.Error = errBulkRuntimeNil.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logical := messageLogicalConnSnapshot(msg)
|
||||||
|
scope := serverFileScope(logical)
|
||||||
|
bulk, ok := runtime.lookup(scope, req.BulkID)
|
||||||
|
if !ok && req.DataID != 0 {
|
||||||
|
bulk, ok = runtime.lookupByDataID(scope, req.DataID)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
resp.Error = errBulkNotFound.Error()
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resp.BulkID == "" {
|
||||||
|
resp.BulkID = bulk.ID()
|
||||||
|
}
|
||||||
|
bulk.markReset(bulkResetError(bulkRemoteResetError(req.Error)))
|
||||||
|
resp.Accepted = true
|
||||||
|
replyBulkControlIfNeeded(msg, resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) handleInboundBulkRelease(msg *Message) {
|
||||||
|
req, err := decodeBulkReleaseRequest(msg)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime := s.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logical := messageLogicalConnSnapshot(msg)
|
||||||
|
scope := serverFileScope(logical)
|
||||||
|
bulk, ok := runtime.lookup(scope, req.BulkID)
|
||||||
|
if !ok && req.DataID != 0 {
|
||||||
|
bulk, ok = runtime.lookupByDataID(scope, req.DataID)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bulk.releaseOutboundWindow(req.Bytes, req.Chunks)
|
||||||
|
}
|
||||||
|
|
||||||
|
func replyBulkControlIfNeeded(msg *Message, value interface{}) {
|
||||||
|
if msg == nil || !requiresSignalReplyWait(msg.TransferMsg) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = msg.ReplyObj(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkOpenClient(ctx context.Context, c Client, req BulkOpenRequest) (BulkOpenResponse, error) {
|
||||||
|
if c == nil {
|
||||||
|
return BulkOpenResponse{}, errBulkClientNil
|
||||||
|
}
|
||||||
|
msg, err := c.SendObjCtx(ctx, BulkOpenSignalKey, req)
|
||||||
|
if err != nil {
|
||||||
|
return BulkOpenResponse{}, err
|
||||||
|
}
|
||||||
|
return decodeBulkOpenResponse(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkOpenServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkOpenRequest) (BulkOpenResponse, error) {
|
||||||
|
if s == nil {
|
||||||
|
return BulkOpenResponse{}, errBulkServerNil
|
||||||
|
}
|
||||||
|
if logical == nil {
|
||||||
|
return BulkOpenResponse{}, errBulkLogicalConnNil
|
||||||
|
}
|
||||||
|
msg, err := s.SendObjCtxLogical(ctx, logical, BulkOpenSignalKey, req)
|
||||||
|
if err != nil {
|
||||||
|
return BulkOpenResponse{}, err
|
||||||
|
}
|
||||||
|
return decodeBulkOpenResponse(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkOpenServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkOpenRequest) (BulkOpenResponse, error) {
|
||||||
|
if s == nil {
|
||||||
|
return BulkOpenResponse{}, errBulkServerNil
|
||||||
|
}
|
||||||
|
if transport == nil {
|
||||||
|
return BulkOpenResponse{}, errBulkTransportNil
|
||||||
|
}
|
||||||
|
msg, err := s.SendObjCtxTransport(ctx, transport, BulkOpenSignalKey, req)
|
||||||
|
if err != nil {
|
||||||
|
return BulkOpenResponse{}, err
|
||||||
|
}
|
||||||
|
return decodeBulkOpenResponse(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkCloseClient(ctx context.Context, c Client, req BulkCloseRequest) (BulkCloseResponse, error) {
|
||||||
|
if c == nil {
|
||||||
|
return BulkCloseResponse{}, errBulkClientNil
|
||||||
|
}
|
||||||
|
msg, err := c.SendObjCtx(ctx, BulkCloseSignalKey, req)
|
||||||
|
if err != nil {
|
||||||
|
return BulkCloseResponse{}, err
|
||||||
|
}
|
||||||
|
return decodeBulkCloseResponse(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkCloseServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkCloseRequest) (BulkCloseResponse, error) {
|
||||||
|
if s == nil {
|
||||||
|
return BulkCloseResponse{}, errBulkServerNil
|
||||||
|
}
|
||||||
|
if logical == nil {
|
||||||
|
return BulkCloseResponse{}, errBulkLogicalConnNil
|
||||||
|
}
|
||||||
|
msg, err := s.SendObjCtxLogical(ctx, logical, BulkCloseSignalKey, req)
|
||||||
|
if err != nil {
|
||||||
|
return BulkCloseResponse{}, err
|
||||||
|
}
|
||||||
|
return decodeBulkCloseResponse(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkCloseServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkCloseRequest) (BulkCloseResponse, error) {
|
||||||
|
if s == nil {
|
||||||
|
return BulkCloseResponse{}, errBulkServerNil
|
||||||
|
}
|
||||||
|
if transport == nil {
|
||||||
|
return BulkCloseResponse{}, errBulkTransportNil
|
||||||
|
}
|
||||||
|
msg, err := s.SendObjCtxTransport(ctx, transport, BulkCloseSignalKey, req)
|
||||||
|
if err != nil {
|
||||||
|
return BulkCloseResponse{}, err
|
||||||
|
}
|
||||||
|
return decodeBulkCloseResponse(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkResetClient(ctx context.Context, c Client, req BulkResetRequest) (BulkResetResponse, error) {
|
||||||
|
if c == nil {
|
||||||
|
return BulkResetResponse{}, errBulkClientNil
|
||||||
|
}
|
||||||
|
msg, err := c.SendObjCtx(ctx, BulkResetSignalKey, req)
|
||||||
|
if err != nil {
|
||||||
|
return BulkResetResponse{}, err
|
||||||
|
}
|
||||||
|
return decodeBulkResetResponse(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkResetServerLogical(ctx context.Context, s Server, logical *LogicalConn, req BulkResetRequest) (BulkResetResponse, error) {
|
||||||
|
if s == nil {
|
||||||
|
return BulkResetResponse{}, errBulkServerNil
|
||||||
|
}
|
||||||
|
if logical == nil {
|
||||||
|
return BulkResetResponse{}, errBulkLogicalConnNil
|
||||||
|
}
|
||||||
|
msg, err := s.SendObjCtxLogical(ctx, logical, BulkResetSignalKey, req)
|
||||||
|
if err != nil {
|
||||||
|
return BulkResetResponse{}, err
|
||||||
|
}
|
||||||
|
return decodeBulkResetResponse(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkResetServerTransport(ctx context.Context, s Server, transport *TransportConn, req BulkResetRequest) (BulkResetResponse, error) {
|
||||||
|
if s == nil {
|
||||||
|
return BulkResetResponse{}, errBulkServerNil
|
||||||
|
}
|
||||||
|
if transport == nil {
|
||||||
|
return BulkResetResponse{}, errBulkTransportNil
|
||||||
|
}
|
||||||
|
msg, err := s.SendObjCtxTransport(ctx, transport, BulkResetSignalKey, req)
|
||||||
|
if err != nil {
|
||||||
|
return BulkResetResponse{}, err
|
||||||
|
}
|
||||||
|
return decodeBulkResetResponse(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkReleaseClient(c Client, req BulkReleaseRequest) error {
|
||||||
|
if c == nil {
|
||||||
|
return errBulkClientNil
|
||||||
|
}
|
||||||
|
return c.SendObj(BulkReleaseSignalKey, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkReleaseServerLogical(s Server, logical *LogicalConn, req BulkReleaseRequest) error {
|
||||||
|
if s == nil {
|
||||||
|
return errBulkServerNil
|
||||||
|
}
|
||||||
|
if logical == nil {
|
||||||
|
return errBulkLogicalConnNil
|
||||||
|
}
|
||||||
|
return s.SendObjLogical(logical, BulkReleaseSignalKey, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendBulkReleaseServerTransport(s Server, transport *TransportConn, req BulkReleaseRequest) error {
|
||||||
|
if s == nil {
|
||||||
|
return errBulkServerNil
|
||||||
|
}
|
||||||
|
if transport == nil {
|
||||||
|
return errBulkTransportNil
|
||||||
|
}
|
||||||
|
return s.SendObjTransport(transport, BulkReleaseSignalKey, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkOpenRequest(msg *Message) (BulkOpenRequest, error) {
|
||||||
|
var req BulkOpenRequest
|
||||||
|
if msg == nil {
|
||||||
|
return BulkOpenRequest{}, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
if err := msg.Value.Orm(&req); err != nil {
|
||||||
|
return BulkOpenRequest{}, err
|
||||||
|
}
|
||||||
|
req = normalizeBulkOpenRequest(req)
|
||||||
|
if req.BulkID == "" {
|
||||||
|
return BulkOpenRequest{}, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
if !validBulkRange(req.Range) {
|
||||||
|
return BulkOpenRequest{}, errBulkRangeInvalid
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkCloseRequest(msg *Message) (BulkCloseRequest, error) {
|
||||||
|
var req BulkCloseRequest
|
||||||
|
if msg == nil {
|
||||||
|
return BulkCloseRequest{}, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
if err := msg.Value.Orm(&req); err != nil {
|
||||||
|
return BulkCloseRequest{}, err
|
||||||
|
}
|
||||||
|
if req.BulkID == "" {
|
||||||
|
return BulkCloseRequest{}, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkResetRequest(msg *Message) (BulkResetRequest, error) {
|
||||||
|
var req BulkResetRequest
|
||||||
|
if msg == nil {
|
||||||
|
return BulkResetRequest{}, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
if err := msg.Value.Orm(&req); err != nil {
|
||||||
|
return BulkResetRequest{}, err
|
||||||
|
}
|
||||||
|
if req.BulkID == "" && req.DataID == 0 {
|
||||||
|
return BulkResetRequest{}, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkReleaseRequest(msg *Message) (BulkReleaseRequest, error) {
|
||||||
|
var req BulkReleaseRequest
|
||||||
|
if msg == nil {
|
||||||
|
return BulkReleaseRequest{}, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
if err := msg.Value.Orm(&req); err != nil {
|
||||||
|
return BulkReleaseRequest{}, err
|
||||||
|
}
|
||||||
|
if req.BulkID == "" && req.DataID == 0 {
|
||||||
|
return BulkReleaseRequest{}, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
if req.Bytes < 0 || req.Chunks < 0 {
|
||||||
|
return BulkReleaseRequest{}, errBulkRangeInvalid
|
||||||
|
}
|
||||||
|
return req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkOpenResponse(msg Message) (BulkOpenResponse, error) {
|
||||||
|
var resp BulkOpenResponse
|
||||||
|
if err := msg.Value.Orm(&resp); err != nil {
|
||||||
|
return BulkOpenResponse{}, err
|
||||||
|
}
|
||||||
|
return resp, bulkControlResultError("open", resp.Accepted, resp.Error, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkCloseResponse(msg Message) (BulkCloseResponse, error) {
|
||||||
|
var resp BulkCloseResponse
|
||||||
|
if err := msg.Value.Orm(&resp); err != nil {
|
||||||
|
return BulkCloseResponse{}, err
|
||||||
|
}
|
||||||
|
return resp, bulkControlResultError("close", resp.Accepted, resp.Error, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkResetResponse(msg Message) (BulkResetResponse, error) {
|
||||||
|
var resp BulkResetResponse
|
||||||
|
if err := msg.Value.Orm(&resp); err != nil {
|
||||||
|
return BulkResetResponse{}, err
|
||||||
|
}
|
||||||
|
return resp, bulkControlResultError("reset", resp.Accepted, resp.Error, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkControlResultError(op string, accepted bool, message string, callErr error) error {
|
||||||
|
if callErr != nil {
|
||||||
|
return callErr
|
||||||
|
}
|
||||||
|
if message != "" {
|
||||||
|
return bulkControlMessageError(message)
|
||||||
|
}
|
||||||
|
if accepted {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if op == "open" {
|
||||||
|
return errBulkRejected
|
||||||
|
}
|
||||||
|
return errors.New("bulk " + op + " rejected")
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkControlMessageError(message string) error {
|
||||||
|
switch message {
|
||||||
|
case errBulkNotFound.Error():
|
||||||
|
return errBulkNotFound
|
||||||
|
case errBulkAlreadyExists.Error():
|
||||||
|
return errBulkAlreadyExists
|
||||||
|
case errBulkHandlerNotConfigured.Error():
|
||||||
|
return errBulkHandlerNotConfigured
|
||||||
|
case errBulkLogicalConnNil.Error():
|
||||||
|
return errBulkLogicalConnNil
|
||||||
|
case errBulkTransportNil.Error():
|
||||||
|
return errBulkTransportNil
|
||||||
|
case errBulkRuntimeNil.Error():
|
||||||
|
return errBulkRuntimeNil
|
||||||
|
case errBulkIDEmpty.Error():
|
||||||
|
return errBulkIDEmpty
|
||||||
|
case errBulkRangeInvalid.Error():
|
||||||
|
return errBulkRangeInvalid
|
||||||
|
case errBulkDataIDEmpty.Error():
|
||||||
|
return errBulkDataIDEmpty
|
||||||
|
default:
|
||||||
|
return errors.New(message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkRemoteResetError(message string) error {
|
||||||
|
if message == "" {
|
||||||
|
return errBulkReset
|
||||||
|
}
|
||||||
|
return errors.New(message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkTransportGeneration(logical *LogicalConn, transport *TransportConn) uint64 {
|
||||||
|
return streamTransportGeneration(logical, transport)
|
||||||
|
}
|
||||||
@@ -0,0 +1,723 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"b612.me/notify/internal/transport"
|
||||||
|
"b612.me/stario"
|
||||||
|
"context"
|
||||||
|
cryptorand "crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
systemBulkAttachKey = "_notify_bulk_attach"
|
||||||
|
bulkDedicatedRecordMagic = "NBR1"
|
||||||
|
bulkDedicatedRecordHeaderLen = 8
|
||||||
|
bulkDedicatedAttachTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type bulkAttachRequest struct {
|
||||||
|
PeerID string
|
||||||
|
BulkID string
|
||||||
|
AttachToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkAttachResponse struct {
|
||||||
|
Accepted bool
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBulkAttachToken() string {
|
||||||
|
var buf [16]byte
|
||||||
|
if _, err := cryptorand.Read(buf[:]); err == nil {
|
||||||
|
return hex.EncodeToString(buf[:])
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkAttachRequest(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachRequest, error) {
|
||||||
|
var req bulkAttachRequest
|
||||||
|
if decodeFn == nil {
|
||||||
|
decodeFn = Decode
|
||||||
|
}
|
||||||
|
raw := []byte(data)
|
||||||
|
value, err := decodeFn(raw)
|
||||||
|
if err != nil {
|
||||||
|
return req, err
|
||||||
|
}
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case bulkAttachRequest:
|
||||||
|
return typed, nil
|
||||||
|
case *bulkAttachRequest:
|
||||||
|
if typed == nil {
|
||||||
|
return req, errors.New("bulk attach request is nil")
|
||||||
|
}
|
||||||
|
return *typed, nil
|
||||||
|
default:
|
||||||
|
return req, errors.New("invalid bulk attach payload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkAttachResponse(decodeFn func([]byte) (interface{}, error), data MsgVal) (bulkAttachResponse, error) {
|
||||||
|
var resp bulkAttachResponse
|
||||||
|
if decodeFn == nil {
|
||||||
|
decodeFn = Decode
|
||||||
|
}
|
||||||
|
raw := []byte(data)
|
||||||
|
value, err := decodeFn(raw)
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
switch typed := value.(type) {
|
||||||
|
case bulkAttachResponse:
|
||||||
|
return typed, nil
|
||||||
|
case *bulkAttachResponse:
|
||||||
|
if typed == nil {
|
||||||
|
return resp, errors.New("bulk attach response is nil")
|
||||||
|
}
|
||||||
|
return *typed, nil
|
||||||
|
default:
|
||||||
|
return resp, errors.New("invalid bulk attach response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeDirectSignalFrame(queue *stario.StarQueue, sequenceEn func(interface{}) ([]byte, error), msgEn func([]byte, []byte) []byte, secretKey []byte, msg TransferMsg) ([]byte, error) {
|
||||||
|
if queue == nil {
|
||||||
|
queue = stario.NewQueue()
|
||||||
|
}
|
||||||
|
env, err := wrapTransferMsgEnvelope(msg, sequenceEn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
plain, err := sequenceEn(env)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
payload := msgEn(secretKey, plain)
|
||||||
|
if payload == nil && len(plain) != 0 {
|
||||||
|
return nil, errTransportPayloadEncryptFailed
|
||||||
|
}
|
||||||
|
return queue.BuildMessage(payload), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeDirectSignalPayload(sequenceDe func([]byte) (interface{}, error), msgDe func([]byte, []byte) []byte, secretKey []byte, payload []byte) (TransferMsg, error) {
|
||||||
|
plain := msgDe(secretKey, payload)
|
||||||
|
if plain == nil && len(payload) != 0 {
|
||||||
|
return TransferMsg{}, errTransportPayloadDecryptFailed
|
||||||
|
}
|
||||||
|
value, err := sequenceDe(plain)
|
||||||
|
if err != nil {
|
||||||
|
return TransferMsg{}, err
|
||||||
|
}
|
||||||
|
env, ok := value.(Envelope)
|
||||||
|
if !ok {
|
||||||
|
return TransferMsg{}, errors.New("invalid signal envelope")
|
||||||
|
}
|
||||||
|
return unwrapTransferMsgEnvelope(env, sequenceDe)
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeBulkDedicatedRecord(conn net.Conn, payload []byte) error {
|
||||||
|
return writeBulkDedicatedRecordWithDeadline(conn, payload, time.Time{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeBulkDedicatedRecordWithDeadline(conn net.Conn, payload []byte, deadline time.Time) error {
|
||||||
|
if conn == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
return withRawConnWriteLockDeadline(conn, deadline, func(conn net.Conn) error {
|
||||||
|
var header [bulkDedicatedRecordHeaderLen]byte
|
||||||
|
copy(header[:4], bulkDedicatedRecordMagic)
|
||||||
|
binary.BigEndian.PutUint32(header[4:8], uint32(len(payload)))
|
||||||
|
buffers := net.Buffers{header[:], payload}
|
||||||
|
_, err := buffers.WriteTo(conn)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func readBulkDedicatedRecord(conn net.Conn) ([]byte, error) {
|
||||||
|
if conn == nil {
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
var header [bulkDedicatedRecordHeaderLen]byte
|
||||||
|
if _, err := io.ReadFull(conn, header[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if string(header[:4]) != bulkDedicatedRecordMagic {
|
||||||
|
return nil, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
size := int(binary.BigEndian.Uint32(header[4:8]))
|
||||||
|
if size < 0 {
|
||||||
|
return nil, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
payload := make([]byte, size)
|
||||||
|
if _, err := io.ReadFull(conn, payload); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) dialDedicatedBulkConn(ctx context.Context) (net.Conn, error) {
|
||||||
|
source := c.clientConnectSourceSnapshot()
|
||||||
|
if source != nil && source.canReconnect() {
|
||||||
|
return source.dial(ctx)
|
||||||
|
}
|
||||||
|
conn := c.clientTransportConnSnapshot()
|
||||||
|
if conn == nil || conn.RemoteAddr() == nil {
|
||||||
|
return nil, errClientReconnectSourceUnavailable
|
||||||
|
}
|
||||||
|
return transport.Dial(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) attachDedicatedBulkSidecar(ctx context.Context, bulk *bulkHandle) error {
|
||||||
|
if c == nil || bulk == nil || !bulk.Dedicated() || bulk.dedicatedAttachedSnapshot() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, bulkDedicatedAttachTimeout)
|
||||||
|
defer cancel()
|
||||||
|
conn, err := c.dialDedicatedBulkConn(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
resp, err := c.sendDedicatedBulkAttachRequest(ctx, conn, bulk)
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !resp.Accepted {
|
||||||
|
_ = conn.Close()
|
||||||
|
if resp.Error != "" {
|
||||||
|
return errors.New(resp.Error)
|
||||||
|
}
|
||||||
|
return errors.New("bulk attach rejected")
|
||||||
|
}
|
||||||
|
if err := bulk.attachDedicatedConn(conn); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go c.readDedicatedBulkLoop(bulk, conn)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendDedicatedBulkAttachRequest(ctx context.Context, conn net.Conn, bulk *bulkHandle) (bulkAttachResponse, error) {
|
||||||
|
if c == nil {
|
||||||
|
return bulkAttachResponse{}, errBulkClientNil
|
||||||
|
}
|
||||||
|
if bulk == nil {
|
||||||
|
return bulkAttachResponse{}, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = conn.SetReadDeadline(time.Time{})
|
||||||
|
}()
|
||||||
|
reqPayload, err := c.sequenceEn(bulkAttachRequest{
|
||||||
|
PeerID: c.ensureClientPeerIdentity(),
|
||||||
|
BulkID: bulk.ID(),
|
||||||
|
AttachToken: bulk.dedicatedAttachTokenSnapshot(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return bulkAttachResponse{}, err
|
||||||
|
}
|
||||||
|
queue := stario.NewQueue()
|
||||||
|
msg := TransferMsg{
|
||||||
|
ID: atomic.AddUint64(&c.msgID, 1),
|
||||||
|
Key: systemBulkAttachKey,
|
||||||
|
Value: reqPayload,
|
||||||
|
Type: MSG_SYS_WAIT,
|
||||||
|
}
|
||||||
|
frame, err := encodeDirectSignalFrame(queue, c.sequenceEn, c.msgEn, c.SecretKey, msg)
|
||||||
|
if err != nil {
|
||||||
|
return bulkAttachResponse{}, err
|
||||||
|
}
|
||||||
|
if err := writeFullToConn(conn, frame); err != nil {
|
||||||
|
return bulkAttachResponse{}, err
|
||||||
|
}
|
||||||
|
replyCh := make(chan Message, 1)
|
||||||
|
readBuf := streamReadBuffer()
|
||||||
|
for {
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
_ = conn.SetReadDeadline(deadline)
|
||||||
|
}
|
||||||
|
n, err := conn.Read(readBuf)
|
||||||
|
if err != nil {
|
||||||
|
return bulkAttachResponse{}, err
|
||||||
|
}
|
||||||
|
parseErr := queue.ParseMessageOwned(readBuf[:n], "bulk-attach", func(msgq stario.MsgQueue) error {
|
||||||
|
transfer, err := decodeDirectSignalPayload(c.sequenceDe, c.msgDe, c.SecretKey, msgq.Msg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
replyCh <- Message{
|
||||||
|
ServerConn: c,
|
||||||
|
TransferMsg: transfer,
|
||||||
|
NetType: NET_CLIENT,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if parseErr != nil {
|
||||||
|
return bulkAttachResponse{}, parseErr
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case reply := <-replyCh:
|
||||||
|
return decodeBulkAttachResponse(c.sequenceDe, reply.Value)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) readDedicatedBulkLoop(bulk *bulkHandle, conn net.Conn) {
|
||||||
|
for {
|
||||||
|
payload, err := readBulkDedicatedRecord(conn)
|
||||||
|
if err != nil {
|
||||||
|
handleDedicatedBulkReadError(bulk, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
plain, err := c.decryptTransportPayload(payload)
|
||||||
|
if err != nil {
|
||||||
|
_ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error())
|
||||||
|
bulk.markReset(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain)
|
||||||
|
if err != nil {
|
||||||
|
_ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error())
|
||||||
|
bulk.markReset(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil {
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
_ = c.sendDedicatedBulkReset(context.Background(), bulk, err.Error())
|
||||||
|
bulk.markReset(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if bulk.Context().Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) handleBulkAttachSystemMessage(message Message) bool {
|
||||||
|
if message.Key != systemBulkAttachKey {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
current := messageLogicalConnSnapshot(&message)
|
||||||
|
resp := bulkAttachResponse{}
|
||||||
|
var (
|
||||||
|
req bulkAttachRequest
|
||||||
|
logical *LogicalConn
|
||||||
|
bulk *bulkHandle
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
req, err = decodeBulkAttachRequest(s.sequenceDe, message.Value)
|
||||||
|
if err == nil {
|
||||||
|
logical, bulk, err = s.resolveInboundDedicatedBulk(current, req)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
resp.Error = err.Error()
|
||||||
|
} else {
|
||||||
|
resp.Accepted = true
|
||||||
|
}
|
||||||
|
if current != nil {
|
||||||
|
_ = s.replyDedicatedBulkAttach(current, message, resp)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
if attachErr := s.finishInboundDedicatedBulkAttach(current, logical, bulk); attachErr != nil {
|
||||||
|
bulk.markReset(attachErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) resolveInboundDedicatedBulk(current *LogicalConn, req bulkAttachRequest) (*LogicalConn, *bulkHandle, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, nil, errBulkServerNil
|
||||||
|
}
|
||||||
|
if current == nil {
|
||||||
|
return nil, nil, errBulkLogicalConnNil
|
||||||
|
}
|
||||||
|
if req.PeerID == "" || req.BulkID == "" || req.AttachToken == "" {
|
||||||
|
return nil, nil, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
logical := s.GetLogicalConn(req.PeerID)
|
||||||
|
if logical == nil {
|
||||||
|
return nil, nil, errBulkLogicalConnNil
|
||||||
|
}
|
||||||
|
runtime := s.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return nil, nil, errBulkRuntimeNil
|
||||||
|
}
|
||||||
|
bulk, ok := runtime.lookup(serverFileScope(logical), req.BulkID)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, errBulkNotFound
|
||||||
|
}
|
||||||
|
if !bulk.Dedicated() {
|
||||||
|
return nil, nil, errors.New("bulk is not dedicated")
|
||||||
|
}
|
||||||
|
if bulk.dedicatedAttachTokenSnapshot() != req.AttachToken {
|
||||||
|
return nil, nil, errors.New("bulk attach token mismatch")
|
||||||
|
}
|
||||||
|
return logical, bulk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) finishInboundDedicatedBulkAttach(current *LogicalConn, logical *LogicalConn, bulk *bulkHandle) error {
|
||||||
|
if current == nil || logical == nil || bulk == nil {
|
||||||
|
return errBulkLogicalConnNil
|
||||||
|
}
|
||||||
|
conn, err := current.detachTransportForTransfer()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := bulk.attachDedicatedConn(conn); err != nil {
|
||||||
|
if conn != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
go s.readDedicatedBulkLoop(logical, bulk, conn)
|
||||||
|
current.markSessionStopped("bulk dedicated attach", nil)
|
||||||
|
s.removeLogical(current)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) replyDedicatedBulkAttach(client *LogicalConn, message Message, resp bulkAttachResponse) error {
|
||||||
|
if s == nil || client == nil {
|
||||||
|
return errBulkServerNil
|
||||||
|
}
|
||||||
|
encoded, err := s.sequenceEn(resp)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
reply := TransferMsg{
|
||||||
|
ID: message.ID,
|
||||||
|
Key: systemBulkAttachKey,
|
||||||
|
Value: encoded,
|
||||||
|
Type: MSG_SYS_REPLY,
|
||||||
|
}
|
||||||
|
if message.inboundConn != nil {
|
||||||
|
return s.sendTransferInbound(client, messageTransportConnSnapshot(&message), message.inboundConn, reply)
|
||||||
|
}
|
||||||
|
_, err = s.sendLogical(client, reply)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) readDedicatedBulkLoop(logical *LogicalConn, bulk *bulkHandle, conn net.Conn) {
|
||||||
|
for {
|
||||||
|
payload, err := readBulkDedicatedRecord(conn)
|
||||||
|
if err != nil {
|
||||||
|
handleDedicatedBulkReadError(bulk, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
plain, err := s.decryptTransportPayloadLogical(logical, payload)
|
||||||
|
if err != nil {
|
||||||
|
_ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error())
|
||||||
|
bulk.markReset(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
items, err := decodeDedicatedBulkInboundItems(bulk.dataIDSnapshot(), plain)
|
||||||
|
if err != nil {
|
||||||
|
_ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error())
|
||||||
|
bulk.markReset(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, item := range items {
|
||||||
|
if err := dispatchDedicatedBulkInboundItem(bulk, item); err != nil {
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
_ = s.sendDedicatedBulkReset(context.Background(), logical, bulk, err.Error())
|
||||||
|
bulk.markReset(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if bulk.Context().Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleDedicatedBulkReadError(bulk *bulkHandle, err error) {
|
||||||
|
if bulk == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if bulk.Context().Err() != nil || bulk.remoteClosedSnapshot() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) {
|
||||||
|
if bulk.Dedicated() || bulk.localClosedSnapshot() {
|
||||||
|
bulk.markRemoteClosed()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
bulk.markReset(transportDetachedError("dedicated bulk read error", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) dedicatedBulkSender(bulk *bulkHandle) (*bulkDedicatedSender, error) {
|
||||||
|
if c == nil || bulk == nil {
|
||||||
|
return nil, errBulkClientNil
|
||||||
|
}
|
||||||
|
if sender := bulk.dedicatedSenderSnapshot(); sender != nil {
|
||||||
|
return sender, nil
|
||||||
|
}
|
||||||
|
conn := bulk.dedicatedConnSnapshot()
|
||||||
|
if conn == nil {
|
||||||
|
return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
|
||||||
|
}
|
||||||
|
sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), c.encryptTransportPayload, func(items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||||
|
return c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), items)
|
||||||
|
}, func(err error) {
|
||||||
|
bulk.markReset(err)
|
||||||
|
})
|
||||||
|
actual := bulk.installDedicatedSender(sender)
|
||||||
|
if actual != sender {
|
||||||
|
sender.stop()
|
||||||
|
}
|
||||||
|
return actual, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendDedicatedBulkData(ctx context.Context, bulk *bulkHandle, chunk []byte) error {
|
||||||
|
if c == nil || bulk == nil {
|
||||||
|
return errBulkClientNil
|
||||||
|
}
|
||||||
|
sender, err := c.dedicatedBulkSender(bulk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendDedicatedBulkWrite(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) {
|
||||||
|
if c == nil || bulk == nil {
|
||||||
|
return 0, errBulkClientNil
|
||||||
|
}
|
||||||
|
sender, err := c.dedicatedBulkSender(bulk)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendDedicatedBulkClose(ctx context.Context, bulk *bulkHandle, full bool) error {
|
||||||
|
if c == nil || bulk == nil {
|
||||||
|
return errBulkClientNil
|
||||||
|
}
|
||||||
|
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
flags := uint8(0)
|
||||||
|
if full {
|
||||||
|
flags = bulkFastPayloadFlagFullClose
|
||||||
|
}
|
||||||
|
sender, err := c.dedicatedBulkSender(bulk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendDedicatedBulkReset(ctx context.Context, bulk *bulkHandle, message string) error {
|
||||||
|
if c == nil || bulk == nil {
|
||||||
|
return errBulkClientNil
|
||||||
|
}
|
||||||
|
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
sender, err := c.dedicatedBulkSender(bulk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendDedicatedBulkRelease(ctx context.Context, bulk *bulkHandle, bytes int64, chunks int) error {
|
||||||
|
if c == nil || bulk == nil {
|
||||||
|
return errBulkClientNil
|
||||||
|
}
|
||||||
|
payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
frame, err := c.encodeDedicatedBulkBatchPayload(bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{
|
||||||
|
Type: bulkFastPayloadTypeRelease,
|
||||||
|
Payload: payload,
|
||||||
|
}})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn := bulk.dedicatedConnSnapshot()
|
||||||
|
if conn == nil {
|
||||||
|
return transportDetachedError("dedicated bulk sidecar not attached", nil)
|
||||||
|
}
|
||||||
|
deadline, _ := sendCtx.Deadline()
|
||||||
|
return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) dedicatedBulkSender(logical *LogicalConn, bulk *bulkHandle) (*bulkDedicatedSender, error) {
|
||||||
|
if s == nil || bulk == nil {
|
||||||
|
return nil, errBulkServerNil
|
||||||
|
}
|
||||||
|
if logical == nil {
|
||||||
|
logical = bulk.LogicalConn()
|
||||||
|
}
|
||||||
|
if logical == nil {
|
||||||
|
return nil, errBulkLogicalConnNil
|
||||||
|
}
|
||||||
|
if sender := bulk.dedicatedSenderSnapshot(); sender != nil {
|
||||||
|
return sender, nil
|
||||||
|
}
|
||||||
|
conn := bulk.dedicatedConnSnapshot()
|
||||||
|
if conn == nil {
|
||||||
|
return nil, transportDetachedError("dedicated bulk sidecar not attached", nil)
|
||||||
|
}
|
||||||
|
sender := newBulkDedicatedSender(conn, bulk.dataIDSnapshot(), func(plain []byte) ([]byte, error) {
|
||||||
|
return s.encryptTransportPayloadLogical(logical, plain)
|
||||||
|
}, func(items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||||
|
return s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), items)
|
||||||
|
}, func(err error) {
|
||||||
|
bulk.markReset(err)
|
||||||
|
})
|
||||||
|
actual := bulk.installDedicatedSender(sender)
|
||||||
|
if actual != sender {
|
||||||
|
sender.stop()
|
||||||
|
}
|
||||||
|
return actual, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendDedicatedBulkData(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, chunk []byte) error {
|
||||||
|
if s == nil || bulk == nil {
|
||||||
|
return errBulkServerNil
|
||||||
|
}
|
||||||
|
sender, err := s.dedicatedBulkSender(logical, bulk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sender.submitData(ctx, bulk.nextOutboundDataSeq(), chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendDedicatedBulkWrite(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, payload []byte) (int, error) {
|
||||||
|
if s == nil || bulk == nil {
|
||||||
|
return 0, errBulkServerNil
|
||||||
|
}
|
||||||
|
sender, err := s.dedicatedBulkSender(logical, bulk)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return sender.submitWrite(ctx, bulk.nextOutboundDataSeq(), payload, bulk.chunkSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendDedicatedBulkClose(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, full bool) error {
|
||||||
|
if s == nil || bulk == nil {
|
||||||
|
return errBulkServerNil
|
||||||
|
}
|
||||||
|
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
flags := uint8(0)
|
||||||
|
if full {
|
||||||
|
flags = bulkFastPayloadFlagFullClose
|
||||||
|
}
|
||||||
|
sender, err := s.dedicatedBulkSender(logical, bulk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sender.submitControl(sendCtx, bulkFastPayloadTypeClose, flags, 0, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendDedicatedBulkReset(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, message string) error {
|
||||||
|
if s == nil || bulk == nil {
|
||||||
|
return errBulkServerNil
|
||||||
|
}
|
||||||
|
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
sender, err := s.dedicatedBulkSender(logical, bulk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sender.submitControl(sendCtx, bulkFastPayloadTypeReset, 0, 0, []byte(message))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendDedicatedBulkRelease(ctx context.Context, logical *LogicalConn, bulk *bulkHandle, bytes int64, chunks int) error {
|
||||||
|
if s == nil || bulk == nil {
|
||||||
|
return errBulkServerNil
|
||||||
|
}
|
||||||
|
payload, err := encodeBulkDedicatedReleasePayload(bytes, chunks)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sendCtx, cancel, err := bulkWriteContext(ctx, bulk.writeTimeout)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer cancel()
|
||||||
|
frame, err := s.encodeDedicatedBulkBatchPayload(logical, bulk.dataIDSnapshot(), []bulkDedicatedSendRequest{{
|
||||||
|
Type: bulkFastPayloadTypeRelease,
|
||||||
|
Payload: payload,
|
||||||
|
}})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn := bulk.dedicatedConnSnapshot()
|
||||||
|
if conn == nil {
|
||||||
|
return transportDetachedError("dedicated bulk sidecar not attached", nil)
|
||||||
|
}
|
||||||
|
deadline, _ := sendCtx.Deadline()
|
||||||
|
return writeBulkDedicatedRecordWithDeadline(conn, frame, deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) encodeDedicatedBulkBatchPayload(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, errBulkClientNil
|
||||||
|
}
|
||||||
|
if c.fastPlainEncode != nil {
|
||||||
|
return encodeBulkDedicatedBatchPayloadFast(c.fastPlainEncode, c.SecretKey, dataID, items)
|
||||||
|
}
|
||||||
|
plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c.encryptTransportPayload(plain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) encodeDedicatedBulkBatchPayload(logical *LogicalConn, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, errBulkServerNil
|
||||||
|
}
|
||||||
|
if logical == nil {
|
||||||
|
return nil, errBulkLogicalConnNil
|
||||||
|
}
|
||||||
|
if fastPlainEncode := logical.fastPlainEncodeSnapshot(); fastPlainEncode != nil {
|
||||||
|
return encodeBulkDedicatedBatchPayloadFast(fastPlainEncode, logical.secretKeySnapshot(), dataID, items)
|
||||||
|
}
|
||||||
|
plain, err := encodeBulkDedicatedBatchPlain(dataID, items)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.encryptTransportPayloadLogical(logical, plain)
|
||||||
|
}
|
||||||
@@ -0,0 +1,663 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
bulkDedicatedBatchMagic = "NBD2"
|
||||||
|
bulkDedicatedBatchVersion = 1
|
||||||
|
bulkDedicatedBatchHeaderLen = 20
|
||||||
|
bulkDedicatedBatchItemHeaderLen = 16
|
||||||
|
bulkDedicatedBatchMaxItems = 32
|
||||||
|
bulkDedicatedBatchMaxPlainBytes = 8 * 1024 * 1024
|
||||||
|
bulkDedicatedSendQueueSize = bulkDedicatedBatchMaxItems
|
||||||
|
bulkDedicatedReleasePayloadLen = 12
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
bulkDedicatedRequestQueued int32 = iota
|
||||||
|
bulkDedicatedRequestStarted
|
||||||
|
bulkDedicatedRequestCanceled
|
||||||
|
)
|
||||||
|
|
||||||
|
type bulkDedicatedRequestState struct {
|
||||||
|
value atomic.Int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkDedicatedBatchItem struct {
|
||||||
|
Type uint8
|
||||||
|
Flags uint8
|
||||||
|
Seq uint64
|
||||||
|
Payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkDedicatedSendRequest struct {
|
||||||
|
Type uint8
|
||||||
|
Flags uint8
|
||||||
|
Seq uint64
|
||||||
|
Payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkDedicatedBatchRequest struct {
|
||||||
|
Ctx context.Context
|
||||||
|
Items []bulkDedicatedSendRequest
|
||||||
|
Deadline time.Time
|
||||||
|
Ack chan error
|
||||||
|
State *bulkDedicatedRequestState
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkDedicatedSender struct {
|
||||||
|
conn net.Conn
|
||||||
|
dataID uint64
|
||||||
|
encrypt func([]byte) ([]byte, error)
|
||||||
|
encodeBatch func([]bulkDedicatedSendRequest) ([]byte, error)
|
||||||
|
fail func(error)
|
||||||
|
|
||||||
|
reqCh chan bulkDedicatedBatchRequest
|
||||||
|
stopCh chan struct{}
|
||||||
|
doneCh chan struct{}
|
||||||
|
stopOnce sync.Once
|
||||||
|
flushMu sync.Mutex
|
||||||
|
queued atomic.Int64
|
||||||
|
|
||||||
|
errMu sync.Mutex
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBulkDedicatedSender(conn net.Conn, dataID uint64, encrypt func([]byte) ([]byte, error), encodeBatch func([]bulkDedicatedSendRequest) ([]byte, error), fail func(error)) *bulkDedicatedSender {
|
||||||
|
sender := &bulkDedicatedSender{
|
||||||
|
conn: conn,
|
||||||
|
dataID: dataID,
|
||||||
|
encrypt: encrypt,
|
||||||
|
encodeBatch: encodeBatch,
|
||||||
|
fail: fail,
|
||||||
|
reqCh: make(chan bulkDedicatedBatchRequest, bulkDedicatedSendQueueSize),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go sender.run()
|
||||||
|
return sender
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) submitData(ctx context.Context, seq uint64, payload []byte) error {
|
||||||
|
if s == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
items := []bulkDedicatedSendRequest{{
|
||||||
|
Type: bulkFastPayloadTypeData,
|
||||||
|
Seq: seq,
|
||||||
|
Payload: append([]byte(nil), payload...),
|
||||||
|
}}
|
||||||
|
return s.submitBatch(ctx, items, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) submitWrite(ctx context.Context, startSeq uint64, payload []byte, chunkSize int) (int, error) {
|
||||||
|
if s == nil {
|
||||||
|
return 0, errTransportDetached
|
||||||
|
}
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if chunkSize <= 0 {
|
||||||
|
chunkSize = defaultBulkChunkSize
|
||||||
|
}
|
||||||
|
written := 0
|
||||||
|
seq := startSeq
|
||||||
|
for written < len(payload) {
|
||||||
|
var itemBuf [bulkDedicatedBatchMaxItems]bulkDedicatedSendRequest
|
||||||
|
items := itemBuf[:0]
|
||||||
|
batchBytes := bulkDedicatedBatchHeaderLen
|
||||||
|
start := written
|
||||||
|
for written < len(payload) && len(items) < bulkDedicatedBatchMaxItems {
|
||||||
|
end := written + chunkSize
|
||||||
|
if end > len(payload) {
|
||||||
|
end = len(payload)
|
||||||
|
}
|
||||||
|
itemLen := bulkDedicatedSendRequestLenFromPayloadLen(end - written)
|
||||||
|
if len(items) > 0 && batchBytes+itemLen > bulkDedicatedBatchMaxPlainBytes {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
items = append(items, bulkDedicatedSendRequest{
|
||||||
|
Type: bulkFastPayloadTypeData,
|
||||||
|
Seq: seq,
|
||||||
|
Payload: payload[written:end],
|
||||||
|
})
|
||||||
|
batchBytes += itemLen
|
||||||
|
seq++
|
||||||
|
written = end
|
||||||
|
}
|
||||||
|
if len(items) == 0 {
|
||||||
|
end := written + chunkSize
|
||||||
|
if end > len(payload) {
|
||||||
|
end = len(payload)
|
||||||
|
}
|
||||||
|
items = append(items, bulkDedicatedSendRequest{
|
||||||
|
Type: bulkFastPayloadTypeData,
|
||||||
|
Seq: seq,
|
||||||
|
Payload: payload[written:end],
|
||||||
|
})
|
||||||
|
seq++
|
||||||
|
written = end
|
||||||
|
}
|
||||||
|
if err := s.submitWriteBatch(ctx, items); err != nil {
|
||||||
|
return start, err
|
||||||
|
}
|
||||||
|
start = written
|
||||||
|
}
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) submitWriteBatch(ctx context.Context, items []bulkDedicatedSendRequest) error {
|
||||||
|
if s == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
if len(items) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if submitted, err := s.tryDirectSubmitBatch(ctx, items); submitted {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
queuedItems := make([]bulkDedicatedSendRequest, len(items))
|
||||||
|
copy(queuedItems, items)
|
||||||
|
return s.submitBatch(ctx, queuedItems, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) submitControl(ctx context.Context, frameType uint8, flags uint8, seq uint64, payload []byte) error {
|
||||||
|
if s == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
items := []bulkDedicatedSendRequest{{
|
||||||
|
Type: frameType,
|
||||||
|
Flags: flags,
|
||||||
|
Seq: seq,
|
||||||
|
}}
|
||||||
|
if len(payload) > 0 {
|
||||||
|
items[0].Payload = append([]byte(nil), payload...)
|
||||||
|
}
|
||||||
|
return s.submitBatch(ctx, items, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) submitBatch(ctx context.Context, items []bulkDedicatedSendRequest, wait bool) error {
|
||||||
|
if s == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if err := s.errSnapshot(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req := bulkDedicatedBatchRequest{
|
||||||
|
Ctx: ctx,
|
||||||
|
Items: items,
|
||||||
|
State: &bulkDedicatedRequestState{},
|
||||||
|
}
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
req.Deadline = deadline
|
||||||
|
}
|
||||||
|
if wait {
|
||||||
|
req.Ack = make(chan error, 1)
|
||||||
|
}
|
||||||
|
s.queued.Add(1)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
s.queued.Add(-1)
|
||||||
|
return normalizeStreamDeadlineError(ctx.Err())
|
||||||
|
case <-s.stopCh:
|
||||||
|
s.queued.Add(-1)
|
||||||
|
return s.stoppedErr()
|
||||||
|
case s.reqCh <- req:
|
||||||
|
if !wait {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.waitAck(req)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) tryDirectSubmitBatch(ctx context.Context, items []bulkDedicatedSendRequest) (bool, error) {
|
||||||
|
if s == nil {
|
||||||
|
return true, errTransportDetached
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
if len(items) == 0 {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if err := s.errSnapshot(); err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return true, normalizeStreamDeadlineError(ctx.Err())
|
||||||
|
case <-s.stopCh:
|
||||||
|
return true, s.stoppedErr()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if s.queued.Load() != 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if !s.flushMu.TryLock() {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
defer s.flushMu.Unlock()
|
||||||
|
if s.queued.Load() != 0 {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
if err := s.errSnapshot(); err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return true, normalizeStreamDeadlineError(ctx.Err())
|
||||||
|
case <-s.stopCh:
|
||||||
|
return true, s.stoppedErr()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
deadline, _ := ctx.Deadline()
|
||||||
|
if err := s.flush(items, deadline); err != nil {
|
||||||
|
err = normalizeDedicatedBulkSendError(err)
|
||||||
|
s.setErr(err)
|
||||||
|
s.failPending(err)
|
||||||
|
if s.fail != nil {
|
||||||
|
go s.fail(err)
|
||||||
|
}
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) waitAck(req bulkDedicatedBatchRequest) error {
|
||||||
|
if s == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
ctx := req.Ctx
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-req.Ack:
|
||||||
|
return normalizeDedicatedBulkSendError(err)
|
||||||
|
case <-ctx.Done():
|
||||||
|
if req.tryCancel() {
|
||||||
|
return normalizeStreamDeadlineError(ctx.Err())
|
||||||
|
}
|
||||||
|
return normalizeDedicatedBulkSendError(<-req.Ack)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) stop() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.stopOnce.Do(func() {
|
||||||
|
s.setErr(errTransportDetached)
|
||||||
|
close(s.stopCh)
|
||||||
|
})
|
||||||
|
<-s.doneCh
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) run() {
|
||||||
|
defer close(s.doneCh)
|
||||||
|
|
||||||
|
for {
|
||||||
|
req, ok := s.nextRequest()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !req.tryStart() {
|
||||||
|
s.finishRequest(req, req.canceledErr())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := req.contextErr(); err != nil {
|
||||||
|
s.finishRequest(req, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s.flushMu.Lock()
|
||||||
|
err := s.errSnapshot()
|
||||||
|
if err == nil {
|
||||||
|
err = s.flush(req.Items, req.Deadline)
|
||||||
|
}
|
||||||
|
s.flushMu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
err = normalizeDedicatedBulkSendError(err)
|
||||||
|
s.setErr(err)
|
||||||
|
s.finishRequest(req, err)
|
||||||
|
s.failPending(err)
|
||||||
|
if s.fail != nil {
|
||||||
|
go s.fail(err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.finishRequest(req, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r bulkDedicatedBatchRequest) contextErr() error {
|
||||||
|
if r.Ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-r.Ctx.Done():
|
||||||
|
return normalizeStreamDeadlineError(r.Ctx.Err())
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r bulkDedicatedBatchRequest) tryStart() bool {
|
||||||
|
if r.State == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return r.State.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestStarted)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r bulkDedicatedBatchRequest) tryCancel() bool {
|
||||||
|
if r.State == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return r.State.value.CompareAndSwap(bulkDedicatedRequestQueued, bulkDedicatedRequestCanceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r bulkDedicatedBatchRequest) canceledErr() error {
|
||||||
|
if err := r.contextErr(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return context.Canceled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) nextRequest() (bulkDedicatedBatchRequest, bool) {
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
s.failPending(s.stoppedErr())
|
||||||
|
return bulkDedicatedBatchRequest{}, false
|
||||||
|
case req := <-s.reqCh:
|
||||||
|
return req, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) flush(batch []bulkDedicatedSendRequest, deadline time.Time) error {
|
||||||
|
if s == nil || s.conn == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
payload []byte
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if s.encodeBatch != nil {
|
||||||
|
payload, err = s.encodeBatch(batch)
|
||||||
|
} else {
|
||||||
|
plain, plainErr := encodeBulkDedicatedBatchPlain(s.dataID, batch)
|
||||||
|
if plainErr != nil {
|
||||||
|
return plainErr
|
||||||
|
}
|
||||||
|
payload, err = s.encrypt(plain)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return writeBulkDedicatedRecordWithDeadline(s.conn, payload, deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) ack(req bulkDedicatedBatchRequest, err error) {
|
||||||
|
if req.Ack != nil {
|
||||||
|
req.Ack <- err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) finishRequest(req bulkDedicatedBatchRequest, err error) {
|
||||||
|
if s != nil {
|
||||||
|
s.queued.Add(-1)
|
||||||
|
}
|
||||||
|
s.ack(req, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) failPending(err error) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case item := <-s.reqCh:
|
||||||
|
s.finishRequest(item, err)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) setErr(err error) {
|
||||||
|
if s == nil || err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.errMu.Lock()
|
||||||
|
if s.err == nil {
|
||||||
|
s.err = err
|
||||||
|
}
|
||||||
|
s.errMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) errSnapshot() error {
|
||||||
|
if s == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
s.errMu.Lock()
|
||||||
|
defer s.errMu.Unlock()
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *bulkDedicatedSender) stoppedErr() error {
|
||||||
|
if err := s.errSnapshot(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkDedicatedSendRequestLen(req bulkDedicatedSendRequest) int {
|
||||||
|
return bulkDedicatedSendRequestLenFromPayloadLen(len(req.Payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkDedicatedSendRequestLenFromPayloadLen(payloadLen int) int {
|
||||||
|
return bulkDedicatedBatchItemHeaderLen + payloadLen
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeBulkDedicatedReleasePayload(bytes int64, chunks int) ([]byte, error) {
|
||||||
|
if bytes <= 0 && chunks <= 0 {
|
||||||
|
return nil, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
if chunks < 0 {
|
||||||
|
return nil, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
payload := make([]byte, bulkDedicatedReleasePayloadLen)
|
||||||
|
binary.BigEndian.PutUint64(payload[:8], uint64(bytes))
|
||||||
|
binary.BigEndian.PutUint32(payload[8:12], uint32(chunks))
|
||||||
|
return payload, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkDedicatedReleasePayload(payload []byte) (int64, int, error) {
|
||||||
|
if len(payload) != bulkDedicatedReleasePayloadLen {
|
||||||
|
return 0, 0, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
bytes := int64(binary.BigEndian.Uint64(payload[:8]))
|
||||||
|
chunks := int(binary.BigEndian.Uint32(payload[8:12]))
|
||||||
|
if bytes <= 0 && chunks <= 0 {
|
||||||
|
return 0, 0, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
return bytes, chunks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeBulkDedicatedBatchPlain(dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||||
|
if dataID == 0 || len(items) == 0 {
|
||||||
|
return nil, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
total := bulkDedicatedBatchPlainLen(items)
|
||||||
|
buf := make([]byte, total)
|
||||||
|
if err := writeBulkDedicatedBatchPlain(buf, dataID, items); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeBulkDedicatedBatchPayloadFast(encode transportFastPlainEncoder, secretKey []byte, dataID uint64, items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||||
|
if encode == nil {
|
||||||
|
return nil, errTransportPayloadEncryptFailed
|
||||||
|
}
|
||||||
|
plainLen := bulkDedicatedBatchPlainLen(items)
|
||||||
|
return encode(secretKey, plainLen, func(dst []byte) error {
|
||||||
|
return writeBulkDedicatedBatchPlain(dst, dataID, items)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkDedicatedBatchPlainLen(items []bulkDedicatedSendRequest) int {
|
||||||
|
total := bulkDedicatedBatchHeaderLen
|
||||||
|
for _, item := range items {
|
||||||
|
total += bulkDedicatedSendRequestLen(item)
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeBulkDedicatedBatchPlain(buf []byte, dataID uint64, items []bulkDedicatedSendRequest) error {
|
||||||
|
if dataID == 0 || len(items) == 0 {
|
||||||
|
return errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
if len(buf) != bulkDedicatedBatchPlainLen(items) {
|
||||||
|
return errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
copy(buf[:4], bulkDedicatedBatchMagic)
|
||||||
|
buf[4] = bulkDedicatedBatchVersion
|
||||||
|
binary.BigEndian.PutUint64(buf[8:16], dataID)
|
||||||
|
binary.BigEndian.PutUint32(buf[16:20], uint32(len(items)))
|
||||||
|
offset := bulkDedicatedBatchHeaderLen
|
||||||
|
for _, item := range items {
|
||||||
|
buf[offset] = item.Type
|
||||||
|
buf[offset+1] = item.Flags
|
||||||
|
binary.BigEndian.PutUint64(buf[offset+4:offset+12], item.Seq)
|
||||||
|
binary.BigEndian.PutUint32(buf[offset+12:offset+16], uint32(len(item.Payload)))
|
||||||
|
offset += bulkDedicatedBatchItemHeaderLen
|
||||||
|
copy(buf[offset:offset+len(item.Payload)], item.Payload)
|
||||||
|
offset += len(item.Payload)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkDedicatedBatchPlain(payload []byte) (uint64, []bulkDedicatedBatchItem, bool, error) {
|
||||||
|
if len(payload) < 4 || string(payload[:4]) != bulkDedicatedBatchMagic {
|
||||||
|
return 0, nil, false, nil
|
||||||
|
}
|
||||||
|
if len(payload) < bulkDedicatedBatchHeaderLen {
|
||||||
|
return 0, nil, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
if payload[4] != bulkDedicatedBatchVersion {
|
||||||
|
return 0, nil, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
dataID := binary.BigEndian.Uint64(payload[8:16])
|
||||||
|
count := int(binary.BigEndian.Uint32(payload[16:20]))
|
||||||
|
if dataID == 0 || count <= 0 {
|
||||||
|
return 0, nil, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
items := make([]bulkDedicatedBatchItem, 0, count)
|
||||||
|
offset := bulkDedicatedBatchHeaderLen
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
if len(payload)-offset < bulkDedicatedBatchItemHeaderLen {
|
||||||
|
return 0, nil, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
itemType := payload[offset]
|
||||||
|
switch itemType {
|
||||||
|
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
|
||||||
|
default:
|
||||||
|
return 0, nil, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
flags := payload[offset+1]
|
||||||
|
seq := binary.BigEndian.Uint64(payload[offset+4 : offset+12])
|
||||||
|
dataLen := int(binary.BigEndian.Uint32(payload[offset+12 : offset+16]))
|
||||||
|
offset += bulkDedicatedBatchItemHeaderLen
|
||||||
|
if dataLen < 0 || len(payload)-offset < dataLen {
|
||||||
|
return 0, nil, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
items = append(items, bulkDedicatedBatchItem{
|
||||||
|
Type: itemType,
|
||||||
|
Flags: flags,
|
||||||
|
Seq: seq,
|
||||||
|
Payload: payload[offset : offset+dataLen],
|
||||||
|
})
|
||||||
|
offset += dataLen
|
||||||
|
}
|
||||||
|
if offset != len(payload) {
|
||||||
|
return 0, nil, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
return dataID, items, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeDedicatedBulkInboundItems(expectedDataID uint64, plain []byte) ([]bulkDedicatedBatchItem, error) {
|
||||||
|
if dataID, items, matched, err := decodeBulkDedicatedBatchPlain(plain); matched {
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if expectedDataID == 0 || dataID != expectedDataID {
|
||||||
|
return nil, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
frame, matched, err := decodeBulkFastFrame(plain)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !matched || expectedDataID == 0 || frame.DataID != expectedDataID {
|
||||||
|
return nil, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
return []bulkDedicatedBatchItem{{
|
||||||
|
Type: frame.Type,
|
||||||
|
Flags: frame.Flags,
|
||||||
|
Seq: frame.Seq,
|
||||||
|
Payload: frame.Payload,
|
||||||
|
}}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeDedicatedBulkSendError(err error) error {
|
||||||
|
switch {
|
||||||
|
case err == nil:
|
||||||
|
return nil
|
||||||
|
case errors.Is(err, net.ErrClosed):
|
||||||
|
return errTransportDetached
|
||||||
|
default:
|
||||||
|
return normalizeStreamDeadlineError(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dispatchDedicatedBulkInboundItem(bulk *bulkHandle, item bulkDedicatedBatchItem) error {
|
||||||
|
if bulk == nil {
|
||||||
|
return io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
switch item.Type {
|
||||||
|
case bulkFastPayloadTypeData:
|
||||||
|
return bulk.pushOwnedChunkNoReset(item.Payload)
|
||||||
|
case bulkFastPayloadTypeClose:
|
||||||
|
if item.Flags&bulkFastPayloadFlagFullClose != 0 {
|
||||||
|
bulk.markPeerClosed()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
bulk.markRemoteClosed()
|
||||||
|
return nil
|
||||||
|
case bulkFastPayloadTypeReset:
|
||||||
|
resetErr := errBulkReset
|
||||||
|
if len(item.Payload) > 0 {
|
||||||
|
resetErr = bulkRemoteResetError(string(item.Payload))
|
||||||
|
}
|
||||||
|
bulk.markReset(bulkResetError(resetErr))
|
||||||
|
return nil
|
||||||
|
case bulkFastPayloadTypeRelease:
|
||||||
|
bytes, chunks, err := decodeBulkDedicatedReleasePayload(item.Payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
bulk.releaseOutboundWindow(bytes, chunks)
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,296 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBulkDedicatedBatchPlainRoundTrip(t *testing.T) {
|
||||||
|
releasePayload, err := encodeBulkDedicatedReleasePayload(4096, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encodeBulkDedicatedReleasePayload failed: %v", err)
|
||||||
|
}
|
||||||
|
items := []bulkDedicatedSendRequest{
|
||||||
|
{
|
||||||
|
Type: bulkFastPayloadTypeData,
|
||||||
|
Seq: 7,
|
||||||
|
Payload: []byte("hello"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: bulkFastPayloadTypeClose,
|
||||||
|
Flags: bulkFastPayloadFlagFullClose,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: bulkFastPayloadTypeReset,
|
||||||
|
Payload: []byte("boom"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Type: bulkFastPayloadTypeRelease,
|
||||||
|
Payload: releasePayload,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
plain, err := encodeBulkDedicatedBatchPlain(42, items)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encodeBulkDedicatedBatchPlain failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dataID, decoded, matched, err := decodeBulkDedicatedBatchPlain(plain)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("decodeBulkDedicatedBatchPlain failed: %v", err)
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
t.Fatal("decodeBulkDedicatedBatchPlain should match dedicated batch")
|
||||||
|
}
|
||||||
|
if dataID != 42 {
|
||||||
|
t.Fatalf("decoded data id = %d, want 42", dataID)
|
||||||
|
}
|
||||||
|
if len(decoded) != len(items) {
|
||||||
|
t.Fatalf("decoded item count = %d, want %d", len(decoded), len(items))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range items {
|
||||||
|
if decoded[i].Type != items[i].Type {
|
||||||
|
t.Fatalf("item %d type = %d, want %d", i, decoded[i].Type, items[i].Type)
|
||||||
|
}
|
||||||
|
if decoded[i].Flags != items[i].Flags {
|
||||||
|
t.Fatalf("item %d flags = %d, want %d", i, decoded[i].Flags, items[i].Flags)
|
||||||
|
}
|
||||||
|
if decoded[i].Seq != items[i].Seq {
|
||||||
|
t.Fatalf("item %d seq = %d, want %d", i, decoded[i].Seq, items[i].Seq)
|
||||||
|
}
|
||||||
|
if got, want := string(decoded[i].Payload), string(items[i].Payload); got != want {
|
||||||
|
t.Fatalf("item %d payload = %q, want %q", i, got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBulkOpenRoundTripDedicatedMultiWriteTCP(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
acceptCh := make(chan BulkAcceptInfo, 1)
|
||||||
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
|
acceptCh <- info
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||||
|
t.Fatalf("server Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||||
|
t.Fatalf("client Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||||
|
Range: BulkRange{
|
||||||
|
Offset: 0,
|
||||||
|
Length: 1024,
|
||||||
|
},
|
||||||
|
Dedicated: true,
|
||||||
|
ChunkSize: 4,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client OpenBulk dedicated failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
accepted := waitAcceptedBulk(t, acceptCh, 2*time.Second)
|
||||||
|
|
||||||
|
clientParts := []string{"aa", "bb", "cc", "dd", "ee", "ff"}
|
||||||
|
for _, part := range clientParts {
|
||||||
|
if _, err := bulk.Write([]byte(part)); err != nil {
|
||||||
|
t.Fatalf("client dedicated bulk Write(%q) failed: %v", part, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
readBulkExactly(t, accepted.Bulk, "aabbccddeeff", 2*time.Second)
|
||||||
|
|
||||||
|
serverParts := []string{"11", "22", "33", "44", "55", "66"}
|
||||||
|
for _, part := range serverParts {
|
||||||
|
if _, err := accepted.Bulk.Write([]byte(part)); err != nil {
|
||||||
|
t.Fatalf("server dedicated bulk Write(%q) failed: %v", part, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
readBulkExactly(t, bulk, "112233445566", 2*time.Second)
|
||||||
|
|
||||||
|
if err := bulk.CloseWrite(); err != nil {
|
||||||
|
t.Fatalf("client dedicated bulk CloseWrite failed: %v", err)
|
||||||
|
}
|
||||||
|
waitForBulkReadEOF(t, accepted.Bulk, 2*time.Second)
|
||||||
|
|
||||||
|
if err := accepted.Bulk.Close(); err != nil {
|
||||||
|
t.Fatalf("server dedicated bulk Close failed: %v", err)
|
||||||
|
}
|
||||||
|
waitForBulkReadEOF(t, bulk, 2*time.Second)
|
||||||
|
waitForBulkContextDone(t, bulk.Context(), 2*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBulkDedicatedSenderRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
sender := newBulkDedicatedSender(left, 1, func(plain []byte) ([]byte, error) {
|
||||||
|
return plain, nil
|
||||||
|
}, nil, nil)
|
||||||
|
defer sender.stop()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errCh <- sender.submitControl(ctx, bulkFastPayloadTypeClose, 0, 0, nil)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("sender.submitControl should fail when receiver stalls")
|
||||||
|
}
|
||||||
|
if !isTimeoutLikeError(err) {
|
||||||
|
t.Fatalf("sender.submitControl error = %v, want timeout-like error", err)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("sender.submitControl should not hang when receiver stalls")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBulkDedicatedSenderSubmitWriteDirectPathRespectsWriteDeadlineWhenReceiverStalls(t *testing.T) {
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
sender := newBulkDedicatedSender(left, 1, func(plain []byte) ([]byte, error) {
|
||||||
|
return plain, nil
|
||||||
|
}, nil, nil)
|
||||||
|
defer sender.stop()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
payload := make([]byte, 256*1024)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := sender.submitWrite(ctx, 1, payload, len(payload))
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("sender.submitWrite should fail when receiver stalls")
|
||||||
|
}
|
||||||
|
if !isTimeoutLikeError(err) {
|
||||||
|
t.Fatalf("sender.submitWrite error = %v, want timeout-like error", err)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("sender.submitWrite should not hang when receiver stalls")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBulkDedicatedSenderSkipsQueuedCanceledRequest(t *testing.T) {
|
||||||
|
conn := newBlockingPacketWriteConn()
|
||||||
|
sender := newBulkDedicatedSender(conn, 1, func(plain []byte) ([]byte, error) {
|
||||||
|
return plain, nil
|
||||||
|
}, nil, nil)
|
||||||
|
defer sender.stop()
|
||||||
|
|
||||||
|
firstErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
firstErrCh <- sender.submitControl(context.Background(), bulkFastPayloadTypeClose, 0, 1, nil)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-conn.startCh:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("first dedicated bulk write did not start")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
secondErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
secondErrCh <- sender.submitControl(ctx, bulkFastPayloadTypeReset, 0, 2, nil)
|
||||||
|
}()
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-secondErrCh:
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("second dedicated bulk submit error = %v, want %v", err, context.Canceled)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("second dedicated bulk submit did not return after cancel")
|
||||||
|
}
|
||||||
|
|
||||||
|
close(conn.unblockCh)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-firstErrCh:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("first dedicated bulk submit failed: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("first dedicated bulk submit did not finish")
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
if got, want := conn.writeCount.Load(), int32(2); got != want {
|
||||||
|
t.Fatalf("dedicated bulk write count = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBulkDedicatedSenderReturnsFlushResultAfterStartedContextCancel(t *testing.T) {
|
||||||
|
conn := newBlockingPacketWriteConn()
|
||||||
|
sender := newBulkDedicatedSender(conn, 1, func(plain []byte) ([]byte, error) {
|
||||||
|
return plain, nil
|
||||||
|
}, nil, nil)
|
||||||
|
defer sender.stop()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
errCh <- sender.submitControl(ctx, bulkFastPayloadTypeClose, 0, 1, nil)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-conn.startCh:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("dedicated bulk write did not start")
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatalf("sender.submitControl returned before flush completed: %v", err)
|
||||||
|
case <-time.After(50 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
close(conn.unblockCh)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sender.submitControl failed after started flush: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("sender.submitControl did not return after started flush completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,155 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const bulkDispatchRejectTimeout = 300 * time.Millisecond
|
||||||
|
|
||||||
|
func (c *ClientCommon) dispatchFastBulkFrame(frame bulkFastFrame) {
|
||||||
|
if frame.DataID == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime := c.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bulk, ok := runtime.lookupByDataID(clientFileScope(), frame.DataID)
|
||||||
|
if !ok {
|
||||||
|
if c.showError || c.debugMode {
|
||||||
|
fmt.Println("client bulk data for unknown data id", frame.DataID)
|
||||||
|
}
|
||||||
|
c.bestEffortRejectInboundBulkData("", frame.DataID, errBulkNotFound.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bulk.acceptsClientSessionEpoch(c.currentClientSessionEpoch()) {
|
||||||
|
if c.showError || c.debugMode {
|
||||||
|
fmt.Println("client bulk data rejected by stale session epoch", frame.DataID)
|
||||||
|
}
|
||||||
|
detachErr := transportDetachedSessionEpochError()
|
||||||
|
bulk.markReset(detachErr)
|
||||||
|
c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, detachErr.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch frame.Type {
|
||||||
|
case bulkFastPayloadTypeData:
|
||||||
|
if err := bulk.pushOwnedChunk(frame.Payload); err != nil {
|
||||||
|
if c.showError || c.debugMode {
|
||||||
|
fmt.Println("client bulk push chunk error", err)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
c.bestEffortRejectInboundBulkData(bulk.ID(), frame.DataID, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case bulkFastPayloadTypeClose:
|
||||||
|
if frame.Flags&bulkFastPayloadFlagFullClose != 0 {
|
||||||
|
bulk.markPeerClosed()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bulk.markRemoteClosed()
|
||||||
|
case bulkFastPayloadTypeReset:
|
||||||
|
resetErr := errBulkReset
|
||||||
|
if len(frame.Payload) > 0 {
|
||||||
|
resetErr = bulkRemoteResetError(string(frame.Payload))
|
||||||
|
}
|
||||||
|
bulk.markReset(bulkResetError(resetErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) dispatchFastBulkData(frame bulkFastDataFrame) {
|
||||||
|
c.dispatchFastBulkFrame(frame)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) dispatchFastBulkFrame(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastFrame) {
|
||||||
|
if logical == nil || frame.DataID == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime := s.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bulk, ok := runtime.lookupByDataID(serverFileScope(logical), frame.DataID)
|
||||||
|
if !ok {
|
||||||
|
if s.showError || s.debugMode {
|
||||||
|
fmt.Println("server bulk data for unknown data id", frame.DataID)
|
||||||
|
}
|
||||||
|
s.bestEffortRejectInboundBulkData(logical, transport, conn, "", frame.DataID, errBulkNotFound.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bulk.acceptsTransportGeneration(transport) {
|
||||||
|
if s.showError || s.debugMode {
|
||||||
|
fmt.Println("server bulk data rejected by transport generation mismatch", frame.DataID)
|
||||||
|
}
|
||||||
|
detachErr := transportDetachedGenerationMismatchError(bulk.TransportGeneration(), transport)
|
||||||
|
s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, detachErr.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch frame.Type {
|
||||||
|
case bulkFastPayloadTypeData:
|
||||||
|
if err := bulk.pushOwnedChunk(frame.Payload); err != nil {
|
||||||
|
if s.showError || s.debugMode {
|
||||||
|
fmt.Println("server bulk push chunk error", err)
|
||||||
|
}
|
||||||
|
if !errors.Is(err, io.EOF) {
|
||||||
|
s.bestEffortRejectInboundBulkData(logical, transport, conn, bulk.ID(), frame.DataID, err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case bulkFastPayloadTypeClose:
|
||||||
|
if frame.Flags&bulkFastPayloadFlagFullClose != 0 {
|
||||||
|
bulk.markPeerClosed()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bulk.markRemoteClosed()
|
||||||
|
case bulkFastPayloadTypeReset:
|
||||||
|
resetErr := errBulkReset
|
||||||
|
if len(frame.Payload) > 0 {
|
||||||
|
resetErr = bulkRemoteResetError(string(frame.Payload))
|
||||||
|
}
|
||||||
|
bulk.markReset(bulkResetError(resetErr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) dispatchFastBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, frame bulkFastDataFrame) {
|
||||||
|
s.dispatchFastBulkFrame(logical, transport, conn, frame)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) bestEffortRejectInboundBulkData(bulkID string, dataID uint64, message string) {
|
||||||
|
if c == nil || (bulkID == "" && dataID == 0) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), bulkDispatchRejectTimeout)
|
||||||
|
defer cancel()
|
||||||
|
_, _ = sendBulkResetClient(ctx, c, BulkResetRequest{
|
||||||
|
BulkID: bulkID,
|
||||||
|
DataID: dataID,
|
||||||
|
Error: message,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) bestEffortRejectInboundBulkData(logical *LogicalConn, transport *TransportConn, conn net.Conn, bulkID string, dataID uint64, message string) {
|
||||||
|
if s == nil || logical == nil || (bulkID == "" && dataID == 0) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
payload, err := encode(BulkResetRequest{
|
||||||
|
BulkID: bulkID,
|
||||||
|
DataID: dataID,
|
||||||
|
Error: message,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
env, err := wrapTransferMsgEnvelope(TransferMsg{
|
||||||
|
Key: BulkResetSignalKey,
|
||||||
|
Value: payload,
|
||||||
|
Type: MSG_ASYNC,
|
||||||
|
}, s.sequenceEn)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = s.sendEnvelopeInboundTransport(logical, transport, conn, env)
|
||||||
|
}
|
||||||
@@ -0,0 +1,350 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkBulkEndToEndThroughput(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
network string
|
||||||
|
payloadSize int
|
||||||
|
dedicated bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "tcp_shared_1MiB",
|
||||||
|
network: "tcp",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp_dedicated_1MiB",
|
||||||
|
network: "tcp",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
dedicated: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix_shared_1MiB",
|
||||||
|
network: "unix",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix_dedicated_1MiB",
|
||||||
|
network: "unix",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
dedicated: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkBulkEndToEndThroughputNetwork(b, tc.network, tc.payloadSize, tc.dedicated)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkBulkEndToEndThroughputConcurrent(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
network string
|
||||||
|
payloadSize int
|
||||||
|
concurrency int
|
||||||
|
dedicated bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "tcp_dedicated_4x1MiB",
|
||||||
|
network: "tcp",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
concurrency: 4,
|
||||||
|
dedicated: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unix_dedicated_4x1MiB",
|
||||||
|
network: "unix",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
concurrency: 4,
|
||||||
|
dedicated: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkBulkEndToEndThroughputConcurrentNetwork(b, tc.network, tc.payloadSize, tc.concurrency, tc.dedicated)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkBulkEndToEndThroughputNetwork(b *testing.B, network string, payloadSize int, dedicated bool) {
|
||||||
|
b.Helper()
|
||||||
|
if network == "unix" && runtime.GOOS == "windows" {
|
||||||
|
b.Skip("unix socket is not available on windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
server := newBulkBenchmarkServer(b, network)
|
||||||
|
client := newBulkBenchmarkClient(b, network, server)
|
||||||
|
|
||||||
|
totalBytes := int64(payloadSize)
|
||||||
|
if b.N > 1 {
|
||||||
|
totalBytes = int64(payloadSize) * int64(b.N)
|
||||||
|
}
|
||||||
|
bulk, accepted := openBenchmarkBulkPair(b, client, server.acceptCh, BulkOpenOptions{
|
||||||
|
Range: BulkRange{
|
||||||
|
Offset: 0,
|
||||||
|
Length: totalBytes,
|
||||||
|
},
|
||||||
|
ChunkSize: payloadSize,
|
||||||
|
Dedicated: dedicated,
|
||||||
|
})
|
||||||
|
|
||||||
|
drainDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(io.Discard, accepted.Bulk)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
drainDone <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
drainDone <- nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
payload := make([]byte, payloadSize)
|
||||||
|
for i := range payload {
|
||||||
|
payload[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(payloadSize))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
n, err := bulk.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("bulk Write failed at iter %d: %v", i, err)
|
||||||
|
}
|
||||||
|
if n != len(payload) {
|
||||||
|
b.Fatalf("bulk Write bytes mismatch at iter %d: got %d want %d", i, n, len(payload))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := bulk.CloseWrite(); err != nil {
|
||||||
|
b.Fatalf("bulk CloseWrite failed: %v", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-drainDone:
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("server drain failed: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
b.Fatal("timed out waiting for server drain")
|
||||||
|
}
|
||||||
|
b.StopTimer()
|
||||||
|
|
||||||
|
_ = accepted.Bulk.Close()
|
||||||
|
_ = bulk.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkBulkEndToEndThroughputConcurrentNetwork(b *testing.B, network string, payloadSize int, concurrency int, dedicated bool) {
|
||||||
|
b.Helper()
|
||||||
|
if concurrency <= 0 {
|
||||||
|
b.Fatal("concurrency must be > 0")
|
||||||
|
}
|
||||||
|
if network == "unix" && runtime.GOOS == "windows" {
|
||||||
|
b.Skip("unix socket is not available on windows")
|
||||||
|
}
|
||||||
|
|
||||||
|
server := newBulkBenchmarkServer(b, network)
|
||||||
|
client := newBulkBenchmarkClient(b, network, server)
|
||||||
|
|
||||||
|
totalBytes := int64(payloadSize)
|
||||||
|
if b.N > 1 {
|
||||||
|
totalBytes = int64(payloadSize) * int64(b.N)
|
||||||
|
}
|
||||||
|
|
||||||
|
bulks := make([]Bulk, 0, concurrency)
|
||||||
|
acceptedBulks := make([]Bulk, 0, concurrency)
|
||||||
|
for index := 0; index < concurrency; index++ {
|
||||||
|
bulk, accepted := openBenchmarkBulkPair(b, client, server.acceptCh, BulkOpenOptions{
|
||||||
|
Range: BulkRange{
|
||||||
|
Offset: int64(index) * totalBytes,
|
||||||
|
Length: totalBytes,
|
||||||
|
},
|
||||||
|
ChunkSize: payloadSize,
|
||||||
|
Dedicated: dedicated,
|
||||||
|
})
|
||||||
|
bulks = append(bulks, bulk)
|
||||||
|
acceptedBulks = append(acceptedBulks, accepted.Bulk)
|
||||||
|
}
|
||||||
|
|
||||||
|
drainDone := make(chan error, concurrency)
|
||||||
|
for _, acceptedBulk := range acceptedBulks {
|
||||||
|
bulk := acceptedBulk
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(io.Discard, bulk)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
drainDone <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
drainDone <- nil
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := make([]byte, payloadSize)
|
||||||
|
for i := range payload {
|
||||||
|
payload[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(payloadSize))
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
errCh := make(chan error, concurrency)
|
||||||
|
for index, bulk := range bulks {
|
||||||
|
count := b.N / concurrency
|
||||||
|
if index < b.N%concurrency {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
wg.Add(1)
|
||||||
|
go func(bulk Bulk, count int) {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
n, err := bulk.Write(payload)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n != len(payload) {
|
||||||
|
errCh <- errors.New("bulk write bytes mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(bulk, count)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
close(errCh)
|
||||||
|
for err := range errCh {
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("concurrent bulk write failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for index, bulk := range bulks {
|
||||||
|
if err := bulk.CloseWrite(); err != nil {
|
||||||
|
b.Fatalf("bulk %d CloseWrite failed: %v", index, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for index := 0; index < concurrency; index++ {
|
||||||
|
select {
|
||||||
|
case err := <-drainDone:
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("server drain failed: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(15 * time.Second):
|
||||||
|
b.Fatalf("timed out waiting for server drain %d/%d", index+1, concurrency)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.StopTimer()
|
||||||
|
|
||||||
|
for _, bulk := range acceptedBulks {
|
||||||
|
_ = bulk.Close()
|
||||||
|
}
|
||||||
|
for _, bulk := range bulks {
|
||||||
|
_ = bulk.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkBenchmarkServer struct {
|
||||||
|
server *ServerCommon
|
||||||
|
acceptCh chan BulkAcceptInfo
|
||||||
|
addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBulkBenchmarkServer(tb testing.TB, network string) bulkBenchmarkServer {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
tb.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
if network == "udp" {
|
||||||
|
if err := UseSignalReliabilityServer(server, bulkBenchmarkSignalReliabilityOptions()); err != nil {
|
||||||
|
tb.Fatalf("UseSignalReliabilityServer failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
acceptCh := make(chan BulkAcceptInfo, 32)
|
||||||
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
|
acceptCh <- info
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
addr := bulkBenchmarkListenAddr(tb, network)
|
||||||
|
if err := server.Listen(network, addr); err != nil {
|
||||||
|
tb.Fatalf("server Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
tb.Cleanup(func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
return bulkBenchmarkServer{
|
||||||
|
server: server,
|
||||||
|
acceptCh: acceptCh,
|
||||||
|
addr: signalRoundTripServerAddr(server, addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBulkBenchmarkClient(tb testing.TB, network string, server bulkBenchmarkServer) *ClientCommon {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
tb.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if network == "udp" {
|
||||||
|
if err := UseSignalReliabilityClient(client, bulkBenchmarkSignalReliabilityOptions()); err != nil {
|
||||||
|
tb.Fatalf("UseSignalReliabilityClient failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := client.Connect(network, server.addr); err != nil {
|
||||||
|
tb.Fatalf("client Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
tb.Cleanup(func() {
|
||||||
|
_ = client.Stop()
|
||||||
|
})
|
||||||
|
return client
|
||||||
|
}
|
||||||
|
|
||||||
|
func openBenchmarkBulkPair(tb testing.TB, client *ClientCommon, acceptCh <-chan BulkAcceptInfo, opt BulkOpenOptions) (Bulk, BulkAcceptInfo) {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
bulk, err := client.OpenBulk(context.Background(), opt)
|
||||||
|
if err != nil {
|
||||||
|
tb.Fatalf("client OpenBulk failed: %v", err)
|
||||||
|
}
|
||||||
|
return bulk, waitBenchmarkAcceptedBulk(tb, acceptCh, 5*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkBenchmarkListenAddr(tb testing.TB, network string) string {
|
||||||
|
tb.Helper()
|
||||||
|
switch network {
|
||||||
|
case "unix":
|
||||||
|
return filepath.Join(tb.TempDir(), "notify-bulk.sock")
|
||||||
|
case "udp", "tcp":
|
||||||
|
return "127.0.0.1:0"
|
||||||
|
default:
|
||||||
|
tb.Fatalf("unsupported benchmark network %q", network)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkBenchmarkSignalReliabilityOptions() *SignalReliabilityOptions {
|
||||||
|
return &SignalReliabilityOptions{
|
||||||
|
Enabled: true,
|
||||||
|
AckTimeout: 3 * time.Second,
|
||||||
|
SendRetry: 8,
|
||||||
|
ReceiveCacheLimit: 512,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,280 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errBulkFastPayloadInvalid = errors.New("invalid bulk fast payload")
|
||||||
|
)
|
||||||
|
|
||||||
|
var bulkFastFrameScratchPool sync.Pool
|
||||||
|
|
||||||
|
const (
|
||||||
|
bulkFastPayloadMagic = "NBF1"
|
||||||
|
bulkFastPayloadVersion = 1
|
||||||
|
bulkFastPayloadTypeData = 1
|
||||||
|
bulkFastPayloadTypeClose = 2
|
||||||
|
bulkFastPayloadTypeReset = 3
|
||||||
|
bulkFastPayloadTypeRelease = 4
|
||||||
|
bulkFastPayloadHeaderLen = 28
|
||||||
|
bulkFastPayloadFlagFullClose = 1 << 0
|
||||||
|
)
|
||||||
|
|
||||||
|
type bulkFastFrame struct {
|
||||||
|
Type uint8
|
||||||
|
Flags uint8
|
||||||
|
DataID uint64
|
||||||
|
Seq uint64
|
||||||
|
Payload []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type bulkFastDataFrame = bulkFastFrame
|
||||||
|
|
||||||
|
func encodeBulkFastFrameHeader(dst []byte, frameType uint8, flags uint8, dataID uint64, seq uint64, payloadLen int) error {
|
||||||
|
if dataID == 0 {
|
||||||
|
return errBulkDataIDEmpty
|
||||||
|
}
|
||||||
|
if len(dst) < bulkFastPayloadHeaderLen {
|
||||||
|
return errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
copy(dst[:4], bulkFastPayloadMagic)
|
||||||
|
dst[4] = bulkFastPayloadVersion
|
||||||
|
dst[5] = frameType
|
||||||
|
dst[6] = flags
|
||||||
|
dst[7] = 0
|
||||||
|
binary.BigEndian.PutUint64(dst[8:16], dataID)
|
||||||
|
binary.BigEndian.PutUint64(dst[16:24], seq)
|
||||||
|
binary.BigEndian.PutUint32(dst[24:28], uint32(payloadLen))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeBulkFastDataFrameHeader(dst []byte, dataID uint64, seq uint64, payloadLen int) error {
|
||||||
|
return encodeBulkFastFrameHeader(dst, bulkFastPayloadTypeData, 0, dataID, seq, payloadLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeBulkFastDataFrame(dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||||
|
frame := make([]byte, bulkFastPayloadHeaderLen+len(payload))
|
||||||
|
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(payload)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(frame[bulkFastPayloadHeaderLen:], payload)
|
||||||
|
return frame, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeBulkFastControlFrame(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||||
|
frame := make([]byte, bulkFastPayloadHeaderLen+len(payload))
|
||||||
|
if err := encodeBulkFastFrameHeader(frame, frameType, flags, dataID, seq, len(payload)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(frame[bulkFastPayloadHeaderLen:], payload)
|
||||||
|
return frame, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkFastFrame(payload []byte) (bulkFastFrame, bool, error) {
|
||||||
|
if len(payload) < 4 || string(payload[:4]) != bulkFastPayloadMagic {
|
||||||
|
return bulkFastFrame{}, false, nil
|
||||||
|
}
|
||||||
|
if len(payload) < bulkFastPayloadHeaderLen {
|
||||||
|
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
if payload[4] != bulkFastPayloadVersion {
|
||||||
|
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
switch payload[5] {
|
||||||
|
case bulkFastPayloadTypeData, bulkFastPayloadTypeClose, bulkFastPayloadTypeReset, bulkFastPayloadTypeRelease:
|
||||||
|
default:
|
||||||
|
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
dataLen := int(binary.BigEndian.Uint32(payload[24:28]))
|
||||||
|
if dataLen < 0 || len(payload) != bulkFastPayloadHeaderLen+dataLen {
|
||||||
|
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
dataID := binary.BigEndian.Uint64(payload[8:16])
|
||||||
|
if dataID == 0 {
|
||||||
|
return bulkFastFrame{}, true, errBulkFastPayloadInvalid
|
||||||
|
}
|
||||||
|
return bulkFastFrame{
|
||||||
|
Type: payload[5],
|
||||||
|
Flags: payload[6],
|
||||||
|
DataID: dataID,
|
||||||
|
Seq: binary.BigEndian.Uint64(payload[16:24]),
|
||||||
|
Payload: payload[bulkFastPayloadHeaderLen:],
|
||||||
|
}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBulkFastDataFrame(payload []byte) (bulkFastDataFrame, bool, error) {
|
||||||
|
frame, matched, err := decodeBulkFastFrame(payload)
|
||||||
|
if !matched || err != nil {
|
||||||
|
return frame, matched, err
|
||||||
|
}
|
||||||
|
if frame.Type != bulkFastPayloadTypeData {
|
||||||
|
return bulkFastDataFrame{}, false, nil
|
||||||
|
}
|
||||||
|
return frame, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) encodeFastBulkDataPayload(dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
|
||||||
|
if c != nil && c.fastBulkEncode != nil {
|
||||||
|
return c.fastBulkEncode(c.SecretKey, dataID, seq, chunk)
|
||||||
|
}
|
||||||
|
scratch := getBulkFastFrameScratch(len(chunk))
|
||||||
|
defer putBulkFastFrameScratch(scratch)
|
||||||
|
frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)]
|
||||||
|
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(frame[bulkFastPayloadHeaderLen:], chunk)
|
||||||
|
return c.encryptTransportPayload(frame)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendFastBulkData(ctx context.Context, dataID uint64, seq uint64, chunk []byte) error {
|
||||||
|
payload, err := c.encodeFastBulkDataPayload(dataID, seq, chunk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
binding := c.clientTransportBindingSnapshot()
|
||||||
|
if binding == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
if sender := binding.bulkBatchSenderSnapshot(); sender != nil {
|
||||||
|
return sender.submit(ctx, payload)
|
||||||
|
}
|
||||||
|
return c.writePayloadToTransport(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) encodeBulkFastControlPayload(frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||||
|
plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return c.encryptTransportPayload(plain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) encodeFastBulkDataPayloadLogical(logical *LogicalConn, dataID uint64, seq uint64, chunk []byte) ([]byte, error) {
|
||||||
|
if logical != nil {
|
||||||
|
if fastBulkEncode := logical.fastBulkEncodeSnapshot(); fastBulkEncode != nil {
|
||||||
|
return fastBulkEncode(logical.secretKeySnapshot(), dataID, seq, chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
scratch := getBulkFastFrameScratch(len(chunk))
|
||||||
|
defer putBulkFastFrameScratch(scratch)
|
||||||
|
frame := scratch[:bulkFastPayloadHeaderLen+len(chunk)]
|
||||||
|
if err := encodeBulkFastDataFrameHeader(frame, dataID, seq, len(chunk)); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
copy(frame[bulkFastPayloadHeaderLen:], chunk)
|
||||||
|
return s.encryptTransportPayloadLogical(logical, frame)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFastBulkDataTransport(ctx context.Context, logical *LogicalConn, transport *TransportConn, dataID uint64, seq uint64, chunk []byte) error {
|
||||||
|
if err := s.ensureServerTransportSendReady(transport); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if logical == nil && transport != nil {
|
||||||
|
logical = transport.logicalConnSnapshot()
|
||||||
|
}
|
||||||
|
if logical == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
payload, err := s.encodeFastBulkDataPayloadLogical(logical, dataID, seq, chunk)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if binding := logical.transportBindingSnapshot(); binding != nil {
|
||||||
|
if binding.queueSnapshot() != nil {
|
||||||
|
if sender := binding.bulkBatchSenderSnapshot(); sender != nil {
|
||||||
|
return sender.submit(ctx, payload)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return s.writeEnvelopePayload(logical, transport, nil, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) encodeBulkFastControlPayloadLogical(logical *LogicalConn, frameType uint8, flags uint8, dataID uint64, seq uint64, payload []byte) ([]byte, error) {
|
||||||
|
plain, err := encodeBulkFastControlFrame(frameType, flags, dataID, seq, payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.encryptTransportPayloadLogical(logical, plain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getBulkFastFrameScratch(payloadLen int) []byte {
|
||||||
|
need := bulkFastPayloadHeaderLen + payloadLen
|
||||||
|
if buf, ok := bulkFastFrameScratchPool.Get().([]byte); ok && cap(buf) >= need {
|
||||||
|
return buf[:need]
|
||||||
|
}
|
||||||
|
return make([]byte, need)
|
||||||
|
}
|
||||||
|
|
||||||
|
func putBulkFastFrameScratch(buf []byte) {
|
||||||
|
if cap(buf) == 0 || cap(buf) > 4*1024*1024 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
bulkFastFrameScratchPool.Put(buf[:0])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) dispatchInboundTransportPayload(payload []byte, now time.Time) error {
|
||||||
|
plain, err := c.decryptTransportPayload(payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if frame, matched, err := decodeBulkFastFrame(plain); matched {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.dispatchFastBulkFrame(frame)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if frame, matched, err := decodeStreamFastDataFrame(plain); matched {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.dispatchFastStreamData(frame)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
env, err := c.decodeEnvelopePlain(plain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.dispatchEnvelope(env, now)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) dispatchInboundTransportPayload(logical *LogicalConn, transport *TransportConn, conn net.Conn, payload []byte, now time.Time) error {
|
||||||
|
if logical == nil && transport != nil {
|
||||||
|
logical = transport.logicalConnSnapshot()
|
||||||
|
}
|
||||||
|
if logical == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
plain, err := s.decryptTransportPayloadLogical(logical, payload)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if frame, matched, err := decodeBulkFastFrame(plain); matched {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.dispatchFastBulkFrame(logical, transport, conn, frame)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if frame, matched, err := decodeStreamFastDataFrame(plain); matched {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.dispatchFastStreamData(logical, transport, conn, frame)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
env, err := s.decodeEnvelopePlain(plain)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.dispatchEnvelope(logical, transport, conn, env, now)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
+196
@@ -0,0 +1,196 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type bulkRuntime struct {
|
||||||
|
rolePrefix string
|
||||||
|
seq atomic.Uint64
|
||||||
|
dataSeq atomic.Uint64
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
handler func(BulkAcceptInfo) error
|
||||||
|
bulks map[string]*bulkHandle
|
||||||
|
data map[string]*bulkHandle
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBulkRuntime(rolePrefix string) *bulkRuntime {
|
||||||
|
return &bulkRuntime{
|
||||||
|
rolePrefix: rolePrefix,
|
||||||
|
bulks: make(map[string]*bulkHandle),
|
||||||
|
data: make(map[string]*bulkHandle),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) nextID() string {
|
||||||
|
if r == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s-%d", r.rolePrefix, r.seq.Add(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) nextDataID() uint64 {
|
||||||
|
if r == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return r.dataSeq.Add(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) setHandler(fn func(BulkAcceptInfo) error) {
|
||||||
|
if r == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
r.handler = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) handlerSnapshot() func(BulkAcceptInfo) error {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
return r.handler
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) register(scope string, bulk *bulkHandle) error {
|
||||||
|
if r == nil {
|
||||||
|
return errBulkRuntimeNil
|
||||||
|
}
|
||||||
|
if bulk == nil || bulk.id == "" {
|
||||||
|
return errBulkIDEmpty
|
||||||
|
}
|
||||||
|
key := bulkRuntimeKey(scope, bulk.id)
|
||||||
|
dataKey := bulkRuntimeDataKey(scope, bulk.dataID)
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if _, ok := r.bulks[key]; ok {
|
||||||
|
return errBulkAlreadyExists
|
||||||
|
}
|
||||||
|
if bulk.dataID == 0 {
|
||||||
|
return errBulkDataIDEmpty
|
||||||
|
}
|
||||||
|
if _, ok := r.data[dataKey]; ok {
|
||||||
|
return errBulkAlreadyExists
|
||||||
|
}
|
||||||
|
r.bulks[key] = bulk
|
||||||
|
r.data[dataKey] = bulk
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) lookup(scope string, bulkID string) (*bulkHandle, bool) {
|
||||||
|
if r == nil || bulkID == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
key := bulkRuntimeKey(scope, bulkID)
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
bulk, ok := r.bulks[key]
|
||||||
|
return bulk, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) lookupByDataID(scope string, dataID uint64) (*bulkHandle, bool) {
|
||||||
|
if r == nil || dataID == 0 {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
key := bulkRuntimeDataKey(scope, dataID)
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
bulk, ok := r.data[key]
|
||||||
|
return bulk, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) remove(scope string, bulkID string) {
|
||||||
|
if r == nil || bulkID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := bulkRuntimeKey(scope, bulkID)
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if bulk := r.bulks[key]; bulk != nil && bulk.dataID != 0 {
|
||||||
|
delete(r.data, bulkRuntimeDataKey(scope, bulk.dataID))
|
||||||
|
}
|
||||||
|
delete(r.bulks, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) closeAll(err error) {
|
||||||
|
r.closeMatching(func(string) bool { return true }, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) closeScope(scope string, err error) {
|
||||||
|
scope = normalizeFileScope(scope)
|
||||||
|
r.closeMatching(func(key string) bool {
|
||||||
|
return strings.HasPrefix(key, scope+"\x00")
|
||||||
|
}, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) closeMatching(match func(string) bool, err error) {
|
||||||
|
if r == nil || match == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resetErr := bulkRuntimeCloseError(err)
|
||||||
|
r.mu.RLock()
|
||||||
|
bulks := make([]*bulkHandle, 0, len(r.bulks))
|
||||||
|
for key, bulk := range r.bulks {
|
||||||
|
if bulk == nil || !match(key) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
bulks = append(bulks, bulk)
|
||||||
|
}
|
||||||
|
r.mu.RUnlock()
|
||||||
|
for _, bulk := range bulks {
|
||||||
|
bulk.markReset(resetErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *bulkRuntime) snapshots() []BulkSnapshot {
|
||||||
|
if r == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
r.mu.RLock()
|
||||||
|
snapshots := make([]BulkSnapshot, 0, len(r.bulks))
|
||||||
|
for _, bulk := range r.bulks {
|
||||||
|
if bulk == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
snapshots = append(snapshots, bulk.snapshot())
|
||||||
|
}
|
||||||
|
r.mu.RUnlock()
|
||||||
|
sortBulkSnapshots(snapshots)
|
||||||
|
return snapshots
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkRuntimeKey(scope string, bulkID string) string {
|
||||||
|
return normalizeFileScope(scope) + "\x00" + bulkID
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkRuntimeDataKey(scope string, dataID uint64) string {
|
||||||
|
return normalizeFileScope(scope) + "\x01" + strconv.FormatUint(dataID, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkRuntimeCloseError(err error) error {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return errServiceShutdown
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) getBulkRuntime() *bulkRuntime {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.bulkRuntime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) getBulkRuntime() *bulkRuntime {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.bulkRuntime
|
||||||
|
}
|
||||||
@@ -0,0 +1,120 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sort"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type BulkSnapshot struct {
|
||||||
|
ID string
|
||||||
|
DataID uint64
|
||||||
|
Scope string
|
||||||
|
Range BulkRange
|
||||||
|
Metadata BulkMetadata
|
||||||
|
BindingOwner string
|
||||||
|
BindingAlive bool
|
||||||
|
BindingCurrent bool
|
||||||
|
BindingReason string
|
||||||
|
BindingError string
|
||||||
|
Dedicated bool
|
||||||
|
DedicatedAttached bool
|
||||||
|
SessionEpoch uint64
|
||||||
|
LogicalClientID string
|
||||||
|
TransportGeneration uint64
|
||||||
|
TransportAttached bool
|
||||||
|
TransportHasRuntimeConn bool
|
||||||
|
TransportCurrent bool
|
||||||
|
TransportDetachReason string
|
||||||
|
TransportDetachKind string
|
||||||
|
TransportDetachGeneration uint64
|
||||||
|
TransportDetachError string
|
||||||
|
TransportDetachedAt time.Time
|
||||||
|
ReattachEligible bool
|
||||||
|
LocalClosed bool
|
||||||
|
LocalReadClosed bool
|
||||||
|
RemoteClosed bool
|
||||||
|
PeerReadClosed bool
|
||||||
|
BufferedChunks int
|
||||||
|
BufferedBytes int
|
||||||
|
ReadTimeout time.Duration
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
ChunkSize int
|
||||||
|
WindowBytes int
|
||||||
|
MaxInFlight int
|
||||||
|
BytesRead int64
|
||||||
|
BytesWritten int64
|
||||||
|
ReadCalls int64
|
||||||
|
WriteCalls int64
|
||||||
|
OpenedAt time.Time
|
||||||
|
LastReadAt time.Time
|
||||||
|
LastWriteAt time.Time
|
||||||
|
ResetError string
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientBulkSnapshotReader interface {
|
||||||
|
clientBulkSnapshots() []BulkSnapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverBulkSnapshotReader interface {
|
||||||
|
serverBulkSnapshots() []BulkSnapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errClientBulkSnapshotNil = errors.New("client bulk snapshot target is nil")
|
||||||
|
errServerBulkSnapshotNil = errors.New("server bulk snapshot target is nil")
|
||||||
|
errClientBulkSnapshotUnsupported = errors.New("client bulk snapshot target type is unsupported")
|
||||||
|
errServerBulkSnapshotUnsupported = errors.New("server bulk snapshot target type is unsupported")
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetClientBulkSnapshots(c Client) ([]BulkSnapshot, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, errClientBulkSnapshotNil
|
||||||
|
}
|
||||||
|
reader, ok := any(c).(clientBulkSnapshotReader)
|
||||||
|
if !ok {
|
||||||
|
return nil, errClientBulkSnapshotUnsupported
|
||||||
|
}
|
||||||
|
return reader.clientBulkSnapshots(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetServerBulkSnapshots(s Server) ([]BulkSnapshot, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, errServerBulkSnapshotNil
|
||||||
|
}
|
||||||
|
reader, ok := any(s).(serverBulkSnapshotReader)
|
||||||
|
if !ok {
|
||||||
|
return nil, errServerBulkSnapshotUnsupported
|
||||||
|
}
|
||||||
|
return reader.serverBulkSnapshots(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientBulkSnapshots() []BulkSnapshot {
|
||||||
|
return bulkSnapshotsFromRuntime(c.getBulkRuntime())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) serverBulkSnapshots() []BulkSnapshot {
|
||||||
|
return bulkSnapshotsFromRuntime(s.getBulkRuntime())
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkSnapshotsFromRuntime(runtime *bulkRuntime) []BulkSnapshot {
|
||||||
|
if runtime == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return runtime.snapshots()
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortBulkSnapshots(src []BulkSnapshot) {
|
||||||
|
sort.Slice(src, func(i, j int) bool {
|
||||||
|
if src[i].Scope != src[j].Scope {
|
||||||
|
return src[i].Scope < src[j].Scope
|
||||||
|
}
|
||||||
|
if src[i].ID != src[j].ID {
|
||||||
|
return src[i].ID < src[j].ID
|
||||||
|
}
|
||||||
|
if src[i].DataID != src[j].DataID {
|
||||||
|
return src[i].DataID < src[j].DataID
|
||||||
|
}
|
||||||
|
return src[i].TransportGeneration < src[j].TransportGeneration
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,186 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkModernPSKSealPlainThroughput(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "seal_1MiB",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "seal_4MiB",
|
||||||
|
payloadSize: 4 * 1024 * 1024,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
key, aad, err := deriveModernPSKKey(integrationSharedSecret, integrationModernPSKOptions())
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("deriveModernPSKKey failed: %v", err)
|
||||||
|
}
|
||||||
|
transport := buildModernPSKTransportBundle(aad)
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
payload := make([]byte, tc.payloadSize)
|
||||||
|
for i := range payload {
|
||||||
|
payload[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sink []byte
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(tc.payloadSize))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
wire, err := transport.fastPlainEncode(key, len(payload), func(dst []byte) error {
|
||||||
|
copy(dst, payload)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("fastPlainEncode failed: %v", err)
|
||||||
|
}
|
||||||
|
sink = wire
|
||||||
|
}
|
||||||
|
b.StopTimer()
|
||||||
|
_ = sink
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkDedicatedWireLocalhostThroughput(b *testing.B) {
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
payloadSize int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "wire_1MiB",
|
||||||
|
payloadSize: 1024 * 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wire_4MiB",
|
||||||
|
payloadSize: 4 * 1024 * 1024,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
key, aad, err := deriveModernPSKKey(integrationSharedSecret, integrationModernPSKOptions())
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("deriveModernPSKKey failed: %v", err)
|
||||||
|
}
|
||||||
|
transport := buildModernPSKTransportBundle(aad)
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b.Run(tc.name, func(b *testing.B) {
|
||||||
|
benchmarkDedicatedWireLocalhostThroughput(b, key, transport.fastPlainEncode, tc.payloadSize)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func benchmarkDedicatedWireLocalhostThroughput(b *testing.B, key []byte, encode transportFastPlainEncoder, payloadSize int) {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("net.Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
b.Cleanup(func() {
|
||||||
|
_ = listener.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
acceptCh := make(chan net.Conn, 1)
|
||||||
|
acceptErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
acceptErrCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
acceptCh <- conn
|
||||||
|
}()
|
||||||
|
|
||||||
|
clientConn, err := net.Dial("tcp", listener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("net.Dial failed: %v", err)
|
||||||
|
}
|
||||||
|
b.Cleanup(func() {
|
||||||
|
_ = clientConn.Close()
|
||||||
|
})
|
||||||
|
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
|
||||||
|
_ = tcpConn.SetNoDelay(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
var serverConn net.Conn
|
||||||
|
select {
|
||||||
|
case conn := <-acceptCh:
|
||||||
|
serverConn = conn
|
||||||
|
case err := <-acceptErrCh:
|
||||||
|
b.Fatalf("Accept failed: %v", err)
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
b.Fatal("timed out waiting for accept")
|
||||||
|
}
|
||||||
|
b.Cleanup(func() {
|
||||||
|
if serverConn != nil {
|
||||||
|
_ = serverConn.Close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
drainDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(io.Discard, serverConn)
|
||||||
|
if err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
drainDone <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
drainDone <- nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
sender := newBulkDedicatedSender(clientConn, 1, func(plain []byte) ([]byte, error) {
|
||||||
|
return encode(key, len(plain), func(dst []byte) error {
|
||||||
|
copy(dst, plain)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}, func(items []bulkDedicatedSendRequest) ([]byte, error) {
|
||||||
|
return encodeBulkDedicatedBatchPayloadFast(encode, key, 1, items)
|
||||||
|
}, nil)
|
||||||
|
defer sender.stop()
|
||||||
|
|
||||||
|
payload := make([]byte, payloadSize)
|
||||||
|
for i := range payload {
|
||||||
|
payload[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.SetBytes(int64(payloadSize))
|
||||||
|
b.ResetTimer()
|
||||||
|
seq := uint64(1)
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
n, err := sender.submitWrite(context.Background(), seq, payload, payloadSize)
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("submitWrite failed at iter %d: %v", i, err)
|
||||||
|
}
|
||||||
|
if n != len(payload) {
|
||||||
|
b.Fatalf("submitWrite bytes mismatch at iter %d: got %d want %d", i, n, len(payload))
|
||||||
|
}
|
||||||
|
seq++
|
||||||
|
}
|
||||||
|
b.StopTimer()
|
||||||
|
|
||||||
|
_ = clientConn.Close()
|
||||||
|
select {
|
||||||
|
case err := <-drainDone:
|
||||||
|
if err != nil {
|
||||||
|
b.Fatalf("server drain failed: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
b.Fatal("timed out waiting for server drain")
|
||||||
|
}
|
||||||
|
}
|
||||||
+1494
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,85 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBulkOpenDedicatedUDPRejected(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err := server.Listen("udp", "127.0.0.1:0"); err != nil {
|
||||||
|
t.Fatalf("server Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := client.Connect("udp", signalRoundTripServerAddr(server, "")); err != nil {
|
||||||
|
t.Fatalf("client Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||||
|
Range: BulkRange{
|
||||||
|
Offset: 0,
|
||||||
|
Length: 128,
|
||||||
|
},
|
||||||
|
Dedicated: true,
|
||||||
|
})
|
||||||
|
if !errors.Is(err, errBulkDedicatedStreamOnly) {
|
||||||
|
t.Fatalf("client OpenBulk dedicated over udp error = %v, want %v", err, errBulkDedicatedStreamOnly)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerOpenBulkLogicalDedicatedUDPRejected(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err := server.Listen("udp", "127.0.0.1:0"); err != nil {
|
||||||
|
t.Fatalf("server Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := client.Connect("udp", signalRoundTripServerAddr(server, "")); err != nil {
|
||||||
|
t.Fatalf("client Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
logical := waitForTransferControlLogicalConn(t, server, 2*time.Second)
|
||||||
|
_, err := server.OpenBulkLogical(context.Background(), logical, BulkOpenOptions{
|
||||||
|
Range: BulkRange{
|
||||||
|
Offset: 0,
|
||||||
|
Length: 128,
|
||||||
|
},
|
||||||
|
Dedicated: true,
|
||||||
|
})
|
||||||
|
if !errors.Is(err, errBulkDedicatedStreamOnly) {
|
||||||
|
t.Fatalf("server OpenBulkLogical dedicated over udp error = %v, want %v", err, errBulkDedicatedStreamOnly)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,330 +1,105 @@
|
|||||||
package notify
|
package notify
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"b612.me/stario"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"b612.me/starnet"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// StarNotifyC 为Client端
|
type ClientCommon struct {
|
||||||
type StarNotifyC struct {
|
alive atomic.Value
|
||||||
Connc net.Conn
|
status Status
|
||||||
dialTimeout time.Duration
|
byeFromServer bool
|
||||||
clientSign map[string]chan string
|
conn net.Conn
|
||||||
// FuncLists 当不使用channel时,使用此记录调用函数
|
mu sync.Mutex
|
||||||
FuncLists map[string]func(CMsg)
|
msgID uint64
|
||||||
stopSign context.Context
|
peerIdentity string
|
||||||
cancel context.CancelFunc
|
sessionEpoch uint64
|
||||||
defaultFunc func(CMsg)
|
sessionOwnerState atomic.Int32
|
||||||
// Stop 停止信 号
|
sessionRuntime atomic.Pointer[clientSessionRuntime]
|
||||||
Stop chan int
|
connectSource atomic.Pointer[clientConnectSource]
|
||||||
// UseChannel 是否使用channel作为信息传递
|
queue *stario.StarQueue
|
||||||
UseChannel bool
|
stopFn context.CancelFunc
|
||||||
isUDP bool
|
stopCtx context.Context
|
||||||
// Queue 是用来处理收发信息的简单消息队列
|
parallelNum int
|
||||||
Queue *starnet.StarQueue
|
maxReadTimeout time.Duration
|
||||||
// Online 当前链接是否处于活跃状态
|
maxWriteTimeout time.Duration
|
||||||
Online bool
|
keyExchangeFn func(c Client) error
|
||||||
lockPool map[string]CMsg
|
linkFns map[string]func(message *Message)
|
||||||
|
defaultFns func(message *Message)
|
||||||
|
msgEn func([]byte, []byte) []byte
|
||||||
|
msgDe func([]byte, []byte) []byte
|
||||||
|
fastStreamEncode transportFastStreamEncoder
|
||||||
|
fastBulkEncode transportFastBulkEncoder
|
||||||
|
fastPlainEncode transportFastPlainEncoder
|
||||||
|
handshakeRsaPubKey []byte
|
||||||
|
SecretKey []byte
|
||||||
|
noFinSyncMsgMaxKeepSeconds int
|
||||||
|
lastHeartbeat int64
|
||||||
|
heartbeatPeriod time.Duration
|
||||||
|
wg stario.WaitGroup
|
||||||
|
netType NetType
|
||||||
|
showError bool
|
||||||
|
skipKeyExchange bool
|
||||||
|
useHeartBeat bool
|
||||||
|
sequenceDe func([]byte) (interface{}, error)
|
||||||
|
sequenceEn func(interface{}) ([]byte, error)
|
||||||
|
logicalSession *logicalSessionState
|
||||||
|
onFileEvent func(FileEvent)
|
||||||
|
fileEventObserver func(FileEvent)
|
||||||
|
fileTransferCfg fileTransferConfig
|
||||||
|
signalReliableCfg signalReliabilityConfig
|
||||||
|
streamRuntime *streamRuntime
|
||||||
|
recordRuntime *recordRuntime
|
||||||
|
bulkRuntime *bulkRuntime
|
||||||
|
connectionRetryState *connectionRetryState
|
||||||
|
securityReadyCheck bool
|
||||||
|
debugMode bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// CMsg 指明当前客户端被通知的关键字
|
func NewClient() Client {
|
||||||
type CMsg struct {
|
transport := defaultModernPSKTransportBundle()
|
||||||
Key string
|
var client = ClientCommon{
|
||||||
Value string
|
maxReadTimeout: 0,
|
||||||
mode string
|
maxWriteTimeout: 0,
|
||||||
wait chan int
|
peerIdentity: newClientPeerIdentity(),
|
||||||
|
sequenceEn: encode,
|
||||||
|
sequenceDe: Decode,
|
||||||
|
keyExchangeFn: aesRsaHello,
|
||||||
|
SecretKey: nil,
|
||||||
|
handshakeRsaPubKey: defaultRsaPubKey,
|
||||||
|
msgEn: transport.msgEn,
|
||||||
|
msgDe: transport.msgDe,
|
||||||
|
fastStreamEncode: transport.fastStreamEncode,
|
||||||
|
fastBulkEncode: transport.fastBulkEncode,
|
||||||
|
fastPlainEncode: transport.fastPlainEncode,
|
||||||
|
skipKeyExchange: true,
|
||||||
|
securityReadyCheck: true,
|
||||||
}
|
}
|
||||||
|
client.alive.Store(false)
|
||||||
func (star *StarNotifyC) starinitc() {
|
client.useHeartBeat = true
|
||||||
star.stopSign, star.cancel = context.WithCancel(context.Background())
|
client.heartbeatPeriod = time.Second * 20
|
||||||
star.Queue = starnet.NewQueue()
|
client.linkFns = make(map[string]func(*Message))
|
||||||
star.Queue.EncodeFunc = encodeFunc
|
client.defaultFns = func(message *Message) {
|
||||||
star.Queue.DecodeFunc = decodeFunc
|
|
||||||
star.Queue.Encode = true
|
|
||||||
star.FuncLists = make(map[string]func(CMsg))
|
|
||||||
star.UseChannel = false
|
|
||||||
star.Stop = make(chan int, 5)
|
|
||||||
star.clientSign = make(map[string]chan string)
|
|
||||||
star.Online = false
|
|
||||||
star.lockPool = make(map[string]CMsg)
|
|
||||||
star.Queue.RestoreDuration(time.Second * 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notify 用于获取一个通知
|
|
||||||
func (star *StarNotifyC) Notify(key string) chan string {
|
|
||||||
if _, ok := star.clientSign[key]; !ok {
|
|
||||||
ch := make(chan string, 20)
|
|
||||||
star.clientSign[key] = ch
|
|
||||||
}
|
|
||||||
return star.clientSign[key]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (star *StarNotifyC) store(key, value string) {
|
|
||||||
if _, ok := star.clientSign[key]; !ok {
|
|
||||||
ch := make(chan string, 20)
|
|
||||||
ch <- value
|
|
||||||
star.clientSign[key] = ch
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
star.clientSign[key] <- value
|
client.wg = stario.NewWaitGroup(0)
|
||||||
}
|
client.fileTransferCfg = defaultFileTransferConfig()
|
||||||
func NewNotifyCWithTimeOut(netype, value string, timeout time.Duration) (*StarNotifyC, error) {
|
client.signalReliableCfg = defaultSignalReliabilityConfig()
|
||||||
var err error
|
client.logicalSession = newLogicalSessionState(client.fileTransferCfg, client.signalReliableCfg)
|
||||||
var star StarNotifyC
|
client.streamRuntime = newStreamRuntime("cstrm")
|
||||||
star.starinitc()
|
client.recordRuntime = newRecordRuntime()
|
||||||
star.isUDP = false
|
client.bulkRuntime = newBulkRuntime("cblk")
|
||||||
if strings.Index(netype, "udp") >= 0 {
|
client.connectionRetryState = newConnectionRetryState()
|
||||||
star.isUDP = true
|
client.onFileEvent = normalizeFileEventCallback(nil)
|
||||||
}
|
client.fileEventObserver = normalizeFileEventCallback(nil)
|
||||||
star.Connc, err = net.DialTimeout(netype, value, timeout)
|
client.stopCtx, client.stopFn = context.WithCancel(context.Background())
|
||||||
if err != nil {
|
client.sessionRuntime.Store(newClientSessionRuntimeBase(client.stopCtx, client.stopFn))
|
||||||
return nil, err
|
bindClientStreamControl(&client)
|
||||||
}
|
bindClientBulkControl(&client)
|
||||||
star.dialTimeout = timeout
|
client.getTransferState().setBuiltinHandler(client.builtinFileTransferHandler)
|
||||||
go star.cnotify()
|
return &client
|
||||||
go func() {
|
|
||||||
<-star.stopSign.Done()
|
|
||||||
star.Connc.Close()
|
|
||||||
star.Online = false
|
|
||||||
return
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
buf := make([]byte, 8192)
|
|
||||||
n, err := star.Connc.Read(buf)
|
|
||||||
if n != 0 {
|
|
||||||
star.Queue.ParseMessage(buf[0:n], star.Connc)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
star.Connc.Close()
|
|
||||||
star.ClientStop()
|
|
||||||
//star, _ = NewNotifyC(netype, value)
|
|
||||||
star.Online = false
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
star.Online = true
|
|
||||||
return &star, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewNotifyC 用于新建一个Client端进程
|
|
||||||
func NewNotifyC(netype, value string) (*StarNotifyC, error) {
|
|
||||||
var err error
|
|
||||||
var star StarNotifyC
|
|
||||||
star.starinitc()
|
|
||||||
star.isUDP = false
|
|
||||||
if strings.Index(netype, "udp") >= 0 {
|
|
||||||
star.isUDP = true
|
|
||||||
}
|
|
||||||
star.Connc, err = net.Dial(netype, value)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
go star.cnotify()
|
|
||||||
go func() {
|
|
||||||
<-star.stopSign.Done()
|
|
||||||
star.Connc.Close()
|
|
||||||
star.Online = false
|
|
||||||
return
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
buf := make([]byte, 8192)
|
|
||||||
n, err := star.Connc.Read(buf)
|
|
||||||
if n != 0 {
|
|
||||||
star.Queue.ParseMessage(buf[0:n], star.Connc)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
star.Connc.Close()
|
|
||||||
star.ClientStop()
|
|
||||||
//star, _ = NewNotifyC(netype, value)
|
|
||||||
star.Online = false
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
star.Online = true
|
|
||||||
return &star, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send 用于向Server端发送数据
|
|
||||||
func (star *StarNotifyC) Send(name string) error {
|
|
||||||
return star.SendValue(name, "")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (star *StarNotifyC) SendValueRaw(key string, msg interface{}) error {
|
|
||||||
encodeData, err := encode(msg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return star.SendValue(key, string(encodeData))
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendValue 用于向Server端发送key-value类型数据
|
|
||||||
func (star *StarNotifyC) SendValue(name, value string) error {
|
|
||||||
var err error
|
|
||||||
var key []byte
|
|
||||||
for _, v := range []byte(name) {
|
|
||||||
if v == byte(124) || v == byte(92) {
|
|
||||||
key = append(key, byte(92))
|
|
||||||
}
|
|
||||||
key = append(key, v)
|
|
||||||
}
|
|
||||||
_, err = star.Connc.Write(star.Queue.BuildMessage([]byte("pa" + "||" + string(key) + "||" + value)))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (star *StarNotifyC) trim(name string) string {
|
|
||||||
var slash bool = false
|
|
||||||
var key []byte
|
|
||||||
for _, v := range []byte(name) {
|
|
||||||
if v == byte(92) && !slash {
|
|
||||||
slash = true
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
slash = false
|
|
||||||
key = append(key, v)
|
|
||||||
}
|
|
||||||
return string(key)
|
|
||||||
}
|
|
||||||
func (star *StarNotifyC) SendValueWaitRaw(key string, msg interface{}, tmout time.Duration) (CMsg, error) {
|
|
||||||
encodeData, err := encode(msg)
|
|
||||||
if err != nil {
|
|
||||||
return CMsg{}, err
|
|
||||||
}
|
|
||||||
return star.SendValueWait(key, string(encodeData), tmout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendValueWait 用于向Server端发送key-value类型数据并等待结果返回,此结果不会通过标准返回流程处理
|
|
||||||
func (star *StarNotifyC) SendValueWait(name, value string, tmout time.Duration) (CMsg, error) {
|
|
||||||
var err error
|
|
||||||
var tmceed <-chan time.Time
|
|
||||||
if star.UseChannel {
|
|
||||||
return CMsg{}, errors.New("Do Not Use UseChannel Mode!")
|
|
||||||
}
|
|
||||||
rand.Seed(time.Now().UnixNano())
|
|
||||||
mode := "cr" + fmt.Sprintf("%05d", rand.Intn(99999))
|
|
||||||
var key []byte
|
|
||||||
for _, v := range []byte(name) {
|
|
||||||
if v == byte(124) || v == byte(92) {
|
|
||||||
key = append(key, byte(92))
|
|
||||||
}
|
|
||||||
key = append(key, v)
|
|
||||||
}
|
|
||||||
_, err = star.Connc.Write(star.Queue.BuildMessage([]byte(mode + "||" + string(key) + "||" + value)))
|
|
||||||
if err != nil {
|
|
||||||
return CMsg{}, err
|
|
||||||
}
|
|
||||||
if int64(tmout) > 0 {
|
|
||||||
tmceed = time.After(tmout)
|
|
||||||
}
|
|
||||||
var source CMsg
|
|
||||||
source.wait = make(chan int, 2)
|
|
||||||
star.lockPool[mode] = source
|
|
||||||
select {
|
|
||||||
case <-source.wait:
|
|
||||||
res := star.lockPool[mode]
|
|
||||||
delete(star.lockPool, mode)
|
|
||||||
return res, nil
|
|
||||||
case <-tmceed:
|
|
||||||
return CMsg{}, errors.New("Time Exceed")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReplyMsg 用于向Server端Reply信息
|
|
||||||
func (star *StarNotifyC) ReplyMsg(data CMsg, name, value string) error {
|
|
||||||
var err error
|
|
||||||
var key []byte
|
|
||||||
for _, v := range []byte(name) {
|
|
||||||
if v == byte(124) || v == byte(92) {
|
|
||||||
key = append(key, byte(92))
|
|
||||||
}
|
|
||||||
key = append(key, v)
|
|
||||||
}
|
|
||||||
_, err = star.Connc.Write(star.Queue.BuildMessage([]byte(data.mode + "||" + string(key) + "||" + value)))
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (star *StarNotifyC) cnotify() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-star.stopSign.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
data, err := star.Queue.RestoreOne()
|
|
||||||
if err != nil {
|
|
||||||
time.Sleep(time.Millisecond * 20)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if string(data.Msg) == "b612ryzstop" {
|
|
||||||
star.ClientStop()
|
|
||||||
star.Online = false
|
|
||||||
return
|
|
||||||
}
|
|
||||||
strs := strings.SplitN(string(data.Msg), "||", 3)
|
|
||||||
if len(strs) < 3 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
strs[1] = star.trim(strs[1])
|
|
||||||
if star.UseChannel {
|
|
||||||
go star.store(strs[1], strs[2])
|
|
||||||
} else {
|
|
||||||
mode, key, value := strs[0], strs[1], strs[2]
|
|
||||||
if mode[0:2] != "cr" {
|
|
||||||
if msg, ok := star.FuncLists[key]; ok {
|
|
||||||
go msg(CMsg{key, value, mode, nil})
|
|
||||||
} else {
|
|
||||||
if star.defaultFunc != nil {
|
|
||||||
go star.defaultFunc(CMsg{key, value, mode, nil})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if sa, ok := star.lockPool[mode]; ok {
|
|
||||||
sa.Key = key
|
|
||||||
sa.Value = value
|
|
||||||
sa.mode = mode
|
|
||||||
star.lockPool[mode] = sa
|
|
||||||
sa.wait <- 1
|
|
||||||
} else {
|
|
||||||
if msg, ok := star.FuncLists[key]; ok {
|
|
||||||
go msg(CMsg{key, value, mode, nil})
|
|
||||||
} else {
|
|
||||||
if star.defaultFunc != nil {
|
|
||||||
go star.defaultFunc(CMsg{key, value, mode, nil})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientStop 终止client端运行
|
|
||||||
func (star *StarNotifyC) ClientStop() {
|
|
||||||
if star.isUDP {
|
|
||||||
star.Send("b612ryzstop")
|
|
||||||
}
|
|
||||||
star.cancel()
|
|
||||||
star.Stop <- 1
|
|
||||||
star.Stop <- 1
|
|
||||||
star.Stop <- 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNotify 用于设置关键词的调用函数
|
|
||||||
func (star *StarNotifyC) SetNotify(name string, data func(CMsg)) {
|
|
||||||
star.FuncLists[name] = data
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDefaultNotify 用于设置默认关键词的调用函数
|
|
||||||
func (star *StarNotifyC) SetDefaultNotify(data func(CMsg)) {
|
|
||||||
star.defaultFunc = data
|
|
||||||
}
|
}
|
||||||
|
|||||||
+198
@@ -0,0 +1,198 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetBulkHandler(fn func(BulkAcceptInfo) error) {
|
||||||
|
runtime := c.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime.setHandler(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, errBulkClientNil
|
||||||
|
}
|
||||||
|
runtime := c.getBulkRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return nil, errBulkRuntimeNil
|
||||||
|
}
|
||||||
|
req := clientBulkRequest(runtime, opt)
|
||||||
|
if req.BulkID == "" {
|
||||||
|
return nil, errBulkIDEmpty
|
||||||
|
}
|
||||||
|
if req.Dedicated {
|
||||||
|
if err := clientDedicatedBulkSupportError(c); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !validBulkRange(req.Range) {
|
||||||
|
return nil, errBulkRangeInvalid
|
||||||
|
}
|
||||||
|
if _, exists := runtime.lookup(clientFileScope(), req.BulkID); exists {
|
||||||
|
return nil, errBulkAlreadyExists
|
||||||
|
}
|
||||||
|
resp, err := sendBulkOpenClient(ctx, c, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.DataID != 0 {
|
||||||
|
req.DataID = resp.DataID
|
||||||
|
}
|
||||||
|
req.Dedicated = resp.Dedicated
|
||||||
|
if resp.AttachToken != "" {
|
||||||
|
req.AttachToken = resp.AttachToken
|
||||||
|
}
|
||||||
|
if req.DataID == 0 {
|
||||||
|
return nil, errBulkDataIDEmpty
|
||||||
|
}
|
||||||
|
bulk := newBulkHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientBulkCloseSender(c), clientBulkResetSender(c), clientBulkDataSender(c, c.currentClientSessionEpoch()), clientBulkWriteSender(c, c.currentClientSessionEpoch()), clientBulkReleaseSender(c))
|
||||||
|
bulk.setClientSnapshotOwner(c)
|
||||||
|
if err := runtime.register(clientFileScope(), bulk); err != nil {
|
||||||
|
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
|
||||||
|
BulkID: req.BulkID,
|
||||||
|
DataID: req.DataID,
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if bulk.Dedicated() {
|
||||||
|
if err := c.attachDedicatedBulkSidecar(ctx, bulk); err != nil {
|
||||||
|
runtime.remove(clientFileScope(), bulk.ID())
|
||||||
|
_, _ = sendBulkResetClient(context.Background(), c, BulkResetRequest{
|
||||||
|
BulkID: bulk.ID(),
|
||||||
|
DataID: bulk.dataIDSnapshot(),
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return bulk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientBulkRequest(runtime *bulkRuntime, opt BulkOpenOptions) BulkOpenRequest {
|
||||||
|
opt = normalizeBulkOpenOptions(opt)
|
||||||
|
id := opt.ID
|
||||||
|
if id == "" && runtime != nil {
|
||||||
|
id = runtime.nextID()
|
||||||
|
}
|
||||||
|
return normalizeBulkOpenRequest(BulkOpenRequest{
|
||||||
|
BulkID: id,
|
||||||
|
Range: opt.Range,
|
||||||
|
Metadata: cloneBulkMetadata(opt.Metadata),
|
||||||
|
ReadTimeout: opt.ReadTimeout,
|
||||||
|
WriteTimeout: opt.WriteTimeout,
|
||||||
|
Dedicated: opt.Dedicated,
|
||||||
|
ChunkSize: opt.ChunkSize,
|
||||||
|
WindowBytes: opt.WindowBytes,
|
||||||
|
MaxInFlight: opt.MaxInFlight,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientBulkCloseSender(c *ClientCommon) bulkCloseSender {
|
||||||
|
return func(ctx context.Context, bulk *bulkHandle, full bool) error {
|
||||||
|
if bulk != nil && bulk.Dedicated() {
|
||||||
|
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.sendDedicatedBulkClose(ctx, bulk, full)
|
||||||
|
}
|
||||||
|
_, err := sendBulkCloseClient(ctx, c, BulkCloseRequest{
|
||||||
|
BulkID: bulk.ID(),
|
||||||
|
Full: full,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientBulkResetSender(c *ClientCommon) bulkResetSender {
|
||||||
|
return func(ctx context.Context, bulk *bulkHandle, message string) error {
|
||||||
|
if bulk != nil && bulk.Dedicated() {
|
||||||
|
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.sendDedicatedBulkReset(ctx, bulk, message)
|
||||||
|
}
|
||||||
|
_, err := sendBulkResetClient(ctx, c, BulkResetRequest{
|
||||||
|
BulkID: bulk.ID(),
|
||||||
|
DataID: bulk.dataIDSnapshot(),
|
||||||
|
Error: message,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientBulkDataSender(c *ClientCommon, epoch uint64) bulkDataSender {
|
||||||
|
return func(ctx context.Context, bulk *bulkHandle, chunk []byte) error {
|
||||||
|
if c == nil {
|
||||||
|
return errBulkClientNil
|
||||||
|
}
|
||||||
|
if ctx != nil {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if bulk != nil && bulk.Dedicated() {
|
||||||
|
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.sendDedicatedBulkData(ctx, bulk, chunk)
|
||||||
|
}
|
||||||
|
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
dataID := bulk.dataIDSnapshot()
|
||||||
|
if dataID == 0 {
|
||||||
|
return errBulkDataPathNotReady
|
||||||
|
}
|
||||||
|
return c.sendFastBulkData(ctx, dataID, bulk.nextOutboundDataSeq(), chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientBulkWriteSender(c *ClientCommon, epoch uint64) bulkWriteSender {
|
||||||
|
return func(ctx context.Context, bulk *bulkHandle, payload []byte) (int, error) {
|
||||||
|
if c == nil {
|
||||||
|
return 0, errBulkClientNil
|
||||||
|
}
|
||||||
|
if ctx != nil {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return 0, ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if bulk != nil && bulk.Dedicated() {
|
||||||
|
if err := bulk.waitDedicatedReady(ctx); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return c.sendDedicatedBulkWrite(ctx, bulk, payload)
|
||||||
|
}
|
||||||
|
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
|
||||||
|
return 0, errTransportDetached
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientBulkReleaseSender(c *ClientCommon) bulkReleaseSender {
|
||||||
|
return func(bulk *bulkHandle, bytes int64, chunks int) error {
|
||||||
|
if c == nil || bulk == nil {
|
||||||
|
return errBulkClientNil
|
||||||
|
}
|
||||||
|
if bytes <= 0 && chunks <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if bulk.Dedicated() {
|
||||||
|
return c.sendDedicatedBulkRelease(context.Background(), bulk, bytes, chunks)
|
||||||
|
}
|
||||||
|
return sendBulkReleaseClient(c, BulkReleaseRequest{
|
||||||
|
BulkID: bulk.ID(),
|
||||||
|
DataID: bulk.dataIDSnapshot(),
|
||||||
|
Bytes: bytes,
|
||||||
|
Chunks: chunks,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,155 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *ClientCommon) DebugMode(dmg bool) {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.debugMode = dmg
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) IsDebugMode() bool {
|
||||||
|
return c.debugMode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: SkipExchangeKey only controls the legacy RSA-based key exchange.
|
||||||
|
func (c *ClientCommon) SkipExchangeKey() bool {
|
||||||
|
return c.skipKeyExchange
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: SetSkipExchangeKey only controls the legacy RSA-based key exchange.
|
||||||
|
func (c *ClientCommon) SetSkipExchangeKey(val bool) {
|
||||||
|
c.skipKeyExchange = val
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) ShowError(std bool) {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.showError = std
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetDefaultLink(fn func(message *Message)) {
|
||||||
|
c.defaultFns = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetLink(key string, fn func(*Message)) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.linkFns[key] = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetFileHandler(fn func(FileEvent)) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.onFileEvent = normalizeFileEventCallback(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetFileReceiveDir(dir string) error {
|
||||||
|
return c.getFileReceivePool().setDir(dir)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetTransferResumeStore(store TransferResumeStore) {
|
||||||
|
if runtime := c.getTransferRuntime(); runtime != nil {
|
||||||
|
runtime.setResumeStore(store)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) RecoverTransferSnapshots(ctx context.Context) error {
|
||||||
|
if runtime := c.getTransferRuntime(); runtime != nil {
|
||||||
|
return runtime.recover(ctx)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) GetMsgEn() func([]byte, []byte) []byte {
|
||||||
|
return c.msgEn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: SetMsgEn overrides the transport codec directly.
|
||||||
|
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||||
|
func (c *ClientCommon) SetMsgEn(fn func([]byte, []byte) []byte) {
|
||||||
|
c.msgEn = fn
|
||||||
|
c.fastStreamEncode = nil
|
||||||
|
c.fastBulkEncode = nil
|
||||||
|
c.fastPlainEncode = nil
|
||||||
|
c.securityReadyCheck = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) GetMsgDe() func([]byte, []byte) []byte {
|
||||||
|
return c.msgDe
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: SetMsgDe overrides the transport codec directly.
|
||||||
|
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||||
|
func (c *ClientCommon) SetMsgDe(fn func([]byte, []byte) []byte) {
|
||||||
|
c.msgDe = fn
|
||||||
|
c.fastStreamEncode = nil
|
||||||
|
c.fastBulkEncode = nil
|
||||||
|
c.fastPlainEncode = nil
|
||||||
|
c.securityReadyCheck = false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) HeartbeatPeroid() time.Duration {
|
||||||
|
return c.heartbeatPeriod
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetHeartbeatPeroid(duration time.Duration) {
|
||||||
|
c.heartbeatPeriod = duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) GetSecretKey() []byte {
|
||||||
|
return c.SecretKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: SetSecretKey injects a raw transport key directly.
|
||||||
|
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||||
|
func (c *ClientCommon) SetSecretKey(key []byte) {
|
||||||
|
c.SecretKey = key
|
||||||
|
c.securityReadyCheck = len(key) == 0
|
||||||
|
c.skipKeyExchange = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: RsaPubKey exposes the legacy RSA handshake key. Prefer UseModernPSKClient.
|
||||||
|
func (c *ClientCommon) RsaPubKey() []byte {
|
||||||
|
return c.handshakeRsaPubKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: SetRsaPubKey configures the legacy RSA handshake key. Prefer UseModernPSKClient.
|
||||||
|
func (c *ClientCommon) SetRsaPubKey(key []byte) {
|
||||||
|
c.handshakeRsaPubKey = key
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) Stop() error {
|
||||||
|
if !sessionIsAlive(&c.alive) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.stopClientSession("recv stop signal from user", nil)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) StopMonitorChan() <-chan struct{} {
|
||||||
|
return sessionStopChan(c.clientStopContextSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) Status() Status {
|
||||||
|
return sessionStatusValue(&c.mu, &c.status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) GetSequenceEn() func(interface{}) ([]byte, error) {
|
||||||
|
return c.sequenceEn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetSequenceEn(fn func(interface{}) ([]byte, error)) {
|
||||||
|
c.sequenceEn = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) GetSequenceDe() func([]byte) (interface{}, error) {
|
||||||
|
return c.sequenceDe
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetSequenceDe(fn func([]byte) (interface{}, error)) {
|
||||||
|
c.sequenceDe = fn
|
||||||
|
}
|
||||||
+437
@@ -0,0 +1,437 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"b612.me/starcrypto"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type clientConnTransportDetachState struct {
|
||||||
|
Generation uint64
|
||||||
|
Reason string
|
||||||
|
Err string
|
||||||
|
At time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
clientConnTransportDetachKindReadError = "read_error"
|
||||||
|
clientConnTransportDetachKindHeartbeatTimeout = "heartbeat_timeout"
|
||||||
|
clientConnTransportDetachKindOther = "other"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ClientConn struct {
|
||||||
|
alive atomic.Value
|
||||||
|
status Status
|
||||||
|
logicalView atomic.Pointer[LogicalConn]
|
||||||
|
logicalState atomic.Pointer[logicalConnState]
|
||||||
|
runtimeState atomic.Pointer[logicalConnRuntimeState]
|
||||||
|
transportState atomic.Pointer[clientConnTransportState]
|
||||||
|
sessionRuntime atomic.Pointer[clientConnSessionRuntime]
|
||||||
|
attachment atomic.Pointer[clientConnAttachmentState]
|
||||||
|
identityBound atomic.Bool
|
||||||
|
ClientID string
|
||||||
|
ClientAddr net.Addr
|
||||||
|
server Server
|
||||||
|
}
|
||||||
|
|
||||||
|
type Status struct {
|
||||||
|
Alive bool
|
||||||
|
Reason string
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) readTUMessage() {
|
||||||
|
if logical := c.LogicalConn(); logical != nil {
|
||||||
|
logical.readTUMessage()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.readTUMessageLoop(rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
|
||||||
|
if logical := c.LogicalConn(); logical != nil {
|
||||||
|
logical.readTUMessageLoop(rt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stopCtx := rt.transportStopCtx
|
||||||
|
if stopCtx == nil {
|
||||||
|
stopCtx = rt.stopCtx
|
||||||
|
}
|
||||||
|
if stopCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn := rt.tuConn
|
||||||
|
generation := rt.transportGeneration
|
||||||
|
defer closeClientConnSessionRuntimeTransportDone(rt)
|
||||||
|
buf := streamReadBuffer()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-sessionStopChan(stopCtx):
|
||||||
|
if c.shouldCloseClientConnTransportOnStop(conn) {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
num, data, err := c.readFromTUTransportConnWithBuffer(conn, buf)
|
||||||
|
if !c.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: rsaDecode exists only for the legacy MSG_KEY_CHANGE flow.
|
||||||
|
func (c *ClientConn) rsaDecode(message Message) {
|
||||||
|
privKey, err := starcrypto.DecodeRsaPrivateKey(c.clientConnHandshakeRsaKeySnapshot(), "")
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
message.Reply([]byte("failed"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data, err := starcrypto.RSADecrypt(privKey, message.Value)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(err)
|
||||||
|
message.Reply([]byte("failed"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
message.Reply([]byte("success"))
|
||||||
|
c.setClientConnSecretKey(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) sayGoodByeForTU() error {
|
||||||
|
if c == nil || c.server == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
_, err := c.server.SendWaitLogical(c.LogicalConn(), "bye", nil, time.Second*3)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
_, err = c.server.sendWait(c, TransferMsg{
|
||||||
|
ID: 10010,
|
||||||
|
Key: "bye",
|
||||||
|
Value: nil,
|
||||||
|
Type: MSG_SYS_WAIT,
|
||||||
|
}, time.Second*3)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) GetSecretKey() []byte {
|
||||||
|
return c.clientConnSecretKeySnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: SetSecretKey injects a raw per-connection transport key directly.
|
||||||
|
func (c *ClientConn) SetSecretKey(key []byte) {
|
||||||
|
c.setClientConnSecretKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) GetMsgEn() func([]byte, []byte) []byte {
|
||||||
|
return c.clientConnMsgEnSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: SetMsgEn overrides the per-connection transport codec directly.
|
||||||
|
func (c *ClientConn) SetMsgEn(fn func([]byte, []byte) []byte) {
|
||||||
|
c.setClientConnMsgEn(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) GetMsgDe() func([]byte, []byte) []byte {
|
||||||
|
return c.clientConnMsgDeSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: SetMsgDe overrides the per-connection transport codec directly.
|
||||||
|
func (c *ClientConn) SetMsgDe(fn func([]byte, []byte) []byte) {
|
||||||
|
c.setClientConnMsgDe(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) StopMonitorChan() <-chan struct{} {
|
||||||
|
return sessionStopChan(c.clientConnStopContextSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) Status() Status {
|
||||||
|
return c.clientConnStatusSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) Server() Server {
|
||||||
|
if c != nil {
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
if server := logical.Server(); server != nil {
|
||||||
|
return server
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return c.server
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) GetRemoteAddr() net.Addr {
|
||||||
|
return c.clientConnRemoteAddrSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) markClientConnIdentityBound() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
logical.markIdentityBound()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state := c.ensureLogicalConnState()
|
||||||
|
if state == nil {
|
||||||
|
c.identityBound.Store(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state.updatePeer(func(peer *logicalConnPeerState) {
|
||||||
|
peer.identityBound = true
|
||||||
|
})
|
||||||
|
c.syncLegacyLogicalFieldsFromState(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnIdentityBoundSnapshot() bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.clientConnLogicalPeerStateSnapshot().identityBound
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) markClientConnStreamTransport() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
logical.markStreamTransport()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state := c.ensureClientConnTransportState()
|
||||||
|
if state == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state.streamTransport.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnUsesStreamTransportSnapshot() bool {
|
||||||
|
state := c.ensureClientConnTransportState()
|
||||||
|
if state == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return state.streamTransport.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) shouldPreserveLogicalPeerOnTransportLoss() bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.clientConnIdentityBoundSnapshot() && c.clientConnUsesStreamTransportSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) markClientConnTransportAttached() uint64 {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
return logical.markTransportAttached()
|
||||||
|
}
|
||||||
|
state := c.ensureClientConnTransportState()
|
||||||
|
if state == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
gen := state.transportGen.Add(1)
|
||||||
|
state.attachCount.Add(1)
|
||||||
|
state.lastAttachAt.Store(time.Now().UnixNano())
|
||||||
|
return gen
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportGenerationSnapshot() uint64 {
|
||||||
|
state := c.ensureClientConnTransportState()
|
||||||
|
if state == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return state.transportGen.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportAttachCountSnapshot() uint64 {
|
||||||
|
state := c.ensureClientConnTransportState()
|
||||||
|
if state == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return state.attachCount.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) markClientConnTransportDetached(reason string, err error) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
logical.markTransportDetached(reason, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state := c.ensureClientConnTransportState()
|
||||||
|
if state == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
detachState := &clientConnTransportDetachState{
|
||||||
|
Generation: c.clientConnTransportGenerationSnapshot(),
|
||||||
|
Reason: reason,
|
||||||
|
At: time.Now(),
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
detachState.Err = err.Error()
|
||||||
|
}
|
||||||
|
state.detachCount.Add(1)
|
||||||
|
state.transportDetach.Store(detachState)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportDetachCountSnapshot() uint64 {
|
||||||
|
state := c.ensureClientConnTransportState()
|
||||||
|
if state == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return state.detachCount.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clearClientConnTransportDetachState() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
logical.clearTransportDetachState()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.setClientConnTransportDetachState(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportDetachSnapshot() *clientConnTransportDetachState {
|
||||||
|
state := c.ensureClientConnTransportState()
|
||||||
|
if state == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return cloneClientConnTransportDetachState(state.transportDetach.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnLogicalTransportDetachedSnapshot() bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !c.clientConnIdentityBoundSnapshot() || !c.clientConnUsesStreamTransportSnapshot() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !c.clientConnAliveSnapshot() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !c.clientConnTransportAttachedSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnLastTransportAttachedAtSnapshot() time.Time {
|
||||||
|
state := c.ensureClientConnTransportState()
|
||||||
|
if state == nil {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
unixNano := state.lastAttachAt.Load()
|
||||||
|
if unixNano == 0 {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
return time.Unix(0, unixNano)
|
||||||
|
}
|
||||||
|
|
||||||
|
func classifyClientConnTransportDetachReason(reason string) string {
|
||||||
|
switch reason {
|
||||||
|
case "":
|
||||||
|
return ""
|
||||||
|
case "read error":
|
||||||
|
return clientConnTransportDetachKindReadError
|
||||||
|
case "heartbeat timeout":
|
||||||
|
return clientConnTransportDetachKindHeartbeatTimeout
|
||||||
|
default:
|
||||||
|
return clientConnTransportDetachKindOther
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportDetachKindSnapshot() string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
detach := c.clientConnTransportDetachSnapshot()
|
||||||
|
if detach == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return classifyClientConnTransportDetachReason(detach.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportDetachGenerationSnapshot() uint64 {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
detach := c.clientConnTransportDetachSnapshot()
|
||||||
|
if detach == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if detach.Generation == 0 {
|
||||||
|
return c.clientConnTransportGenerationSnapshot()
|
||||||
|
}
|
||||||
|
return detach.Generation
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportDetachExpirySnapshot() (time.Time, bool) {
|
||||||
|
if c == nil {
|
||||||
|
return time.Time{}, false
|
||||||
|
}
|
||||||
|
detach := c.clientConnTransportDetachSnapshot()
|
||||||
|
if detach == nil || detach.At.IsZero() {
|
||||||
|
return time.Time{}, false
|
||||||
|
}
|
||||||
|
if c.server == nil {
|
||||||
|
return time.Time{}, false
|
||||||
|
}
|
||||||
|
keepSec := c.server.DetachedClientKeepSec()
|
||||||
|
if keepSec <= 0 {
|
||||||
|
return time.Time{}, false
|
||||||
|
}
|
||||||
|
return detach.At.Add(time.Duration(keepSec) * time.Second), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportDetachExpiredSnapshot(now time.Time) bool {
|
||||||
|
if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
expiry, ok := c.clientConnTransportDetachExpirySnapshot()
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !now.Before(expiry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportDetachRemainingSnapshot(now time.Time) time.Duration {
|
||||||
|
if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
expiry, ok := c.clientConnTransportDetachExpirySnapshot()
|
||||||
|
if !ok {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if !now.Before(expiry) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return expiry.Sub(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnReattachEligibleSnapshot(now time.Time) bool {
|
||||||
|
if c == nil || !c.clientConnLogicalTransportDetachedSnapshot() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !c.clientConnAliveSnapshot() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if c.clientConnTransportAttachedSnapshot() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if c.clientConnTransportDetachExpiredSnapshot(now) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -0,0 +1,333 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type clientConnAttachmentState struct {
|
||||||
|
maxReadTimeout time.Duration
|
||||||
|
maxWriteTimeout time.Duration
|
||||||
|
msgEn func([]byte, []byte) []byte
|
||||||
|
msgDe func([]byte, []byte) []byte
|
||||||
|
fastStreamEncode transportFastStreamEncoder
|
||||||
|
fastBulkEncode transportFastBulkEncoder
|
||||||
|
fastPlainEncode transportFastPlainEncoder
|
||||||
|
handshakeRsaKey []byte
|
||||||
|
secretKey []byte
|
||||||
|
lastHeartBeat int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneClientConnAttachmentState(src *clientConnAttachmentState) *clientConnAttachmentState {
|
||||||
|
if src == nil {
|
||||||
|
return &clientConnAttachmentState{}
|
||||||
|
}
|
||||||
|
cloned := *src
|
||||||
|
cloned.handshakeRsaKey = cloneClientConnAttachmentBytes(src.handshakeRsaKey)
|
||||||
|
cloned.secretKey = cloneClientConnAttachmentBytes(src.secretKey)
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneClientConnAttachmentBytes(src []byte) []byte {
|
||||||
|
if len(src) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return append([]byte(nil), src...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) attachmentStateSnapshot() *clientConnAttachmentState {
|
||||||
|
if c == nil {
|
||||||
|
return &clientConnAttachmentState{}
|
||||||
|
}
|
||||||
|
if state := c.attachment.Load(); state != nil {
|
||||||
|
if client := c.compatClientConn(); client != nil {
|
||||||
|
client.attachment.Store(state)
|
||||||
|
}
|
||||||
|
return cloneClientConnAttachmentState(state)
|
||||||
|
}
|
||||||
|
client := c.compatClientConn()
|
||||||
|
if client != nil {
|
||||||
|
if state := client.attachment.Load(); state != nil {
|
||||||
|
if c.attachment.CompareAndSwap(nil, state) {
|
||||||
|
client.attachment.Store(state)
|
||||||
|
return cloneClientConnAttachmentState(state)
|
||||||
|
}
|
||||||
|
return c.attachmentStateSnapshot()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &clientConnAttachmentState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) setAttachmentState(state *clientConnAttachmentState) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next := cloneClientConnAttachmentState(state)
|
||||||
|
c.attachment.Store(next)
|
||||||
|
if client := c.compatClientConn(); client != nil {
|
||||||
|
client.attachment.Store(next)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) updateAttachmentState(apply func(*clientConnAttachmentState)) {
|
||||||
|
if c == nil || apply == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
current := c.attachment.Load()
|
||||||
|
if current == nil {
|
||||||
|
if client := c.compatClientConn(); client != nil {
|
||||||
|
current = client.attachment.Load()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next := cloneClientConnAttachmentState(current)
|
||||||
|
apply(next)
|
||||||
|
if current == nil {
|
||||||
|
if c.attachment.CompareAndSwap((*clientConnAttachmentState)(nil), next) {
|
||||||
|
if client := c.compatClientConn(); client != nil {
|
||||||
|
client.attachment.Store(next)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c.attachment.CompareAndSwap(current, next) {
|
||||||
|
if client := c.compatClientConn(); client != nil {
|
||||||
|
client.attachment.Store(next)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnAttachmentStateSnapshot() *clientConnAttachmentState {
|
||||||
|
if c == nil {
|
||||||
|
return &clientConnAttachmentState{}
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
return logical.attachmentStateSnapshot()
|
||||||
|
}
|
||||||
|
if state := c.attachment.Load(); state != nil {
|
||||||
|
return cloneClientConnAttachmentState(state)
|
||||||
|
}
|
||||||
|
return &clientConnAttachmentState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnAttachmentState(state *clientConnAttachmentState) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
logical.setAttachmentState(state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.attachment.Store(cloneClientConnAttachmentState(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) updateClientConnAttachmentState(apply func(*clientConnAttachmentState)) {
|
||||||
|
if c == nil || apply == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
logical.updateAttachmentState(apply)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
current := c.attachment.Load()
|
||||||
|
next := cloneClientConnAttachmentState(current)
|
||||||
|
apply(next)
|
||||||
|
if current == nil {
|
||||||
|
if c.attachment.CompareAndSwap((*clientConnAttachmentState)(nil), next) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if c.attachment.CompareAndSwap(current, next) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) applyClientConnAttachmentProfile(maxReadTimeout time.Duration, maxWriteTimeout time.Duration, msgEn func([]byte, []byte) []byte, msgDe func([]byte, []byte) []byte, handshakeRsaKey []byte, secretKey []byte) {
|
||||||
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.maxReadTimeout = maxReadTimeout
|
||||||
|
state.maxWriteTimeout = maxWriteTimeout
|
||||||
|
state.msgEn = msgEn
|
||||||
|
state.msgDe = msgDe
|
||||||
|
state.handshakeRsaKey = cloneClientConnAttachmentBytes(handshakeRsaKey)
|
||||||
|
state.secretKey = cloneClientConnAttachmentBytes(secretKey)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) inheritClientConnAttachmentProfile(src *ClientConn) {
|
||||||
|
if c == nil || src == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.setClientConnAttachmentState(src.clientConnAttachmentStateSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnMaxReadTimeoutSnapshot() time.Duration {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return c.clientConnAttachmentStateSnapshot().maxReadTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnMaxWriteTimeout(timeout time.Duration) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
logical.updateAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.maxWriteTimeout = timeout
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.maxWriteTimeout = timeout
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnMaxWriteTimeoutSnapshot() time.Duration {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return c.clientConnAttachmentStateSnapshot().maxWriteTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnMsgEnSnapshot() func([]byte, []byte) []byte {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.clientConnAttachmentStateSnapshot().msgEn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnMsgEn(fn func([]byte, []byte) []byte) {
|
||||||
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.msgEn = fn
|
||||||
|
state.fastStreamEncode = nil
|
||||||
|
state.fastBulkEncode = nil
|
||||||
|
state.fastPlainEncode = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnMsgDeSnapshot() func([]byte, []byte) []byte {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.clientConnAttachmentStateSnapshot().msgDe
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnMsgDe(fn func([]byte, []byte) []byte) {
|
||||||
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.msgDe = fn
|
||||||
|
state.fastStreamEncode = nil
|
||||||
|
state.fastBulkEncode = nil
|
||||||
|
state.fastPlainEncode = nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnFastStreamEncode(fn transportFastStreamEncoder) {
|
||||||
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.fastStreamEncode = fn
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnFastStreamEncodeSnapshot() transportFastStreamEncoder {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.clientConnAttachmentStateSnapshot().fastStreamEncode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnFastBulkEncode(fn transportFastBulkEncoder) {
|
||||||
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.fastBulkEncode = fn
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnFastBulkEncodeSnapshot() transportFastBulkEncoder {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.clientConnAttachmentStateSnapshot().fastBulkEncode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnFastPlainEncode(fn transportFastPlainEncoder) {
|
||||||
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.fastPlainEncode = fn
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnFastPlainEncodeSnapshot() transportFastPlainEncoder {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.clientConnAttachmentStateSnapshot().fastPlainEncode
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnHandshakeRsaKeySnapshot() []byte {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.clientConnAttachmentStateSnapshot().handshakeRsaKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnSecretKeySnapshot() []byte {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.clientConnAttachmentStateSnapshot().secretKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnSecretKey(key []byte) {
|
||||||
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.secretKey = cloneClientConnAttachmentBytes(key)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnLastHeartbeatUnixSnapshot() int64 {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return c.clientConnAttachmentStateSnapshot().lastHeartBeat
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnLastHeartbeatUnix(unix int64) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
logical.setClientConnLastHeartbeatUnix(unix)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.updateClientConnAttachmentState(func(state *clientConnAttachmentState) {
|
||||||
|
state.lastHeartBeat = unix
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) markClientConnHeartbeatNow() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
logical.markHeartbeatNow()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.setClientConnLastHeartbeatUnix(time.Now().Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnRemoteAddr(addr net.Addr) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state := c.ensureLogicalConnState()
|
||||||
|
if state == nil {
|
||||||
|
c.ClientAddr = addr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state.updatePeer(func(peer *logicalConnPeerState) {
|
||||||
|
peer.clientAddr = addr
|
||||||
|
})
|
||||||
|
c.syncLegacyLogicalFieldsFromState(state)
|
||||||
|
}
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *ClientConn) startClientConnSession(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
|
||||||
|
if c == nil {
|
||||||
|
return stopCtx, stopFn
|
||||||
|
}
|
||||||
|
return c.LogicalConn().startSession(tuConn, stopCtx, stopFn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) startClientConnSessionTransport(tuConn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (context.Context, context.CancelFunc) {
|
||||||
|
if c == nil {
|
||||||
|
return stopCtx, stopFn
|
||||||
|
}
|
||||||
|
return c.LogicalConn().startSessionTransport(tuConn, stopCtx, stopFn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) attachClientConnSessionTransport(tuConn net.Conn) error {
|
||||||
|
if c == nil {
|
||||||
|
return errors.New("client conn is nil")
|
||||||
|
}
|
||||||
|
return c.LogicalConn().attachSessionTransport(tuConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) detachClientConnTransportForTransfer() (net.Conn, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, errors.New("client conn is nil")
|
||||||
|
}
|
||||||
|
return c.LogicalConn().detachTransportForTransfer()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) stopServerOwnedSession(reason string, err error) {
|
||||||
|
c.stopServerOwnedSessionWith(nil, reason, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) stopServerOwnedSession(reason string, err error) {
|
||||||
|
c.stopServerOwnedSessionWith(nil, reason, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) stopServerOwnedSessionWith(removeFn func(*ClientConn), reason string, err error) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.markSessionStopped(reason, err)
|
||||||
|
c.detachServerOwnedSessionWith(removeFn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) stopServerOwnedSessionWith(removeFn func(*LogicalConn), reason string, err error) {
|
||||||
|
client := c.compatClientConn()
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
client.markSessionStopped(reason, err)
|
||||||
|
c.detachServerOwnedSessionWith(removeFn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) detachServerOwnedSession() {
|
||||||
|
c.detachServerOwnedSessionWith(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) detachServerOwnedSession() {
|
||||||
|
c.detachServerOwnedSessionWith(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) detachServerOwnedSessionWith(removeFn func(*ClientConn)) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.detachServerOwnedTransport()
|
||||||
|
if removeFn != nil {
|
||||||
|
removeFn(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.server != nil {
|
||||||
|
c.server.removeClient(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) detachServerOwnedSessionWith(removeFn func(*LogicalConn)) {
|
||||||
|
client := c.compatClientConn()
|
||||||
|
if client == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.detachServerOwnedTransport()
|
||||||
|
if removeFn != nil {
|
||||||
|
removeFn(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if client.server != nil {
|
||||||
|
client.server.removeLogical(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) detachServerOwnedTransport() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.LogicalConn().detachServerOwnedTransport()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) detachServerOwnedTransport() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.closeTransport()
|
||||||
|
c.clearSessionRuntimeTransport()
|
||||||
|
}
|
||||||
@@ -0,0 +1,232 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type clientConnSessionRuntime struct {
|
||||||
|
transport *transportBinding
|
||||||
|
transportAttached bool
|
||||||
|
transportGeneration uint64
|
||||||
|
tuConn net.Conn
|
||||||
|
stopCtx context.Context
|
||||||
|
stopFn context.CancelFunc
|
||||||
|
transportStopCtx context.Context
|
||||||
|
transportStopFn context.CancelFunc
|
||||||
|
transportDone chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnSessionRuntime(rt *clientConnSessionRuntime) {
|
||||||
|
if c == nil || rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logical := c.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
if rt.transport == nil && rt.tuConn != nil {
|
||||||
|
rt.transport = newTransportBinding(rt.tuConn, nil)
|
||||||
|
}
|
||||||
|
normalizeClientConnSessionRuntimeTransportState(rt)
|
||||||
|
ensureClientConnSessionRuntimeTransportLifecycle(rt)
|
||||||
|
ensureClientConnSessionRuntimeTransportDone(rt)
|
||||||
|
c.sessionRuntime.Store(rt)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logical.setSessionRuntime(rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnSessionRuntimeSnapshot() *clientConnSessionRuntime {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
state := c.ensureLogicalConnRuntimeState()
|
||||||
|
if state == nil {
|
||||||
|
return c.sessionRuntime.Load()
|
||||||
|
}
|
||||||
|
rt := state.sessionRuntimeSnapshot()
|
||||||
|
if rt != c.sessionRuntime.Load() {
|
||||||
|
c.sessionRuntime.Store(rt)
|
||||||
|
}
|
||||||
|
return rt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clearClientConnSessionRuntimeTransport() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logical := c.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt.transportStopFn != nil {
|
||||||
|
rt.transportStopFn()
|
||||||
|
}
|
||||||
|
next := *rt
|
||||||
|
next.transport = nil
|
||||||
|
next.transportAttached = false
|
||||||
|
next.transportGeneration = 0
|
||||||
|
next.tuConn = nil
|
||||||
|
next.transportStopCtx = nil
|
||||||
|
next.transportStopFn = nil
|
||||||
|
next.transportDone = nil
|
||||||
|
c.setClientConnSessionRuntime(&next)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logical.clearSessionRuntimeTransport()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportSnapshot() net.Conn {
|
||||||
|
logical := c.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rt.transport != nil {
|
||||||
|
return rt.transport.connSnapshot()
|
||||||
|
}
|
||||||
|
return rt.tuConn
|
||||||
|
}
|
||||||
|
return logical.transportSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnStopContextSnapshot() context.Context {
|
||||||
|
logical := c.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return rt.stopCtx
|
||||||
|
}
|
||||||
|
return logical.stopContextSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnStopFuncSnapshot() context.CancelFunc {
|
||||||
|
logical := c.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return rt.stopFn
|
||||||
|
}
|
||||||
|
return logical.stopFuncSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) closeClientConnTransport() {
|
||||||
|
logical := c.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
conn := c.clientConnTransportSnapshot()
|
||||||
|
if conn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = conn.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logical.closeTransport()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportBindingSnapshot() *transportBinding {
|
||||||
|
logical := c.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rt.transport != nil {
|
||||||
|
return rt.transport
|
||||||
|
}
|
||||||
|
if rt.tuConn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return newTransportBinding(rt.tuConn, nil)
|
||||||
|
}
|
||||||
|
return logical.transportBindingSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeClientConnSessionRuntimeTransportState(rt *clientConnSessionRuntime) {
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt.transport != nil {
|
||||||
|
rt.transportAttached = rt.transport.connSnapshot() != nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rt.transportAttached = rt.tuConn != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureClientConnSessionRuntimeTransportLifecycle(rt *clientConnSessionRuntime) {
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt.tuConn == nil {
|
||||||
|
rt.transportStopCtx = nil
|
||||||
|
rt.transportStopFn = nil
|
||||||
|
rt.transportDone = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt.transportStopCtx != nil && rt.transportStopFn != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
parent := rt.stopCtx
|
||||||
|
if parent == nil {
|
||||||
|
parent = context.Background()
|
||||||
|
}
|
||||||
|
rt.transportStopCtx, rt.transportStopFn = context.WithCancel(parent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureClientConnSessionRuntimeTransportDone(rt *clientConnSessionRuntime) {
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt.tuConn == nil {
|
||||||
|
rt.transportDone = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt.transportDone != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rt.transportDone = make(chan struct{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeClientConnSessionRuntimeTransportDone(rt *clientConnSessionRuntime) {
|
||||||
|
if rt == nil || rt.transportDone == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-rt.transportDone:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
close(rt.transportDone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportStopContextSnapshot() context.Context {
|
||||||
|
logical := c.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rt.transportStopCtx != nil {
|
||||||
|
return rt.transportStopCtx
|
||||||
|
}
|
||||||
|
return rt.stopCtx
|
||||||
|
}
|
||||||
|
return logical.transportStopContextSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) clientConnTransportAttachedSnapshot() bool {
|
||||||
|
logical := c.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return rt.transportAttached
|
||||||
|
}
|
||||||
|
return logical.transportAttachedSnapshot()
|
||||||
|
}
|
||||||
@@ -0,0 +1,443 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"b612.me/stario"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientConnReadTUMessagePreservesServerStopReason(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
left, right := net.Pipe()
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
|
||||||
|
client, _, _ := newRegisteredServerClientForTest(t, server, "client-stop", left, stopCtx, stopFn)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
client.readTUMessage()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
server.stopClientSession(client, "recv stop signal from server", nil)
|
||||||
|
_ = right.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("readTUMessage should exit after server stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
if status := client.Status(); status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
|
||||||
|
t.Fatalf("unexpected status after server stop: %+v", status)
|
||||||
|
}
|
||||||
|
if got := server.GetLogicalConn(client.ClientID); got != nil {
|
||||||
|
t.Fatalf("logical should be removed after server stop, got %+v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientConnReadTUMessageReadErrorStopsAndRemovesClient(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
left, right := net.Pipe()
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
|
||||||
|
client, _, _ := newRegisteredServerClientForTest(t, server, "client-read-error", left, stopCtx, stopFn)
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
client.readTUMessage()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
_ = right.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("readTUMessage should exit after read error")
|
||||||
|
}
|
||||||
|
|
||||||
|
status := client.Status()
|
||||||
|
if status.Alive || status.Reason != "read error" || status.Err == nil {
|
||||||
|
t.Fatalf("unexpected status after read error: %+v", status)
|
||||||
|
}
|
||||||
|
if got := server.GetLogicalConn(client.ClientID); got != nil {
|
||||||
|
t.Fatalf("logical should be removed after read error, got %+v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientConnMarkSessionStoppedUsesRuntimeStopFn(t *testing.T) {
|
||||||
|
client := &ClientConn{}
|
||||||
|
client.markSessionStarted()
|
||||||
|
|
||||||
|
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
|
||||||
|
defer runtimeCancel()
|
||||||
|
client.setClientConnSessionRuntime(&clientConnSessionRuntime{
|
||||||
|
stopCtx: runtimeCtx,
|
||||||
|
stopFn: runtimeCancel,
|
||||||
|
})
|
||||||
|
|
||||||
|
client.markSessionStopped("runtime stop", nil)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-runtimeCtx.Done():
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("runtime stop context should be canceled by markSessionStopped")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientConnDetachServerOwnedSessionClearsRuntimeTransport(t *testing.T) {
|
||||||
|
client := &ClientConn{}
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
client.startClientConnSession(left, stopCtx, stopFn)
|
||||||
|
|
||||||
|
client.detachServerOwnedSession()
|
||||||
|
|
||||||
|
if got := client.clientConnTransportSnapshot(); got != nil {
|
||||||
|
t.Fatalf("runtime transport should be cleared after detach, got %v", got)
|
||||||
|
}
|
||||||
|
if got := client.clientConnStopContextSnapshot(); got != stopCtx {
|
||||||
|
t.Fatalf("runtime stop context should be preserved after detach, got %v want %v", got, stopCtx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientConnReadFromTUTransportUsesRuntimeConn(t *testing.T) {
|
||||||
|
client := &ClientConn{}
|
||||||
|
runtimeLeft, runtimeRight := net.Pipe()
|
||||||
|
defer runtimeLeft.Close()
|
||||||
|
defer runtimeRight.Close()
|
||||||
|
|
||||||
|
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
|
||||||
|
defer runtimeCancel()
|
||||||
|
client.setClientConnSessionRuntime(&clientConnSessionRuntime{
|
||||||
|
tuConn: runtimeLeft,
|
||||||
|
stopCtx: runtimeCtx,
|
||||||
|
stopFn: runtimeCancel,
|
||||||
|
})
|
||||||
|
|
||||||
|
payload := []byte("runtime-tu-conn")
|
||||||
|
writeDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
_, err := runtimeRight.Write(payload)
|
||||||
|
writeDone <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
num, data, err := client.readFromTUTransport()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("readFromTUTransport failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := string(data[:num]), string(payload); got != want {
|
||||||
|
t.Fatalf("payload mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case err := <-writeDone:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("runtime writer failed: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("runtime writer did not finish")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStartClientConnSessionInitializesDefaultRuntime(t *testing.T) {
|
||||||
|
client := &ClientConn{}
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
stopCtx, stopFn := client.startClientConnSession(left, nil, nil)
|
||||||
|
defer stopFn()
|
||||||
|
|
||||||
|
if !client.Status().Alive {
|
||||||
|
t.Fatalf("client should start alive: %+v", client.Status())
|
||||||
|
}
|
||||||
|
if stopCtx == nil || stopFn == nil {
|
||||||
|
t.Fatal("startClientConnSession should initialize default stop context")
|
||||||
|
}
|
||||||
|
if got := client.clientConnTransportSnapshot(); got != left {
|
||||||
|
t.Fatal("runtime transport snapshot should match passed conn")
|
||||||
|
}
|
||||||
|
if got := client.clientConnStopContextSnapshot(); got != stopCtx {
|
||||||
|
t.Fatal("runtime stop context snapshot should match returned context")
|
||||||
|
}
|
||||||
|
if got := client.clientConnStopFuncSnapshot(); got == nil {
|
||||||
|
t.Fatal("runtime stop func snapshot should be initialized")
|
||||||
|
}
|
||||||
|
if got := client.GetRemoteAddr(); got == nil || got.String() != left.RemoteAddr().String() {
|
||||||
|
t.Fatalf("client remote addr mismatch: got %v want %v", got, left.RemoteAddr())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogicalConnSessionTransportLifecycleUsesLogicalRuntimeOwner(t *testing.T) {
|
||||||
|
client := &ClientConn{ClientID: "logical-runtime"}
|
||||||
|
logical := client.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
t.Fatal("LogicalConn should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
firstLeft, firstRight := net.Pipe()
|
||||||
|
defer firstRight.Close()
|
||||||
|
stopCtx, stopFn := logical.startSession(firstLeft, nil, nil)
|
||||||
|
defer stopFn()
|
||||||
|
|
||||||
|
if stopCtx == nil {
|
||||||
|
t.Fatal("logical startSession should initialize stop context")
|
||||||
|
}
|
||||||
|
if got := logical.transportSnapshot(); got != firstLeft {
|
||||||
|
t.Fatalf("logical transport snapshot mismatch: got %v want %v", got, firstLeft)
|
||||||
|
}
|
||||||
|
if !logical.transportAttachedSnapshot() {
|
||||||
|
t.Fatal("logical transport should be attached after startSession")
|
||||||
|
}
|
||||||
|
|
||||||
|
firstGeneration := logical.transportGenerationSnapshot()
|
||||||
|
if firstGeneration == 0 {
|
||||||
|
t.Fatal("logical transport generation should advance for stream runtime")
|
||||||
|
}
|
||||||
|
|
||||||
|
secondLeft, secondRight := net.Pipe()
|
||||||
|
defer secondRight.Close()
|
||||||
|
if err := logical.attachSessionTransport(secondLeft); err != nil {
|
||||||
|
t.Fatalf("logical attachSessionTransport failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := logical.transportSnapshot(); got != secondLeft {
|
||||||
|
t.Fatalf("logical transport snapshot after attach mismatch: got %v want %v", got, secondLeft)
|
||||||
|
}
|
||||||
|
if !logical.transportAttachedSnapshot() {
|
||||||
|
t.Fatal("logical transport should stay attached after attachSessionTransport")
|
||||||
|
}
|
||||||
|
if got := logical.transportGenerationSnapshot(); got <= firstGeneration {
|
||||||
|
t.Fatalf("logical transport generation should advance after attach: got %d want > %d", got, firstGeneration)
|
||||||
|
}
|
||||||
|
|
||||||
|
detachedConn, err := logical.detachTransportForTransfer()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("logical detachTransportForTransfer failed: %v", err)
|
||||||
|
}
|
||||||
|
if detachedConn != secondLeft {
|
||||||
|
t.Fatalf("detached conn mismatch: got %v want %v", detachedConn, secondLeft)
|
||||||
|
}
|
||||||
|
if got := logical.transportSnapshot(); got != nil {
|
||||||
|
t.Fatalf("logical transport should be cleared after detach, got %v", got)
|
||||||
|
}
|
||||||
|
if logical.transportAttachedSnapshot() {
|
||||||
|
t.Fatal("logical transport should be detached after detachTransportForTransfer")
|
||||||
|
}
|
||||||
|
if got := logical.stopContextSnapshot(); got != stopCtx {
|
||||||
|
t.Fatalf("logical stop context should be preserved after detach, got %v want %v", got, stopCtx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogicalConnOwnerStateMutationsSyncLegacyClientView(t *testing.T) {
|
||||||
|
client := &ClientConn{ClientID: "logical-owner-state"}
|
||||||
|
logical := client.LogicalConn()
|
||||||
|
if logical == nil {
|
||||||
|
t.Fatal("LogicalConn should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
logical.markIdentityBound()
|
||||||
|
logical.markStreamTransport()
|
||||||
|
attachGeneration := logical.markTransportAttached()
|
||||||
|
logical.setClientConnLastHeartbeatUnix(12345)
|
||||||
|
logical.markTransportDetached("read error", errors.New("boom"))
|
||||||
|
|
||||||
|
if !client.clientConnIdentityBoundSnapshot() {
|
||||||
|
t.Fatal("legacy client identity-bound snapshot should follow logical state")
|
||||||
|
}
|
||||||
|
if !client.clientConnUsesStreamTransportSnapshot() {
|
||||||
|
t.Fatal("legacy client stream-transport snapshot should follow logical state")
|
||||||
|
}
|
||||||
|
if got := client.clientConnTransportGenerationSnapshot(); got != attachGeneration {
|
||||||
|
t.Fatalf("legacy client transport generation = %d, want %d", got, attachGeneration)
|
||||||
|
}
|
||||||
|
if got := client.clientConnLastHeartbeatUnixSnapshot(); got != 12345 {
|
||||||
|
t.Fatalf("legacy client last heartbeat = %d, want %d", got, 12345)
|
||||||
|
}
|
||||||
|
detach := client.clientConnTransportDetachSnapshot()
|
||||||
|
if detach == nil {
|
||||||
|
t.Fatal("legacy client detach snapshot should follow logical state")
|
||||||
|
}
|
||||||
|
if detach.Reason != "read error" || detach.Err != "boom" || detach.Generation != attachGeneration {
|
||||||
|
t.Fatalf("legacy client detach snapshot mismatch: %+v", detach)
|
||||||
|
}
|
||||||
|
|
||||||
|
logical.clearTransportDetachState()
|
||||||
|
if got := client.clientConnTransportDetachSnapshot(); got != nil {
|
||||||
|
t.Fatalf("legacy client detach snapshot should clear with logical state, got %+v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogicalDetachTransportForTransferKeepsHandoffConnAlive(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
client := &ClientConn{
|
||||||
|
ClientID: "client-handoff",
|
||||||
|
server: server,
|
||||||
|
}
|
||||||
|
client.startClientConnSessionTransport(left, stopCtx, stopFn)
|
||||||
|
|
||||||
|
logical := client.LogicalConn()
|
||||||
|
detachedConn, err := logical.detachTransportForTransfer()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("logical detachTransportForTransfer failed: %v", err)
|
||||||
|
}
|
||||||
|
defer detachedConn.Close()
|
||||||
|
|
||||||
|
payload := []byte("handoff-payload")
|
||||||
|
readDone := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, len(payload))
|
||||||
|
_ = right.SetReadDeadline(time.Now().Add(time.Second))
|
||||||
|
if _, err := io.ReadFull(right, buf); err != nil {
|
||||||
|
readDone <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !bytes.Equal(buf, payload) {
|
||||||
|
readDone <- fmt.Errorf("payload mismatch: got %q want %q", string(buf), string(payload))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
readDone <- nil
|
||||||
|
}()
|
||||||
|
|
||||||
|
_ = detachedConn.SetWriteDeadline(time.Now().Add(time.Second))
|
||||||
|
if _, err := detachedConn.Write(payload); err != nil {
|
||||||
|
t.Fatalf("detached handoff conn write failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-readDone:
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("handoff conn read failed: %v", err)
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("timed out waiting for handoff conn read")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientConnTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) {
|
||||||
|
client := &ClientConn{}
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
client.setClientConnSessionRuntime(&clientConnSessionRuntime{
|
||||||
|
transport: newTransportBinding(left, nil),
|
||||||
|
tuConn: left,
|
||||||
|
stopCtx: stopCtx,
|
||||||
|
stopFn: stopFn,
|
||||||
|
})
|
||||||
|
|
||||||
|
binding := client.clientConnTransportBindingSnapshot()
|
||||||
|
if binding == nil {
|
||||||
|
t.Fatal("runtime transport binding should exist")
|
||||||
|
}
|
||||||
|
if got := binding.connSnapshot(); got != left {
|
||||||
|
t.Fatal("runtime transport binding conn should match runtime conn")
|
||||||
|
}
|
||||||
|
if got := binding.queueSnapshot(); got != nil {
|
||||||
|
t.Fatalf("server-side peer binding queue should remain nil, got %v", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientConnDetachServerOwnedSessionCancelsTransportOnly(t *testing.T) {
|
||||||
|
client := &ClientConn{}
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
client.startClientConnSession(left, stopCtx, stopFn)
|
||||||
|
|
||||||
|
transportStopCtx := client.clientConnTransportStopContextSnapshot()
|
||||||
|
client.detachServerOwnedSession()
|
||||||
|
|
||||||
|
if transportStopCtx == nil {
|
||||||
|
t.Fatal("transport stop context should exist")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-transportStopCtx.Done():
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("transport stop context should be canceled after detach")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-client.clientConnStopContextSnapshot().Done():
|
||||||
|
t.Fatal("logical stop context should remain active after pure detach")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if client.clientConnTransportAttachedSnapshot() {
|
||||||
|
t.Fatal("client conn transport should be marked detached after pure detach")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAttachClientConnSessionTransportRebindsRuntimeAndStartsReadLoop(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
queue := stario.NewQueueCtx(stopCtx, 4, 1024)
|
||||||
|
server.setServerSessionRuntime(&serverSessionRuntime{
|
||||||
|
stopCtx: stopCtx,
|
||||||
|
stopFn: stopFn,
|
||||||
|
queue: queue,
|
||||||
|
})
|
||||||
|
|
||||||
|
oldLeft, oldRight := net.Pipe()
|
||||||
|
defer oldRight.Close()
|
||||||
|
client := &ClientConn{
|
||||||
|
ClientID: "client-reattach",
|
||||||
|
server: server,
|
||||||
|
}
|
||||||
|
client.startClientConnSession(oldLeft, stopCtx, stopFn)
|
||||||
|
|
||||||
|
newLeft, newRight := net.Pipe()
|
||||||
|
defer newRight.Close()
|
||||||
|
if err := client.attachClientConnSessionTransport(newLeft); err != nil {
|
||||||
|
t.Fatalf("attachClientConnSessionTransport failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := client.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
t.Fatal("client conn runtime should exist after attach")
|
||||||
|
}
|
||||||
|
if rt.tuConn != newLeft || !rt.transportAttached {
|
||||||
|
t.Fatalf("attached client conn runtime mismatch: %+v", rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
wire := queue.BuildMessage([]byte("reattached"))
|
||||||
|
if _, err := newRight.Write(wire); err != nil {
|
||||||
|
t.Fatalf("new transport write failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msg := <-queue.RestoreChan():
|
||||||
|
source := assertServerInboundQueueSource(t, msg.Conn, client)
|
||||||
|
if got, want := source.TransportGeneration, client.clientConnTransportGenerationSnapshot(); got != want {
|
||||||
|
t.Fatalf("queue transport generation mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := string(msg.Msg), "reattached"; got != want {
|
||||||
|
t.Fatalf("queue payload mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("reattached server-owned transport did not push framed message")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,38 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newStartedClientConnForTest(t *testing.T, id string, server Server, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*ClientConn, context.Context, context.CancelFunc) {
|
||||||
|
t.Helper()
|
||||||
|
client := &ClientConn{
|
||||||
|
ClientID: id,
|
||||||
|
server: server,
|
||||||
|
}
|
||||||
|
stopCtx, stopFn = client.startClientConnSession(conn, stopCtx, stopFn)
|
||||||
|
return client, stopCtx, stopFn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRegisteredServerClientForTest(t *testing.T, server *ServerCommon, id string, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*ClientConn, context.Context, context.CancelFunc) {
|
||||||
|
t.Helper()
|
||||||
|
client, stopCtx, stopFn := newStartedClientConnForTest(t, id, server, conn, stopCtx, stopFn)
|
||||||
|
server.getPeerRegistry().registerClient(client)
|
||||||
|
return client, stopCtx, stopFn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRegisteredServerLogicalForTest(t *testing.T, server *ServerCommon, id string, conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc) (*LogicalConn, context.Context, context.CancelFunc) {
|
||||||
|
t.Helper()
|
||||||
|
client, stopCtx, stopFn := newStartedClientConnForTest(t, id, server, conn, stopCtx, stopFn)
|
||||||
|
logical := logicalConnFromClient(client)
|
||||||
|
server.getPeerRegistry().registerLogical(logical)
|
||||||
|
return logical, stopCtx, stopFn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServerCodecClientConnForTest(server *ServerCommon) *ClientConn {
|
||||||
|
client := &ClientConn{server: server}
|
||||||
|
client.applyClientConnAttachmentProfile(0, 0, server.defaultMsgEn, server.defaultMsgDe, server.handshakeRsaKey, server.SecretKey)
|
||||||
|
return client
|
||||||
|
}
|
||||||
@@ -0,0 +1,223 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type serverLogicalTransportDetacher interface {
|
||||||
|
detachLogicalSessionTransport(logical *LogicalConn, reason string, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverInboundSourcePusher interface {
|
||||||
|
pushMessageSource([]byte, interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) readTUMessage() {
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.readTUMessageLoop(rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) readTUMessageLoop(rt *clientConnSessionRuntime) {
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stopCtx := rt.transportStopCtx
|
||||||
|
if stopCtx == nil {
|
||||||
|
stopCtx = rt.stopCtx
|
||||||
|
}
|
||||||
|
if stopCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn := rt.tuConn
|
||||||
|
generation := rt.transportGeneration
|
||||||
|
defer closeClientConnSessionRuntimeTransportDone(rt)
|
||||||
|
buf := streamReadBuffer()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-sessionStopChan(stopCtx):
|
||||||
|
if c.shouldCloseTransportOnStop(conn) {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
num, data, err := c.readFromTUTransportConnWithBuffer(conn, buf)
|
||||||
|
if !c.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
data = streamReadBuffer()
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 {
|
||||||
|
_ = conn.SetReadDeadline(time.Now().Add(timeout))
|
||||||
|
}
|
||||||
|
num, err := conn.Read(data)
|
||||||
|
return num, data, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
|
||||||
|
if err == os.ErrDeadlineExceeded {
|
||||||
|
if num != 0 {
|
||||||
|
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-sessionStopChan(stopCtx):
|
||||||
|
if c.shouldCloseTransportOnStop(conn) {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if detacher, ok := c.Server().(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
|
||||||
|
detacher.detachLogicalSessionTransport(c, "read error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.stopServerOwnedSession("read error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
|
||||||
|
if c == nil || len(data) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
server := c.Server()
|
||||||
|
if server == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if pusher, ok := server.(serverInboundSourcePusher); ok {
|
||||||
|
pusher.pushMessageSource(data, newServerInboundSource(c, conn, nil, generation))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
server.pushMessage(data, c.clientConnIDSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) shouldCloseTransportOnStop(conn net.Conn) bool {
|
||||||
|
if c == nil || conn == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil || !rt.transportAttached {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
current := rt.tuConn
|
||||||
|
if rt.transport != nil && rt.transport.connSnapshot() != nil {
|
||||||
|
current = rt.transport.connSnapshot()
|
||||||
|
}
|
||||||
|
return current == conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) readFromTUTransport() (int, []byte, error) {
|
||||||
|
binding := c.clientConnTransportBindingSnapshot()
|
||||||
|
if binding == nil {
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
conn := binding.connSnapshot()
|
||||||
|
return c.readFromTUTransportConn(conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) readFromTUTransportConn(conn net.Conn) (int, []byte, error) {
|
||||||
|
return c.readFromTUTransportConnWithBuffer(conn, streamReadBuffer())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) readFromTUTransportConnWithBuffer(conn net.Conn, data []byte) (int, []byte, error) {
|
||||||
|
if logical := c.LogicalConn(); logical != nil {
|
||||||
|
return logical.readFromTUTransportConnWithBuffer(conn, data)
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
data = streamReadBuffer()
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return 0, nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
if timeout := c.clientConnMaxReadTimeoutSnapshot(); timeout > 0 {
|
||||||
|
_ = conn.SetReadDeadline(time.Now().Add(timeout))
|
||||||
|
}
|
||||||
|
num, err := conn.Read(data)
|
||||||
|
return num, data, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) handleTUTransportReadResult(num int, data []byte, err error) bool {
|
||||||
|
return c.handleTUTransportReadResultWithSession(c.clientConnTransportStopContextSnapshot(), c.clientConnTransportSnapshot(), c.clientConnTransportGenerationSnapshot(), num, data, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) handleTUTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, generation uint64, num int, data []byte, err error) bool {
|
||||||
|
if logical := c.LogicalConn(); logical != nil {
|
||||||
|
return logical.handleTUTransportReadResultWithSession(stopCtx, conn, generation, num, data, err)
|
||||||
|
}
|
||||||
|
if err == os.ErrDeadlineExceeded {
|
||||||
|
if num != 0 {
|
||||||
|
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
select {
|
||||||
|
case <-sessionStopChan(stopCtx):
|
||||||
|
if c.shouldCloseClientConnTransportOnStop(conn) {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if detacher, ok := c.server.(serverLogicalTransportDetacher); ok && c.shouldPreserveLogicalPeerOnTransportLoss() {
|
||||||
|
detacher.detachLogicalSessionTransport(logicalConnFromClient(c), "read error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.stopServerOwnedSession("read error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.pushServerOwnedTransportMessage(data[:num], conn, generation)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) pushServerOwnedTransportMessage(data []byte, conn net.Conn, generation uint64) {
|
||||||
|
if logical := c.LogicalConn(); logical != nil {
|
||||||
|
logical.pushServerOwnedTransportMessage(data, conn, generation)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c == nil || c.server == nil || len(data) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if pusher, ok := c.server.(serverInboundSourcePusher); ok {
|
||||||
|
pusher.pushMessageSource(data, newServerInboundSource(logicalConnFromClient(c), conn, nil, generation))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.server.pushMessage(data, c.clientConnIDSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) shouldCloseClientConnTransportOnStop(conn net.Conn) bool {
|
||||||
|
if logical := c.LogicalConn(); logical != nil {
|
||||||
|
return logical.shouldCloseTransportOnStop(conn)
|
||||||
|
}
|
||||||
|
if c == nil || conn == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
rt := c.clientConnSessionRuntimeSnapshot()
|
||||||
|
if rt == nil || !rt.transportAttached {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
current := rt.tuConn
|
||||||
|
if rt.transport != nil && rt.transport.connSnapshot() != nil {
|
||||||
|
current = rt.transport.connSnapshot()
|
||||||
|
}
|
||||||
|
return current == conn
|
||||||
|
}
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "sync/atomic"
|
||||||
|
|
||||||
|
type clientConnTransportState struct {
|
||||||
|
streamTransport atomic.Bool
|
||||||
|
transportGen atomic.Uint64
|
||||||
|
attachCount atomic.Uint64
|
||||||
|
detachCount atomic.Uint64
|
||||||
|
lastAttachAt atomic.Int64
|
||||||
|
transportDetach atomic.Pointer[clientConnTransportDetachState]
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneClientConnTransportDetachState(src *clientConnTransportDetachState) *clientConnTransportDetachState {
|
||||||
|
if src == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cloned := *src
|
||||||
|
return &cloned
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) ensureTransportState() *clientConnTransportState {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if state := c.transportState.Load(); state != nil {
|
||||||
|
if client := c.compatClientConn(); client != nil {
|
||||||
|
client.transportState.Store(state)
|
||||||
|
}
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
client := c.compatClientConn()
|
||||||
|
if client != nil {
|
||||||
|
if state := client.transportState.Load(); state != nil {
|
||||||
|
if c.transportState.CompareAndSwap(nil, state) {
|
||||||
|
client.transportState.Store(state)
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
return c.ensureTransportState()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
state := &clientConnTransportState{}
|
||||||
|
if c.transportState.CompareAndSwap(nil, state) {
|
||||||
|
if client != nil {
|
||||||
|
client.transportState.Store(state)
|
||||||
|
}
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
return c.ensureTransportState()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) ensureClientConnTransportState() *clientConnTransportState {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
return logical.ensureTransportState()
|
||||||
|
}
|
||||||
|
if state := c.transportState.Load(); state != nil {
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
state := &clientConnTransportState{}
|
||||||
|
if c.transportState.CompareAndSwap(nil, state) {
|
||||||
|
return state
|
||||||
|
}
|
||||||
|
return c.transportState.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientConn) setClientConnTransportDetachState(state *clientConnTransportDetachState) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if logical := c.logicalView.Load(); logical != nil {
|
||||||
|
logical.setTransportDetachState(state)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
transportState := c.ensureClientConnTransportState()
|
||||||
|
if transportState == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
transportState.transportDetach.Store(cloneClientConnTransportDetachState(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *LogicalConn) setTransportDetachState(state *clientConnTransportDetachState) {
|
||||||
|
transportState := c.ensureTransportState()
|
||||||
|
if transportState == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
transportState.transportDetach.Store(cloneClientConnTransportDetachState(state))
|
||||||
|
if client := c.compatClientConn(); client != nil {
|
||||||
|
client.transportState.Store(transportState)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,121 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"b612.me/notify/internal/transport"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
clientConnectSourceConn = "conn"
|
||||||
|
clientConnectSourceNetwork = "network"
|
||||||
|
clientConnectSourceTimeout = "timeout"
|
||||||
|
clientConnectSourceFactory = "factory"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errClientReconnectSourceUnavailable = errors.New("client reconnect source is unavailable")
|
||||||
|
|
||||||
|
type clientConnectSource struct {
|
||||||
|
kind string
|
||||||
|
network string
|
||||||
|
addr string
|
||||||
|
dialFn func(context.Context) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientConnConnectSource(conn net.Conn) *clientConnectSource {
|
||||||
|
source := &clientConnectSource{kind: clientConnectSourceConn}
|
||||||
|
if conn == nil {
|
||||||
|
return source
|
||||||
|
}
|
||||||
|
if remoteAddr := conn.RemoteAddr(); remoteAddr != nil {
|
||||||
|
source.network = remoteAddr.Network()
|
||||||
|
source.addr = remoteAddr.String()
|
||||||
|
}
|
||||||
|
if source.network == "" {
|
||||||
|
if localAddr := conn.LocalAddr(); localAddr != nil {
|
||||||
|
source.network = localAddr.Network()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return source
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientNetworkConnectSource(network string, addr string) *clientConnectSource {
|
||||||
|
return &clientConnectSource{
|
||||||
|
kind: clientConnectSourceNetwork,
|
||||||
|
network: network,
|
||||||
|
addr: addr,
|
||||||
|
dialFn: func(context.Context) (net.Conn, error) {
|
||||||
|
return transport.Dial(network, addr)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientTimeoutConnectSource(network string, addr string, timeout time.Duration) *clientConnectSource {
|
||||||
|
return &clientConnectSource{
|
||||||
|
kind: clientConnectSourceTimeout,
|
||||||
|
network: network,
|
||||||
|
addr: addr,
|
||||||
|
dialFn: func(context.Context) (net.Conn, error) {
|
||||||
|
return transport.DialTimeout(network, addr, timeout)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientFactoryConnectSource(dialFn func(context.Context) (net.Conn, error)) *clientConnectSource {
|
||||||
|
return &clientConnectSource{
|
||||||
|
kind: clientConnectSourceFactory,
|
||||||
|
dialFn: dialFn,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *clientConnectSource) clone() *clientConnectSource {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := *s
|
||||||
|
return &out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *clientConnectSource) canReconnect() bool {
|
||||||
|
return s != nil && s.dialFn != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *clientConnectSource) isUDP() bool {
|
||||||
|
if s == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return transport.IsUDPNetwork(s.network)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *clientConnectSource) dial(ctx context.Context) (net.Conn, error) {
|
||||||
|
if s == nil || s.dialFn == nil {
|
||||||
|
return nil, errClientReconnectSourceUnavailable
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return s.dialFn(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) setClientConnectSource(source *clientConnectSource) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if source == nil {
|
||||||
|
c.connectSource.Store(nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.connectSource.Store(source.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientConnectSourceSnapshot() *clientConnectSource {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if source := c.connectSource.Load(); source != nil {
|
||||||
|
return source.clone()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
func (c *ClientCommon) dispatchMsg(message Message) {
|
||||||
|
switch message.TransferMsg.Type {
|
||||||
|
case MSG_SYS_WAIT:
|
||||||
|
fallthrough
|
||||||
|
case MSG_SYS:
|
||||||
|
c.sysMsg(message)
|
||||||
|
return
|
||||||
|
case MSG_KEY_CHANGE:
|
||||||
|
fallthrough
|
||||||
|
case MSG_SYS_REPLY:
|
||||||
|
fallthrough
|
||||||
|
case MSG_SYNC_REPLY:
|
||||||
|
if c.getPendingWaitPool().deliver(message.ID, message) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if c.dispatchInternalTransferControl(message) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
callFn := func(fn func(*Message)) {
|
||||||
|
fn(&message)
|
||||||
|
}
|
||||||
|
fn, ok := c.linkFns[message.Key]
|
||||||
|
if ok {
|
||||||
|
callFn(fn)
|
||||||
|
}
|
||||||
|
if c.defaultFns != nil {
|
||||||
|
callFn(c.defaultFns)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sysMsg(message Message) {
|
||||||
|
switch message.Key {
|
||||||
|
case "bye":
|
||||||
|
if message.TransferMsg.Type == MSG_SYS_WAIT {
|
||||||
|
c.setByeFromServer(true)
|
||||||
|
message.Reply(nil)
|
||||||
|
c.stopClientSession("recv stop signal from server", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.stopClientSessionFromServer("recv stop signal from server", nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"b612.me/starcrypto"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Deprecated: ExchangeKey drives the legacy RSA-based key exchange flow.
|
||||||
|
// Prefer UseModernPSKClient.
|
||||||
|
func (c *ClientCommon) ExchangeKey(newKey []byte) error {
|
||||||
|
pubKey, err := starcrypto.DecodeRsaPublicKey(c.handshakeRsaPubKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
newSendKey, err := starcrypto.RSAEncrypt(pubKey, newKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
data, err := c.sendWait(TransferMsg{
|
||||||
|
ID: 19961127,
|
||||||
|
Key: "sirius",
|
||||||
|
Value: newSendKey,
|
||||||
|
Type: MSG_KEY_CHANGE,
|
||||||
|
}, time.Second*10)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if string(data.Value) != "success" {
|
||||||
|
return errors.New("cannot exchange new aes-key")
|
||||||
|
}
|
||||||
|
c.SecretKey = newKey
|
||||||
|
time.Sleep(time.Millisecond * 100)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: aesRsaHello is the legacy RSA-based key exchange bootstrap.
|
||||||
|
func aesRsaHello(c Client) error {
|
||||||
|
newAesKey := []byte(fmt.Sprintf("%d%d%d%s", time.Now().UnixNano(), rand.Int63(), rand.Int63(), "b612.me"))
|
||||||
|
newAesKey = []byte(starcrypto.Md5Str(newAesKey))
|
||||||
|
return c.ExchangeKey(newAesKey)
|
||||||
|
}
|
||||||
@@ -0,0 +1,158 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestReconnectClientRejectsDirectConnSource(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
server.SetSecretKey(secret)
|
||||||
|
})
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
|
||||||
|
client.SetSecretKey(secret)
|
||||||
|
if err := client.ConnectByConn(left); err != nil {
|
||||||
|
t.Fatalf("ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ReconnectClient(context.Background(), client)
|
||||||
|
if !errors.Is(err, errClientReconnectSourceUnavailable) {
|
||||||
|
t.Fatalf("ReconnectClient error = %v, want %v", err, errClientReconnectSourceUnavailable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconnectClientWithFactorySource(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
client.SetSecretKey(secret)
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
server.SetSecretKey(secret)
|
||||||
|
})
|
||||||
|
|
||||||
|
dialCount := 0
|
||||||
|
var peers []net.Conn
|
||||||
|
dialFn := func(context.Context) (net.Conn, error) {
|
||||||
|
dialCount++
|
||||||
|
left, right := net.Pipe()
|
||||||
|
peers = append(peers, right)
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.ConnectByFactory(context.Background(), dialFn); err != nil {
|
||||||
|
t.Fatalf("ConnectByFactory failed: %v", err)
|
||||||
|
}
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
before, err := GetClientRuntimeSnapshot(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientRuntimeSnapshot before reconnect failed: %v", err)
|
||||||
|
}
|
||||||
|
if !before.CanReconnect || before.ConnectSource != clientConnectSourceFactory {
|
||||||
|
t.Fatalf("unexpected reconnect snapshot before reconnect: %+v", before)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ReconnectClient(context.Background(), client); err != nil {
|
||||||
|
t.Fatalf("ReconnectClient failed: %v", err)
|
||||||
|
}
|
||||||
|
after, err := GetClientRuntimeSnapshot(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientRuntimeSnapshot after reconnect failed: %v", err)
|
||||||
|
}
|
||||||
|
if !after.Alive || !after.HasRuntimeConn || !after.CanReconnect {
|
||||||
|
t.Fatalf("unexpected reconnect snapshot after reconnect: %+v", after)
|
||||||
|
}
|
||||||
|
if got, want := dialCount, 2; got != want {
|
||||||
|
t.Fatalf("dial count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("final Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
for _, peer := range peers {
|
||||||
|
_ = peer.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReconnectClientWithRetryRecordsRetryState(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
client.SetSecretKey(secret)
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
server.SetSecretKey(secret)
|
||||||
|
})
|
||||||
|
|
||||||
|
dialCount := 0
|
||||||
|
wantErr := errors.New("dial failed once")
|
||||||
|
var peers []net.Conn
|
||||||
|
dialFn := func(context.Context) (net.Conn, error) {
|
||||||
|
dialCount++
|
||||||
|
if dialCount == 2 {
|
||||||
|
return nil, wantErr
|
||||||
|
}
|
||||||
|
left, right := net.Pipe()
|
||||||
|
peers = append(peers, right)
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.ConnectByFactory(context.Background(), dialFn); err != nil {
|
||||||
|
t.Fatalf("ConnectByFactory failed: %v", err)
|
||||||
|
}
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ReconnectClientWithRetry(context.Background(), client, &ConnectRetryOptions{
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: 0,
|
||||||
|
MaxDelay: 0,
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("ReconnectClientWithRetry failed: %v", err)
|
||||||
|
}
|
||||||
|
snapshot, err := GetClientRuntimeSnapshot(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Retry.RetryEventTotal, uint64(1); got != want {
|
||||||
|
t.Fatalf("retry events mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Retry.LastRetryAttempt, 1; got != want {
|
||||||
|
t.Fatalf("last retry attempt mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Retry.LastRetryError, wantErr.Error(); got != want {
|
||||||
|
t.Fatalf("last retry error mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if snapshot.Retry.LastResultError != "" {
|
||||||
|
t.Fatalf("last result error should be empty, got %q", snapshot.Retry.LastResultError)
|
||||||
|
}
|
||||||
|
if got, want := dialCount, 3; got != want {
|
||||||
|
t.Fatalf("dial count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("final Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
for _, peer := range peers {
|
||||||
|
_ = peer.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,66 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetRecordStreamHandler(fn func(RecordAcceptInfo) error) {
|
||||||
|
runtime := c.getRecordRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime.setHandler(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, errStreamClientNil
|
||||||
|
}
|
||||||
|
opt = normalizeRecordOpenOptions(opt)
|
||||||
|
stream, err := c.OpenStream(ctx, opt.Stream)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
record, err := WrapStreamAsRecord(stream, opt)
|
||||||
|
if err != nil {
|
||||||
|
_ = stream.Reset(err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return record, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) claimInboundRecordStream(stream *streamHandle) (bool, error) {
|
||||||
|
if stream == nil || stream.Channel() != StreamRecordChannel {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
runtime := c.getRecordRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return true, errRecordRuntimeNil
|
||||||
|
}
|
||||||
|
handler := runtime.handlerSnapshot()
|
||||||
|
if handler == nil {
|
||||||
|
return true, errRecordHandlerNotConfigured
|
||||||
|
}
|
||||||
|
record, err := WrapStreamAsRecord(stream, RecordOpenOptions{
|
||||||
|
Stream: StreamOpenOptions{
|
||||||
|
ID: stream.ID(),
|
||||||
|
Channel: stream.Channel(),
|
||||||
|
Metadata: stream.Metadata(),
|
||||||
|
ReadTimeout: stream.readTimeoutSnapshot(),
|
||||||
|
WriteTimeout: stream.writeTimeoutSnapshot(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return true, err
|
||||||
|
}
|
||||||
|
info := RecordAcceptInfo{
|
||||||
|
ID: stream.ID(),
|
||||||
|
Metadata: stream.Metadata(),
|
||||||
|
TransportGeneration: stream.TransportGeneration(),
|
||||||
|
RecordStream: record,
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
if err := handler(info); err != nil {
|
||||||
|
_ = record.Reset(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,523 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"b612.me/stario"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *ClientCommon) closeClientTransport() {
|
||||||
|
c.closeClientTransportBinding(c.clientTransportBindingSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) closeClientTransportConn(conn net.Conn) {
|
||||||
|
if c == nil || conn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) closeClientTransportBinding(binding *transportBinding) {
|
||||||
|
if binding == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.closeClientTransportConn(binding.connSnapshot())
|
||||||
|
binding.stopBackgroundWorkers()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) beginClientSessionEpoch() uint64 {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return atomic.AddUint64(&c.sessionEpoch, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) currentClientSessionEpoch() uint64 {
|
||||||
|
if c == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return atomic.LoadUint64(&c.sessionEpoch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) isClientSessionEpochCurrent(epoch uint64) bool {
|
||||||
|
if c == nil || epoch == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return c.currentClientSessionEpoch() == epoch
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) stopClientSessionIfCurrent(epoch uint64, reason string, err error) bool {
|
||||||
|
if !c.isClientSessionEpochCurrent(epoch) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.stopClientSession(reason, err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) setByeFromServer(val bool) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
c.byeFromServer = val
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) resetClientStopState() {
|
||||||
|
c.setByeFromServer(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) shouldSayGoodByeOnStop() bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
return !c.byeFromServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) stopClientSession(reason string, err error) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.markSessionStopped(reason, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) stopClientSessionFromServer(reason string, err error) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.setByeFromServer(true)
|
||||||
|
c.markSessionStopped(reason, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) beginClientConnectAttempt() (func(success bool), error) {
|
||||||
|
if !c.beginClientSessionStart() {
|
||||||
|
return nil, errors.New("client already run")
|
||||||
|
}
|
||||||
|
return func(success bool) {
|
||||||
|
if success {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.cleanupFailedClientStart()
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientCanAttachTransport() bool {
|
||||||
|
if c == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !sessionIsAlive(&c.alive) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if c.clientTransportAttachedSnapshot() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return rt.stopCtx != nil && rt.queue != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) attachClientWithConnSource(conn net.Conn, source *clientConnectSource) error {
|
||||||
|
if c == nil {
|
||||||
|
return errors.New("client is nil")
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return errors.New("conn is nil")
|
||||||
|
}
|
||||||
|
if err := c.attachClientSessionTransport(conn); err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.bootstrapClientTransportRuntime(c.clientSessionRuntimeSnapshot(), true, false); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.setClientConnectSource(source)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) Connect(network string, addr string) error {
|
||||||
|
if err := c.validateSecurityConfiguration(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
source := newClientNetworkConnectSource(network, addr)
|
||||||
|
c.applySignalReliabilityTransportDefault(source.isUDP())
|
||||||
|
if c.clientCanAttachTransport() {
|
||||||
|
conn, err := source.dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.attachClientWithConnSource(conn, source)
|
||||||
|
}
|
||||||
|
finish, err := c.beginClientConnectAttempt()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
started := false
|
||||||
|
defer func() {
|
||||||
|
finish(started)
|
||||||
|
}()
|
||||||
|
conn, err := source.dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.startClientWithConnSource(conn, source); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
started = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) ConnectTimeout(network string, addr string, timeout time.Duration) error {
|
||||||
|
if err := c.validateSecurityConfiguration(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
source := newClientTimeoutConnectSource(network, addr, timeout)
|
||||||
|
c.applySignalReliabilityTransportDefault(source.isUDP())
|
||||||
|
if c.clientCanAttachTransport() {
|
||||||
|
conn, err := source.dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.attachClientWithConnSource(conn, source)
|
||||||
|
}
|
||||||
|
finish, err := c.beginClientConnectAttempt()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
started := false
|
||||||
|
defer func() {
|
||||||
|
finish(started)
|
||||||
|
}()
|
||||||
|
conn, err := source.dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := c.startClientWithConnSource(conn, source); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
started = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) ConnectByConn(conn net.Conn) error {
|
||||||
|
if err := c.validateSecurityConfiguration(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return errors.New("conn is nil")
|
||||||
|
}
|
||||||
|
source := newClientConnConnectSource(conn)
|
||||||
|
c.applySignalReliabilityTransportDefault(false)
|
||||||
|
if c.clientCanAttachTransport() {
|
||||||
|
return c.attachClientWithConnSource(conn, source)
|
||||||
|
}
|
||||||
|
finish, err := c.beginClientConnectAttempt()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
started := false
|
||||||
|
defer func() {
|
||||||
|
finish(started)
|
||||||
|
}()
|
||||||
|
if err := c.startClientWithConnSource(conn, source); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
started = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error {
|
||||||
|
if err := c.validateSecurityConfiguration(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if dialFn == nil {
|
||||||
|
return errors.New("dialFn is nil")
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
source := newClientFactoryConnectSource(dialFn)
|
||||||
|
if c.clientCanAttachTransport() {
|
||||||
|
c.applySignalReliabilityTransportDefault(false)
|
||||||
|
conn, err := dialFn(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return errors.New("conn is nil")
|
||||||
|
}
|
||||||
|
return c.attachClientWithConnSource(conn, source)
|
||||||
|
}
|
||||||
|
finish, err := c.beginClientConnectAttempt()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
started := false
|
||||||
|
defer func() {
|
||||||
|
finish(started)
|
||||||
|
}()
|
||||||
|
conn, err := dialFn(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return errors.New("conn is nil")
|
||||||
|
}
|
||||||
|
c.applySignalReliabilityTransportDefault(false)
|
||||||
|
if err := c.startClientWithConnSource(conn, source); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
started = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) startClientWithConn(conn net.Conn) error {
|
||||||
|
return c.startClientWithConnSource(conn, newClientConnConnectSource(conn))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) startClientWithConnSource(conn net.Conn, source *clientConnectSource) error {
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
epoch := c.beginClientSessionEpoch()
|
||||||
|
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||||
|
c.setClientConnectSource(source)
|
||||||
|
rt := newClientSessionRuntime(conn, stopCtx, stopFn, queue, epoch)
|
||||||
|
c.setClientSessionRuntime(rt)
|
||||||
|
c.resetClientStopState()
|
||||||
|
c.markSessionStarted()
|
||||||
|
return c.clientPostInit(rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) monitorPool() {
|
||||||
|
c.monitorPoolLoop(c.clientStopContextSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) monitorPoolLoop(stopCtx context.Context) {
|
||||||
|
if stopCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stopCtx.Done():
|
||||||
|
if c.clientStopContextSnapshot() == stopCtx {
|
||||||
|
c.getPendingWaitPool().closeAll()
|
||||||
|
c.getFileAckPool().closeAll()
|
||||||
|
c.getSignalAckPool().closeAll()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-time.After(time.Second * 30):
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
c.getPendingWaitPool().cleanupExpired(int64(c.noFinSyncMsgMaxKeepSeconds), now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientPostInit(rt *clientSessionRuntime) error {
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
go c.monitorPoolLoop(rt.stopCtx)
|
||||||
|
if err := c.startClientTransportRuntime(rt); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return c.bootstrapClientTransportRuntime(rt, true, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) startClientTransportRuntime(rt *clientSessionRuntime) error {
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
transportStopCtx := rt.transportStopCtx
|
||||||
|
if transportStopCtx == nil {
|
||||||
|
transportStopCtx = rt.stopCtx
|
||||||
|
}
|
||||||
|
if c.useHeartBeat {
|
||||||
|
go c.heartbeatLoop(transportStopCtx, rt.epoch)
|
||||||
|
}
|
||||||
|
go c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, rt.epoch)
|
||||||
|
go c.loadMessageLoop(rt)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) bootstrapClientTransportRuntime(rt *clientSessionRuntime, runKeyExchange bool, stopSessionOnFailure bool) error {
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if runKeyExchange && !c.skipKeyExchange {
|
||||||
|
if err := c.keyExchangeFn(c); err != nil {
|
||||||
|
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "key exchange failed", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err := c.announceClientPeerIdentity(); err != nil {
|
||||||
|
return c.failClientTransportBootstrap(rt, stopSessionOnFailure, "peer attach failed", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) failClientTransportBootstrap(rt *clientSessionRuntime, stopSessionOnFailure bool, reason string, err error) error {
|
||||||
|
if c == nil || rt == nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.retireClientSessionRuntime(rt, true)
|
||||||
|
c.closeClientTransportConn(rt.conn)
|
||||||
|
if stopSessionOnFailure {
|
||||||
|
c.stopClientSessionIfCurrent(rt.epoch, reason, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.clearClientSessionRuntimeTransport()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) Heartbeat() {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
epoch := rt.epoch
|
||||||
|
if epoch == 0 {
|
||||||
|
epoch = c.currentClientSessionEpoch()
|
||||||
|
}
|
||||||
|
transportStopCtx := rt.transportStopCtx
|
||||||
|
if transportStopCtx == nil {
|
||||||
|
transportStopCtx = rt.stopCtx
|
||||||
|
}
|
||||||
|
c.heartbeatLoop(transportStopCtx, epoch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) heartbeatLoop(stopCtx context.Context, epoch uint64) {
|
||||||
|
if stopCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
failedCount := 0
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stopCtx.Done():
|
||||||
|
return
|
||||||
|
case <-time.After(c.heartbeatPeriod):
|
||||||
|
}
|
||||||
|
err := c.sendHeartbeat()
|
||||||
|
var stop bool
|
||||||
|
failedCount, stop = c.handleHeartbeatResultWithSession(epoch, err, failedCount)
|
||||||
|
if stop {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) readMessage() {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
epoch := rt.epoch
|
||||||
|
if epoch == 0 {
|
||||||
|
epoch = c.currentClientSessionEpoch()
|
||||||
|
}
|
||||||
|
transportStopCtx := rt.transportStopCtx
|
||||||
|
if transportStopCtx == nil {
|
||||||
|
transportStopCtx = rt.stopCtx
|
||||||
|
}
|
||||||
|
c.readMessageLoop(transportStopCtx, rt.conn, rt.queue, epoch)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) readMessageLoop(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, epoch uint64) {
|
||||||
|
if stopCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
binding := newTransportBinding(conn, queue)
|
||||||
|
dispatcher := c.clientInboundDispatcherSnapshot()
|
||||||
|
buf := streamReadBuffer()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stopCtx.Done():
|
||||||
|
c.closeClientTransportBinding(binding)
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
readNum, data, err := c.readFromTransportBindingWithBuffer(binding, buf)
|
||||||
|
if !c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, readNum, data, err, epoch, dispatcher) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sayGoodBye() error {
|
||||||
|
_, err := c.sendWait(TransferMsg{
|
||||||
|
ID: 10010,
|
||||||
|
Key: "bye",
|
||||||
|
Value: nil,
|
||||||
|
Type: MSG_SYS_WAIT,
|
||||||
|
}, time.Second*3)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) loadMessage() {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.loadMessageLoop(rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) loadMessageLoop(rt *clientSessionRuntime) {
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stopCtx := rt.transportStopCtx
|
||||||
|
if stopCtx == nil {
|
||||||
|
stopCtx = rt.stopCtx
|
||||||
|
}
|
||||||
|
if stopCtx == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
queue := rt.queue
|
||||||
|
if rt.transport != nil {
|
||||||
|
queue = rt.transport.queueSnapshot()
|
||||||
|
}
|
||||||
|
if queue == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dispatcher := rt.inboundDispatcher
|
||||||
|
if dispatcher == nil {
|
||||||
|
dispatcher = newInboundDispatcher()
|
||||||
|
defer dispatcher.CloseAndWait()
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-stopCtx.Done():
|
||||||
|
sessionStopping := rt.stopCtx != nil && rt.stopCtx.Err() != nil
|
||||||
|
if sessionStopping && rt.inboundDispatcher != nil {
|
||||||
|
rt.inboundDispatcher.CloseAndWait()
|
||||||
|
}
|
||||||
|
if sessionStopping && !rt.runtimeShouldSuppressGoodByeOnStop() && c.shouldSayGoodByeOnStop() {
|
||||||
|
c.sayGoodBye()
|
||||||
|
}
|
||||||
|
c.closeClientTransportBinding(rt.transport)
|
||||||
|
return
|
||||||
|
case data, ok := <-queue.RestoreChan():
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
msg := data
|
||||||
|
c.wg.Add(1)
|
||||||
|
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
|
||||||
|
defer c.wg.Done()
|
||||||
|
now := time.Now()
|
||||||
|
if err := c.dispatchInboundTransportPayload(msg.Msg, now); err != nil {
|
||||||
|
if c.showError || c.debugMode {
|
||||||
|
fmt.Println("client decode envelope error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
c.wg.Done()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
+193
@@ -0,0 +1,193 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *ClientCommon) send(msg TransferMsg) (WaitMsg, error) {
|
||||||
|
if err := c.ensureClientSendReady(); err != nil {
|
||||||
|
return WaitMsg{}, err
|
||||||
|
}
|
||||||
|
var wait WaitMsg
|
||||||
|
if msg.Type != MSG_SYNC_REPLY && msg.Type != MSG_KEY_CHANGE && msg.Type != MSG_SYS_REPLY || msg.ID == 0 {
|
||||||
|
msg.ID = atomic.AddUint64(&c.msgID, 1)
|
||||||
|
}
|
||||||
|
env, err := wrapTransferMsgEnvelope(msg, c.sequenceEn)
|
||||||
|
if err != nil {
|
||||||
|
return WaitMsg{}, err
|
||||||
|
}
|
||||||
|
if requiresSignalReplyWait(msg) {
|
||||||
|
wait = c.getPendingWaitPool().createAndStore(msg)
|
||||||
|
}
|
||||||
|
err = c.sendSignalEnvelopeMaybeReliable(env, msg)
|
||||||
|
if err != nil {
|
||||||
|
if requiresSignalReplyWait(msg) {
|
||||||
|
c.getPendingWaitPool().removeAndClose(msg.ID)
|
||||||
|
}
|
||||||
|
return WaitMsg{}, err
|
||||||
|
}
|
||||||
|
return wait, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendEnvelope(env Envelope) error {
|
||||||
|
if err := c.ensureClientSendReady(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
payload, err := c.encodeEnvelopePayload(env)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if batchedControlEnvelope(env) {
|
||||||
|
return c.writeControlPayloadToTransport(payload)
|
||||||
|
}
|
||||||
|
return c.writePayloadToTransport(payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) dispatchEnvelope(env Envelope, now time.Time) {
|
||||||
|
switch env.Kind {
|
||||||
|
case EnvelopeSignalAck:
|
||||||
|
if c.handleSignalAckEnvelope(env) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case EnvelopeStreamData:
|
||||||
|
c.dispatchStreamEnvelope(env)
|
||||||
|
return
|
||||||
|
case EnvelopeSignal:
|
||||||
|
transfer, err := unwrapTransferMsgEnvelope(env, c.sequenceDe)
|
||||||
|
if err != nil {
|
||||||
|
if c.showError || c.debugMode {
|
||||||
|
fmt.Println("client unwrap signal envelope error", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if c.handleReceivedSignalReliability(transfer) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
message := Message{
|
||||||
|
ServerConn: c,
|
||||||
|
TransferMsg: transfer,
|
||||||
|
NetType: NET_CLIENT,
|
||||||
|
Time: now,
|
||||||
|
}
|
||||||
|
c.dispatchMsg(message)
|
||||||
|
case EnvelopeFileMeta, EnvelopeFileChunk, EnvelopeFileEnd, EnvelopeFileAbort, EnvelopeAck:
|
||||||
|
c.dispatchFileEnvelope(env, now)
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) Send(key string, value MsgVal) error {
|
||||||
|
_, err := c.send(TransferMsg{
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
Type: MSG_ASYNC,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendWait(msg TransferMsg, timeout time.Duration) (Message, error) {
|
||||||
|
data, err := c.send(msg)
|
||||||
|
if err != nil {
|
||||||
|
return Message{}, err
|
||||||
|
}
|
||||||
|
stopCh := sessionStopChan(c.clientStopContextSnapshot())
|
||||||
|
if timeout.Seconds() == 0 {
|
||||||
|
msg, ok := <-data.Reply
|
||||||
|
if !ok {
|
||||||
|
return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c))
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-time.After(timeout):
|
||||||
|
c.getPendingWaitPool().removeAndClose(data.TransferMsg.ID)
|
||||||
|
return Message{}, os.ErrDeadlineExceeded
|
||||||
|
case <-stopCh:
|
||||||
|
return Message{}, errServiceShutdown
|
||||||
|
case msg, ok := <-data.Reply:
|
||||||
|
if !ok {
|
||||||
|
return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c))
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendCtx(msg TransferMsg, ctx context.Context) (Message, error) {
|
||||||
|
data, err := c.send(msg)
|
||||||
|
if err != nil {
|
||||||
|
return Message{}, err
|
||||||
|
}
|
||||||
|
stopCh := sessionStopChan(c.clientStopContextSnapshot())
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.getPendingWaitPool().removeAndClose(data.TransferMsg.ID)
|
||||||
|
return Message{}, normalizeStreamDeadlineError(ctx.Err())
|
||||||
|
case <-stopCh:
|
||||||
|
return Message{}, errServiceShutdown
|
||||||
|
case msg, ok := <-data.Reply:
|
||||||
|
if !ok {
|
||||||
|
return msg, pendingWaitClosedErrorWith(stopCh, clientTransportDetachedError(c))
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error) {
|
||||||
|
data, err := c.sequenceEn(val)
|
||||||
|
if err != nil {
|
||||||
|
return Message{}, err
|
||||||
|
}
|
||||||
|
return c.sendCtx(TransferMsg{
|
||||||
|
Key: key,
|
||||||
|
Value: data,
|
||||||
|
Type: MSG_SYNC_ASK,
|
||||||
|
}, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SendObj(key string, val interface{}) error {
|
||||||
|
data, err := encode(val)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = c.send(TransferMsg{
|
||||||
|
Key: key,
|
||||||
|
Value: data,
|
||||||
|
Type: MSG_ASYNC,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SendCtx(ctx context.Context, key string, value MsgVal) (Message, error) {
|
||||||
|
return c.sendCtx(TransferMsg{
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
Type: MSG_SYNC_ASK,
|
||||||
|
}, ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SendWait(key string, value MsgVal, timeout time.Duration) (Message, error) {
|
||||||
|
return c.sendWait(TransferMsg{
|
||||||
|
Key: key,
|
||||||
|
Value: value,
|
||||||
|
Type: MSG_SYNC_ASK,
|
||||||
|
}, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error) {
|
||||||
|
data, err := c.sequenceEn(value)
|
||||||
|
if err != nil {
|
||||||
|
return Message{}, err
|
||||||
|
}
|
||||||
|
return c.SendWait(key, data, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) Reply(m Message, value MsgVal) error {
|
||||||
|
return m.Reply(value)
|
||||||
|
}
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientStopSessionIfCurrentEpoch(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
client.markSessionStarted()
|
||||||
|
|
||||||
|
staleEpoch := client.beginClientSessionEpoch()
|
||||||
|
currentEpoch := client.beginClientSessionEpoch()
|
||||||
|
|
||||||
|
if client.stopClientSessionIfCurrent(staleEpoch, "stale", nil) {
|
||||||
|
t.Fatal("stale epoch should not stop current session")
|
||||||
|
}
|
||||||
|
status := client.Status()
|
||||||
|
if !status.Alive || status.Reason != "" || status.Err != nil {
|
||||||
|
t.Fatalf("unexpected status after stale stop: %+v", status)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !client.stopClientSessionIfCurrent(currentEpoch, "current", nil) {
|
||||||
|
t.Fatal("current epoch should stop session")
|
||||||
|
}
|
||||||
|
status = client.Status()
|
||||||
|
if status.Alive || status.Reason != "current" || status.Err != nil {
|
||||||
|
t.Fatalf("unexpected status after current stop: %+v", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientReadErrorWithStaleEpochDoesNotStopCurrentSession(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
client.markSessionStarted()
|
||||||
|
|
||||||
|
staleEpoch := client.beginClientSessionEpoch()
|
||||||
|
currentEpoch := client.beginClientSessionEpoch()
|
||||||
|
|
||||||
|
readErr := errors.New("read failed")
|
||||||
|
client.handleTransportReadResultWithSession(context.Background(), nil, nil, 0, nil, readErr, staleEpoch)
|
||||||
|
|
||||||
|
status := client.Status()
|
||||||
|
if !status.Alive || status.Reason != "" || status.Err != nil {
|
||||||
|
t.Fatalf("unexpected status after stale read error: %+v", status)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.handleTransportReadResultWithSession(context.Background(), nil, nil, 0, nil, readErr, currentEpoch)
|
||||||
|
|
||||||
|
status = client.Status()
|
||||||
|
if status.Alive || status.Reason != "client read error" || !errors.Is(status.Err, readErr) {
|
||||||
|
t.Fatalf("unexpected status after current read error: %+v", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHeartbeatFailureWithStaleEpochDoesNotStopCurrentSession(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
client.markSessionStarted()
|
||||||
|
|
||||||
|
staleEpoch := client.beginClientSessionEpoch()
|
||||||
|
currentEpoch := client.beginClientSessionEpoch()
|
||||||
|
heartbeatErr := errors.New("heartbeat failed")
|
||||||
|
|
||||||
|
failedCount, stop := client.handleHeartbeatResultWithSession(staleEpoch, heartbeatErr, 2)
|
||||||
|
if failedCount != 3 || !stop {
|
||||||
|
t.Fatalf("unexpected stale heartbeat result: failedCount=%d stop=%v", failedCount, stop)
|
||||||
|
}
|
||||||
|
status := client.Status()
|
||||||
|
if !status.Alive || status.Reason != "" || status.Err != nil {
|
||||||
|
t.Fatalf("unexpected status after stale heartbeat error: %+v", status)
|
||||||
|
}
|
||||||
|
|
||||||
|
failedCount, stop = client.handleHeartbeatResultWithSession(currentEpoch, heartbeatErr, 2)
|
||||||
|
if failedCount != 3 || !stop {
|
||||||
|
t.Fatalf("unexpected current heartbeat result: failedCount=%d stop=%v", failedCount, stop)
|
||||||
|
}
|
||||||
|
status = client.Status()
|
||||||
|
if status.Alive || status.Reason != "heartbeat failed more than 3 times" || status.Err == nil || status.Err.Error() != "heartbeat failed more than 3 times" {
|
||||||
|
t.Fatalf("unexpected status after current heartbeat error: %+v", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,323 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"b612.me/stario"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type clientSessionRuntime struct {
|
||||||
|
transport *transportBinding
|
||||||
|
transportAttached bool
|
||||||
|
conn net.Conn
|
||||||
|
stopCtx context.Context
|
||||||
|
stopFn context.CancelFunc
|
||||||
|
transportStopCtx context.Context
|
||||||
|
transportStopFn context.CancelFunc
|
||||||
|
queue *stario.StarQueue
|
||||||
|
inboundDispatcher *inboundDispatcher
|
||||||
|
epoch uint64
|
||||||
|
suppressGoodByeOnStop *atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientSessionRuntimeBase(stopCtx context.Context, stopFn context.CancelFunc) *clientSessionRuntime {
|
||||||
|
return &clientSessionRuntime{
|
||||||
|
stopCtx: stopCtx,
|
||||||
|
stopFn: stopFn,
|
||||||
|
inboundDispatcher: newInboundDispatcher(),
|
||||||
|
suppressGoodByeOnStop: &atomic.Bool{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareClientSessionRuntime(rt *clientSessionRuntime) *clientSessionRuntime {
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rt.inboundDispatcher == nil {
|
||||||
|
rt.inboundDispatcher = newInboundDispatcher()
|
||||||
|
}
|
||||||
|
if rt.suppressGoodByeOnStop == nil {
|
||||||
|
rt.suppressGoodByeOnStop = &atomic.Bool{}
|
||||||
|
}
|
||||||
|
if rt.transport == nil && rt.conn != nil {
|
||||||
|
rt.transport = newTransportBinding(rt.conn, rt.queue)
|
||||||
|
}
|
||||||
|
normalizeClientSessionRuntimeTransportState(rt)
|
||||||
|
ensureClientSessionRuntimeTransportLifecycle(rt)
|
||||||
|
return rt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) setClientSessionRuntime(rt *clientSessionRuntime) {
|
||||||
|
if c == nil || rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var oldBinding *transportBinding
|
||||||
|
if prev := c.clientSessionRuntimeSnapshot(); prev != nil && prev.transport != nil && prev.transport != rt.transport {
|
||||||
|
oldBinding = prev.transport
|
||||||
|
}
|
||||||
|
rt = prepareClientSessionRuntime(rt)
|
||||||
|
c.sessionRuntime.Store(rt)
|
||||||
|
c.stopCtx = rt.stopCtx
|
||||||
|
c.stopFn = rt.stopFn
|
||||||
|
if rt.transport != nil {
|
||||||
|
c.queue = rt.transport.queueSnapshot()
|
||||||
|
c.conn = rt.transport.connSnapshot()
|
||||||
|
} else {
|
||||||
|
c.queue = rt.queue
|
||||||
|
c.conn = rt.conn
|
||||||
|
}
|
||||||
|
if oldBinding != nil {
|
||||||
|
oldBinding.stopBackgroundWorkers()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) resetClientSessionRuntimeBase() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
c.sessionRuntime.Store(newClientSessionRuntimeBase(stopCtx, stopFn))
|
||||||
|
c.conn = nil
|
||||||
|
c.queue = nil
|
||||||
|
c.stopCtx = stopCtx
|
||||||
|
c.stopFn = stopFn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) cleanupFailedClientStart() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt != nil && rt.stopFn != nil {
|
||||||
|
rt.stopFn()
|
||||||
|
}
|
||||||
|
c.cleanupClientSessionResources()
|
||||||
|
c.rollbackClientSessionStart()
|
||||||
|
c.resetClientSessionRuntimeBase()
|
||||||
|
}
|
||||||
|
|
||||||
|
func newClientSessionRuntime(conn net.Conn, stopCtx context.Context, stopFn context.CancelFunc, queue *stario.StarQueue, epoch uint64) *clientSessionRuntime {
|
||||||
|
return prepareClientSessionRuntime(&clientSessionRuntime{
|
||||||
|
transport: newTransportBinding(conn, queue),
|
||||||
|
transportAttached: conn != nil,
|
||||||
|
conn: conn,
|
||||||
|
stopCtx: stopCtx,
|
||||||
|
stopFn: stopFn,
|
||||||
|
queue: queue,
|
||||||
|
inboundDispatcher: newInboundDispatcher(),
|
||||||
|
epoch: epoch,
|
||||||
|
suppressGoodByeOnStop: &atomic.Bool{},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *clientSessionRuntime) runtimeShouldSuppressGoodByeOnStop() bool {
|
||||||
|
if rt == nil || rt.suppressGoodByeOnStop == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return rt.suppressGoodByeOnStop.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rt *clientSessionRuntime) markRuntimeSuppressGoodByeOnStop() {
|
||||||
|
if rt == nil || rt.suppressGoodByeOnStop == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rt.suppressGoodByeOnStop.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) retireClientSessionRuntime(rt *clientSessionRuntime, suppressGoodBye bool) {
|
||||||
|
if c == nil || rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if suppressGoodBye {
|
||||||
|
rt.markRuntimeSuppressGoodByeOnStop()
|
||||||
|
}
|
||||||
|
if rt.transportStopFn != nil {
|
||||||
|
rt.transportStopFn()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clearClientSessionRuntimeTransport() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt.transportStopFn != nil {
|
||||||
|
rt.transportStopFn()
|
||||||
|
}
|
||||||
|
next := *rt
|
||||||
|
next.transport = nil
|
||||||
|
next.transportAttached = false
|
||||||
|
next.conn = nil
|
||||||
|
next.transportStopCtx = nil
|
||||||
|
next.transportStopFn = nil
|
||||||
|
c.setClientSessionRuntime(&next)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clearClientSessionRuntimeQueue() {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next := *rt
|
||||||
|
next.queue = nil
|
||||||
|
if next.transport != nil {
|
||||||
|
next.transport = newTransportBinding(next.transport.connSnapshot(), nil)
|
||||||
|
}
|
||||||
|
c.setClientSessionRuntime(&next)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) attachClientSessionTransport(conn net.Conn) error {
|
||||||
|
if c == nil {
|
||||||
|
return errors.New("client is nil")
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return errors.New("conn is nil")
|
||||||
|
}
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return errors.New("client session runtime is nil")
|
||||||
|
}
|
||||||
|
if rt.queue == nil {
|
||||||
|
return errClientSessionQueueUnavailable
|
||||||
|
}
|
||||||
|
oldBinding := rt.transport
|
||||||
|
if rt.transportStopFn != nil {
|
||||||
|
rt.transportStopFn()
|
||||||
|
}
|
||||||
|
next := *rt
|
||||||
|
next.transport = newTransportBinding(conn, rt.queue)
|
||||||
|
next.transportAttached = true
|
||||||
|
next.conn = conn
|
||||||
|
next.transportStopCtx = nil
|
||||||
|
next.transportStopFn = nil
|
||||||
|
next.suppressGoodByeOnStop = &atomic.Bool{}
|
||||||
|
c.setClientSessionRuntime(&next)
|
||||||
|
if oldConn := oldBinding.connSnapshot(); oldConn != nil && oldConn != conn {
|
||||||
|
_ = oldConn.Close()
|
||||||
|
}
|
||||||
|
return c.startClientTransportRuntime(c.clientSessionRuntimeSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientSessionRuntimeSnapshot() *clientSessionRuntime {
|
||||||
|
if c == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.sessionRuntime.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeClientSessionRuntimeTransportState(rt *clientSessionRuntime) {
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt.transport != nil {
|
||||||
|
rt.transportAttached = rt.transport.connSnapshot() != nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rt.transportAttached = rt.conn != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ensureClientSessionRuntimeTransportLifecycle(rt *clientSessionRuntime) {
|
||||||
|
if rt == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt.conn == nil {
|
||||||
|
rt.transportStopCtx = nil
|
||||||
|
rt.transportStopFn = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rt.transportStopCtx != nil && rt.transportStopFn != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
parent := rt.stopCtx
|
||||||
|
if parent == nil {
|
||||||
|
parent = context.Background()
|
||||||
|
}
|
||||||
|
rt.transportStopCtx, rt.transportStopFn = context.WithCancel(parent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientTransportConnSnapshot() net.Conn {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rt.transport != nil {
|
||||||
|
return rt.transport.connSnapshot()
|
||||||
|
}
|
||||||
|
return rt.conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientInboundDispatcherSnapshot() *inboundDispatcher {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return rt.inboundDispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientStopContextSnapshot() context.Context {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return rt.stopCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientStopFuncSnapshot() context.CancelFunc {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return rt.stopFn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientQueueSnapshot() *stario.StarQueue {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rt.transport != nil {
|
||||||
|
return rt.transport.queueSnapshot()
|
||||||
|
}
|
||||||
|
return rt.queue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientTransportBindingSnapshot() *transportBinding {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rt.transport != nil {
|
||||||
|
return rt.transport
|
||||||
|
}
|
||||||
|
if rt.conn == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return newTransportBinding(rt.conn, rt.queue)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientTransportStopContextSnapshot() context.Context {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if rt.transportStopCtx != nil {
|
||||||
|
return rt.transportStopCtx
|
||||||
|
}
|
||||||
|
return rt.stopCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientTransportAttachedSnapshot() bool {
|
||||||
|
rt := c.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return rt.transportAttached
|
||||||
|
}
|
||||||
@@ -0,0 +1,352 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"b612.me/stario"
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientWriteToTransportUsesRuntimeConn(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
fallbackLeft, fallbackRight := net.Pipe()
|
||||||
|
defer fallbackLeft.Close()
|
||||||
|
defer fallbackRight.Close()
|
||||||
|
runtimeLeft, runtimeRight := net.Pipe()
|
||||||
|
defer runtimeLeft.Close()
|
||||||
|
defer runtimeRight.Close()
|
||||||
|
|
||||||
|
client.conn = fallbackLeft
|
||||||
|
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
|
||||||
|
defer runtimeCancel()
|
||||||
|
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||||
|
conn: runtimeLeft,
|
||||||
|
stopCtx: runtimeCtx,
|
||||||
|
stopFn: runtimeCancel,
|
||||||
|
epoch: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
payload := []byte("runtime-conn")
|
||||||
|
recvCh := make(chan []byte, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, len(payload))
|
||||||
|
_, err := io.ReadFull(runtimeRight, buf)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
recvCh <- buf
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := client.writeToTransport(payload); err != nil {
|
||||||
|
t.Fatalf("writeToTransport failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-errCh:
|
||||||
|
t.Fatalf("runtime conn read failed: %v", err)
|
||||||
|
case got := <-recvCh:
|
||||||
|
if string(got) != string(payload) {
|
||||||
|
t.Fatalf("runtime payload mismatch: got %q want %q", string(got), string(payload))
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("runtime conn did not receive payload")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = fallbackRight.SetReadDeadline(time.Now().Add(20 * time.Millisecond))
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
if _, err := fallbackRight.Read(buf); err == nil {
|
||||||
|
t.Fatal("fallback conn should not receive payload when runtime conn is active")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientMarkSessionStoppedUsesRuntimeStopFn(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if !client.beginClientSessionStart() {
|
||||||
|
t.Fatal("beginClientSessionStart should succeed")
|
||||||
|
}
|
||||||
|
client.markSessionStarted()
|
||||||
|
|
||||||
|
runtimeCtx, runtimeCancel := context.WithCancel(context.Background())
|
||||||
|
defer runtimeCancel()
|
||||||
|
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||||
|
stopCtx: runtimeCtx,
|
||||||
|
stopFn: runtimeCancel,
|
||||||
|
epoch: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
fallbackCtx, fallbackCancel := context.WithCancel(context.Background())
|
||||||
|
defer fallbackCancel()
|
||||||
|
client.stopCtx = fallbackCtx
|
||||||
|
client.stopFn = fallbackCancel
|
||||||
|
|
||||||
|
client.markSessionStopped("runtime stop", nil)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-runtimeCtx.Done():
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("runtime stop context should be canceled by markSessionStopped")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-fallbackCtx.Done():
|
||||||
|
t.Fatal("fallback owner stop context should not be canceled when runtime stopFn is active")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
rt := client.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
t.Fatal("runtime snapshot should remain available after stop")
|
||||||
|
}
|
||||||
|
if rt.conn != nil || rt.queue != nil {
|
||||||
|
t.Fatalf("runtime transport should be cleared after stop: %+v", rt)
|
||||||
|
}
|
||||||
|
if rt.stopCtx == nil {
|
||||||
|
t.Fatalf("runtime stop context should be preserved after stop: %+v", rt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientClearSessionRuntimeTransportPreservesStopState(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||||
|
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||||
|
conn: left,
|
||||||
|
stopCtx: stopCtx,
|
||||||
|
stopFn: stopFn,
|
||||||
|
queue: queue,
|
||||||
|
epoch: 7,
|
||||||
|
})
|
||||||
|
|
||||||
|
client.clearClientSessionRuntimeTransport()
|
||||||
|
|
||||||
|
rt := client.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
t.Fatal("runtime snapshot should remain after transport clear")
|
||||||
|
}
|
||||||
|
if rt.conn != nil {
|
||||||
|
t.Fatalf("runtime conn should be cleared: %+v", rt)
|
||||||
|
}
|
||||||
|
if rt.queue != queue {
|
||||||
|
t.Fatalf("runtime queue should be preserved across pure transport clear: got %v want %v", rt.queue, queue)
|
||||||
|
}
|
||||||
|
if rt.stopCtx != stopCtx || rt.stopFn == nil || rt.epoch != 7 {
|
||||||
|
t.Fatalf("runtime control state should be preserved: %+v", rt)
|
||||||
|
}
|
||||||
|
if client.clientTransportAttachedSnapshot() {
|
||||||
|
t.Fatal("client transport should be marked detached after runtime clear")
|
||||||
|
}
|
||||||
|
if got := client.clientQueueSnapshot(); got != queue {
|
||||||
|
t.Fatalf("client queue snapshot should be preserved after transport clear: got %v want %v", got, queue)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientTransportBindingSnapshotUsesRuntimeBinding(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||||
|
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||||
|
transport: newTransportBinding(left, queue),
|
||||||
|
conn: left,
|
||||||
|
stopCtx: stopCtx,
|
||||||
|
stopFn: stopFn,
|
||||||
|
queue: queue,
|
||||||
|
epoch: 9,
|
||||||
|
})
|
||||||
|
|
||||||
|
binding := client.clientTransportBindingSnapshot()
|
||||||
|
if binding == nil {
|
||||||
|
t.Fatal("runtime transport binding should exist")
|
||||||
|
}
|
||||||
|
if got := binding.connSnapshot(); got != left {
|
||||||
|
t.Fatal("runtime transport binding conn should match runtime conn")
|
||||||
|
}
|
||||||
|
if got := binding.queueSnapshot(); got != queue {
|
||||||
|
t.Fatal("runtime transport binding queue should match runtime queue")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetireClientSessionRuntimeCancelsTransportOnly(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
rt := newClientSessionRuntime(left, stopCtx, stopFn, queue, 3)
|
||||||
|
client.setClientSessionRuntime(rt)
|
||||||
|
client.retireClientSessionRuntime(rt, true)
|
||||||
|
|
||||||
|
transportStopCtx := client.clientTransportStopContextSnapshot()
|
||||||
|
if transportStopCtx == nil {
|
||||||
|
t.Fatal("transport stop context should exist")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-transportStopCtx.Done():
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("transport stop context should be canceled by retireClientSessionRuntime")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-client.clientStopContextSnapshot().Done():
|
||||||
|
t.Fatal("logical stop context should remain active when only retiring transport")
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientClearSessionRuntimeTransportPreservesQueueForEncoding(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
UseLegacySecurityClient(client)
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||||
|
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||||
|
conn: left,
|
||||||
|
stopCtx: stopCtx,
|
||||||
|
stopFn: stopFn,
|
||||||
|
queue: queue,
|
||||||
|
epoch: 8,
|
||||||
|
})
|
||||||
|
client.markSessionStarted()
|
||||||
|
defer client.markSessionStopped("test done", nil)
|
||||||
|
|
||||||
|
client.clearClientSessionRuntimeTransport()
|
||||||
|
|
||||||
|
data, err := client.encodeEnvelope(newSignalAckEnvelope(1003))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encodeEnvelope failed after pure transport clear: %v", err)
|
||||||
|
}
|
||||||
|
if len(data) == 0 {
|
||||||
|
t.Fatal("encodeEnvelope should still return framed payload after pure transport clear")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAttachClientSessionTransportRebindsRuntimeAndDispatchesOnNewConn(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
UseLegacySecurityClient(client)
|
||||||
|
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||||
|
oldLeft, oldRight := net.Pipe()
|
||||||
|
defer oldRight.Close()
|
||||||
|
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||||
|
conn: oldLeft,
|
||||||
|
stopCtx: stopCtx,
|
||||||
|
stopFn: stopFn,
|
||||||
|
queue: queue,
|
||||||
|
epoch: 11,
|
||||||
|
suppressGoodByeOnStop: &atomic.Bool{},
|
||||||
|
})
|
||||||
|
client.markSessionStarted()
|
||||||
|
defer client.markSessionStopped("test done", nil)
|
||||||
|
|
||||||
|
recvCh := make(chan Message, 1)
|
||||||
|
client.SetLink("reattach", func(message *Message) {
|
||||||
|
recvCh <- *message
|
||||||
|
})
|
||||||
|
|
||||||
|
newLeft, newRight := net.Pipe()
|
||||||
|
defer newRight.Close()
|
||||||
|
if err := client.attachClientSessionTransport(newLeft); err != nil {
|
||||||
|
t.Fatalf("attachClientSessionTransport failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rt := client.clientSessionRuntimeSnapshot()
|
||||||
|
if rt == nil {
|
||||||
|
t.Fatal("runtime snapshot should exist after attach")
|
||||||
|
}
|
||||||
|
if rt.conn != newLeft || !rt.transportAttached || rt.queue != queue || rt.epoch != 11 {
|
||||||
|
t.Fatalf("attached runtime mismatch: %+v", rt)
|
||||||
|
}
|
||||||
|
|
||||||
|
env, err := wrapTransferMsgEnvelope(TransferMsg{
|
||||||
|
ID: 42,
|
||||||
|
Key: "reattach",
|
||||||
|
Value: []byte("ok"),
|
||||||
|
Type: MSG_ASYNC,
|
||||||
|
}, client.sequenceEn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
|
||||||
|
}
|
||||||
|
wire, err := client.encodeEnvelope(env)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encodeEnvelope failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := newRight.Write(wire); err != nil {
|
||||||
|
t.Fatalf("new transport write failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case message := <-recvCh:
|
||||||
|
if got, want := message.Key, "reattach"; got != want {
|
||||||
|
t.Fatalf("message key mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := string(message.Value), "ok"; got != want {
|
||||||
|
t.Fatalf("message value mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("reattached transport did not dispatch message")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetClientSessionRuntimeStopsOldBindingWorkersOnReattach(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
queue := stario.NewQueueCtx(stopCtx, 4, math.MaxUint32)
|
||||||
|
|
||||||
|
oldLeft, oldRight := net.Pipe()
|
||||||
|
defer oldLeft.Close()
|
||||||
|
defer oldRight.Close()
|
||||||
|
oldBinding := newTransportBinding(oldLeft, queue)
|
||||||
|
oldSender := oldBinding.bulkBatchSenderSnapshot()
|
||||||
|
|
||||||
|
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||||
|
transport: oldBinding,
|
||||||
|
conn: oldLeft,
|
||||||
|
stopCtx: stopCtx,
|
||||||
|
stopFn: stopFn,
|
||||||
|
queue: queue,
|
||||||
|
epoch: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
newLeft, newRight := net.Pipe()
|
||||||
|
defer newLeft.Close()
|
||||||
|
defer newRight.Close()
|
||||||
|
newBinding := newTransportBinding(newLeft, queue)
|
||||||
|
|
||||||
|
client.setClientSessionRuntime(&clientSessionRuntime{
|
||||||
|
transport: newBinding,
|
||||||
|
conn: newLeft,
|
||||||
|
stopCtx: stopCtx,
|
||||||
|
stopFn: stopFn,
|
||||||
|
queue: queue,
|
||||||
|
epoch: 2,
|
||||||
|
})
|
||||||
|
|
||||||
|
err := oldSender.submit(context.Background(), []byte("payload"))
|
||||||
|
if err != errTransportDetached {
|
||||||
|
t.Fatalf("old sender submit after reattach = %v, want %v", err, errTransportDetached)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,116 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *ClientCommon) SetStreamHandler(fn func(StreamAcceptInfo) error) {
|
||||||
|
runtime := c.getStreamRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime.setHandler(fn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, errStreamClientNil
|
||||||
|
}
|
||||||
|
runtime := c.getStreamRuntime()
|
||||||
|
if runtime == nil {
|
||||||
|
return nil, errStreamRuntimeNil
|
||||||
|
}
|
||||||
|
req := clientStreamRequest(runtime, opt)
|
||||||
|
if req.StreamID == "" {
|
||||||
|
return nil, errStreamIDEmpty
|
||||||
|
}
|
||||||
|
if _, exists := runtime.lookup(clientFileScope(), req.StreamID); exists {
|
||||||
|
return nil, errStreamAlreadyExists
|
||||||
|
}
|
||||||
|
resp, err := sendStreamOpenClient(ctx, c, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp.DataID != 0 {
|
||||||
|
req.DataID = resp.DataID
|
||||||
|
}
|
||||||
|
stream := newStreamHandle(c.clientStopContextSnapshot(), runtime, clientFileScope(), req, c.currentClientSessionEpoch(), nil, nil, resp.TransportGeneration, clientStreamCloseSender(c), clientStreamResetSender(c), clientStreamDataSender(c, c.currentClientSessionEpoch()), runtime.configSnapshot())
|
||||||
|
stream.setClientSnapshotOwner(c)
|
||||||
|
stream.setAddrSnapshot(c.clientStreamAddrSnapshot())
|
||||||
|
if err := runtime.register(clientFileScope(), stream); err != nil {
|
||||||
|
_, _ = sendStreamResetClient(context.Background(), c, StreamResetRequest{
|
||||||
|
StreamID: req.StreamID,
|
||||||
|
Error: err.Error(),
|
||||||
|
})
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientStreamAddrSnapshot() (net.Addr, net.Addr) {
|
||||||
|
if c == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
conn := c.clientTransportConnSnapshot()
|
||||||
|
if conn == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return conn.LocalAddr(), conn.RemoteAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientStreamRequest(runtime *streamRuntime, opt StreamOpenOptions) StreamOpenRequest {
|
||||||
|
id := opt.ID
|
||||||
|
if id == "" && runtime != nil {
|
||||||
|
id = runtime.nextID()
|
||||||
|
}
|
||||||
|
return normalizeStreamOpenRequest(StreamOpenRequest{
|
||||||
|
StreamID: id,
|
||||||
|
Channel: opt.Channel,
|
||||||
|
Metadata: cloneStreamMetadata(opt.Metadata),
|
||||||
|
ReadTimeout: opt.ReadTimeout,
|
||||||
|
WriteTimeout: opt.WriteTimeout,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientStreamCloseSender(c *ClientCommon) streamCloseSender {
|
||||||
|
return func(ctx context.Context, stream *streamHandle, full bool) error {
|
||||||
|
_, err := sendStreamCloseClient(ctx, c, StreamCloseRequest{
|
||||||
|
StreamID: stream.ID(),
|
||||||
|
Full: full,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientStreamResetSender(c *ClientCommon) streamResetSender {
|
||||||
|
return func(ctx context.Context, stream *streamHandle, message string) error {
|
||||||
|
_, err := sendStreamResetClient(ctx, c, StreamResetRequest{
|
||||||
|
StreamID: stream.ID(),
|
||||||
|
Error: message,
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientStreamDataSender(c *ClientCommon, epoch uint64) streamDataSender {
|
||||||
|
return func(ctx context.Context, stream *streamHandle, chunk []byte) error {
|
||||||
|
if c == nil {
|
||||||
|
return errStreamClientNil
|
||||||
|
}
|
||||||
|
if epoch != 0 && !c.isClientSessionEpochCurrent(epoch) {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
if ctx != nil {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if dataID := stream.dataIDSnapshot(); dataID != 0 {
|
||||||
|
return c.sendFastStreamData(dataID, stream.nextOutboundDataSeq(), chunk)
|
||||||
|
}
|
||||||
|
return c.sendEnvelope(newStreamDataEnvelope(stream.ID(), chunk))
|
||||||
|
}
|
||||||
|
}
|
||||||
-159
@@ -1,159 +0,0 @@
|
|||||||
package notify
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_usechannel(t *testing.T) {
|
|
||||||
server, err := NewNotifyS("udp", "127.0.0.1:1926")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
server.SetNotify("nihao", func(data SMsg) string {
|
|
||||||
fmt.Println("server recv:", data.Key, data.Value)
|
|
||||||
if data.Value != "" {
|
|
||||||
data.Reply("nba")
|
|
||||||
return "nb"
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
})
|
|
||||||
client, err := NewNotifyC("udp", "127.0.0.1:1926")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
client.UseChannel = true
|
|
||||||
//time.Sleep(time.Second * 10)
|
|
||||||
client.Send("nihao")
|
|
||||||
client.SendValue("nihao", "lalala")
|
|
||||||
txt := <-client.Notify("nihao")
|
|
||||||
fmt.Println("client", txt)
|
|
||||||
txt = <-client.Notify("nihao")
|
|
||||||
fmt.Println("client", txt)
|
|
||||||
server.ServerStop()
|
|
||||||
<-client.Stop
|
|
||||||
client.ClientStop()
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_nochannel(t *testing.T) {
|
|
||||||
server, err := NewNotifyS("udp", "127.0.0.1:1926")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
server.SetNotify("nihao", func(data SMsg) string {
|
|
||||||
fmt.Println("server recv:", data.Key, data.Value)
|
|
||||||
if data.Value != "" {
|
|
||||||
data.Reply("nbaz")
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
})
|
|
||||||
client, err := NewNotifyC("udp", "127.0.0.1:1926")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//time.Sleep(time.Second * 10)
|
|
||||||
client.UseChannel = false
|
|
||||||
client.SetNotify("nihao", func(data CMsg) {
|
|
||||||
fmt.Println("client recv:", data.Key, data.Value)
|
|
||||||
if data.Value != "" {
|
|
||||||
time.Sleep(time.Millisecond * 900)
|
|
||||||
client.SendValue("nihao", "dsb")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
client.SendValue("nihao", "lalala")
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
server.ServerStop()
|
|
||||||
<-client.Stop
|
|
||||||
client.ClientStop()
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_pipec(t *testing.T) {
|
|
||||||
server, err := NewNotifyS("tcp", "127.0.0.1:1926")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
server.SetNotify("ni\\||hao", func(data SMsg) string {
|
|
||||||
fmt.Println("name-get", data.GetName())
|
|
||||||
fmt.Println("name-set", data.SetName("iiiis"))
|
|
||||||
fmt.Println("name-get", data.GetName())
|
|
||||||
fmt.Println("server recv:", data.Key, data.Value, data.mode)
|
|
||||||
if data.Value != "" {
|
|
||||||
data.Reply("nba")
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
})
|
|
||||||
client, err := NewNotifyC("tcp", "127.0.0.1:1926")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
client.UseChannel = false
|
|
||||||
sa, err := client.SendValueWait("ni\\||hao", "lalaeee", time.Second*10)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Println(sa)
|
|
||||||
sa, err = client.SendValueWait("ni\\||hao", "lalasdeee", time.Second*10)
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
fmt.Println(sa)
|
|
||||||
fmt.Println("sukidesu")
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
server.ServerStop()
|
|
||||||
<-client.Stop
|
|
||||||
client.ClientStop()
|
|
||||||
time.Sleep(time.Second * 2)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_pips(t *testing.T) {
|
|
||||||
var testmsg SMsg
|
|
||||||
server, err := NewNotifyS("udp", "127.0.0.1:1926")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
server.SetNotify("nihao", func(data SMsg) string {
|
|
||||||
fmt.Println("server recv:", data.Key, data.Value, data.mode)
|
|
||||||
testmsg = data
|
|
||||||
if data.Value != "" {
|
|
||||||
data.Reply("nbaz")
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
})
|
|
||||||
client, err := NewNotifyC("udp", "127.0.0.1:1926")
|
|
||||||
if err != nil {
|
|
||||||
fmt.Println(err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
//time.Sleep(time.Second * 10)
|
|
||||||
client.UseChannel = false
|
|
||||||
client.SetNotify("nihao", func(data CMsg) {
|
|
||||||
fmt.Println("client recv:", data.Key, data.Value, data.mode)
|
|
||||||
if data.mode != "pa" {
|
|
||||||
time.Sleep(time.Millisecond * 1200)
|
|
||||||
client.ReplyMsg(data, "nihao", "dsb")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
client.SendValue("nihao", "lalala")
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
fmt.Println(server.SendWait(testmsg, "nihao", "wozuinb", time.Second*20))
|
|
||||||
fmt.Println("sakura")
|
|
||||||
server.ServerStop()
|
|
||||||
<-client.Stop
|
|
||||||
client.ClientStop()
|
|
||||||
time.Sleep(time.Second * 3)
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,206 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"b612.me/stario"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func batchedControlEnvelope(env Envelope) bool {
|
||||||
|
switch env.Kind {
|
||||||
|
case EnvelopeSignal, EnvelopeSignalAck:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeDeadlineFromTimeout(timeout time.Duration) time.Time {
|
||||||
|
if timeout <= 0 {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
return time.Now().Add(timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendHeartbeat() error {
|
||||||
|
_, err := c.sendWait(TransferMsg{
|
||||||
|
ID: 10000,
|
||||||
|
Key: "heartbeat",
|
||||||
|
Value: nil,
|
||||||
|
Type: MSG_SYS_WAIT,
|
||||||
|
}, time.Second*5)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) handleHeartbeatResult(err error, failedCount int) (int, bool) {
|
||||||
|
return c.handleHeartbeatResultWithSession(c.currentClientSessionEpoch(), err, failedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) handleHeartbeatResultWithSession(epoch uint64, err error, failedCount int) (int, bool) {
|
||||||
|
if err == nil {
|
||||||
|
c.lastHeartbeat = time.Now().Unix()
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
if c.debugMode {
|
||||||
|
fmt.Println("failed to recv heartbeat,timeout!")
|
||||||
|
}
|
||||||
|
failedCount++
|
||||||
|
if failedCount < 3 {
|
||||||
|
return failedCount, false
|
||||||
|
}
|
||||||
|
if c.debugMode {
|
||||||
|
fmt.Println("heatbeat failed more than 3 times,stop client")
|
||||||
|
}
|
||||||
|
if !c.stopClientSessionIfCurrent(epoch, "heartbeat failed more than 3 times", errors.New("heartbeat failed more than 3 times")) {
|
||||||
|
return failedCount, true
|
||||||
|
}
|
||||||
|
return failedCount, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) readFromTransport() (int, []byte, error) {
|
||||||
|
return c.readFromTransportBinding(c.clientTransportBindingSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) readFromTransportConn(conn net.Conn) (int, []byte, error) {
|
||||||
|
return c.readFromTransportBinding(newTransportBinding(conn, nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) readFromTransportBinding(binding *transportBinding) (int, []byte, error) {
|
||||||
|
return c.readFromTransportBindingWithBuffer(binding, streamReadBuffer())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) readFromTransportBindingWithBuffer(binding *transportBinding, data []byte) (int, []byte, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
data = streamReadBuffer()
|
||||||
|
}
|
||||||
|
if binding == nil {
|
||||||
|
return 0, data, net.ErrClosed
|
||||||
|
}
|
||||||
|
conn := binding.connSnapshot()
|
||||||
|
if conn == nil {
|
||||||
|
return 0, data, net.ErrClosed
|
||||||
|
}
|
||||||
|
if c.maxReadTimeout.Seconds() != 0 {
|
||||||
|
_ = conn.SetReadDeadline(time.Now().Add(c.maxReadTimeout))
|
||||||
|
}
|
||||||
|
readNum, err := conn.Read(data)
|
||||||
|
return readNum, data, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) handleTransportReadResult(readNum int, data []byte, err error) bool {
|
||||||
|
return c.handleTransportReadResultWithSession(c.clientStopContextSnapshot(), c.clientTransportConnSnapshot(), c.clientQueueSnapshot(), readNum, data, err, c.currentClientSessionEpoch())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) handleTransportReadResultWithSession(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, readNum int, data []byte, err error, epoch uint64) bool {
|
||||||
|
return c.handleTransportReadResultWithSessionDispatcher(stopCtx, conn, queue, readNum, data, err, epoch, c.clientInboundDispatcherSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) handleTransportReadResultWithSessionDispatcher(stopCtx context.Context, conn net.Conn, queue *stario.StarQueue, readNum int, data []byte, err error, epoch uint64, dispatcher *inboundDispatcher) bool {
|
||||||
|
binding := newTransportBinding(conn, queue)
|
||||||
|
if err == os.ErrDeadlineExceeded {
|
||||||
|
if readNum != 0 && queue != nil {
|
||||||
|
if !c.pushMessageFast(queue, data[:readNum], dispatcher) {
|
||||||
|
queue.ParseMessage(data[:readNum], "b612")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if c.showError || c.debugMode {
|
||||||
|
fmt.Println("client read error", err)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-sessionStopChan(stopCtx):
|
||||||
|
c.closeClientTransportBinding(binding)
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
c.stopClientSessionIfCurrent(epoch, "client read error", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if queue != nil {
|
||||||
|
if !c.pushMessageFast(queue, data[:readNum], dispatcher) {
|
||||||
|
queue.ParseMessage(data[:readNum], "b612")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) pushMessageFast(queue *stario.StarQueue, data []byte, dispatcher *inboundDispatcher) bool {
|
||||||
|
if queue == nil || dispatcher == nil || len(data) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if err := queue.ParseMessageOwned(data, "b612", func(msg stario.MsgQueue) error {
|
||||||
|
payload := msg.Msg
|
||||||
|
c.wg.Add(1)
|
||||||
|
if !dispatcher.Dispatch(clientInboundDispatchSource(), func() {
|
||||||
|
defer c.wg.Done()
|
||||||
|
now := time.Now()
|
||||||
|
if err := c.dispatchInboundTransportPayload(payload, now); err != nil {
|
||||||
|
if c.showError || c.debugMode {
|
||||||
|
fmt.Println("client decode envelope error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
c.wg.Done()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}); err != nil && (c.showError || c.debugMode) {
|
||||||
|
fmt.Println("client parse inbound frame error", err)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) writeToTransport(data []byte) error {
|
||||||
|
binding := c.clientTransportBindingSnapshot()
|
||||||
|
if binding == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
return binding.withConnWriteLock(func(conn net.Conn) error {
|
||||||
|
if c.maxWriteTimeout.Seconds() != 0 {
|
||||||
|
_ = conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
|
||||||
|
}
|
||||||
|
return writeFullToConnUnlocked(conn, data)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) writePayloadToTransport(payload []byte) error {
|
||||||
|
binding := c.clientTransportBindingSnapshot()
|
||||||
|
if binding == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
queue := binding.queueSnapshot()
|
||||||
|
if queue == nil {
|
||||||
|
return errClientSessionQueueUnavailable
|
||||||
|
}
|
||||||
|
return binding.withConnWriteLock(func(conn net.Conn) error {
|
||||||
|
if c.maxWriteTimeout.Seconds() != 0 {
|
||||||
|
_ = conn.SetWriteDeadline(time.Now().Add(c.maxWriteTimeout))
|
||||||
|
}
|
||||||
|
return writeFramedPayloadUnlocked(conn, queue, payload)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) writeControlPayloadToTransport(payload []byte) error {
|
||||||
|
binding := c.clientTransportBindingSnapshot()
|
||||||
|
if binding == nil {
|
||||||
|
return net.ErrClosed
|
||||||
|
}
|
||||||
|
queue := binding.queueSnapshot()
|
||||||
|
if queue == nil {
|
||||||
|
return errClientSessionQueueUnavailable
|
||||||
|
}
|
||||||
|
conn := binding.connSnapshot()
|
||||||
|
if conn == nil || isPacketTransportConn(conn) {
|
||||||
|
return c.writePayloadToTransport(payload)
|
||||||
|
}
|
||||||
|
sender := binding.controlBatchSenderSnapshot()
|
||||||
|
if sender == nil {
|
||||||
|
return c.writePayloadToTransport(payload)
|
||||||
|
}
|
||||||
|
return sender.submit(payload, writeDeadlineFromTimeout(c.maxWriteTimeout))
|
||||||
|
}
|
||||||
@@ -0,0 +1,82 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client interface {
|
||||||
|
SetDefaultLink(func(message *Message))
|
||||||
|
SetLink(string, func(*Message))
|
||||||
|
SetFileHandler(func(FileEvent))
|
||||||
|
SetStreamHandler(func(StreamAcceptInfo) error)
|
||||||
|
SetRecordStreamHandler(func(RecordAcceptInfo) error)
|
||||||
|
SetBulkHandler(func(BulkAcceptInfo) error)
|
||||||
|
SetTransferHandler(func(TransferAcceptInfo) (TransferReceiveOptions, error))
|
||||||
|
GetStreamConfig() StreamConfig
|
||||||
|
SetStreamConfig(StreamConfig)
|
||||||
|
SetTransferResumeStore(TransferResumeStore)
|
||||||
|
RecoverTransferSnapshots(context.Context) error
|
||||||
|
SetFileReceiveDir(dir string) error
|
||||||
|
send(msg TransferMsg) (WaitMsg, error)
|
||||||
|
sendEnvelope(env Envelope) error
|
||||||
|
sendWait(msg TransferMsg, timeout time.Duration) (Message, error)
|
||||||
|
Send(key string, value MsgVal) error
|
||||||
|
SendWait(key string, value MsgVal, timeout time.Duration) (Message, error)
|
||||||
|
SendWaitObj(key string, value interface{}, timeout time.Duration) (Message, error)
|
||||||
|
SendCtx(ctx context.Context, key string, value MsgVal) (Message, error)
|
||||||
|
Reply(m Message, value MsgVal) error
|
||||||
|
// Deprecated: ExchangeKey drives the legacy RSA-based key exchange flow.
|
||||||
|
// Prefer UseModernPSKClient.
|
||||||
|
ExchangeKey(newKey []byte) error
|
||||||
|
Connect(network string, addr string) error
|
||||||
|
ConnectTimeout(network string, addr string, timeout time.Duration) error
|
||||||
|
ConnectByConn(conn net.Conn) error
|
||||||
|
ConnectByFactory(ctx context.Context, dialFn func(context.Context) (net.Conn, error)) error
|
||||||
|
// Deprecated: SkipExchangeKey only controls the legacy RSA-based key exchange.
|
||||||
|
SkipExchangeKey() bool
|
||||||
|
// Deprecated: SetSkipExchangeKey only controls the legacy RSA-based key exchange.
|
||||||
|
SetSkipExchangeKey(bool)
|
||||||
|
|
||||||
|
GetMsgEn() func([]byte, []byte) []byte
|
||||||
|
// Deprecated: SetMsgEn overrides the transport codec directly.
|
||||||
|
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||||
|
SetMsgEn(func([]byte, []byte) []byte)
|
||||||
|
GetMsgDe() func([]byte, []byte) []byte
|
||||||
|
// Deprecated: SetMsgDe overrides the transport codec directly.
|
||||||
|
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||||
|
SetMsgDe(func([]byte, []byte) []byte)
|
||||||
|
|
||||||
|
Heartbeat()
|
||||||
|
HeartbeatPeroid() time.Duration
|
||||||
|
SetHeartbeatPeroid(duration time.Duration)
|
||||||
|
|
||||||
|
GetSecretKey() []byte
|
||||||
|
// Deprecated: SetSecretKey injects a raw transport key directly.
|
||||||
|
// Prefer UseModernPSKClient or UseLegacySecurityClient.
|
||||||
|
SetSecretKey(key []byte)
|
||||||
|
// Deprecated: RsaPubKey exposes the legacy RSA handshake key. Prefer UseModernPSKClient.
|
||||||
|
RsaPubKey() []byte
|
||||||
|
// Deprecated: SetRsaPubKey configures the legacy RSA handshake key. Prefer UseModernPSKClient.
|
||||||
|
SetRsaPubKey([]byte)
|
||||||
|
|
||||||
|
Stop() error
|
||||||
|
StopMonitorChan() <-chan struct{}
|
||||||
|
Status() Status
|
||||||
|
ShowError(bool)
|
||||||
|
DebugMode(bool)
|
||||||
|
IsDebugMode() bool
|
||||||
|
|
||||||
|
GetSequenceEn() func(interface{}) ([]byte, error)
|
||||||
|
SetSequenceEn(func(interface{}) ([]byte, error))
|
||||||
|
GetSequenceDe() func([]byte) (interface{}, error)
|
||||||
|
SetSequenceDe(func([]byte) (interface{}, error))
|
||||||
|
SendObjCtx(ctx context.Context, key string, val interface{}) (Message, error)
|
||||||
|
SendObj(key string, val interface{}) error
|
||||||
|
OpenStream(ctx context.Context, opt StreamOpenOptions) (Stream, error)
|
||||||
|
OpenRecordStream(ctx context.Context, opt RecordOpenOptions) (RecordStream, error)
|
||||||
|
OpenBulk(ctx context.Context, opt BulkOpenOptions) (Bulk, error)
|
||||||
|
SendTransfer(ctx context.Context, opt TransferSendOptions) (TransferHandle, error)
|
||||||
|
SendFile(ctx context.Context, filePath string) error
|
||||||
|
}
|
||||||
@@ -0,0 +1,544 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errInMemoryListenerClosed = errors.New("in-memory listener closed")
|
||||||
|
|
||||||
|
type inMemoryListener struct {
|
||||||
|
closed chan struct{}
|
||||||
|
once sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInMemoryListener() *inMemoryListener {
|
||||||
|
return &inMemoryListener{
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *inMemoryListener) Accept() (net.Conn, error) {
|
||||||
|
<-l.closed
|
||||||
|
return nil, errInMemoryListenerClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *inMemoryListener) Close() error {
|
||||||
|
l.once.Do(func() {
|
||||||
|
close(l.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *inMemoryListener) Addr() net.Addr {
|
||||||
|
return inMemoryAddr("in-memory-listener")
|
||||||
|
}
|
||||||
|
|
||||||
|
type inMemoryAddr string
|
||||||
|
|
||||||
|
func (a inMemoryAddr) Network() string { return "in-memory" }
|
||||||
|
func (a inMemoryAddr) String() string { return string(a) }
|
||||||
|
|
||||||
|
func TestConnectByConnRequiresModernPSK(t *testing.T) {
|
||||||
|
client := NewClient()
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer left.Close()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
err := client.ConnectByConn(left)
|
||||||
|
if !errors.Is(err, errModernPSKRequired) {
|
||||||
|
t.Fatalf("ConnectByConn error = %v, want %v", err, errModernPSKRequired)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectByConnWithConfiguredSecurity(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
server.SetSecretKey(secret)
|
||||||
|
})
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
|
||||||
|
client.SetSecretKey(secret)
|
||||||
|
if err := client.ConnectByConn(left); err != nil {
|
||||||
|
t.Fatalf("ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectByFactoryRequiresModernPSK(t *testing.T) {
|
||||||
|
client := NewClient()
|
||||||
|
called := false
|
||||||
|
|
||||||
|
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
||||||
|
called = true
|
||||||
|
left, right := net.Pipe()
|
||||||
|
_ = right.Close()
|
||||||
|
return left, nil
|
||||||
|
})
|
||||||
|
if !errors.Is(err, errModernPSKRequired) {
|
||||||
|
t.Fatalf("ConnectByFactory error = %v, want %v", err, errModernPSKRequired)
|
||||||
|
}
|
||||||
|
if called {
|
||||||
|
t.Fatal("dialFn should not be called before security validation passes")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectByFactoryRejectsNilDialFn(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
||||||
|
|
||||||
|
err := client.ConnectByFactory(context.Background(), nil)
|
||||||
|
if err == nil || err.Error() != "dialFn is nil" {
|
||||||
|
t.Fatalf("ConnectByFactory nil dialFn error = %v, want dialFn is nil", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectByFactoryPropagatesDialError(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
||||||
|
wantErr := errors.New("dial failed")
|
||||||
|
|
||||||
|
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
||||||
|
return nil, wantErr
|
||||||
|
})
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Fatalf("ConnectByFactory error = %v, want %v", err, wantErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectByFactoryWithConfiguredSecurity(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
server.SetSecretKey(secret)
|
||||||
|
})
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
|
||||||
|
client.SetSecretKey(secret)
|
||||||
|
if err := client.ConnectByFactory(nil, func(ctx context.Context) (net.Conn, error) {
|
||||||
|
if ctx == nil {
|
||||||
|
t.Fatal("ConnectByFactory should normalize nil context")
|
||||||
|
}
|
||||||
|
return left, nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("ConnectByFactory failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectByFactoryRejectsConcurrentStart(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
client.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
firstDialEntered := make(chan struct{}, 1)
|
||||||
|
firstDone := make(chan error, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
firstDone <- client.ConnectByFactory(ctx, func(ctx context.Context) (net.Conn, error) {
|
||||||
|
firstDialEntered <- struct{}{}
|
||||||
|
<-ctx.Done()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-firstDialEntered:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("first connect attempt did not enter dialFn")
|
||||||
|
}
|
||||||
|
|
||||||
|
secondDialCalled := false
|
||||||
|
err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
||||||
|
secondDialCalled = true
|
||||||
|
return nil, errors.New("second dial should not run")
|
||||||
|
})
|
||||||
|
if err == nil || err.Error() != "client already run" {
|
||||||
|
t.Fatalf("concurrent ConnectByFactory error = %v, want client already run", err)
|
||||||
|
}
|
||||||
|
if secondDialCalled {
|
||||||
|
t.Fatal("second dialFn should not be called during first connect start")
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
select {
|
||||||
|
case err = <-firstDone:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("first ConnectByFactory did not finish after cancel")
|
||||||
|
}
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("first ConnectByFactory error = %v, want %v", err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantErr := errors.New("dial after rollback")
|
||||||
|
err = client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
||||||
|
return nil, wantErr
|
||||||
|
})
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Fatalf("ConnectByFactory after rollback error = %v, want %v", err, wantErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectByConnReattachesDetachedAliveSession(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
client.SetSecretKey(secret)
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
server.SetSecretKey(secret)
|
||||||
|
})
|
||||||
|
|
||||||
|
firstLeft, firstRight := net.Pipe()
|
||||||
|
defer firstRight.Close()
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, firstRight)
|
||||||
|
if err := client.ConnectByConn(firstLeft); err != nil {
|
||||||
|
t.Fatalf("initial ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
before := client.clientSessionRuntimeSnapshot()
|
||||||
|
if before == nil {
|
||||||
|
t.Fatal("runtime should exist after initial connect")
|
||||||
|
}
|
||||||
|
initialEpoch := before.epoch
|
||||||
|
initialStopCtx := before.stopCtx
|
||||||
|
initialQueue := before.queue
|
||||||
|
|
||||||
|
client.clearClientSessionRuntimeTransport()
|
||||||
|
|
||||||
|
recvCh := make(chan Message, 1)
|
||||||
|
client.SetLink("reattach-public", func(message *Message) {
|
||||||
|
recvCh <- *message
|
||||||
|
})
|
||||||
|
|
||||||
|
secondLeft, secondRight := net.Pipe()
|
||||||
|
defer secondRight.Close()
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, secondRight)
|
||||||
|
if err := client.ConnectByConn(secondLeft); err != nil {
|
||||||
|
t.Fatalf("reattach ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
after := client.clientSessionRuntimeSnapshot()
|
||||||
|
if after == nil {
|
||||||
|
t.Fatal("runtime should exist after reattach")
|
||||||
|
}
|
||||||
|
if after.conn != secondLeft || after.queue != initialQueue || after.stopCtx != initialStopCtx || after.epoch != initialEpoch || !after.transportAttached {
|
||||||
|
t.Fatalf("reattached runtime mismatch: %+v", after)
|
||||||
|
}
|
||||||
|
|
||||||
|
env, err := wrapTransferMsgEnvelope(TransferMsg{
|
||||||
|
ID: 88,
|
||||||
|
Key: "reattach-public",
|
||||||
|
Value: []byte("ok"),
|
||||||
|
Type: MSG_ASYNC,
|
||||||
|
}, client.sequenceEn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("wrapTransferMsgEnvelope failed: %v", err)
|
||||||
|
}
|
||||||
|
wire, err := client.encodeEnvelope(env)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("encodeEnvelope failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := secondRight.Write(wire); err != nil {
|
||||||
|
t.Fatalf("reattached conn write failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case msg := <-recvCh:
|
||||||
|
if got, want := msg.Key, "reattach-public"; got != want {
|
||||||
|
t.Fatalf("message key mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := string(msg.Value), "ok"; got != want {
|
||||||
|
t.Fatalf("message value mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("reattached public conn did not dispatch message")
|
||||||
|
}
|
||||||
|
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("final Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectByFactoryReattachesDetachedAliveSessionAndUpdatesSource(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
secret := []byte("0123456789abcdef0123456789abcdef")
|
||||||
|
client.SetSecretKey(secret)
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
server.SetSecretKey(secret)
|
||||||
|
})
|
||||||
|
|
||||||
|
firstLeft, firstRight := net.Pipe()
|
||||||
|
defer firstRight.Close()
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, firstRight)
|
||||||
|
if err := client.ConnectByConn(firstLeft); err != nil {
|
||||||
|
t.Fatalf("initial ConnectByConn failed: %v", err)
|
||||||
|
}
|
||||||
|
before := client.clientSessionRuntimeSnapshot()
|
||||||
|
if before == nil {
|
||||||
|
t.Fatal("runtime should exist after initial connect")
|
||||||
|
}
|
||||||
|
initialEpoch := before.epoch
|
||||||
|
|
||||||
|
client.clearClientSessionRuntimeTransport()
|
||||||
|
|
||||||
|
var dialCount atomic.Int32
|
||||||
|
secondLeft, secondRight := net.Pipe()
|
||||||
|
defer secondRight.Close()
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, secondRight)
|
||||||
|
if err := client.ConnectByFactory(context.Background(), func(context.Context) (net.Conn, error) {
|
||||||
|
dialCount.Add(1)
|
||||||
|
return secondLeft, nil
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("reattach ConnectByFactory failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := dialCount.Load(), int32(1); got != want {
|
||||||
|
t.Fatalf("dial count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
after := client.clientSessionRuntimeSnapshot()
|
||||||
|
if after == nil {
|
||||||
|
t.Fatal("runtime should exist after factory reattach")
|
||||||
|
}
|
||||||
|
if after.epoch != initialEpoch || after.conn != secondLeft || !after.transportAttached {
|
||||||
|
t.Fatalf("reattached runtime mismatch: %+v", after)
|
||||||
|
}
|
||||||
|
snapshot, err := GetClientRuntimeSnapshot(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.ConnectSource, clientConnectSourceFactory; got != want {
|
||||||
|
t.Fatalf("connect source mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if !snapshot.CanReconnect {
|
||||||
|
t.Fatalf("snapshot should be reconnectable after factory reattach: %+v", snapshot)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("final Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectByConnFailureCleansRuntimeAndAllowsRetry(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
UseLegacySecurityClient(client)
|
||||||
|
failErr := errors.New("key exchange fail for test")
|
||||||
|
client.keyExchangeFn = func(Client) error {
|
||||||
|
return failErr
|
||||||
|
}
|
||||||
|
|
||||||
|
left1, right1 := net.Pipe()
|
||||||
|
defer right1.Close()
|
||||||
|
err := client.ConnectByConn(left1)
|
||||||
|
if !errors.Is(err, failErr) {
|
||||||
|
t.Fatalf("ConnectByConn first error = %v, want %v", err, failErr)
|
||||||
|
}
|
||||||
|
status := client.Status()
|
||||||
|
if status.Alive || status.Reason != "key exchange failed" || !errors.Is(status.Err, failErr) {
|
||||||
|
t.Fatalf("unexpected status after failed key exchange: %+v", status)
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-client.StopMonitorChan():
|
||||||
|
t.Fatal("StopMonitorChan should remain open after failed connect cleanup")
|
||||||
|
case <-time.After(20 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
client.SetSkipExchangeKey(true)
|
||||||
|
left2, right2 := net.Pipe()
|
||||||
|
defer right2.Close()
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
UseLegacySecurityServer(server)
|
||||||
|
})
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right2)
|
||||||
|
if err := client.ConnectByConn(left2); err != nil {
|
||||||
|
t.Fatalf("ConnectByConn second attempt failed: %v", err)
|
||||||
|
}
|
||||||
|
if !client.Status().Alive {
|
||||||
|
t.Fatalf("client should be alive after second ConnectByConn: %+v", client.Status())
|
||||||
|
}
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListenByListenerRequiresModernPSK(t *testing.T) {
|
||||||
|
server := NewServer()
|
||||||
|
listener := newInMemoryListener()
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
err := server.ListenByListener(listener)
|
||||||
|
if !errors.Is(err, errModernPSKRequired) {
|
||||||
|
t.Fatalf("ListenByListener error = %v, want %v", err, errModernPSKRequired)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListenByListenerWithConfiguredSecurity(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
listener := newInMemoryListener()
|
||||||
|
defer listener.Close()
|
||||||
|
|
||||||
|
server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
||||||
|
if err := server.ListenByListener(listener); err != nil {
|
||||||
|
t.Fatalf("ListenByListener failed: %v", err)
|
||||||
|
}
|
||||||
|
if !server.Status().Alive {
|
||||||
|
t.Fatal("server should be alive after ListenByListener")
|
||||||
|
}
|
||||||
|
if err := server.Stop(); err != nil {
|
||||||
|
t.Fatalf("Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListenByListenerRejectsNil(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
server.SetSecretKey([]byte("0123456789abcdef0123456789abcdef"))
|
||||||
|
err := server.ListenByListener(nil)
|
||||||
|
if err == nil || err.Error() != "listener is nil" {
|
||||||
|
t.Fatalf("ListenByListener nil error = %v, want listener is nil", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientReadMessagePreservesUserStopReason(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
left, right := net.Pipe()
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
|
||||||
|
client.conn = left
|
||||||
|
client.stopCtx = stopCtx
|
||||||
|
client.stopFn = stopFn
|
||||||
|
client.markSessionStarted()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
client.readMessage()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
_ = right.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("readMessage should exit after user stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
status := client.Status()
|
||||||
|
if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil {
|
||||||
|
t.Fatalf("unexpected status after user stop: %+v", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientReadMessagePreservesServerStopReason(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
left, right := net.Pipe()
|
||||||
|
stopCtx, stopFn := context.WithCancel(context.Background())
|
||||||
|
defer stopFn()
|
||||||
|
|
||||||
|
client.conn = left
|
||||||
|
client.stopCtx = stopCtx
|
||||||
|
client.stopFn = stopFn
|
||||||
|
client.markSessionStarted()
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
client.readMessage()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
client.stopClientSessionFromServer("recv stop signal from server", nil)
|
||||||
|
_ = right.Close()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("readMessage should exit after server stop")
|
||||||
|
}
|
||||||
|
|
||||||
|
status := client.Status()
|
||||||
|
if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
|
||||||
|
t.Fatalf("unexpected status after server stop: %+v", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientStopClientSessionFromServerDisablesGoodBye(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
client.markSessionStarted()
|
||||||
|
|
||||||
|
client.stopClientSessionFromServer("recv stop signal from server", nil)
|
||||||
|
|
||||||
|
if client.shouldSayGoodByeOnStop() {
|
||||||
|
t.Fatal("server stop should disable goodbye on stop")
|
||||||
|
}
|
||||||
|
status := client.Status()
|
||||||
|
if status.Alive || status.Reason != "recv stop signal from server" || status.Err != nil {
|
||||||
|
t.Fatalf("unexpected status after server stop helper: %+v", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientStopClientSessionKeepsGoodByeEnabled(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
client.markSessionStarted()
|
||||||
|
|
||||||
|
client.stopClientSession("recv stop signal from user", nil)
|
||||||
|
|
||||||
|
if !client.shouldSayGoodByeOnStop() {
|
||||||
|
t.Fatal("local stop should keep goodbye enabled")
|
||||||
|
}
|
||||||
|
status := client.Status()
|
||||||
|
if status.Alive || status.Reason != "recv stop signal from user" || status.Err != nil {
|
||||||
|
t.Fatalf("unexpected status after local stop helper: %+v", status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientReadMessageLoopUsesProvidedStopCtx(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
loopCtx, loopCancel := context.WithCancel(context.Background())
|
||||||
|
loopCancel()
|
||||||
|
|
||||||
|
client.stopCtx = context.Background()
|
||||||
|
client.conn = nil
|
||||||
|
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
client.readMessageLoop(loopCtx, left, nil, 1)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("readMessageLoop should exit when provided stopCtx is canceled")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := right.Write([]byte("x")); err == nil {
|
||||||
|
t.Fatal("peer conn should be closed when loop exits")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,257 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultConnectRetryAttempts = 3
|
||||||
|
defaultConnectRetryBase = 200 * time.Millisecond
|
||||||
|
defaultConnectRetryMax = 2 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectRetryOptions struct {
|
||||||
|
MaxAttempts int
|
||||||
|
BaseDelay time.Duration
|
||||||
|
MaxDelay time.Duration
|
||||||
|
ShouldRetry func(error) bool
|
||||||
|
OnRetry func(ConnectRetryEvent)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnectRetryEvent struct {
|
||||||
|
Attempt int
|
||||||
|
MaxAttempts int
|
||||||
|
Err error
|
||||||
|
NextDelay time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errConnectRetryClientNil = errors.New("connect retry client is nil")
|
||||||
|
errConnectRetryServerNil = errors.New("connect retry server is nil")
|
||||||
|
errConnectRetryFnNil = errors.New("connect retry fn is nil")
|
||||||
|
errConnectRetryDialFnNil = errors.New("connect retry dialFn is nil")
|
||||||
|
errClientReconnectNil = errors.New("client reconnect target is nil")
|
||||||
|
errClientReconnectUnsupported = errors.New("client reconnect target type is unsupported")
|
||||||
|
errClientReconnectActive = errors.New("client reconnect requires an inactive session")
|
||||||
|
)
|
||||||
|
|
||||||
|
func DefaultConnectRetryOptions() ConnectRetryOptions {
|
||||||
|
return ConnectRetryOptions{
|
||||||
|
MaxAttempts: defaultConnectRetryAttempts,
|
||||||
|
BaseDelay: defaultConnectRetryBase,
|
||||||
|
MaxDelay: defaultConnectRetryMax,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeConnectRetryOptions(opts *ConnectRetryOptions) ConnectRetryOptions {
|
||||||
|
cfg := DefaultConnectRetryOptions()
|
||||||
|
if opts == nil {
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
if opts.MaxAttempts > 0 {
|
||||||
|
cfg.MaxAttempts = opts.MaxAttempts
|
||||||
|
}
|
||||||
|
if opts.BaseDelay > 0 {
|
||||||
|
cfg.BaseDelay = opts.BaseDelay
|
||||||
|
}
|
||||||
|
if opts.MaxDelay > 0 {
|
||||||
|
cfg.MaxDelay = opts.MaxDelay
|
||||||
|
}
|
||||||
|
cfg.ShouldRetry = opts.ShouldRetry
|
||||||
|
cfg.OnRetry = opts.OnRetry
|
||||||
|
if cfg.MaxDelay < cfg.BaseDelay {
|
||||||
|
cfg.MaxDelay = cfg.BaseDelay
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func RetryConnect(ctx context.Context, opts *ConnectRetryOptions, fn func(context.Context) error) error {
|
||||||
|
if fn == nil {
|
||||||
|
return errConnectRetryFnNil
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
cfg := normalizeConnectRetryOptions(opts)
|
||||||
|
var lastErr error
|
||||||
|
for attempt := 1; attempt <= cfg.MaxAttempts; attempt++ {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
lastErr = fn(ctx)
|
||||||
|
if lastErr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if cfg.ShouldRetry != nil && !cfg.ShouldRetry(lastErr) {
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
if attempt >= cfg.MaxAttempts {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
delay := connectRetryBackoffDelay(cfg, attempt)
|
||||||
|
if cfg.OnRetry != nil {
|
||||||
|
cfg.OnRetry(ConnectRetryEvent{
|
||||||
|
Attempt: attempt,
|
||||||
|
MaxAttempts: cfg.MaxAttempts,
|
||||||
|
Err: lastErr,
|
||||||
|
NextDelay: delay,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if err := waitConnectRetryDelay(ctx, delay); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnectClientWithRetry(ctx context.Context, client Client, network string, addr string, opts *ConnectRetryOptions) error {
|
||||||
|
if client == nil {
|
||||||
|
return errConnectRetryClientNil
|
||||||
|
}
|
||||||
|
recorder, _ := any(client).(connectionRetryRecorder)
|
||||||
|
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
|
||||||
|
err := RetryConnect(ctx, retryOpts, func(context.Context) error {
|
||||||
|
return client.Connect(network, addr)
|
||||||
|
})
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.recordConnectionRetryResult(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnectClientFactoryWithRetry(ctx context.Context, client Client, dialFn func(context.Context) (net.Conn, error), opts *ConnectRetryOptions) error {
|
||||||
|
if client == nil {
|
||||||
|
return errConnectRetryClientNil
|
||||||
|
}
|
||||||
|
if dialFn == nil {
|
||||||
|
return errConnectRetryDialFnNil
|
||||||
|
}
|
||||||
|
recorder, _ := any(client).(connectionRetryRecorder)
|
||||||
|
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
|
||||||
|
err := RetryConnect(ctx, retryOpts, func(ctx context.Context) error {
|
||||||
|
return client.ConnectByFactory(ctx, dialFn)
|
||||||
|
})
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.recordConnectionRetryResult(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientReconnecter interface {
|
||||||
|
reconnect(context.Context) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReconnectClient(ctx context.Context, client Client) error {
|
||||||
|
if client == nil {
|
||||||
|
return errClientReconnectNil
|
||||||
|
}
|
||||||
|
reconnecter, ok := any(client).(clientReconnecter)
|
||||||
|
if !ok {
|
||||||
|
return errClientReconnectUnsupported
|
||||||
|
}
|
||||||
|
return reconnecter.reconnect(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ReconnectClientWithRetry(ctx context.Context, client Client, opts *ConnectRetryOptions) error {
|
||||||
|
if client == nil {
|
||||||
|
return errConnectRetryClientNil
|
||||||
|
}
|
||||||
|
recorder, _ := any(client).(connectionRetryRecorder)
|
||||||
|
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
|
||||||
|
err := RetryConnect(ctx, retryOpts, func(ctx context.Context) error {
|
||||||
|
return ReconnectClient(ctx, client)
|
||||||
|
})
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.recordConnectionRetryResult(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func ListenServerWithRetry(ctx context.Context, server Server, network string, addr string, opts *ConnectRetryOptions) error {
|
||||||
|
if server == nil {
|
||||||
|
return errConnectRetryServerNil
|
||||||
|
}
|
||||||
|
recorder, _ := any(server).(connectionRetryRecorder)
|
||||||
|
retryOpts := wrapConnectRetryOptionsWithRecorder(opts, recorder)
|
||||||
|
err := RetryConnect(ctx, retryOpts, func(context.Context) error {
|
||||||
|
return server.Listen(network, addr)
|
||||||
|
})
|
||||||
|
if recorder != nil {
|
||||||
|
recorder.recordConnectionRetryResult(err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) reconnect(ctx context.Context) error {
|
||||||
|
if c == nil {
|
||||||
|
return errClientReconnectNil
|
||||||
|
}
|
||||||
|
if sessionIsAlive(&c.alive) {
|
||||||
|
return errClientReconnectActive
|
||||||
|
}
|
||||||
|
source := c.clientConnectSourceSnapshot()
|
||||||
|
if source == nil || !source.canReconnect() {
|
||||||
|
return errClientReconnectSourceUnavailable
|
||||||
|
}
|
||||||
|
finish, err := c.beginClientConnectAttempt()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
started := false
|
||||||
|
defer func() {
|
||||||
|
finish(started)
|
||||||
|
}()
|
||||||
|
if err := c.validateSecurityConfiguration(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.closeClientTransport()
|
||||||
|
c.applySignalReliabilityTransportDefault(source.isUDP())
|
||||||
|
conn, err := source.dial(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if conn == nil {
|
||||||
|
return errors.New("conn is nil")
|
||||||
|
}
|
||||||
|
if err := c.startClientWithConnSource(conn, source); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
started = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectRetryBackoffDelay(cfg ConnectRetryOptions, failedAttempt int) time.Duration {
|
||||||
|
delay := cfg.BaseDelay
|
||||||
|
if delay <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
for i := 1; i < failedAttempt; i++ {
|
||||||
|
if delay >= cfg.MaxDelay/2 {
|
||||||
|
return cfg.MaxDelay
|
||||||
|
}
|
||||||
|
delay *= 2
|
||||||
|
}
|
||||||
|
if delay > cfg.MaxDelay {
|
||||||
|
return cfg.MaxDelay
|
||||||
|
}
|
||||||
|
return delay
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitConnectRetryDelay(ctx context.Context, delay time.Duration) error {
|
||||||
|
if delay <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(delay)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,147 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectionRetrySnapshot struct {
|
||||||
|
RetryEventTotal uint64
|
||||||
|
LastRetryAttempt int
|
||||||
|
LastRetryDelay time.Duration
|
||||||
|
LastRetryError string
|
||||||
|
LastRetryAt time.Time
|
||||||
|
LastResultError string
|
||||||
|
LastResultAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type connectionRetryState struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
retryEventTotal uint64
|
||||||
|
lastRetryAttempt int
|
||||||
|
lastRetryDelay time.Duration
|
||||||
|
lastRetryError string
|
||||||
|
lastRetryAt time.Time
|
||||||
|
lastResultError string
|
||||||
|
lastResultAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func newConnectionRetryState() *connectionRetryState {
|
||||||
|
return &connectionRetryState{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *connectionRetryState) recordRetryEvent(event ConnectRetryEvent) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.retryEventTotal++
|
||||||
|
s.lastRetryAttempt = event.Attempt
|
||||||
|
s.lastRetryDelay = event.NextDelay
|
||||||
|
if event.Err != nil {
|
||||||
|
s.lastRetryError = event.Err.Error()
|
||||||
|
} else {
|
||||||
|
s.lastRetryError = ""
|
||||||
|
}
|
||||||
|
s.lastRetryAt = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *connectionRetryState) recordResult(err error) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
s.lastResultError = err.Error()
|
||||||
|
} else {
|
||||||
|
s.lastResultError = ""
|
||||||
|
}
|
||||||
|
s.lastResultAt = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *connectionRetryState) snapshot() ConnectionRetrySnapshot {
|
||||||
|
if s == nil {
|
||||||
|
return ConnectionRetrySnapshot{}
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return ConnectionRetrySnapshot{
|
||||||
|
RetryEventTotal: s.retryEventTotal,
|
||||||
|
LastRetryAttempt: s.lastRetryAttempt,
|
||||||
|
LastRetryDelay: s.lastRetryDelay,
|
||||||
|
LastRetryError: s.lastRetryError,
|
||||||
|
LastRetryAt: s.lastRetryAt,
|
||||||
|
LastResultError: s.lastResultError,
|
||||||
|
LastResultAt: s.lastResultAt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type connectionRetryRecorder interface {
|
||||||
|
recordConnectionRetryEvent(event ConnectRetryEvent)
|
||||||
|
recordConnectionRetryResult(err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapConnectRetryOptionsWithRecorder(opts *ConnectRetryOptions, recorder connectionRetryRecorder) *ConnectRetryOptions {
|
||||||
|
if recorder == nil {
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
if opts == nil {
|
||||||
|
return &ConnectRetryOptions{
|
||||||
|
OnRetry: recorder.recordConnectionRetryEvent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next := *opts
|
||||||
|
originOnRetry := next.OnRetry
|
||||||
|
next.OnRetry = func(event ConnectRetryEvent) {
|
||||||
|
recorder.recordConnectionRetryEvent(event)
|
||||||
|
if originOnRetry != nil {
|
||||||
|
originOnRetry(event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &next
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) getConnectionRetryState() *connectionRetryState {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.connectionRetryState == nil {
|
||||||
|
c.connectionRetryState = newConnectionRetryState()
|
||||||
|
}
|
||||||
|
return c.connectionRetryState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) recordConnectionRetryEvent(event ConnectRetryEvent) {
|
||||||
|
c.getConnectionRetryState().recordRetryEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) recordConnectionRetryResult(err error) {
|
||||||
|
c.getConnectionRetryState().recordResult(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) connectionRetrySnapshot() ConnectionRetrySnapshot {
|
||||||
|
return c.getConnectionRetryState().snapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) getConnectionRetryState() *connectionRetryState {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
if s.connectionRetryState == nil {
|
||||||
|
s.connectionRetryState = newConnectionRetryState()
|
||||||
|
}
|
||||||
|
return s.connectionRetryState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) recordConnectionRetryEvent(event ConnectRetryEvent) {
|
||||||
|
s.getConnectionRetryState().recordRetryEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) recordConnectionRetryResult(err error) {
|
||||||
|
s.getConnectionRetryState().recordResult(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) connectionRetrySnapshot() ConnectionRetrySnapshot {
|
||||||
|
return s.getConnectionRetryState().snapshot()
|
||||||
|
}
|
||||||
@@ -0,0 +1,309 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRetryConnectSucceedsAfterRetries(t *testing.T) {
|
||||||
|
var attempts int
|
||||||
|
wantErr := errors.New("dial failed")
|
||||||
|
|
||||||
|
err := RetryConnect(context.Background(), &ConnectRetryOptions{
|
||||||
|
MaxAttempts: 4,
|
||||||
|
BaseDelay: time.Millisecond,
|
||||||
|
MaxDelay: 2 * time.Millisecond,
|
||||||
|
}, func(context.Context) error {
|
||||||
|
attempts++
|
||||||
|
if attempts < 3 {
|
||||||
|
return wantErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RetryConnect failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := attempts, 3; got != want {
|
||||||
|
t.Fatalf("attempts mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryConnectReturnsLastError(t *testing.T) {
|
||||||
|
var attempts int
|
||||||
|
wantErr := errors.New("connect failed")
|
||||||
|
|
||||||
|
err := RetryConnect(context.Background(), &ConnectRetryOptions{
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Millisecond,
|
||||||
|
MaxDelay: time.Millisecond,
|
||||||
|
}, func(context.Context) error {
|
||||||
|
attempts++
|
||||||
|
return wantErr
|
||||||
|
})
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
|
||||||
|
}
|
||||||
|
if got, want := attempts, 3; got != want {
|
||||||
|
t.Fatalf("attempts mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryConnectContextCanceled(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
var attempts int
|
||||||
|
|
||||||
|
err := RetryConnect(ctx, &ConnectRetryOptions{
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: 100 * time.Millisecond,
|
||||||
|
MaxDelay: 100 * time.Millisecond,
|
||||||
|
}, func(context.Context) error {
|
||||||
|
attempts++
|
||||||
|
cancel()
|
||||||
|
return errors.New("fail")
|
||||||
|
})
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("RetryConnect error = %v, want context canceled", err)
|
||||||
|
}
|
||||||
|
if got, want := attempts, 1; got != want {
|
||||||
|
t.Fatalf("attempts mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectRetryRejectsNilInputs(t *testing.T) {
|
||||||
|
if err := RetryConnect(context.Background(), nil, nil); !errors.Is(err, errConnectRetryFnNil) {
|
||||||
|
t.Fatalf("RetryConnect nil fn error = %v, want %v", err, errConnectRetryFnNil)
|
||||||
|
}
|
||||||
|
if err := ConnectClientWithRetry(context.Background(), nil, "tcp", "127.0.0.1:1", nil); !errors.Is(err, errConnectRetryClientNil) {
|
||||||
|
t.Fatalf("ConnectClientWithRetry nil client error = %v, want %v", err, errConnectRetryClientNil)
|
||||||
|
}
|
||||||
|
if err := ConnectClientFactoryWithRetry(context.Background(), nil, nil, nil); !errors.Is(err, errConnectRetryClientNil) {
|
||||||
|
t.Fatalf("ConnectClientFactoryWithRetry nil client error = %v, want %v", err, errConnectRetryClientNil)
|
||||||
|
}
|
||||||
|
if err := ConnectClientFactoryWithRetry(context.Background(), NewClient(), nil, nil); !errors.Is(err, errConnectRetryDialFnNil) {
|
||||||
|
t.Fatalf("ConnectClientFactoryWithRetry nil dialFn error = %v, want %v", err, errConnectRetryDialFnNil)
|
||||||
|
}
|
||||||
|
if err := ListenServerWithRetry(context.Background(), nil, "tcp", "127.0.0.1:1", nil); !errors.Is(err, errConnectRetryServerNil) {
|
||||||
|
t.Fatalf("ListenServerWithRetry nil server error = %v, want %v", err, errConnectRetryServerNil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectRetryBackoffDelayCapped(t *testing.T) {
|
||||||
|
cfg := normalizeConnectRetryOptions(&ConnectRetryOptions{
|
||||||
|
MaxAttempts: 5,
|
||||||
|
BaseDelay: 10 * time.Millisecond,
|
||||||
|
MaxDelay: 30 * time.Millisecond,
|
||||||
|
})
|
||||||
|
if got, want := connectRetryBackoffDelay(cfg, 1), 10*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("delay attempt1 mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := connectRetryBackoffDelay(cfg, 2), 20*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("delay attempt2 mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := connectRetryBackoffDelay(cfg, 3), 30*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("delay attempt3 mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := connectRetryBackoffDelay(cfg, 4), 30*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("delay attempt4 mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryConnectShouldRetryCanStopEarly(t *testing.T) {
|
||||||
|
var attempts int
|
||||||
|
wantErr := errors.New("not retriable")
|
||||||
|
|
||||||
|
err := RetryConnect(context.Background(), &ConnectRetryOptions{
|
||||||
|
MaxAttempts: 5,
|
||||||
|
BaseDelay: time.Millisecond,
|
||||||
|
MaxDelay: 2 * time.Millisecond,
|
||||||
|
ShouldRetry: func(err error) bool {
|
||||||
|
return !errors.Is(err, wantErr)
|
||||||
|
},
|
||||||
|
}, func(context.Context) error {
|
||||||
|
attempts++
|
||||||
|
return wantErr
|
||||||
|
})
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
|
||||||
|
}
|
||||||
|
if got, want := attempts, 1; got != want {
|
||||||
|
t.Fatalf("attempts mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryConnectOnRetryHook(t *testing.T) {
|
||||||
|
var events []ConnectRetryEvent
|
||||||
|
wantErr := errors.New("dial failed")
|
||||||
|
|
||||||
|
err := RetryConnect(context.Background(), &ConnectRetryOptions{
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Millisecond,
|
||||||
|
MaxDelay: 2 * time.Millisecond,
|
||||||
|
OnRetry: func(event ConnectRetryEvent) {
|
||||||
|
events = append(events, event)
|
||||||
|
},
|
||||||
|
}, func(context.Context) error {
|
||||||
|
return wantErr
|
||||||
|
})
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Fatalf("RetryConnect error = %v, want %v", err, wantErr)
|
||||||
|
}
|
||||||
|
if got, want := len(events), 2; got != want {
|
||||||
|
t.Fatalf("retry events mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := events[0].Attempt, 1; got != want {
|
||||||
|
t.Fatalf("event[0] attempt mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := events[0].MaxAttempts, 3; got != want {
|
||||||
|
t.Fatalf("event[0] max attempts mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if !errors.Is(events[0].Err, wantErr) {
|
||||||
|
t.Fatalf("event[0] err mismatch: got %v want %v", events[0].Err, wantErr)
|
||||||
|
}
|
||||||
|
if got, want := events[0].NextDelay, time.Millisecond; got != want {
|
||||||
|
t.Fatalf("event[0] next delay mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := events[1].Attempt, 2; got != want {
|
||||||
|
t.Fatalf("event[1] attempt mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := events[1].NextDelay, 2*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("event[1] next delay mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConnectClientFactoryWithRetryRecoversFromFailedStart(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
UseLegacySecurityClient(client)
|
||||||
|
server := newRunningPeerAttachServerForTest(t, func(server *ServerCommon) {
|
||||||
|
UseLegacySecurityServer(server)
|
||||||
|
})
|
||||||
|
|
||||||
|
wantErr := errors.New("key exchange failed on first attempt")
|
||||||
|
keyExchangeAttempts := 0
|
||||||
|
client.keyExchangeFn = func(Client) error {
|
||||||
|
keyExchangeAttempts++
|
||||||
|
if keyExchangeAttempts == 1 {
|
||||||
|
return wantErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dialAttempts := 0
|
||||||
|
var peerConns []net.Conn
|
||||||
|
dialFn := func(context.Context) (net.Conn, error) {
|
||||||
|
dialAttempts++
|
||||||
|
left, right := net.Pipe()
|
||||||
|
peerConns = append(peerConns, right)
|
||||||
|
bootstrapPeerAttachConnForTest(t, server, right)
|
||||||
|
return left, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := ConnectClientFactoryWithRetry(context.Background(), client, dialFn, &ConnectRetryOptions{
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Millisecond,
|
||||||
|
MaxDelay: time.Millisecond,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ConnectClientFactoryWithRetry failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := dialAttempts, 2; got != want {
|
||||||
|
t.Fatalf("dial attempts mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := keyExchangeAttempts, 2; got != want {
|
||||||
|
t.Fatalf("key exchange attempts mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if status := client.Status(); !status.Alive {
|
||||||
|
t.Fatalf("client should be alive after retry success: %+v", status)
|
||||||
|
}
|
||||||
|
runtimeSnapshot, err := GetClientRuntimeSnapshot(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientRuntimeSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := runtimeSnapshot.Retry.RetryEventTotal, uint64(1); got != want {
|
||||||
|
t.Fatalf("client retry events mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := runtimeSnapshot.Retry.LastRetryAttempt, 1; got != want {
|
||||||
|
t.Fatalf("client last retry attempt mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := runtimeSnapshot.Retry.LastRetryError, wantErr.Error(); got != want {
|
||||||
|
t.Fatalf("client last retry error mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if runtimeSnapshot.Retry.LastRetryAt.IsZero() {
|
||||||
|
t.Fatal("client last retry time should be recorded")
|
||||||
|
}
|
||||||
|
if runtimeSnapshot.Retry.LastResultError != "" {
|
||||||
|
t.Fatalf("client last result error should be empty on success, got %q", runtimeSnapshot.Retry.LastResultError)
|
||||||
|
}
|
||||||
|
if runtimeSnapshot.Retry.LastResultAt.IsZero() {
|
||||||
|
t.Fatal("client last result time should be recorded")
|
||||||
|
}
|
||||||
|
|
||||||
|
client.setByeFromServer(true)
|
||||||
|
if err := client.Stop(); err != nil {
|
||||||
|
t.Fatalf("client Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
for _, conn := range peerConns {
|
||||||
|
_ = conn.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListenServerWithRetryRecoversFromFailedStart(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
var retryEvents []ConnectRetryEvent
|
||||||
|
|
||||||
|
err := ListenServerWithRetry(context.Background(), server, "tcp", "127.0.0.1:0", &ConnectRetryOptions{
|
||||||
|
MaxAttempts: 3,
|
||||||
|
BaseDelay: time.Millisecond,
|
||||||
|
MaxDelay: time.Millisecond,
|
||||||
|
OnRetry: func(event ConnectRetryEvent) {
|
||||||
|
retryEvents = append(retryEvents, event)
|
||||||
|
if event.Attempt == 1 {
|
||||||
|
UseLegacySecurityServer(server)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListenServerWithRetry failed: %v", err)
|
||||||
|
}
|
||||||
|
if status := server.Status(); !status.Alive {
|
||||||
|
t.Fatalf("server should be alive after retry success: %+v", status)
|
||||||
|
}
|
||||||
|
if got := len(retryEvents); got < 1 {
|
||||||
|
t.Fatal("OnRetry should be called at least once")
|
||||||
|
}
|
||||||
|
if got, want := retryEvents[0].Attempt, 1; got != want {
|
||||||
|
t.Fatalf("retry event attempt mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if !errors.Is(retryEvents[0].Err, errModernPSKRequired) {
|
||||||
|
t.Fatalf("retry event err mismatch: got %v want %v", retryEvents[0].Err, errModernPSKRequired)
|
||||||
|
}
|
||||||
|
runtimeSnapshot, err := GetServerRuntimeSnapshot(server)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetServerRuntimeSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := runtimeSnapshot.Retry.RetryEventTotal, uint64(1); got != want {
|
||||||
|
t.Fatalf("server retry events mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := runtimeSnapshot.Retry.LastRetryAttempt, 1; got != want {
|
||||||
|
t.Fatalf("server last retry attempt mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := runtimeSnapshot.Retry.LastRetryError, errModernPSKRequired.Error(); got != want {
|
||||||
|
t.Fatalf("server last retry error mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if runtimeSnapshot.Retry.LastRetryAt.IsZero() {
|
||||||
|
t.Fatal("server last retry time should be recorded")
|
||||||
|
}
|
||||||
|
if runtimeSnapshot.Retry.LastResultError != "" {
|
||||||
|
t.Fatalf("server last result error should be empty on success, got %q", runtimeSnapshot.Retry.LastResultError)
|
||||||
|
}
|
||||||
|
if runtimeSnapshot.Retry.LastResultAt.IsZero() {
|
||||||
|
t.Fatal("server last result time should be recorded")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := server.Stop(); err != nil {
|
||||||
|
t.Fatalf("server Stop failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,181 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const controlBatchMaxPayloads = 16
|
||||||
|
|
||||||
|
type controlBatchRequest struct {
|
||||||
|
payload []byte
|
||||||
|
deadline time.Time
|
||||||
|
done chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
type controlBatchSender struct {
|
||||||
|
binding *transportBinding
|
||||||
|
reqCh chan controlBatchRequest
|
||||||
|
stopCh chan struct{}
|
||||||
|
doneCh chan struct{}
|
||||||
|
|
||||||
|
stopOnce sync.Once
|
||||||
|
errMu sync.Mutex
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newControlBatchSender(binding *transportBinding) *controlBatchSender {
|
||||||
|
sender := &controlBatchSender{
|
||||||
|
binding: binding,
|
||||||
|
reqCh: make(chan controlBatchRequest, controlBatchMaxPayloads*4),
|
||||||
|
stopCh: make(chan struct{}),
|
||||||
|
doneCh: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go sender.run()
|
||||||
|
return sender
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *controlBatchSender) submit(payload []byte, deadline time.Time) error {
|
||||||
|
if s == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
req := controlBatchRequest{
|
||||||
|
payload: payload,
|
||||||
|
deadline: deadline,
|
||||||
|
done: make(chan error, 1),
|
||||||
|
}
|
||||||
|
if err := s.errSnapshot(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
return s.stoppedErr()
|
||||||
|
case s.reqCh <- req:
|
||||||
|
}
|
||||||
|
return <-req.done
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *controlBatchSender) run() {
|
||||||
|
defer close(s.doneCh)
|
||||||
|
for {
|
||||||
|
req, ok := s.nextRequest()
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
batch := []controlBatchRequest{req}
|
||||||
|
drain:
|
||||||
|
for len(batch) < controlBatchMaxPayloads {
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
s.failPending(s.stoppedErr())
|
||||||
|
return
|
||||||
|
case next := <-s.reqCh:
|
||||||
|
batch = append(batch, next)
|
||||||
|
default:
|
||||||
|
break drain
|
||||||
|
}
|
||||||
|
}
|
||||||
|
payloads := make([][]byte, 0, len(batch))
|
||||||
|
for _, item := range batch {
|
||||||
|
payloads = append(payloads, item.payload)
|
||||||
|
}
|
||||||
|
err := s.flush(payloads, controlBatchRequestsEarliestDeadline(batch))
|
||||||
|
if err != nil {
|
||||||
|
s.setErr(err)
|
||||||
|
for _, item := range batch {
|
||||||
|
item.done <- err
|
||||||
|
}
|
||||||
|
s.failPending(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, item := range batch {
|
||||||
|
item.done <- nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *controlBatchSender) nextRequest() (controlBatchRequest, bool) {
|
||||||
|
select {
|
||||||
|
case <-s.stopCh:
|
||||||
|
s.failPending(s.stoppedErr())
|
||||||
|
return controlBatchRequest{}, false
|
||||||
|
case req := <-s.reqCh:
|
||||||
|
return req, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func controlBatchRequestsEarliestDeadline(batch []controlBatchRequest) time.Time {
|
||||||
|
var deadline time.Time
|
||||||
|
for _, item := range batch {
|
||||||
|
if item.deadline.IsZero() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if deadline.IsZero() || item.deadline.Before(deadline) {
|
||||||
|
deadline = item.deadline
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return deadline
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *controlBatchSender) flush(payloads [][]byte, deadline time.Time) error {
|
||||||
|
if s == nil || s.binding == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
queue := s.binding.queueSnapshot()
|
||||||
|
if queue == nil {
|
||||||
|
return errTransportFrameQueueUnavailable
|
||||||
|
}
|
||||||
|
return s.binding.withConnWriteLockDeadline(deadline, func(conn net.Conn) error {
|
||||||
|
return writeFramedPayloadBatchUnlocked(conn, queue, payloads)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *controlBatchSender) stop() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.stopOnce.Do(func() {
|
||||||
|
s.setErr(errTransportDetached)
|
||||||
|
close(s.stopCh)
|
||||||
|
})
|
||||||
|
<-s.doneCh
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *controlBatchSender) failPending(err error) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case item := <-s.reqCh:
|
||||||
|
item.done <- err
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *controlBatchSender) setErr(err error) {
|
||||||
|
if s == nil || err == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.errMu.Lock()
|
||||||
|
if s.err == nil {
|
||||||
|
s.err = err
|
||||||
|
}
|
||||||
|
s.errMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *controlBatchSender) errSnapshot() error {
|
||||||
|
if s == nil {
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
|
s.errMu.Lock()
|
||||||
|
defer s.errMu.Unlock()
|
||||||
|
return s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *controlBatchSender) stoppedErr() error {
|
||||||
|
if err := s.errSnapshot(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return errTransportDetached
|
||||||
|
}
|
||||||
+143
@@ -0,0 +1,143 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
itransfer "b612.me/notify/internal/transfer"
|
||||||
|
"b612.me/starcrypto"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Deprecated: legacy static RSA private key retained only for compatibility
|
||||||
|
// with MSG_KEY_CHANGE.
|
||||||
|
var defaultRsaKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||||
|
MIIJKAIBAAKCAgEAxmeMqr9yfJFKZn26oe/HvC7bZXNLC9Nk55AuTkb4XuIoqXDb
|
||||||
|
AJD2Y/p167oJLKIqL3edcj7h+oTfn6s79vxT0ZCEf37ILU0G+scRzVwYHiLMwOUC
|
||||||
|
bS2o4Xor3zqUi9f1piJBvoBNh8RKKtsmJW6VQZdiUGJHbgX4MdOdtf/6TvxZMwSX
|
||||||
|
U+PRSCAjy04A31Zi7DEWUWJPyqmHeu++PxXU5lvoMdCGDqpcF2j2uO7oJJUww01M
|
||||||
|
3F5FtTElMrK4/P9gD4kP7NiPhOfVPEfBsYT/DSSjvqNZJZuWnxu+cDxE7J/sBvdp
|
||||||
|
eNRLhqzdmMYagZFuUmVrz8QmsD6jKHgydW+r7irllvb8WJPK/RIMif+4Rg7rDKFb
|
||||||
|
j8+ZQ3HZ/gKELoRSyb3zL6RC2qlGLjC1tdeN7TNTinCv092y39T8jIARJ7tpfePh
|
||||||
|
NBxsBdxfXbCAzHYZIHufI9Zlsc+felQwanlDhq+q8YLcnKHvNKYVyCf/upExpAiA
|
||||||
|
rr88y/KbeKes0KorKkwMBnGUMTothWM25wHozcurixNvP4UMWX7LWD7vOZZuNDQN
|
||||||
|
utZYeTwdsniI3mTO9vlPWEK8JTfxBU7x9SePUMJNDyjfDUJM8C2DOlyhGNPkgazO
|
||||||
|
GdliH87tHkEy/7jJnGclgKmciiVPgwHfFx9GGoBHEfvmAoGGrk4qNbjm7JECAwEA
|
||||||
|
AQKCAgBYzHe05ELFZfG6tYMWf08R9pbTbSqlfFOpIGrZNgJr1SUF0TDzq+3bCXpF
|
||||||
|
qtn4VAw1en/JZkOV8Gp1+Bm6jWymWtwyg/fr7pG1I+vf0dwpgMHLg7P2UX1IjXmd
|
||||||
|
S4a4oEuds69hJ+OLZFsdm0ATeM7ssGicOaBmqd1Pz7rCfnL1bxQtNVzVex1r/paG
|
||||||
|
o77YNr3HoKCwhCPaPM4aQ7sOWSMUhwYBZabaYX0eLShf1O2pkexlPO+tobPpSLmx
|
||||||
|
WzRYZ6QC0AGEq9hwT6KsfCFA5pmQtFllNY7suhpL1AsECLWAgoMNCyb1oW68NBpq
|
||||||
|
CiBK5WBPGH2MW+pE74Pu1P0gen6kLGnApKQjprE1aGuR+xkZe3uEnXwSryU9TXki
|
||||||
|
wINTEMsX8dkmofFqaJhUwSubrb+t7gvv9E9ZZe0X6UgKzAVVqvh4z1pP8VT+xHpu
|
||||||
|
pW7SR8n9cFddaEPUijSb1rSpJrNzfJJ+G7yrB7Cw2kBgQ07vzD3z/3kA9cwFevLS
|
||||||
|
mv3l3OQuB6y9c+AG3cX5WGAt/BVOLjimj9qJt+YglG0SwG31U0PUnnx6QVz/UtJm
|
||||||
|
CbJQ2TpJd+mk0HyuMU+eycp7BWF3PMN+SE4QgKCKWnhsLeAd3gcvifsbLOYE1OPg
|
||||||
|
wv1tqyJy0VsJiSn6Ub6Qq0kPLwCLlQTnLWk5mIhnRpHYufTSwQKCAQEA4gS4FKPU
|
||||||
|
tAcQ82dEYW4OjGfhNWrjFpF+A8K5zufleQWcgzQ3fQho13zH0vZobukfkEVlVxla
|
||||||
|
OIVk7ZgNA4mCSFrATjIx3RMqzrAUvTte0O4wkjYgCwVvTdS1W8nvRLKgugLygyoo
|
||||||
|
r+MLW5IT3eNMK/2fZbftNlAkbc7NCo3c2tS6MXFgjx5JUuzChOY73Kp4p5KS38L5
|
||||||
|
wRRiI8KTIKjBjMZ5q/l8VLKX89bKOCaWibmItoXY6QMbIjargb7YLp3X6uGEyGIu
|
||||||
|
VhPbQ80/+OC2ZqIvDecp4PYnJNZFeqfjyfhJCNqDjBKYwIscBLMU/Wf9OY258OR4
|
||||||
|
snQaerN1M0h9lQKCAQEA4LkZIRLLw+8bIVM+7VXxFwOAGy+MH35tvuNIToItAoUh
|
||||||
|
zjL5LG34PjID8J0DPyP8VRVanak1EcxF0aTEkvnt2f2RAVsW89ytcn8Lybb12Ae8
|
||||||
|
ia2ZWuIM+J40nuKOGPs3lJ9HqdPWmZYWsWKxFJmYBBnwD6CADYqhqambQn0HeaYl
|
||||||
|
/WUD7blLYg+4Kk1mt9/hIw93jTWP/86O2H0ia+AhYPTqyvVXfIXKhat6NlOYksGf
|
||||||
|
Hdv+aCC8Ukg6FyEgiNc/rFn0MWPnEX+cM1AwubviHIBhV8QWILLBTjupwsEBZVah
|
||||||
|
60ftH+HRUCmEeOpI7jyzIlfEUNLoBHfswKMhMPtcDQKCAQEA0JFkQX+xn/PJW6PX
|
||||||
|
AUWrXTvbIg0hw8i9DcFa76klJBnehWDhN5tUDE5Uo8PJOVgdTWgMjWSS0geezHX8
|
||||||
|
xF/XfudoAIDnbMfsP9FTQhCQfaLf5XzW8vSv8pWwSiS9jJp+IUjo+8siwrR03aqe
|
||||||
|
dKr0tr+ToS0qVG1+QGqO4gdpX/LgYxHp9ggPx9s94aAIa6hQMOrcaGqnSNqDedZr
|
||||||
|
KL8x5LOewek3J32rJVP3Rfut/SfeFfjL4rKADoF+oPs4yUPVZSV4/+VCNyKZuyaj
|
||||||
|
uwm6qFlPrLe9+J+OHbsxYG+fj9hzpRzoOZFLrppwX5HWc8XLcpnrlXVwP9VOPh5u
|
||||||
|
r8VcRQKCAQAJFHGHfJLvH8Ig3pQ0UryjCWkrsAghXaJhjB1nzqqy514uTrDysp7N
|
||||||
|
JIg0OKPg8TtI1MwMgsG6Ll7D0bx/k8mgfTZWr6+FuuznK2r2g4X7bJSZm4IOwgN0
|
||||||
|
KDBIGy9SoxPj1Wu32O9a1U2lbS9qfao+wC2K9Bk4ctmFWW0Eiri6mZP/YQ1/lXUO
|
||||||
|
SURPsUDtPQaDvCRAeGGRHG95H9U8NpoiqMKz4KXgSiecrwkJGOeZRml/c1wcKPZy
|
||||||
|
/KgcNyJxZQEVnazYMgksE9Pj3uGZH5ZLQISuXyXlvFNDLfX2AIZl6dIxB371QtKK
|
||||||
|
QqMvn4fC2IEEajdsbJkjVRUj03OL3xwhAoIBAAfMhDSvBbDkGTaXnNMjPPSbswqK
|
||||||
|
qcSRhSG27mjs1dDNBKuFbz6TkIOp4nxjuS9Zp19fErXlAE9mF5yXSmuiAkZmWfhs
|
||||||
|
HKpWIdjFJK1EqSfcINe2YuoyUIulz9oG7ObRHD4D8jSPjA8Ete+XsBHGyOtUl09u
|
||||||
|
X4u9uClhqjK+r1Tno2vw5yF6ZxfQtdWuL4W0UL1S8E+VO7vjTjNOYvgjAIpAM/gW
|
||||||
|
sqjA2Qw52UZqhhLXoTfRvtJilxlXXhIRJSsnUoGiYVCQ/upjqJCClEvJfIWdGY/U
|
||||||
|
I2CbFrwJcNvOG1lUsSM55JUmbrSWVPfo7yq2k9GCuFxOy2n/SVlvlQUcNkA=
|
||||||
|
-----END RSA PRIVATE KEY-----`)
|
||||||
|
|
||||||
|
// Deprecated: legacy static RSA public key retained only for compatibility
|
||||||
|
// with MSG_KEY_CHANGE.
|
||||||
|
var defaultRsaPubKey = []byte(`-----BEGIN PUBLIC KEY-----
|
||||||
|
MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEAxmeMqr9yfJFKZn26oe/H
|
||||||
|
vC7bZXNLC9Nk55AuTkb4XuIoqXDbAJD2Y/p167oJLKIqL3edcj7h+oTfn6s79vxT
|
||||||
|
0ZCEf37ILU0G+scRzVwYHiLMwOUCbS2o4Xor3zqUi9f1piJBvoBNh8RKKtsmJW6V
|
||||||
|
QZdiUGJHbgX4MdOdtf/6TvxZMwSXU+PRSCAjy04A31Zi7DEWUWJPyqmHeu++PxXU
|
||||||
|
5lvoMdCGDqpcF2j2uO7oJJUww01M3F5FtTElMrK4/P9gD4kP7NiPhOfVPEfBsYT/
|
||||||
|
DSSjvqNZJZuWnxu+cDxE7J/sBvdpeNRLhqzdmMYagZFuUmVrz8QmsD6jKHgydW+r
|
||||||
|
7irllvb8WJPK/RIMif+4Rg7rDKFbj8+ZQ3HZ/gKELoRSyb3zL6RC2qlGLjC1tdeN
|
||||||
|
7TNTinCv092y39T8jIARJ7tpfePhNBxsBdxfXbCAzHYZIHufI9Zlsc+felQwanlD
|
||||||
|
hq+q8YLcnKHvNKYVyCf/upExpAiArr88y/KbeKes0KorKkwMBnGUMTothWM25wHo
|
||||||
|
zcurixNvP4UMWX7LWD7vOZZuNDQNutZYeTwdsniI3mTO9vlPWEK8JTfxBU7x9SeP
|
||||||
|
UMJNDyjfDUJM8C2DOlyhGNPkgazOGdliH87tHkEy/7jJnGclgKmciiVPgwHfFx9G
|
||||||
|
GoBHEfvmAoGGrk4qNbjm7JECAwEAAQ==
|
||||||
|
-----END PUBLIC KEY-----`)
|
||||||
|
|
||||||
|
// Deprecated: legacy static AES key retained only for compatibility with the
|
||||||
|
// old AES-CFB transport path.
|
||||||
|
var defaultAesKey = []byte{0x19, 0x96, 0x11, 0x27, 228, 187, 187, 231, 142, 137, 230, 179, 189, 229, 184, 133}
|
||||||
|
|
||||||
|
// Deprecated: legacy AES-CFB transport codec retained only for compatibility.
|
||||||
|
func defaultMsgEn(key []byte, d []byte) []byte {
|
||||||
|
data, err := starcrypto.CustomEncryptAesCFB(d, key)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: legacy AES-CFB transport codec retained only for compatibility.
|
||||||
|
func defaultMsgDe(key []byte, d []byte) []byte {
|
||||||
|
data, err := starcrypto.CustomDecryptAesCFB(d, key)
|
||||||
|
if err != nil {
|
||||||
|
log.Print(err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
RegisterName("b612.me/notify.Transfer", TransferMsg{})
|
||||||
|
RegisterName("b612.me/notify.Envelope", Envelope{})
|
||||||
|
RegisterName("b612.me/notify.TransferRange", TransferRange{})
|
||||||
|
RegisterName("b612.me/notify.TransferBeginRequest", TransferBeginRequest{})
|
||||||
|
RegisterName("b612.me/notify.TransferBeginResponse", TransferBeginResponse{})
|
||||||
|
RegisterName("b612.me/notify.TransferResumeRequest", TransferResumeRequest{})
|
||||||
|
RegisterName("b612.me/notify.TransferResumeResponse", TransferResumeResponse{})
|
||||||
|
RegisterName("b612.me/notify.TransferCommitRequest", TransferCommitRequest{})
|
||||||
|
RegisterName("b612.me/notify.TransferCommitResponse", TransferCommitResponse{})
|
||||||
|
RegisterName("b612.me/notify.TransferAbortRequest", TransferAbortRequest{})
|
||||||
|
RegisterName("b612.me/notify.TransferAbortResponse", TransferAbortResponse{})
|
||||||
|
RegisterName("b612.me/notify.StreamOpenRequest", StreamOpenRequest{})
|
||||||
|
RegisterName("b612.me/notify.StreamOpenResponse", StreamOpenResponse{})
|
||||||
|
RegisterName("b612.me/notify.StreamCloseRequest", StreamCloseRequest{})
|
||||||
|
RegisterName("b612.me/notify.StreamCloseResponse", StreamCloseResponse{})
|
||||||
|
RegisterName("b612.me/notify.StreamResetRequest", StreamResetRequest{})
|
||||||
|
RegisterName("b612.me/notify.StreamResetResponse", StreamResetResponse{})
|
||||||
|
RegisterName("b612.me/notify.BulkRange", BulkRange{})
|
||||||
|
RegisterName("b612.me/notify.BulkOpenRequest", BulkOpenRequest{})
|
||||||
|
RegisterName("b612.me/notify.BulkOpenResponse", BulkOpenResponse{})
|
||||||
|
RegisterName("b612.me/notify.BulkCloseRequest", BulkCloseRequest{})
|
||||||
|
RegisterName("b612.me/notify.BulkCloseResponse", BulkCloseResponse{})
|
||||||
|
RegisterName("b612.me/notify.BulkResetRequest", BulkResetRequest{})
|
||||||
|
RegisterName("b612.me/notify.BulkResetResponse", BulkResetResponse{})
|
||||||
|
RegisterName("b612.me/notify.BulkReleaseRequest", BulkReleaseRequest{})
|
||||||
|
RegisterName("b612.me/notify.bulkAttachRequest", bulkAttachRequest{})
|
||||||
|
RegisterName("b612.me/notify.bulkAttachResponse", bulkAttachResponse{})
|
||||||
|
RegisterName("b612.me/notify.peerAttachRequest", peerAttachRequest{})
|
||||||
|
RegisterName("b612.me/notify.peerAttachResponse", peerAttachResponse{})
|
||||||
|
RegisterName("b612.me/notify/transfer.Begin", itransfer.Begin{})
|
||||||
|
RegisterName("b612.me/notify/transfer.BeginAck", itransfer.BeginAck{})
|
||||||
|
RegisterName("b612.me/notify/transfer.Resume", itransfer.Resume{})
|
||||||
|
RegisterName("b612.me/notify/transfer.ResumeAck", itransfer.ResumeAck{})
|
||||||
|
RegisterName("b612.me/notify/transfer.Commit", itransfer.Commit{})
|
||||||
|
RegisterName("b612.me/notify/transfer.CommitAck", itransfer.CommitAck{})
|
||||||
|
RegisterName("b612.me/notify/transfer.Abort", itransfer.Abort{})
|
||||||
|
RegisterName("b612.me/notify/transfer.Segment", itransfer.Segment{})
|
||||||
|
RegisterName("b612.me/notify/transfer.Ack", itransfer.Ack{})
|
||||||
|
}
|
||||||
@@ -0,0 +1,387 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DiagnosticsResetCauseSummary struct {
|
||||||
|
Total int
|
||||||
|
TransportDetached int
|
||||||
|
ServiceShutdown int
|
||||||
|
Backpressure int
|
||||||
|
Other int
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiagnosticsTransferTelemetrySummary struct {
|
||||||
|
SourceReadBytes int64
|
||||||
|
StreamWriteBytes int64
|
||||||
|
SinkWriteBytes int64
|
||||||
|
SourceReadDuration time.Duration
|
||||||
|
StreamWriteDuration time.Duration
|
||||||
|
SinkWriteDuration time.Duration
|
||||||
|
SyncDuration time.Duration
|
||||||
|
VerifyDuration time.Duration
|
||||||
|
CommitDuration time.Duration
|
||||||
|
CommitWaitDuration time.Duration
|
||||||
|
WorkDuration time.Duration
|
||||||
|
ObservedDuration time.Duration
|
||||||
|
SourceReadThroughputBPS float64
|
||||||
|
StreamWriteThroughputBPS float64
|
||||||
|
SinkWriteThroughputBPS float64
|
||||||
|
CommitWaitRatio float64
|
||||||
|
}
|
||||||
|
|
||||||
|
type DiagnosticsSummary struct {
|
||||||
|
LogicalCount int
|
||||||
|
CurrentTransportCount int
|
||||||
|
|
||||||
|
StreamCount int
|
||||||
|
ActiveStreamCount int
|
||||||
|
StaleStreamCount int
|
||||||
|
ResetStreamCount int
|
||||||
|
|
||||||
|
BulkCount int
|
||||||
|
DedicatedBulkCount int
|
||||||
|
ActiveBulkCount int
|
||||||
|
StaleBulkCount int
|
||||||
|
ResetBulkCount int
|
||||||
|
|
||||||
|
TransferCount int
|
||||||
|
ActiveTransferCount int
|
||||||
|
PausedTransferCount int
|
||||||
|
DoneTransferCount int
|
||||||
|
FailedTransferCount int
|
||||||
|
AbortedTransferCount int
|
||||||
|
|
||||||
|
StreamResetCauses DiagnosticsResetCauseSummary
|
||||||
|
BulkResetCauses DiagnosticsResetCauseSummary
|
||||||
|
TransferTelemetry DiagnosticsTransferTelemetrySummary
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientDiagnosticsSnapshot struct {
|
||||||
|
Runtime ClientRuntimeSnapshot
|
||||||
|
Streams []StreamSnapshot
|
||||||
|
Bulks []BulkSnapshot
|
||||||
|
Transfers []TransferSnapshot
|
||||||
|
Summary DiagnosticsSummary
|
||||||
|
}
|
||||||
|
|
||||||
|
type ServerDiagnosticsSnapshot struct {
|
||||||
|
Runtime ServerRuntimeSnapshot
|
||||||
|
Logicals []ClientConnRuntimeSnapshot
|
||||||
|
CurrentTransports []TransportConnRuntimeSnapshot
|
||||||
|
Streams []StreamSnapshot
|
||||||
|
Bulks []BulkSnapshot
|
||||||
|
Transfers []TransferSnapshot
|
||||||
|
Summary DiagnosticsSummary
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errClientDiagnosticsSnapshotNil = errors.New("client diagnostics snapshot target is nil")
|
||||||
|
errServerDiagnosticsSnapshotNil = errors.New("server diagnostics snapshot target is nil")
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetClientDiagnosticsSnapshot(c Client) (ClientDiagnosticsSnapshot, error) {
|
||||||
|
if c == nil {
|
||||||
|
return ClientDiagnosticsSnapshot{}, errClientDiagnosticsSnapshotNil
|
||||||
|
}
|
||||||
|
runtime, err := GetClientRuntimeSnapshot(c)
|
||||||
|
if err != nil {
|
||||||
|
return ClientDiagnosticsSnapshot{}, err
|
||||||
|
}
|
||||||
|
streams, err := GetClientStreamSnapshots(c)
|
||||||
|
if err != nil {
|
||||||
|
return ClientDiagnosticsSnapshot{}, err
|
||||||
|
}
|
||||||
|
bulks, err := GetClientBulkSnapshots(c)
|
||||||
|
if err != nil {
|
||||||
|
return ClientDiagnosticsSnapshot{}, err
|
||||||
|
}
|
||||||
|
transfers, err := GetClientTransferSnapshots(c)
|
||||||
|
if err != nil {
|
||||||
|
return ClientDiagnosticsSnapshot{}, err
|
||||||
|
}
|
||||||
|
snapshot := ClientDiagnosticsSnapshot{
|
||||||
|
Runtime: runtime,
|
||||||
|
Streams: streams,
|
||||||
|
Bulks: bulks,
|
||||||
|
Transfers: transfers,
|
||||||
|
}
|
||||||
|
snapshot.Summary = summarizeClientDiagnosticsSnapshot(snapshot)
|
||||||
|
return snapshot, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetServerDiagnosticsSnapshot(s Server) (ServerDiagnosticsSnapshot, error) {
|
||||||
|
if s == nil {
|
||||||
|
return ServerDiagnosticsSnapshot{}, errServerDiagnosticsSnapshotNil
|
||||||
|
}
|
||||||
|
runtime, err := GetServerRuntimeSnapshot(s)
|
||||||
|
if err != nil {
|
||||||
|
return ServerDiagnosticsSnapshot{}, err
|
||||||
|
}
|
||||||
|
logicals, err := serverLogicalRuntimeSnapshots(s)
|
||||||
|
if err != nil {
|
||||||
|
return ServerDiagnosticsSnapshot{}, err
|
||||||
|
}
|
||||||
|
transports, err := serverCurrentTransportRuntimeSnapshots(s)
|
||||||
|
if err != nil {
|
||||||
|
return ServerDiagnosticsSnapshot{}, err
|
||||||
|
}
|
||||||
|
streams, err := GetServerStreamSnapshots(s)
|
||||||
|
if err != nil {
|
||||||
|
return ServerDiagnosticsSnapshot{}, err
|
||||||
|
}
|
||||||
|
bulks, err := GetServerBulkSnapshots(s)
|
||||||
|
if err != nil {
|
||||||
|
return ServerDiagnosticsSnapshot{}, err
|
||||||
|
}
|
||||||
|
transfers, err := GetServerTransferSnapshots(s)
|
||||||
|
if err != nil {
|
||||||
|
return ServerDiagnosticsSnapshot{}, err
|
||||||
|
}
|
||||||
|
snapshot := ServerDiagnosticsSnapshot{
|
||||||
|
Runtime: runtime,
|
||||||
|
Logicals: logicals,
|
||||||
|
CurrentTransports: transports,
|
||||||
|
Streams: streams,
|
||||||
|
Bulks: bulks,
|
||||||
|
Transfers: transfers,
|
||||||
|
}
|
||||||
|
snapshot.Summary = summarizeServerDiagnosticsSnapshot(snapshot)
|
||||||
|
return snapshot, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverLogicalRuntimeSnapshots(s Server) ([]ClientConnRuntimeSnapshot, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, errServerDiagnosticsSnapshotNil
|
||||||
|
}
|
||||||
|
logicals := s.GetLogicalConnList()
|
||||||
|
out := make([]ClientConnRuntimeSnapshot, 0, len(logicals))
|
||||||
|
for _, logical := range logicals {
|
||||||
|
if logical == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
snapshot, err := GetLogicalConnRuntimeSnapshot(logical)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out = append(out, snapshot)
|
||||||
|
}
|
||||||
|
sortClientConnRuntimeSnapshots(out)
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverCurrentTransportRuntimeSnapshots(s Server) ([]TransportConnRuntimeSnapshot, error) {
|
||||||
|
if s == nil {
|
||||||
|
return nil, errServerDiagnosticsSnapshotNil
|
||||||
|
}
|
||||||
|
transports := s.GetCurrentTransportConnList()
|
||||||
|
out := make([]TransportConnRuntimeSnapshot, 0, len(transports))
|
||||||
|
for _, transport := range transports {
|
||||||
|
if transport == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
snapshot, err := GetTransportConnRuntimeSnapshot(transport)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out = append(out, snapshot)
|
||||||
|
}
|
||||||
|
sortTransportConnRuntimeSnapshots(out)
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeClientDiagnosticsSnapshot(snapshot ClientDiagnosticsSnapshot) DiagnosticsSummary {
|
||||||
|
summary := DiagnosticsSummary{
|
||||||
|
LogicalCount: diagnosticsLogicalCountFromClientRuntime(snapshot.Runtime),
|
||||||
|
}
|
||||||
|
if snapshot.Runtime.TransportAttached {
|
||||||
|
summary.CurrentTransportCount = 1
|
||||||
|
}
|
||||||
|
summarizeStreamSnapshots(&summary, snapshot.Streams)
|
||||||
|
summarizeBulkSnapshots(&summary, snapshot.Bulks)
|
||||||
|
summarizeTransferSnapshots(&summary, snapshot.Transfers)
|
||||||
|
return summary
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeServerDiagnosticsSnapshot(snapshot ServerDiagnosticsSnapshot) DiagnosticsSummary {
|
||||||
|
summary := DiagnosticsSummary{
|
||||||
|
LogicalCount: len(snapshot.Logicals),
|
||||||
|
CurrentTransportCount: len(snapshot.CurrentTransports),
|
||||||
|
}
|
||||||
|
summarizeStreamSnapshots(&summary, snapshot.Streams)
|
||||||
|
summarizeBulkSnapshots(&summary, snapshot.Bulks)
|
||||||
|
summarizeTransferSnapshots(&summary, snapshot.Transfers)
|
||||||
|
return summary
|
||||||
|
}
|
||||||
|
|
||||||
|
func diagnosticsLogicalCountFromClientRuntime(runtime ClientRuntimeSnapshot) int {
|
||||||
|
if runtime.Alive || runtime.SessionEpoch != 0 || runtime.TransportAttached || runtime.HasRuntimeConn || runtime.HasRuntimeQueue {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeStreamSnapshots(summary *DiagnosticsSummary, snapshots []StreamSnapshot) {
|
||||||
|
if summary == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
summary.StreamCount = len(snapshots)
|
||||||
|
for _, snapshot := range snapshots {
|
||||||
|
switch {
|
||||||
|
case snapshot.ResetError != "":
|
||||||
|
summary.ResetStreamCount++
|
||||||
|
accumulateDiagnosticsResetCause(&summary.StreamResetCauses, snapshot.ResetError, errStreamBackpressureExceeded.Error())
|
||||||
|
case streamSnapshotFinished(snapshot):
|
||||||
|
case streamSnapshotBoundActive(snapshot):
|
||||||
|
summary.ActiveStreamCount++
|
||||||
|
default:
|
||||||
|
summary.StaleStreamCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeBulkSnapshots(summary *DiagnosticsSummary, snapshots []BulkSnapshot) {
|
||||||
|
if summary == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
summary.BulkCount = len(snapshots)
|
||||||
|
for _, snapshot := range snapshots {
|
||||||
|
if snapshot.Dedicated {
|
||||||
|
summary.DedicatedBulkCount++
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case snapshot.ResetError != "":
|
||||||
|
summary.ResetBulkCount++
|
||||||
|
accumulateDiagnosticsResetCause(&summary.BulkResetCauses, snapshot.ResetError, errBulkBackpressureExceeded.Error())
|
||||||
|
case bulkSnapshotFinished(snapshot):
|
||||||
|
case bulkSnapshotBoundActive(snapshot):
|
||||||
|
summary.ActiveBulkCount++
|
||||||
|
default:
|
||||||
|
summary.StaleBulkCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func summarizeTransferSnapshots(summary *DiagnosticsSummary, snapshots []TransferSnapshot) {
|
||||||
|
if summary == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
summary.TransferCount = len(snapshots)
|
||||||
|
for _, snapshot := range snapshots {
|
||||||
|
switch snapshot.State {
|
||||||
|
case TransferStateDone:
|
||||||
|
summary.DoneTransferCount++
|
||||||
|
case TransferStateFailed:
|
||||||
|
summary.FailedTransferCount++
|
||||||
|
case TransferStateAborted:
|
||||||
|
summary.AbortedTransferCount++
|
||||||
|
case TransferStatePaused:
|
||||||
|
summary.PausedTransferCount++
|
||||||
|
default:
|
||||||
|
summary.ActiveTransferCount++
|
||||||
|
}
|
||||||
|
accumulateDiagnosticsTransferTelemetry(&summary.TransferTelemetry, snapshot)
|
||||||
|
}
|
||||||
|
finalizeDiagnosticsTransferTelemetry(&summary.TransferTelemetry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamSnapshotFinished(snapshot StreamSnapshot) bool {
|
||||||
|
return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkSnapshotFinished(snapshot BulkSnapshot) bool {
|
||||||
|
return snapshot.ResetError == "" && snapshot.LocalClosed && snapshot.RemoteClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamSnapshotBoundActive(snapshot StreamSnapshot) bool {
|
||||||
|
return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkSnapshotBoundActive(snapshot BulkSnapshot) bool {
|
||||||
|
return snapshot.BindingCurrent && snapshot.TransportAttached && snapshot.TransportCurrent
|
||||||
|
}
|
||||||
|
|
||||||
|
func accumulateDiagnosticsResetCause(summary *DiagnosticsResetCauseSummary, resetError string, backpressureError string) {
|
||||||
|
if summary == nil || resetError == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
summary.Total++
|
||||||
|
if diagnosticsResetErrorMatches(resetError, errTransportDetached) {
|
||||||
|
summary.TransportDetached++
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if diagnosticsResetErrorMatches(resetError, errServiceShutdown) {
|
||||||
|
summary.ServiceShutdown++
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if resetError == backpressureError || strings.HasPrefix(resetError, backpressureError+":") {
|
||||||
|
summary.Backpressure++
|
||||||
|
return
|
||||||
|
}
|
||||||
|
summary.Other++
|
||||||
|
}
|
||||||
|
|
||||||
|
func diagnosticsResetErrorMatches(resetError string, target error) bool {
|
||||||
|
if resetError == "" || target == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
base := target.Error()
|
||||||
|
return resetError == base || strings.HasPrefix(resetError, base+":")
|
||||||
|
}
|
||||||
|
|
||||||
|
func accumulateDiagnosticsTransferTelemetry(summary *DiagnosticsTransferTelemetrySummary, snapshot TransferSnapshot) {
|
||||||
|
if summary == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
summary.SourceReadBytes += transferSummarySourceReadBytes(snapshot)
|
||||||
|
summary.StreamWriteBytes += transferSummaryStreamWriteBytes(snapshot)
|
||||||
|
summary.SinkWriteBytes += transferSummarySinkWriteBytes(snapshot)
|
||||||
|
summary.SourceReadDuration += snapshot.SourceReadDuration
|
||||||
|
summary.StreamWriteDuration += snapshot.StreamWriteDuration
|
||||||
|
summary.SinkWriteDuration += snapshot.SinkWriteDuration
|
||||||
|
summary.SyncDuration += snapshot.SyncDuration
|
||||||
|
summary.VerifyDuration += snapshot.VerifyDuration
|
||||||
|
summary.CommitDuration += snapshot.CommitDuration
|
||||||
|
summary.CommitWaitDuration += snapshot.CommitWaitDuration
|
||||||
|
}
|
||||||
|
|
||||||
|
func finalizeDiagnosticsTransferTelemetry(summary *DiagnosticsTransferTelemetrySummary) {
|
||||||
|
if summary == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
summary.WorkDuration = summary.SourceReadDuration + summary.StreamWriteDuration + summary.SinkWriteDuration +
|
||||||
|
summary.SyncDuration + summary.VerifyDuration + summary.CommitDuration
|
||||||
|
summary.ObservedDuration = summary.WorkDuration + summary.CommitWaitDuration
|
||||||
|
summary.SourceReadThroughputBPS = throughputBytesPerSecond(summary.SourceReadBytes, summary.SourceReadDuration)
|
||||||
|
summary.StreamWriteThroughputBPS = throughputBytesPerSecond(summary.StreamWriteBytes, summary.StreamWriteDuration)
|
||||||
|
summary.SinkWriteThroughputBPS = throughputBytesPerSecond(summary.SinkWriteBytes, summary.SinkWriteDuration)
|
||||||
|
summary.CommitWaitRatio = durationRatio(summary.CommitWaitDuration, summary.ObservedDuration)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortClientConnRuntimeSnapshots(src []ClientConnRuntimeSnapshot) {
|
||||||
|
sort.Slice(src, func(i, j int) bool {
|
||||||
|
if src[i].ClientID != src[j].ClientID {
|
||||||
|
return src[i].ClientID < src[j].ClientID
|
||||||
|
}
|
||||||
|
if src[i].TransportGeneration != src[j].TransportGeneration {
|
||||||
|
return src[i].TransportGeneration < src[j].TransportGeneration
|
||||||
|
}
|
||||||
|
return src[i].RemoteAddress < src[j].RemoteAddress
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortTransportConnRuntimeSnapshots(src []TransportConnRuntimeSnapshot) {
|
||||||
|
sort.Slice(src, func(i, j int) bool {
|
||||||
|
if src[i].ClientID != src[j].ClientID {
|
||||||
|
return src[i].ClientID < src[j].ClientID
|
||||||
|
}
|
||||||
|
if src[i].TransportGeneration != src[j].TransportGeneration {
|
||||||
|
return src[i].TransportGeneration < src[j].TransportGeneration
|
||||||
|
}
|
||||||
|
return src[i].RemoteAddress < src[j].RemoteAddress
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,417 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"math"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
itransfer "b612.me/notify/internal/transfer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetClientDiagnosticsSnapshotDefaults(t *testing.T) {
|
||||||
|
client := NewClient()
|
||||||
|
snapshot, err := GetClientDiagnosticsSnapshot(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientDiagnosticsSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Runtime.OwnerState, "idle"; got != want {
|
||||||
|
t.Fatalf("Runtime.OwnerState = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if len(snapshot.Streams) != 0 || len(snapshot.Bulks) != 0 || len(snapshot.Transfers) != 0 {
|
||||||
|
t.Fatalf("default diagnostics should be empty: %+v", snapshot)
|
||||||
|
}
|
||||||
|
if snapshot.Summary != (DiagnosticsSummary{}) {
|
||||||
|
t.Fatalf("default summary mismatch: %+v", snapshot.Summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClientDiagnosticsSnapshotAggregatesActiveState(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
streamAcceptCh := make(chan StreamAcceptInfo, 1)
|
||||||
|
bulkAcceptCh := make(chan BulkAcceptInfo, 1)
|
||||||
|
server.SetStreamHandler(func(info StreamAcceptInfo) error {
|
||||||
|
streamAcceptCh <- info
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
server.SetBulkHandler(func(info BulkAcceptInfo) error {
|
||||||
|
bulkAcceptCh <- info
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||||
|
t.Fatalf("server Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = server.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||||
|
t.Fatalf("client Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = client.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
stream, err := client.OpenStream(context.Background(), StreamOpenOptions{
|
||||||
|
ID: "diag-client-stream",
|
||||||
|
Channel: StreamDataChannel,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client OpenStream failed: %v", err)
|
||||||
|
}
|
||||||
|
waitAcceptedStream(t, streamAcceptCh, 2*time.Second)
|
||||||
|
|
||||||
|
bulk, err := client.OpenBulk(context.Background(), BulkOpenOptions{
|
||||||
|
ID: "diag-client-bulk",
|
||||||
|
Range: BulkRange{
|
||||||
|
Length: 64,
|
||||||
|
},
|
||||||
|
ChunkSize: 16 * 1024,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("client OpenBulk failed: %v", err)
|
||||||
|
}
|
||||||
|
waitAcceptedBulk(t, bulkAcceptCh, 2*time.Second)
|
||||||
|
|
||||||
|
transferRuntime := client.getTransferRuntime()
|
||||||
|
transferRuntime.ensureTransferDescriptor(fileTransferDirectionSend, clientFileScope(), clientFileScope(), 0, itransfer.Descriptor{
|
||||||
|
ID: "diag-client-transfer-done",
|
||||||
|
Channel: itransfer.DataChannel,
|
||||||
|
Size: 32,
|
||||||
|
Checksum: "sum-client",
|
||||||
|
})
|
||||||
|
transferRuntime.activate(fileTransferDirectionSend, clientFileScope(), "diag-client-transfer-done")
|
||||||
|
transferRuntime.complete(fileTransferDirectionSend, clientFileScope(), "diag-client-transfer-done")
|
||||||
|
|
||||||
|
snapshot, err := GetClientDiagnosticsSnapshot(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientDiagnosticsSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.LogicalCount, 1; got != want {
|
||||||
|
t.Fatalf("LogicalCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.CurrentTransportCount, 1; got != want {
|
||||||
|
t.Fatalf("CurrentTransportCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.StreamCount, 1; got != want {
|
||||||
|
t.Fatalf("StreamCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.ActiveStreamCount, 1; got != want {
|
||||||
|
t.Fatalf("ActiveStreamCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.BulkCount, 1; got != want {
|
||||||
|
t.Fatalf("BulkCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.ActiveBulkCount, 1; got != want {
|
||||||
|
t.Fatalf("ActiveBulkCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.TransferCount, 1; got != want {
|
||||||
|
t.Fatalf("TransferCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.DoneTransferCount, 1; got != want {
|
||||||
|
t.Fatalf("DoneTransferCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got := snapshot.Summary.StaleStreamCount + snapshot.Summary.ResetStreamCount + snapshot.Summary.StaleBulkCount + snapshot.Summary.ResetBulkCount + snapshot.Summary.FailedTransferCount; got != 0 {
|
||||||
|
t.Fatalf("unexpected unhealthy counters in active snapshot: %+v", snapshot.Summary)
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = stream.Close()
|
||||||
|
_ = bulk.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetServerDiagnosticsSnapshotAggregatesStaleAndResetState(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
|
||||||
|
left, right := net.Pipe()
|
||||||
|
defer right.Close()
|
||||||
|
|
||||||
|
logical := server.bootstrapAcceptedLogical("diag-server-peer", nil, left)
|
||||||
|
if logical == nil {
|
||||||
|
t.Fatal("bootstrapAcceptedLogical should return logical")
|
||||||
|
}
|
||||||
|
logical.markIdentityBound()
|
||||||
|
logical.compatClientConn().markClientConnStreamTransport()
|
||||||
|
transport := logical.CurrentTransportConn()
|
||||||
|
if transport == nil {
|
||||||
|
t.Fatal("CurrentTransportConn should return active transport")
|
||||||
|
}
|
||||||
|
scope := serverFileScope(logical)
|
||||||
|
|
||||||
|
streamStale := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{
|
||||||
|
StreamID: "diag-stream-stale",
|
||||||
|
DataID: 1,
|
||||||
|
Channel: StreamDataChannel,
|
||||||
|
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
|
||||||
|
if err := server.getStreamRuntime().register(scope, streamStale); err != nil {
|
||||||
|
t.Fatalf("register stale stream failed: %v", err)
|
||||||
|
}
|
||||||
|
streamReset := newStreamHandle(context.Background(), server.getStreamRuntime(), scope, StreamOpenRequest{
|
||||||
|
StreamID: "diag-stream-reset",
|
||||||
|
DataID: 2,
|
||||||
|
Channel: StreamDataChannel,
|
||||||
|
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, defaultStreamConfig())
|
||||||
|
if err := server.getStreamRuntime().register(scope, streamReset); err != nil {
|
||||||
|
t.Fatalf("register reset stream failed: %v", err)
|
||||||
|
}
|
||||||
|
streamReset.mu.Lock()
|
||||||
|
streamReset.resetErr = errTransportDetached
|
||||||
|
streamReset.mu.Unlock()
|
||||||
|
|
||||||
|
bulkStale := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{
|
||||||
|
BulkID: "diag-bulk-stale",
|
||||||
|
DataID: 3,
|
||||||
|
Range: BulkRange{
|
||||||
|
Length: 16,
|
||||||
|
},
|
||||||
|
ChunkSize: 32 * 1024,
|
||||||
|
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, nil, nil)
|
||||||
|
if err := server.getBulkRuntime().register(scope, bulkStale); err != nil {
|
||||||
|
t.Fatalf("register stale bulk failed: %v", err)
|
||||||
|
}
|
||||||
|
bulkReset := newBulkHandle(context.Background(), server.getBulkRuntime(), scope, BulkOpenRequest{
|
||||||
|
BulkID: "diag-bulk-reset",
|
||||||
|
DataID: 4,
|
||||||
|
Dedicated: true,
|
||||||
|
Range: BulkRange{
|
||||||
|
Length: 16,
|
||||||
|
},
|
||||||
|
ChunkSize: 32 * 1024,
|
||||||
|
}, 0, logical, transport, transport.TransportGeneration(), nil, nil, nil, nil, nil)
|
||||||
|
if err := server.getBulkRuntime().register(scope, bulkReset); err != nil {
|
||||||
|
t.Fatalf("register reset bulk failed: %v", err)
|
||||||
|
}
|
||||||
|
bulkReset.mu.Lock()
|
||||||
|
bulkReset.resetErr = errTransportDetached
|
||||||
|
bulkReset.mu.Unlock()
|
||||||
|
|
||||||
|
transferRuntime := server.getTransferRuntime()
|
||||||
|
transferRuntime.ensureTransferDescriptor(fileTransferDirectionReceive, scope, scope, transport.TransportGeneration(), itransfer.Descriptor{
|
||||||
|
ID: "diag-transfer-failed",
|
||||||
|
Channel: itransfer.DataChannel,
|
||||||
|
Size: 64,
|
||||||
|
Checksum: "sum-server",
|
||||||
|
})
|
||||||
|
transferRuntime.activate(fileTransferDirectionReceive, scope, "diag-transfer-failed")
|
||||||
|
transferRuntime.fail(fileTransferDirectionReceive, scope, "diag-transfer-failed", errors.New("boom"))
|
||||||
|
|
||||||
|
logical.markTransportDetached("heartbeat timeout", nil)
|
||||||
|
logical.detachServerOwnedTransport()
|
||||||
|
|
||||||
|
snapshot, err := GetServerDiagnosticsSnapshot(server)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetServerDiagnosticsSnapshot failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := len(snapshot.Logicals), 1; got != want {
|
||||||
|
t.Fatalf("logical snapshot count = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := len(snapshot.CurrentTransports), 0; got != want {
|
||||||
|
t.Fatalf("current transport snapshot count = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Runtime.DetachedClientCount, 1; got != want {
|
||||||
|
t.Fatalf("DetachedClientCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.LogicalCount, 1; got != want {
|
||||||
|
t.Fatalf("LogicalCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.CurrentTransportCount, 0; got != want {
|
||||||
|
t.Fatalf("CurrentTransportCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.StreamCount, 2; got != want {
|
||||||
|
t.Fatalf("StreamCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.StaleStreamCount, 1; got != want {
|
||||||
|
t.Fatalf("StaleStreamCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.ResetStreamCount, 1; got != want {
|
||||||
|
t.Fatalf("ResetStreamCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.StreamResetCauses.Total, 1; got != want {
|
||||||
|
t.Fatalf("StreamResetCauses.Total = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.StreamResetCauses.TransportDetached, 1; got != want {
|
||||||
|
t.Fatalf("StreamResetCauses.TransportDetached = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.BulkCount, 2; got != want {
|
||||||
|
t.Fatalf("BulkCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.DedicatedBulkCount, 1; got != want {
|
||||||
|
t.Fatalf("DedicatedBulkCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.StaleBulkCount, 1; got != want {
|
||||||
|
t.Fatalf("StaleBulkCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.ResetBulkCount, 1; got != want {
|
||||||
|
t.Fatalf("ResetBulkCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.BulkResetCauses.Total, 1; got != want {
|
||||||
|
t.Fatalf("BulkResetCauses.Total = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.BulkResetCauses.TransportDetached, 1; got != want {
|
||||||
|
t.Fatalf("BulkResetCauses.TransportDetached = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.TransferCount, 1; got != want {
|
||||||
|
t.Fatalf("TransferCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Summary.FailedTransferCount, 1; got != want {
|
||||||
|
t.Fatalf("FailedTransferCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiagnosticsSummaryClassifiesResetCauses(t *testing.T) {
|
||||||
|
summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{
|
||||||
|
Streams: []StreamSnapshot{
|
||||||
|
{ResetError: errTransportDetached.Error()},
|
||||||
|
{ResetError: errServiceShutdown.Error()},
|
||||||
|
{ResetError: errStreamBackpressureExceeded.Error()},
|
||||||
|
{ResetError: "stream boom"},
|
||||||
|
},
|
||||||
|
Bulks: []BulkSnapshot{
|
||||||
|
{ResetError: errTransportDetached.Error()},
|
||||||
|
{ResetError: errServiceShutdown.Error()},
|
||||||
|
{ResetError: errBulkBackpressureExceeded.Error()},
|
||||||
|
{ResetError: "bulk boom"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if got, want := summary.ResetStreamCount, 4; got != want {
|
||||||
|
t.Fatalf("ResetStreamCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.StreamResetCauses.Total, 4; got != want {
|
||||||
|
t.Fatalf("StreamResetCauses.Total = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.StreamResetCauses.TransportDetached, 1; got != want {
|
||||||
|
t.Fatalf("StreamResetCauses.TransportDetached = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.StreamResetCauses.ServiceShutdown, 1; got != want {
|
||||||
|
t.Fatalf("StreamResetCauses.ServiceShutdown = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.StreamResetCauses.Backpressure, 1; got != want {
|
||||||
|
t.Fatalf("StreamResetCauses.Backpressure = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.StreamResetCauses.Other, 1; got != want {
|
||||||
|
t.Fatalf("StreamResetCauses.Other = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := summary.ResetBulkCount, 4; got != want {
|
||||||
|
t.Fatalf("ResetBulkCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.BulkResetCauses.Total, 4; got != want {
|
||||||
|
t.Fatalf("BulkResetCauses.Total = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.BulkResetCauses.TransportDetached, 1; got != want {
|
||||||
|
t.Fatalf("BulkResetCauses.TransportDetached = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.BulkResetCauses.ServiceShutdown, 1; got != want {
|
||||||
|
t.Fatalf("BulkResetCauses.ServiceShutdown = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.BulkResetCauses.Backpressure, 1; got != want {
|
||||||
|
t.Fatalf("BulkResetCauses.Backpressure = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.BulkResetCauses.Other, 1; got != want {
|
||||||
|
t.Fatalf("BulkResetCauses.Other = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDiagnosticsSummaryAggregatesTransferTelemetry(t *testing.T) {
|
||||||
|
summary := summarizeClientDiagnosticsSnapshot(ClientDiagnosticsSnapshot{
|
||||||
|
Transfers: []TransferSnapshot{
|
||||||
|
{
|
||||||
|
ID: "send-done",
|
||||||
|
State: TransferStateDone,
|
||||||
|
SentBytes: 2048,
|
||||||
|
SourceReadDuration: 200 * time.Millisecond,
|
||||||
|
StreamWriteDuration: 400 * time.Millisecond,
|
||||||
|
CommitWaitDuration: 100 * time.Millisecond,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "recv-failed",
|
||||||
|
State: TransferStateFailed,
|
||||||
|
ReceivedBytes: 1024,
|
||||||
|
SinkWriteDuration: 250 * time.Millisecond,
|
||||||
|
SyncDuration: 50 * time.Millisecond,
|
||||||
|
VerifyDuration: 25 * time.Millisecond,
|
||||||
|
CommitDuration: 75 * time.Millisecond,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
if got, want := summary.TransferCount, 2; got != want {
|
||||||
|
t.Fatalf("TransferCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.DoneTransferCount, 1; got != want {
|
||||||
|
t.Fatalf("DoneTransferCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.FailedTransferCount, 1; got != want {
|
||||||
|
t.Fatalf("FailedTransferCount = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
telemetry := summary.TransferTelemetry
|
||||||
|
if got, want := telemetry.SourceReadBytes, int64(2048); got != want {
|
||||||
|
t.Fatalf("SourceReadBytes = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.StreamWriteBytes, int64(2048); got != want {
|
||||||
|
t.Fatalf("StreamWriteBytes = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.SinkWriteBytes, int64(1024); got != want {
|
||||||
|
t.Fatalf("SinkWriteBytes = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.SourceReadDuration, 200*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("SourceReadDuration = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.StreamWriteDuration, 400*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("StreamWriteDuration = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.SinkWriteDuration, 250*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("SinkWriteDuration = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.SyncDuration, 50*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("SyncDuration = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.VerifyDuration, 25*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("VerifyDuration = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.CommitDuration, 75*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("CommitDuration = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.CommitWaitDuration, 100*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("CommitWaitDuration = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.WorkDuration, time.Second; got != want {
|
||||||
|
t.Fatalf("WorkDuration = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.ObservedDuration, 1100*time.Millisecond; got != want {
|
||||||
|
t.Fatalf("ObservedDuration = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.SourceReadThroughputBPS, 10240.0; math.Abs(got-want) > 0.001 {
|
||||||
|
t.Fatalf("SourceReadThroughputBPS = %f, want %f", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.StreamWriteThroughputBPS, 5120.0; math.Abs(got-want) > 0.001 {
|
||||||
|
t.Fatalf("StreamWriteThroughputBPS = %f, want %f", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.SinkWriteThroughputBPS, 4096.0; math.Abs(got-want) > 0.001 {
|
||||||
|
t.Fatalf("SinkWriteThroughputBPS = %f, want %f", got, want)
|
||||||
|
}
|
||||||
|
if got, want := telemetry.CommitWaitRatio, 1.0/11.0; math.Abs(got-want) > 0.000001 {
|
||||||
|
t.Fatalf("CommitWaitRatio = %f, want %f", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetDiagnosticsSnapshotRejectsNil(t *testing.T) {
|
||||||
|
if _, err := GetClientDiagnosticsSnapshot(nil); !errors.Is(err, errClientDiagnosticsSnapshotNil) {
|
||||||
|
t.Fatalf("GetClientDiagnosticsSnapshot nil error = %v, want %v", err, errClientDiagnosticsSnapshotNil)
|
||||||
|
}
|
||||||
|
if _, err := GetServerDiagnosticsSnapshot(nil); !errors.Is(err, errServerDiagnosticsSnapshotNil) {
|
||||||
|
t.Fatalf("GetServerDiagnosticsSnapshot nil error = %v, want %v", err, errServerDiagnosticsSnapshotNil)
|
||||||
|
}
|
||||||
|
}
|
||||||
+184
@@ -0,0 +1,184 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"b612.me/notify/internal/timeutil"
|
||||||
|
crand "crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EnvelopeKind uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
EnvelopeSignal EnvelopeKind = iota
|
||||||
|
EnvelopeSignalAck
|
||||||
|
EnvelopeStreamData
|
||||||
|
EnvelopeFileMeta
|
||||||
|
EnvelopeFileChunk
|
||||||
|
EnvelopeFileEnd
|
||||||
|
EnvelopeFileAbort
|
||||||
|
EnvelopeAck
|
||||||
|
)
|
||||||
|
|
||||||
|
type Envelope struct {
|
||||||
|
Kind EnvelopeKind
|
||||||
|
ID uint64
|
||||||
|
Body []byte
|
||||||
|
Stream StreamPacket
|
||||||
|
File FilePacket
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamPacket struct {
|
||||||
|
StreamID string
|
||||||
|
Chunk []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type FilePacket struct {
|
||||||
|
FileID string
|
||||||
|
Name string
|
||||||
|
Size int64
|
||||||
|
Mode uint32
|
||||||
|
ModTime int64
|
||||||
|
Offset int64
|
||||||
|
Chunk []byte
|
||||||
|
Checksum string
|
||||||
|
Error string
|
||||||
|
Stage string
|
||||||
|
}
|
||||||
|
|
||||||
|
func wrapTransferMsgEnvelope(msg TransferMsg, enFn func(interface{}) ([]byte, error)) (Envelope, error) {
|
||||||
|
body, err := enFn(msg)
|
||||||
|
if err != nil {
|
||||||
|
return Envelope{}, err
|
||||||
|
}
|
||||||
|
return Envelope{
|
||||||
|
Kind: EnvelopeSignal,
|
||||||
|
ID: msg.ID,
|
||||||
|
Body: body,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unwrapTransferMsgEnvelope(env Envelope, deFn func([]byte) (interface{}, error)) (TransferMsg, error) {
|
||||||
|
if env.Kind != EnvelopeSignal {
|
||||||
|
return TransferMsg{}, errors.New("envelope kind is not signal")
|
||||||
|
}
|
||||||
|
data, err := deFn(env.Body)
|
||||||
|
if err != nil {
|
||||||
|
return TransferMsg{}, err
|
||||||
|
}
|
||||||
|
msg, ok := data.(TransferMsg)
|
||||||
|
if !ok {
|
||||||
|
return TransferMsg{}, errors.New("invalid signal envelope payload")
|
||||||
|
}
|
||||||
|
return msg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSignalAckEnvelope(signalID uint64) Envelope {
|
||||||
|
return Envelope{
|
||||||
|
Kind: EnvelopeSignalAck,
|
||||||
|
ID: signalID,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStreamDataEnvelope(streamID string, chunk []byte) Envelope {
|
||||||
|
return Envelope{
|
||||||
|
Kind: EnvelopeStreamData,
|
||||||
|
Stream: StreamPacket{
|
||||||
|
StreamID: streamID,
|
||||||
|
Chunk: chunk,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileMetaEnvelope(fileID string, fileName string, fileSize int64, checksum string, mode uint32, modTime int64) Envelope {
|
||||||
|
return Envelope{
|
||||||
|
Kind: EnvelopeFileMeta,
|
||||||
|
File: FilePacket{
|
||||||
|
FileID: fileID,
|
||||||
|
Name: filepath.Base(fileName),
|
||||||
|
Size: fileSize,
|
||||||
|
Mode: mode,
|
||||||
|
ModTime: modTime,
|
||||||
|
Checksum: checksum,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileChunkEnvelope(fileID string, offset int64, chunk []byte) Envelope {
|
||||||
|
return Envelope{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
File: FilePacket{
|
||||||
|
FileID: fileID,
|
||||||
|
Offset: offset,
|
||||||
|
Chunk: chunk,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileEndEnvelope(fileID string) Envelope {
|
||||||
|
return Envelope{
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
File: FilePacket{
|
||||||
|
FileID: fileID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileAbortEnvelope(fileID string, stage string, offset int64, errMsg string) Envelope {
|
||||||
|
return Envelope{
|
||||||
|
Kind: EnvelopeFileAbort,
|
||||||
|
File: FilePacket{
|
||||||
|
FileID: fileID,
|
||||||
|
Stage: stage,
|
||||||
|
Offset: offset,
|
||||||
|
Error: errMsg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileAckEnvelope(fileID string, stage string, offset int64, errMsg string) Envelope {
|
||||||
|
return Envelope{
|
||||||
|
Kind: EnvelopeAck,
|
||||||
|
File: FilePacket{
|
||||||
|
FileID: fileID,
|
||||||
|
Stage: stage,
|
||||||
|
Offset: offset,
|
||||||
|
Error: errMsg,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var fileIDSerial uint64
|
||||||
|
|
||||||
|
func buildFileID(fileName string) string {
|
||||||
|
base := fileIDBaseName(fileName)
|
||||||
|
ts := uint64(timeutil.NowUnixNano())
|
||||||
|
pid := uint64(os.Getpid())
|
||||||
|
seq := atomic.AddUint64(&fileIDSerial, 1)
|
||||||
|
rnd := uint64(randomFileIDSuffix())
|
||||||
|
return fmt.Sprintf("%s-%x-%x-%x-%x", base, ts, pid, seq, rnd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileIDBaseName(fileName string) string {
|
||||||
|
base := sanitizeFileName(filepath.Base(fileName))
|
||||||
|
switch base {
|
||||||
|
case "", ".", "/", "\\":
|
||||||
|
return "unnamed"
|
||||||
|
default:
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomFileIDSuffix() uint32 {
|
||||||
|
var buf [4]byte
|
||||||
|
if _, err := crand.Read(buf[:]); err == nil {
|
||||||
|
return binary.BigEndian.Uint32(buf[:])
|
||||||
|
}
|
||||||
|
seq := atomic.LoadUint64(&fileIDSerial)
|
||||||
|
mix := uint64(timeutil.NowUnixNano()) ^ (seq << 1) ^ uint64(os.Getpid())
|
||||||
|
return uint32(mix ^ (mix >> 32))
|
||||||
|
}
|
||||||
@@ -0,0 +1,52 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildFileIDUniqueAcrossBurst(t *testing.T) {
|
||||||
|
const total = 512
|
||||||
|
seen := make(map[string]struct{}, total)
|
||||||
|
for i := 0; i < total; i++ {
|
||||||
|
id := buildFileID("report.txt")
|
||||||
|
if _, ok := seen[id]; ok {
|
||||||
|
t.Fatalf("duplicate file id generated: %q", id)
|
||||||
|
}
|
||||||
|
seen[id] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFileIDKeepsReadableBaseName(t *testing.T) {
|
||||||
|
id := buildFileID("/tmp/demo/report.txt")
|
||||||
|
if !strings.HasPrefix(id, "report.txt-") {
|
||||||
|
t.Fatalf("unexpected file id prefix: %q", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := strings.Split(id, "-")
|
||||||
|
if got, want := len(parts), 5; got != want {
|
||||||
|
t.Fatalf("unexpected file id segment count: got %d want %d, id=%q", got, want, id)
|
||||||
|
}
|
||||||
|
for _, part := range parts[1:] {
|
||||||
|
if part == "" {
|
||||||
|
t.Fatalf("unexpected empty file id segment: %q", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildFileIDFallsBackToUnnamedBase(t *testing.T) {
|
||||||
|
id := buildFileID("")
|
||||||
|
if !strings.HasPrefix(id, "unnamed-") {
|
||||||
|
t.Fatalf("unexpected unnamed file id prefix: %q", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewFileMetaEnvelopeKeepsOptionalMeta(t *testing.T) {
|
||||||
|
env := newFileMetaEnvelope("file-1", "/tmp/demo/report.txt", 123, "sum", 0o640, 123456789)
|
||||||
|
if got, want := env.File.Mode, uint32(0o640); got != want {
|
||||||
|
t.Fatalf("mode mismatch: got %o want %o", got, want)
|
||||||
|
}
|
||||||
|
if got, want := env.File.ModTime, int64(123456789); got != want {
|
||||||
|
t.Fatalf("modtime mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
# Signal Demo
|
||||||
|
|
||||||
|
`examples/signal` 演示 `notify` 的最小消息收发路径,覆盖服务端监听、客户端 `SendWait`、服务端 `Reply` 和并发请求。
|
||||||
|
|
||||||
|
## 功能
|
||||||
|
|
||||||
|
- `serve`:启动服务端并监听本地 IPC 端点
|
||||||
|
- `signal`:发送消息并等待回包
|
||||||
|
- 并发发送:`-n` 指定总请求数,`-c` 指定并发数
|
||||||
|
|
||||||
|
## 运行
|
||||||
|
|
||||||
|
在模块根目录执行:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go run ./examples/signal serve
|
||||||
|
```
|
||||||
|
|
||||||
|
另开终端发送单条消息:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go run ./examples/signal signal --msg "hello"
|
||||||
|
```
|
||||||
|
|
||||||
|
并发请求示例:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go run ./examples/signal signal --msg "ping" --n 100 --c 10
|
||||||
|
```
|
||||||
|
|
||||||
|
## 默认端点
|
||||||
|
|
||||||
|
- Windows:`network=npipe`,`addr=notify-signal-demo`
|
||||||
|
- Linux:`network=unix`,`addr=/tmp/notify-signal-demo.sock`
|
||||||
|
|
||||||
|
可通过 `--addr` 覆盖默认地址。
|
||||||
|
|
||||||
|
## 说明
|
||||||
|
|
||||||
|
- 示例中使用固定 PSK,仅用于本地演示。
|
||||||
|
- 示例的并发模式用于接口验证,不作为吞吐基准测试。
|
||||||
@@ -0,0 +1,217 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"b612.me/notify"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultPipeName = "notify-signal-demo"
|
||||||
|
defaultUnixSock = "/tmp/notify-signal-demo.sock"
|
||||||
|
sharedSecret = "0123456789abcdef0123456789abcdef"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
args := os.Args[1:]
|
||||||
|
if len(args) == 0 {
|
||||||
|
if err := runServe(nil); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "serve failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch args[0] {
|
||||||
|
case "serve", "server":
|
||||||
|
if err := runServe(args[1:]); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "serve failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
case "signal":
|
||||||
|
if err := runSignal(args[1:]); err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "signal failed: %v\n", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
case "-h", "--help", "help":
|
||||||
|
printUsage()
|
||||||
|
default:
|
||||||
|
fmt.Fprintf(os.Stderr, "unknown subcommand: %s\n", args[0])
|
||||||
|
printUsage()
|
||||||
|
os.Exit(2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func runServe(args []string) error {
|
||||||
|
network, defaultAddr := defaultEndpoint()
|
||||||
|
|
||||||
|
fs := flag.NewFlagSet("serve", flag.ContinueOnError)
|
||||||
|
addr := fs.String("addr", defaultAddr, "listen address (windows: pipe name or \\\\.\\pipe\\name; linux: unix socket path)")
|
||||||
|
if err := fs.Parse(args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
srv := notify.NewServer()
|
||||||
|
if err := notify.UseModernPSKServer(srv, []byte(sharedSecret), nil); err != nil {
|
||||||
|
return fmt.Errorf("configure modern psk server: %w", err)
|
||||||
|
}
|
||||||
|
srv.SetLink("signal", func(msg *notify.Message) {
|
||||||
|
content := string(msg.Value)
|
||||||
|
fmt.Printf("[server] recv signal: %s\n", content)
|
||||||
|
reply := fmt.Sprintf("ack from server: %s", content)
|
||||||
|
if err := msg.Reply([]byte(reply)); err != nil {
|
||||||
|
fmt.Printf("[server] reply error: %v\n", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
cleanup, err := prepareEndpoint(network, *addr)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
if err := srv.Listen(network, *addr); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf("[server] listening on %s %s\n", network, *addr)
|
||||||
|
|
||||||
|
stopSig := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(stopSig, os.Interrupt, syscall.SIGTERM)
|
||||||
|
<-stopSig
|
||||||
|
|
||||||
|
fmt.Println("[server] stopping...")
|
||||||
|
return srv.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func runSignal(args []string) error {
|
||||||
|
network, defaultAddr := defaultEndpoint()
|
||||||
|
|
||||||
|
fs := flag.NewFlagSet("signal", flag.ContinueOnError)
|
||||||
|
addr := fs.String("addr", defaultAddr, "target address")
|
||||||
|
msg := fs.String("msg", "hello", "signal payload")
|
||||||
|
count := fs.Int("n", 1, "total request count")
|
||||||
|
concurrency := fs.Int("c", 1, "concurrency for requests")
|
||||||
|
timeout := fs.Duration("timeout", 5*time.Second, "wait timeout per request")
|
||||||
|
if err := fs.Parse(args); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if *count <= 0 {
|
||||||
|
return errors.New("-n must be > 0")
|
||||||
|
}
|
||||||
|
if *concurrency <= 0 {
|
||||||
|
return errors.New("-c must be > 0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if *count == 1 && *concurrency == 1 {
|
||||||
|
reply, err := sendOne(network, *addr, *msg, *timeout)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
fmt.Printf("[client] recv reply: %s\n", reply)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
start := time.Now()
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
jobs := make(chan int)
|
||||||
|
errCh := make(chan error, *count)
|
||||||
|
|
||||||
|
worker := func() {
|
||||||
|
defer wg.Done()
|
||||||
|
for i := range jobs {
|
||||||
|
payload := fmt.Sprintf("%s #%d", *msg, i+1)
|
||||||
|
reply, err := sendOne(network, *addr, payload, *timeout)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- fmt.Errorf("job=%d: %w", i+1, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
fmt.Printf("[client] job=%d reply=%s\n", i+1, reply)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < *concurrency; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go worker()
|
||||||
|
}
|
||||||
|
for i := 0; i < *count; i++ {
|
||||||
|
jobs <- i
|
||||||
|
}
|
||||||
|
close(jobs)
|
||||||
|
wg.Wait()
|
||||||
|
close(errCh)
|
||||||
|
|
||||||
|
failures := 0
|
||||||
|
for err := range errCh {
|
||||||
|
failures++
|
||||||
|
fmt.Printf("[client] error: %v\n", err)
|
||||||
|
}
|
||||||
|
fmt.Printf("[client] done total=%d concurrency=%d failures=%d elapsed=%s\n", *count, *concurrency, failures, time.Since(start).Round(time.Millisecond))
|
||||||
|
if failures > 0 {
|
||||||
|
return fmt.Errorf("concurrent signal test finished with %d failures", failures)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendOne(network string, addr string, payload string, timeout time.Duration) (string, error) {
|
||||||
|
cli := notify.NewClient()
|
||||||
|
if err := notify.UseModernPSKClient(cli, []byte(sharedSecret), nil); err != nil {
|
||||||
|
return "", fmt.Errorf("configure modern psk client: %w", err)
|
||||||
|
}
|
||||||
|
if err := cli.Connect(network, addr); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = cli.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
reply, err := cli.SendWait("signal", []byte(payload), timeout)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(reply.Value), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultEndpoint() (network string, addr string) {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
return "npipe", defaultPipeName
|
||||||
|
}
|
||||||
|
return "unix", defaultUnixSock
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareEndpoint(network string, addr string) (func(), error) {
|
||||||
|
if network != "unix" {
|
||||||
|
return func() {}, nil
|
||||||
|
}
|
||||||
|
if addr == "" {
|
||||||
|
return nil, errors.New("unix socket path is empty")
|
||||||
|
}
|
||||||
|
if err := os.MkdirAll(filepath.Dir(addr), 0o755); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = os.Remove(addr)
|
||||||
|
return func() {
|
||||||
|
_ = os.Remove(addr)
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printUsage() {
|
||||||
|
fmt.Println("Usage:")
|
||||||
|
fmt.Println(" signal-demo serve [--addr <addr>]")
|
||||||
|
fmt.Println(" signal-demo signal [--addr <addr>] [--msg <text>] [--n <count>] [--c <concurrency>] [--timeout <duration>]")
|
||||||
|
fmt.Println("")
|
||||||
|
fmt.Println("Defaults:")
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
fmt.Printf(" network=npipe addr=%s\n", defaultPipeName)
|
||||||
|
} else {
|
||||||
|
fmt.Printf(" network=unix addr=%s\n", defaultUnixSock)
|
||||||
|
}
|
||||||
|
}
|
||||||
+177
@@ -0,0 +1,177 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errFileAckCanceled = errors.New("file ack canceled")
|
||||||
|
errFileAckTimeout = errors.New("file ack timeout")
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileAckWait struct {
|
||||||
|
key string
|
||||||
|
scope string
|
||||||
|
pool *fileAckPool
|
||||||
|
reply chan FileEvent
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileAckPool struct {
|
||||||
|
pool sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileAckPool() *fileAckPool {
|
||||||
|
return &fileAckPool{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileAckKey(scope string, fileID string, stage string, offset int64) string {
|
||||||
|
return normalizeFileScope(scope) + "|" + fileID + "|" + stage + "|" + formatInt(offset)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileAckPool) prepare(scope string, fileID string, stage string, offset int64) *fileAckWait {
|
||||||
|
scope = normalizeFileScope(scope)
|
||||||
|
wait := &fileAckWait{
|
||||||
|
key: fileAckKey(scope, fileID, stage, offset),
|
||||||
|
scope: scope,
|
||||||
|
pool: p,
|
||||||
|
reply: make(chan FileEvent, 1),
|
||||||
|
}
|
||||||
|
p.pool.Store(wait.key, wait)
|
||||||
|
return wait
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileAckPool) deliver(scope string, event FileEvent) bool {
|
||||||
|
return p.deliverAny([]string{scope}, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileAckPool) deliverAny(scopes []string, event FileEvent) bool {
|
||||||
|
if p == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, scope := range scopes {
|
||||||
|
key := fileAckKey(scope, event.Packet.FileID, event.Packet.Stage, event.Packet.Offset)
|
||||||
|
data, ok := p.pool.LoadAndDelete(key)
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
wait := data.(*fileAckWait)
|
||||||
|
wait.deliver(event)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *fileAckWait) cancel() {
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if w.pool != nil {
|
||||||
|
w.pool.pool.Delete(w.key)
|
||||||
|
}
|
||||||
|
w.closeReply()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *fileAckWait) deliver(event FileEvent) {
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.closeOnce.Do(func() {
|
||||||
|
select {
|
||||||
|
case w.reply <- event:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
close(w.reply)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *fileAckWait) closeReply() {
|
||||||
|
if w == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.closeOnce.Do(func() {
|
||||||
|
close(w.reply)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileAckPool) waitPrepared(wait *fileAckWait, timeout time.Duration) error {
|
||||||
|
if timeout <= 0 {
|
||||||
|
timeout = defaultFileAckTimeout
|
||||||
|
}
|
||||||
|
timer := time.NewTimer(timeout)
|
||||||
|
defer timer.Stop()
|
||||||
|
select {
|
||||||
|
case event, ok := <-wait.reply:
|
||||||
|
if !ok {
|
||||||
|
return errFileAckCanceled
|
||||||
|
}
|
||||||
|
if event.Err != nil {
|
||||||
|
return event.Err
|
||||||
|
}
|
||||||
|
if event.Packet.Error != "" {
|
||||||
|
return errors.New(event.Packet.Error)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case <-timer.C:
|
||||||
|
wait.cancel()
|
||||||
|
return errFileAckTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileAckPool) wait(scope string, fileID string, stage string, offset int64, timeout time.Duration) error {
|
||||||
|
wait := p.prepare(scope, fileID, stage, offset)
|
||||||
|
return p.waitPrepared(wait, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileAckPool) closeAll() {
|
||||||
|
if p == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.pool.Range(func(_, value interface{}) bool {
|
||||||
|
value.(*fileAckWait).cancel()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileAckPool) closeScope(scope string) {
|
||||||
|
if p == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
scope = normalizeFileScope(scope)
|
||||||
|
p.pool.Range(func(_, value interface{}) bool {
|
||||||
|
wait := value.(*fileAckWait)
|
||||||
|
if wait.scope == scope {
|
||||||
|
wait.cancel()
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileAckPool) closeScopeFamily(scope string) {
|
||||||
|
if p == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
base := normalizeFileScope(scope)
|
||||||
|
p.pool.Range(func(_, value interface{}) bool {
|
||||||
|
wait := value.(*fileAckWait)
|
||||||
|
if scopeBelongsToServerFileScope(wait.scope, base) {
|
||||||
|
wait.cancel()
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatInt(v int64) string {
|
||||||
|
return strconv.FormatInt(v, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) getFileAckPool() *fileAckPool {
|
||||||
|
return c.getLogicalSessionState().fileAckWaits
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) getFileAckPool() *fileAckPool {
|
||||||
|
return s.getLogicalSessionState().fileAckWaits
|
||||||
|
}
|
||||||
@@ -0,0 +1,200 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileTransferRetryHooks struct {
|
||||||
|
onRetry func(err error, attempt int)
|
||||||
|
onTimeout func(err error, attempt int)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileStageByKind(kind EnvelopeKind) string {
|
||||||
|
switch kind {
|
||||||
|
case EnvelopeFileMeta:
|
||||||
|
return "meta"
|
||||||
|
case EnvelopeFileChunk:
|
||||||
|
return "chunk"
|
||||||
|
case EnvelopeFileEnd:
|
||||||
|
return "end"
|
||||||
|
case EnvelopeFileAbort:
|
||||||
|
return "abort"
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendFileAck(src Envelope, processErr error) error {
|
||||||
|
errMsg := ""
|
||||||
|
if processErr != nil {
|
||||||
|
errMsg = processErr.Error()
|
||||||
|
}
|
||||||
|
ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg)
|
||||||
|
return c.sendEnvelope(ack)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFileAck(logical *LogicalConn, src Envelope, processErr error) error {
|
||||||
|
if logical == nil {
|
||||||
|
return s.sendFileAckTransport(nil, src, processErr)
|
||||||
|
}
|
||||||
|
return s.sendFileAckTransport(s.resolveOutboundTransport(logical), src, processErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFileAckTransport(transport *TransportConn, src Envelope, processErr error) error {
|
||||||
|
errMsg := ""
|
||||||
|
if processErr != nil {
|
||||||
|
errMsg = processErr.Error()
|
||||||
|
}
|
||||||
|
ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg)
|
||||||
|
return s.sendEnvelopeTransport(transport, ack)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFileAckInbound(logical *LogicalConn, transport *TransportConn, conn net.Conn, src Envelope, processErr error) error {
|
||||||
|
if conn == nil {
|
||||||
|
return s.sendFileAckTransport(transport, src, processErr)
|
||||||
|
}
|
||||||
|
errMsg := ""
|
||||||
|
if processErr != nil {
|
||||||
|
errMsg = processErr.Error()
|
||||||
|
}
|
||||||
|
ack := newFileAckEnvelope(src.File.FileID, fileStageByKind(src.Kind), src.File.Offset, errMsg)
|
||||||
|
return s.sendEnvelopeInboundTransport(logical, transport, conn, ack)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendFileAbort(fileID string, stage string, offset int64, cause error) error {
|
||||||
|
errMsg := ""
|
||||||
|
if cause != nil {
|
||||||
|
errMsg = cause.Error()
|
||||||
|
}
|
||||||
|
return c.sendEnvelope(newFileAbortEnvelope(fileID, stage, offset, errMsg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFileAbort(logical *LogicalConn, fileID string, stage string, offset int64, cause error) error {
|
||||||
|
if logical == nil {
|
||||||
|
return s.sendFileAbortTransport(nil, fileID, stage, offset, cause)
|
||||||
|
}
|
||||||
|
return s.sendFileAbortTransport(s.resolveOutboundTransport(logical), fileID, stage, offset, cause)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFileAbortTransport(transport *TransportConn, fileID string, stage string, offset int64, cause error) error {
|
||||||
|
errMsg := ""
|
||||||
|
if cause != nil {
|
||||||
|
errMsg = cause.Error()
|
||||||
|
}
|
||||||
|
return s.sendEnvelopeTransport(transport, newFileAbortEnvelope(fileID, stage, offset, errMsg))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendFileEnvelopeWithAck(env Envelope, timeout time.Duration) error {
|
||||||
|
pool := c.getFileAckPool()
|
||||||
|
wait := pool.prepare(clientFileScope(), env.File.FileID, fileStageByKind(env.Kind), env.File.Offset)
|
||||||
|
if err := c.sendEnvelope(env); err != nil {
|
||||||
|
wait.cancel()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return pool.waitPrepared(wait, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFileEnvelopeWithAck(logical *LogicalConn, env Envelope, timeout time.Duration) error {
|
||||||
|
if logical == nil {
|
||||||
|
return s.sendFileEnvelopeWithAckTransport(nil, env, timeout)
|
||||||
|
}
|
||||||
|
return s.sendFileEnvelopeWithAckTransport(s.resolveOutboundTransport(logical), env, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFileEnvelopeWithAckTransport(transport *TransportConn, env Envelope, timeout time.Duration) error {
|
||||||
|
pool := s.getFileAckPool()
|
||||||
|
wait := pool.prepare(serverTransportScopeForTransport(transport), env.File.FileID, fileStageByKind(env.Kind), env.File.Offset)
|
||||||
|
if err := s.sendEnvelopeTransport(transport, env); err != nil {
|
||||||
|
wait.cancel()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return pool.waitPrepared(wait, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendFileEnvelopeReliable(ctx context.Context, env Envelope, cfg fileTransferConfig) error {
|
||||||
|
state := c.getFileTransferState()
|
||||||
|
scope := clientFileScope()
|
||||||
|
stage := fileStageByKind(env.Kind)
|
||||||
|
state.recordRuntimeStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
|
||||||
|
return retryFileTransferSend(ctx, cfg, func(cfg fileTransferConfig) error {
|
||||||
|
return c.sendFileEnvelopeWithAck(env, cfg.AckTimeout)
|
||||||
|
}, fileTransferRetryHooks{
|
||||||
|
onRetry: func(err error, _ int) {
|
||||||
|
state.recordRuntimeRetry(fileTransferDirectionSend, scope, env.File.FileID)
|
||||||
|
},
|
||||||
|
onTimeout: func(err error, _ int) {
|
||||||
|
state.recordRuntimeTimeout(fileTransferDirectionSend, scope, env.File.FileID)
|
||||||
|
state.recordRuntimeFailureStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFileEnvelopeReliable(ctx context.Context, logical *LogicalConn, env Envelope, cfg fileTransferConfig) error {
|
||||||
|
if logical == nil {
|
||||||
|
return s.sendFileEnvelopeReliableTransport(ctx, nil, env, cfg)
|
||||||
|
}
|
||||||
|
return s.sendFileEnvelopeReliableTransport(ctx, s.resolveOutboundTransport(logical), env, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFileEnvelopeReliableTransport(ctx context.Context, transport *TransportConn, env Envelope, cfg fileTransferConfig) error {
|
||||||
|
state := s.getFileTransferState()
|
||||||
|
scope := serverTransportScopeForTransport(transport)
|
||||||
|
stage := fileStageByKind(env.Kind)
|
||||||
|
state.recordRuntimeStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
|
||||||
|
return retryFileTransferSend(ctx, cfg, func(cfg fileTransferConfig) error {
|
||||||
|
return s.sendFileEnvelopeWithAckTransport(transport, env, cfg.AckTimeout)
|
||||||
|
}, fileTransferRetryHooks{
|
||||||
|
onRetry: func(err error, _ int) {
|
||||||
|
state.recordRuntimeRetry(fileTransferDirectionSend, scope, env.File.FileID)
|
||||||
|
},
|
||||||
|
onTimeout: func(err error, _ int) {
|
||||||
|
state.recordRuntimeTimeout(fileTransferDirectionSend, scope, env.File.FileID)
|
||||||
|
state.recordRuntimeFailureStage(fileTransferDirectionSend, scope, env.File.FileID, stage)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func retryFileTransferSend(ctx context.Context, cfg fileTransferConfig, send func(fileTransferConfig) error, hooks ...fileTransferRetryHooks) error {
|
||||||
|
cfg = normalizeFileTransferConfig(cfg)
|
||||||
|
var lastErr error
|
||||||
|
hook := mergeFileTransferRetryHooks(hooks...)
|
||||||
|
for attempt := 0; attempt < cfg.SendRetry; attempt++ {
|
||||||
|
if ctx != nil {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
lastErr = send(cfg)
|
||||||
|
if lastErr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errors.Is(lastErr, errFileAckTimeout) && hook.onTimeout != nil {
|
||||||
|
hook.onTimeout(lastErr, attempt+1)
|
||||||
|
}
|
||||||
|
if attempt+1 < cfg.SendRetry && hook.onRetry != nil {
|
||||||
|
hook.onRetry(lastErr, attempt+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lastErr == nil {
|
||||||
|
lastErr = errors.New("file send failed")
|
||||||
|
}
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeFileTransferRetryHooks(hooks ...fileTransferRetryHooks) fileTransferRetryHooks {
|
||||||
|
var merged fileTransferRetryHooks
|
||||||
|
for _, hook := range hooks {
|
||||||
|
if hook.onRetry != nil {
|
||||||
|
merged.onRetry = hook.onRetry
|
||||||
|
}
|
||||||
|
if hook.onTimeout != nil {
|
||||||
|
merged.onTimeout = hook.onTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return merged
|
||||||
|
}
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRetryFileTransferSendHonorsRetryCount(t *testing.T) {
|
||||||
|
var attempts int
|
||||||
|
|
||||||
|
err := retryFileTransferSend(context.Background(), fileTransferConfig{
|
||||||
|
SendRetry: 3,
|
||||||
|
}, func(cfg fileTransferConfig) error {
|
||||||
|
attempts++
|
||||||
|
return errors.New("send failed")
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("retryFileTransferSend should return the last error")
|
||||||
|
}
|
||||||
|
if got, want := attempts, 3; got != want {
|
||||||
|
t.Fatalf("attempt count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryFileTransferSendStopsAfterSuccess(t *testing.T) {
|
||||||
|
var attempts int
|
||||||
|
|
||||||
|
err := retryFileTransferSend(context.Background(), fileTransferConfig{
|
||||||
|
SendRetry: 5,
|
||||||
|
}, func(cfg fileTransferConfig) error {
|
||||||
|
attempts++
|
||||||
|
if attempts == 3 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return errors.New("send failed")
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("retryFileTransferSend should stop after success: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := attempts, 3; got != want {
|
||||||
|
t.Fatalf("attempt count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryFileTransferSendHonorsContextCancel(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
var attempts int
|
||||||
|
err := retryFileTransferSend(ctx, fileTransferConfig{
|
||||||
|
SendRetry: 3,
|
||||||
|
}, func(cfg fileTransferConfig) error {
|
||||||
|
attempts++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("expected context canceled, got %v", err)
|
||||||
|
}
|
||||||
|
if got, want := attempts, 0; got != want {
|
||||||
|
t.Fatalf("attempt count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRetryFileTransferSendReportsRetryAndTimeoutHooks(t *testing.T) {
|
||||||
|
var attempts int
|
||||||
|
var retries int
|
||||||
|
var timeouts int
|
||||||
|
|
||||||
|
err := retryFileTransferSend(context.Background(), fileTransferConfig{
|
||||||
|
SendRetry: 3,
|
||||||
|
}, func(cfg fileTransferConfig) error {
|
||||||
|
attempts++
|
||||||
|
if attempts < 3 {
|
||||||
|
return errFileAckTimeout
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}, fileTransferRetryHooks{
|
||||||
|
onRetry: func(err error, attempt int) {
|
||||||
|
retries++
|
||||||
|
if !errors.Is(err, errFileAckTimeout) {
|
||||||
|
t.Fatalf("retry err = %v, want %v", err, errFileAckTimeout)
|
||||||
|
}
|
||||||
|
if attempt != retries {
|
||||||
|
t.Fatalf("retry attempt = %d, want %d", attempt, retries)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onTimeout: func(err error, attempt int) {
|
||||||
|
timeouts++
|
||||||
|
if !errors.Is(err, errFileAckTimeout) {
|
||||||
|
t.Fatalf("timeout err = %v, want %v", err, errFileAckTimeout)
|
||||||
|
}
|
||||||
|
if attempt != timeouts {
|
||||||
|
t.Fatalf("timeout attempt = %d, want %d", attempt, timeouts)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("retryFileTransferSend should succeed after timeout retries: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := retries, 2; got != want {
|
||||||
|
t.Fatalf("retry hook count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := timeouts, 2; got != want {
|
||||||
|
t.Fatalf("timeout hook count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,193 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestFileAckPoolPreparedWaitConsumesEarlyAck(t *testing.T) {
|
||||||
|
pool := newFileAckPool()
|
||||||
|
wait := pool.prepare("client:a", "file-1", "chunk", 64)
|
||||||
|
|
||||||
|
ok := pool.deliver("client:a", FileEvent{
|
||||||
|
Packet: FilePacket{
|
||||||
|
FileID: "file-1",
|
||||||
|
Stage: "chunk",
|
||||||
|
Offset: 64,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("deliver should match prepared waiter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pool.waitPrepared(wait, defaultFileAckTimeout); err != nil {
|
||||||
|
t.Fatalf("waitPrepared failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileAckPoolPreparedWaitReturnsAckError(t *testing.T) {
|
||||||
|
pool := newFileAckPool()
|
||||||
|
wait := pool.prepare("client:a", "file-2", "meta", 0)
|
||||||
|
|
||||||
|
ok := pool.deliver("client:a", FileEvent{
|
||||||
|
Packet: FilePacket{
|
||||||
|
FileID: "file-2",
|
||||||
|
Stage: "meta",
|
||||||
|
Offset: 0,
|
||||||
|
Error: "checksum mismatch",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("deliver should match prepared waiter")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := pool.waitPrepared(wait, defaultFileAckTimeout)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("waitPrepared should return ack error")
|
||||||
|
}
|
||||||
|
if got, want := err.Error(), "checksum mismatch"; got != want {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileAckPoolCancelRemovesPreparedWaiter(t *testing.T) {
|
||||||
|
pool := newFileAckPool()
|
||||||
|
wait := pool.prepare("client:a", "file-3", "end", 0)
|
||||||
|
wait.cancel()
|
||||||
|
|
||||||
|
ok := pool.deliver("client:a", FileEvent{
|
||||||
|
Packet: FilePacket{
|
||||||
|
FileID: "file-3",
|
||||||
|
Stage: "end",
|
||||||
|
Offset: 0,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if ok {
|
||||||
|
t.Fatal("deliver should not match canceled waiter")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileAckPoolScopeIsolation(t *testing.T) {
|
||||||
|
pool := newFileAckPool()
|
||||||
|
waitA := pool.prepare("server:client-a", "file-4", "chunk", 128)
|
||||||
|
waitB := pool.prepare("server:client-b", "file-4", "chunk", 128)
|
||||||
|
|
||||||
|
ok := pool.deliver("server:client-a", FileEvent{
|
||||||
|
Packet: FilePacket{
|
||||||
|
FileID: "file-4",
|
||||||
|
Stage: "chunk",
|
||||||
|
Offset: 128,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("deliver should match scopeA waiter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pool.waitPrepared(waitA, defaultFileAckTimeout); err != nil {
|
||||||
|
t.Fatalf("waitPrepared scopeA failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok = pool.deliver("server:client-a", FileEvent{
|
||||||
|
Packet: FilePacket{
|
||||||
|
FileID: "file-4",
|
||||||
|
Stage: "chunk",
|
||||||
|
Offset: 128,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if ok {
|
||||||
|
t.Fatal("scopeA ack should not consume scopeB waiter")
|
||||||
|
}
|
||||||
|
|
||||||
|
ok = pool.deliver("server:client-b", FileEvent{
|
||||||
|
Packet: FilePacket{
|
||||||
|
FileID: "file-4",
|
||||||
|
Stage: "chunk",
|
||||||
|
Offset: 128,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("deliver should match scopeB waiter")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil {
|
||||||
|
t.Fatalf("waitPrepared scopeB failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileAckPoolCloseAllCancelsPreparedWaiters(t *testing.T) {
|
||||||
|
pool := newFileAckPool()
|
||||||
|
wait := pool.prepare("client:a", "file-5", "chunk", 256)
|
||||||
|
|
||||||
|
pool.closeAll()
|
||||||
|
|
||||||
|
err := pool.waitPrepared(wait, defaultFileAckTimeout)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("waitPrepared should return cancel error after closeAll")
|
||||||
|
}
|
||||||
|
if got, want := err.Error(), "file ack canceled"; got != want {
|
||||||
|
t.Fatalf("unexpected error after closeAll: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileAckPoolCloseScopeCancelsMatchingWaitersOnly(t *testing.T) {
|
||||||
|
pool := newFileAckPool()
|
||||||
|
waitA := pool.prepare("server:client-a", "file-6", "chunk", 256)
|
||||||
|
waitB := pool.prepare("server:client-b", "file-6", "chunk", 256)
|
||||||
|
|
||||||
|
pool.closeScope("server:client-a")
|
||||||
|
|
||||||
|
err := pool.waitPrepared(waitA, defaultFileAckTimeout)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("scopeA waiter should be canceled")
|
||||||
|
}
|
||||||
|
if got, want := err.Error(), "file ack canceled"; got != want {
|
||||||
|
t.Fatalf("unexpected scopeA error: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := pool.deliver("server:client-b", FileEvent{
|
||||||
|
Packet: FilePacket{
|
||||||
|
FileID: "file-6",
|
||||||
|
Stage: "chunk",
|
||||||
|
Offset: 256,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("scopeB waiter should remain deliverable")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil {
|
||||||
|
t.Fatalf("waitPrepared scopeB failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerRemoveClientClosesScopedFileAckWaiters(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
clientA := &ClientConn{ClientID: "client-a"}
|
||||||
|
clientB := &ClientConn{ClientID: "client-b"}
|
||||||
|
pool := server.getFileAckPool()
|
||||||
|
|
||||||
|
waitA := pool.prepare(serverFileScope(clientA), "file-7", "end", 0)
|
||||||
|
waitB := pool.prepare(serverFileScope(clientB), "file-7", "end", 0)
|
||||||
|
|
||||||
|
server.removeClient(clientA)
|
||||||
|
|
||||||
|
err := pool.waitPrepared(waitA, defaultFileAckTimeout)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("clientA waiter should be canceled when client is removed")
|
||||||
|
}
|
||||||
|
if got, want := err.Error(), "file ack canceled"; got != want {
|
||||||
|
t.Fatalf("unexpected clientA error: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
ok := pool.deliver(serverFileScope(clientB), FileEvent{
|
||||||
|
Packet: FilePacket{
|
||||||
|
FileID: "file-7",
|
||||||
|
Stage: "end",
|
||||||
|
Offset: 0,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("clientB waiter should remain deliverable")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pool.waitPrepared(waitB, defaultFileAckTimeout); err != nil {
|
||||||
|
t.Fatalf("waitPrepared clientB failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,171 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *ClientCommon) dispatchFileEnvelope(env Envelope, now time.Time) {
|
||||||
|
event := FileEvent{
|
||||||
|
NetType: NET_CLIENT,
|
||||||
|
ServerConn: c,
|
||||||
|
Kind: env.Kind,
|
||||||
|
Packet: env.File,
|
||||||
|
Time: now,
|
||||||
|
}
|
||||||
|
pool := c.getFileReceivePool()
|
||||||
|
switch env.Kind {
|
||||||
|
case EnvelopeAck:
|
||||||
|
event.Packet.Stage = env.File.Stage
|
||||||
|
event.Packet.Error = env.File.Error
|
||||||
|
event.Received = env.File.Offset
|
||||||
|
if c.getFileAckPool().deliver(clientFileScope(), event) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case EnvelopeFileMeta:
|
||||||
|
session, err := pool.onMeta(clientFileScope(), env.File, now)
|
||||||
|
if session != nil {
|
||||||
|
event.Path = session.tmpPath
|
||||||
|
event.Received = session.received
|
||||||
|
fillFileEventTiming(&event, session)
|
||||||
|
}
|
||||||
|
event.Err = err
|
||||||
|
case EnvelopeFileChunk:
|
||||||
|
session, err := pool.onChunk(clientFileScope(), env.File, now)
|
||||||
|
if session != nil {
|
||||||
|
event.Path = session.tmpPath
|
||||||
|
event.Received = session.received
|
||||||
|
fillFileEventTiming(&event, session)
|
||||||
|
}
|
||||||
|
event.Err = err
|
||||||
|
case EnvelopeFileEnd:
|
||||||
|
finalPath, session, err := pool.onEnd(clientFileScope(), env.File, now)
|
||||||
|
if session != nil {
|
||||||
|
event.Path = finalPath
|
||||||
|
event.Received = session.received
|
||||||
|
fillFileEventTiming(&event, session)
|
||||||
|
}
|
||||||
|
event.Err = err
|
||||||
|
case EnvelopeFileAbort:
|
||||||
|
session, err := pool.onAbort(clientFileScope(), env.File, now)
|
||||||
|
event.Received = env.File.Offset
|
||||||
|
if session != nil {
|
||||||
|
event.Path = session.tmpPath
|
||||||
|
fillFileEventTiming(&event, session)
|
||||||
|
}
|
||||||
|
event.Err = err
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if env.Kind == EnvelopeFileMeta || env.Kind == EnvelopeFileChunk || env.Kind == EnvelopeFileEnd || env.Kind == EnvelopeFileAbort {
|
||||||
|
if ackErr := c.sendFileAck(env, event.Err); ackErr != nil && event.Err == nil {
|
||||||
|
event.Err = ackErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fillFileEventProgress(&event)
|
||||||
|
c.publishReceivedFileEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) dispatchFileEnvelope(logical *LogicalConn, transport *TransportConn, conn net.Conn, env Envelope, now time.Time) {
|
||||||
|
if transport == nil && logical != nil {
|
||||||
|
transport = logical.CurrentTransportConn()
|
||||||
|
}
|
||||||
|
event := FileEvent{
|
||||||
|
LogicalConn: logical,
|
||||||
|
NetType: NET_SERVER,
|
||||||
|
TransportConn: transport,
|
||||||
|
Kind: env.Kind,
|
||||||
|
Packet: env.File,
|
||||||
|
Time: now,
|
||||||
|
}
|
||||||
|
pool := s.getFileReceivePool()
|
||||||
|
switch env.Kind {
|
||||||
|
case EnvelopeAck:
|
||||||
|
event.Packet.Stage = env.File.Stage
|
||||||
|
event.Packet.Error = env.File.Error
|
||||||
|
event.Received = env.File.Offset
|
||||||
|
scopes := serverTransportDeliveryScopes(logical)
|
||||||
|
if transport := fileEventTransportConnSnapshot(event); transport != nil {
|
||||||
|
scopes = serverTransportDeliveryScopesForTransport(transport)
|
||||||
|
}
|
||||||
|
if s.getFileAckPool().deliverAny(scopes, event) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case EnvelopeFileMeta:
|
||||||
|
session, err := pool.onMeta(serverFileScope(logical), env.File, now)
|
||||||
|
if session != nil {
|
||||||
|
event.Path = session.tmpPath
|
||||||
|
event.Received = session.received
|
||||||
|
fillFileEventTiming(&event, session)
|
||||||
|
}
|
||||||
|
event.Err = err
|
||||||
|
case EnvelopeFileChunk:
|
||||||
|
session, err := pool.onChunk(serverFileScope(logical), env.File, now)
|
||||||
|
if session != nil {
|
||||||
|
event.Path = session.tmpPath
|
||||||
|
event.Received = session.received
|
||||||
|
fillFileEventTiming(&event, session)
|
||||||
|
}
|
||||||
|
event.Err = err
|
||||||
|
case EnvelopeFileEnd:
|
||||||
|
finalPath, session, err := pool.onEnd(serverFileScope(logical), env.File, now)
|
||||||
|
if session != nil {
|
||||||
|
event.Path = finalPath
|
||||||
|
event.Received = session.received
|
||||||
|
fillFileEventTiming(&event, session)
|
||||||
|
}
|
||||||
|
event.Err = err
|
||||||
|
case EnvelopeFileAbort:
|
||||||
|
session, err := pool.onAbort(serverFileScope(logical), env.File, now)
|
||||||
|
event.Received = env.File.Offset
|
||||||
|
if session != nil {
|
||||||
|
event.Path = session.tmpPath
|
||||||
|
fillFileEventTiming(&event, session)
|
||||||
|
}
|
||||||
|
event.Err = err
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if env.Kind == EnvelopeFileMeta || env.Kind == EnvelopeFileChunk || env.Kind == EnvelopeFileEnd || env.Kind == EnvelopeFileAbort {
|
||||||
|
if ackErr := s.sendFileAckInbound(logical, transport, conn, env, event.Err); ackErr != nil && event.Err == nil {
|
||||||
|
event.Err = ackErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fillFileEventProgress(&event)
|
||||||
|
s.publishReceivedFileEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) emitFileEvent(event FileEvent) {
|
||||||
|
c.mu.Lock()
|
||||||
|
handler := c.onFileEvent
|
||||||
|
c.mu.Unlock()
|
||||||
|
if handler == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) emitFileEvent(event FileEvent) {
|
||||||
|
s.mu.Lock()
|
||||||
|
handler := s.onFileEvent
|
||||||
|
s.mu.Unlock()
|
||||||
|
if handler == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
handler(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) logFileEvent(role string, event FileEvent) {
|
||||||
|
if !(c.debugMode || event.Err != nil) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("%s file event kind=%d file_id=%s received=%d path=%s err=%v\n",
|
||||||
|
role, event.Kind, event.Packet.FileID, event.Received, event.Path, event.Err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) logFileEvent(role string, event FileEvent) {
|
||||||
|
if !(s.debugMode || event.Err != nil) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fmt.Printf("%s file event kind=%d file_id=%s received=%d path=%s err=%v\n",
|
||||||
|
role, event.Kind, event.Packet.FileID, event.Received, event.Path, event.Err)
|
||||||
|
}
|
||||||
+243
@@ -0,0 +1,243 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type FileEvent struct {
|
||||||
|
NetType NetType
|
||||||
|
LogicalConn *LogicalConn
|
||||||
|
// Deprecated: ClientConn aliases LogicalConn for compatibility.
|
||||||
|
ClientConn *ClientConn
|
||||||
|
TransportConn *TransportConn
|
||||||
|
ServerConn Client
|
||||||
|
Kind EnvelopeKind
|
||||||
|
Packet FilePacket
|
||||||
|
Path string
|
||||||
|
Received int64
|
||||||
|
Total int64
|
||||||
|
Percent float64
|
||||||
|
Done bool
|
||||||
|
StartedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
Duration time.Duration
|
||||||
|
RateBPS float64
|
||||||
|
StepDuration time.Duration
|
||||||
|
InstantRateBPS float64
|
||||||
|
Err error
|
||||||
|
Time time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeFileEventTime(now time.Time) time.Time {
|
||||||
|
if now.IsZero() {
|
||||||
|
return time.Now()
|
||||||
|
}
|
||||||
|
return now
|
||||||
|
}
|
||||||
|
|
||||||
|
func hydrateServerFileEventPeerFields(event FileEvent) FileEvent {
|
||||||
|
if event.LogicalConn == nil {
|
||||||
|
event.LogicalConn = logicalConnFromClient(event.ClientConn)
|
||||||
|
}
|
||||||
|
if event.ClientConn == nil {
|
||||||
|
event.ClientConn = event.LogicalConn.compatClientConn()
|
||||||
|
}
|
||||||
|
if event.TransportConn == nil && event.LogicalConn != nil {
|
||||||
|
event.TransportConn = event.LogicalConn.CurrentTransportConn()
|
||||||
|
}
|
||||||
|
return event
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileEventLogicalConnSnapshot(event FileEvent) *LogicalConn {
|
||||||
|
if event.LogicalConn != nil {
|
||||||
|
return event.LogicalConn
|
||||||
|
}
|
||||||
|
return logicalConnFromClient(event.ClientConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileEventTransportConnSnapshot(event FileEvent) *TransportConn {
|
||||||
|
if event.TransportConn != nil {
|
||||||
|
return event.TransportConn
|
||||||
|
}
|
||||||
|
logical := fileEventLogicalConnSnapshot(event)
|
||||||
|
if logical == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return logical.CurrentTransportConn()
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileEventTimeline struct {
|
||||||
|
startedAt time.Time
|
||||||
|
updatedAt time.Time
|
||||||
|
previousUpdatedAt time.Time
|
||||||
|
previousProgress int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillFileEventProgress(event *FileEvent) {
|
||||||
|
if event == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
event.Total = event.Packet.Size
|
||||||
|
if event.Received < 0 {
|
||||||
|
event.Received = 0
|
||||||
|
}
|
||||||
|
if event.Total > 0 && event.Received > event.Total {
|
||||||
|
event.Received = event.Total
|
||||||
|
}
|
||||||
|
switch event.Kind {
|
||||||
|
case EnvelopeFileEnd:
|
||||||
|
event.Done = event.Err == nil
|
||||||
|
if event.Done && event.Total > 0 {
|
||||||
|
event.Received = event.Total
|
||||||
|
}
|
||||||
|
case EnvelopeFileAbort:
|
||||||
|
event.Done = false
|
||||||
|
}
|
||||||
|
if event.Total <= 0 {
|
||||||
|
if event.Done {
|
||||||
|
event.Percent = 100
|
||||||
|
}
|
||||||
|
if !event.StartedAt.IsZero() && !event.UpdatedAt.IsZero() && !event.UpdatedAt.Before(event.StartedAt) {
|
||||||
|
event.Duration = event.UpdatedAt.Sub(event.StartedAt)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
event.Percent = float64(event.Received) * 100 / float64(event.Total)
|
||||||
|
if event.Percent < 0 {
|
||||||
|
event.Percent = 0
|
||||||
|
}
|
||||||
|
if event.Percent > 100 {
|
||||||
|
event.Percent = 100
|
||||||
|
}
|
||||||
|
if !event.StartedAt.IsZero() && !event.UpdatedAt.IsZero() && !event.UpdatedAt.Before(event.StartedAt) {
|
||||||
|
event.Duration = event.UpdatedAt.Sub(event.StartedAt)
|
||||||
|
}
|
||||||
|
if event.Duration > 0 && event.Received > 0 {
|
||||||
|
event.RateBPS = float64(event.Received) / event.Duration.Seconds()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillFileEventTimeline(event *FileEvent, timeline fileEventTimeline) {
|
||||||
|
if event == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
event.StartedAt = timeline.startedAt
|
||||||
|
event.UpdatedAt = timeline.updatedAt
|
||||||
|
if !timeline.previousUpdatedAt.IsZero() && !timeline.updatedAt.Before(timeline.previousUpdatedAt) {
|
||||||
|
event.StepDuration = timeline.updatedAt.Sub(timeline.previousUpdatedAt)
|
||||||
|
}
|
||||||
|
if delta := event.Received - timeline.previousProgress; delta > 0 && event.StepDuration > 0 {
|
||||||
|
event.InstantRateBPS = float64(delta) / event.StepDuration.Seconds()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillFileEventTiming(event *FileEvent, session *fileReceiveSession) {
|
||||||
|
if session == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fillFileEventTimeline(event, fileEventTimeline{
|
||||||
|
startedAt: session.startedAt,
|
||||||
|
updatedAt: session.updatedAt,
|
||||||
|
previousUpdatedAt: session.previousUpdatedAt,
|
||||||
|
previousProgress: session.previousReceived,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func fillFileSendEventTiming(event *FileEvent, session *fileSendSession) {
|
||||||
|
if session == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fillFileEventTimeline(event, fileEventTimeline{
|
||||||
|
startedAt: session.startedAt,
|
||||||
|
updatedAt: session.updatedAt,
|
||||||
|
previousUpdatedAt: session.previousUpdatedAt,
|
||||||
|
previousProgress: session.previousSent,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeFileEventCallback(fn func(FileEvent)) func(FileEvent) {
|
||||||
|
if fn == nil {
|
||||||
|
return func(FileEvent) {}
|
||||||
|
}
|
||||||
|
return fn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) setFileEventObserver(fn func(FileEvent)) {
|
||||||
|
c.mu.Lock()
|
||||||
|
c.fileEventObserver = normalizeFileEventCallback(fn)
|
||||||
|
c.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) setFileEventObserver(fn func(FileEvent)) {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.fileEventObserver = normalizeFileEventCallback(fn)
|
||||||
|
s.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) observeFileEvent(event FileEvent) {
|
||||||
|
c.mu.Lock()
|
||||||
|
observer := c.fileEventObserver
|
||||||
|
c.mu.Unlock()
|
||||||
|
normalizeFileEventCallback(observer)(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) observeFileEvent(event FileEvent) {
|
||||||
|
s.mu.RLock()
|
||||||
|
observer := s.fileEventObserver
|
||||||
|
s.mu.RUnlock()
|
||||||
|
normalizeFileEventCallback(observer)(hydrateServerFileEventPeerFields(event))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) publishReceivedFileEvent(event FileEvent) {
|
||||||
|
c.getFileTransferState().observe(fileTransferDirectionReceive, event)
|
||||||
|
c.observeFileEvent(event)
|
||||||
|
c.logFileEvent("client", event)
|
||||||
|
c.emitFileEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) publishReceivedFileEventMonitorOnly(event FileEvent) {
|
||||||
|
c.getFileTransferState().observeMonitorOnly(fileTransferDirectionReceive, event)
|
||||||
|
c.observeFileEvent(event)
|
||||||
|
c.logFileEvent("client", event)
|
||||||
|
c.emitFileEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) publishReceivedFileEvent(event FileEvent) {
|
||||||
|
event = hydrateServerFileEventPeerFields(event)
|
||||||
|
s.getFileTransferState().observe(fileTransferDirectionReceive, event)
|
||||||
|
s.observeFileEvent(event)
|
||||||
|
s.logFileEvent("server", event)
|
||||||
|
s.emitFileEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) publishReceivedFileEventMonitorOnly(event FileEvent) {
|
||||||
|
event = hydrateServerFileEventPeerFields(event)
|
||||||
|
s.getFileTransferState().observeMonitorOnly(fileTransferDirectionReceive, event)
|
||||||
|
s.observeFileEvent(event)
|
||||||
|
s.logFileEvent("server", event)
|
||||||
|
s.emitFileEvent(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) publishSendFileEvent(event FileEvent) {
|
||||||
|
c.getFileTransferState().observe(fileTransferDirectionSend, event)
|
||||||
|
c.observeFileEvent(event)
|
||||||
|
c.logFileEvent("client-send", event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) publishSendFileEventMonitorOnly(event FileEvent) {
|
||||||
|
c.getFileTransferState().observeMonitorOnly(fileTransferDirectionSend, event)
|
||||||
|
c.observeFileEvent(event)
|
||||||
|
c.logFileEvent("client-send", event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) publishSendFileEvent(event FileEvent) {
|
||||||
|
event = hydrateServerFileEventPeerFields(event)
|
||||||
|
s.getFileTransferState().observe(fileTransferDirectionSend, event)
|
||||||
|
s.observeFileEvent(event)
|
||||||
|
s.logFileEvent("server-send", event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) publishSendFileEventMonitorOnly(event FileEvent) {
|
||||||
|
event = hydrateServerFileEventPeerFields(event)
|
||||||
|
s.getFileTransferState().observeMonitorOnly(fileTransferDirectionSend, event)
|
||||||
|
s.observeFileEvent(event)
|
||||||
|
s.logFileEvent("server-send", event)
|
||||||
|
}
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFillFileEventTimeline(t *testing.T) {
|
||||||
|
event := FileEvent{
|
||||||
|
Received: 150,
|
||||||
|
}
|
||||||
|
timeline := fileEventTimeline{
|
||||||
|
startedAt: time.Unix(100, 0),
|
||||||
|
updatedAt: time.Unix(110, 0),
|
||||||
|
previousUpdatedAt: time.Unix(106, 0),
|
||||||
|
previousProgress: 90,
|
||||||
|
}
|
||||||
|
|
||||||
|
fillFileEventTimeline(&event, timeline)
|
||||||
|
|
||||||
|
if got, want := event.StartedAt, timeline.startedAt; !got.Equal(want) {
|
||||||
|
t.Fatalf("startedAt mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := event.UpdatedAt, timeline.updatedAt; !got.Equal(want) {
|
||||||
|
t.Fatalf("updatedAt mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := event.StepDuration, 4*time.Second; got != want {
|
||||||
|
t.Fatalf("step duration mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := event.InstantRateBPS, 15.0; got != want {
|
||||||
|
t.Fatalf("instant rate mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,95 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestClientPublishSendFileEventObserverOnly(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
|
||||||
|
var observed []FileEvent
|
||||||
|
var handled []FileEvent
|
||||||
|
client.setFileEventObserver(func(event FileEvent) {
|
||||||
|
observed = append(observed, event)
|
||||||
|
})
|
||||||
|
client.SetFileHandler(func(event FileEvent) {
|
||||||
|
handled = append(handled, event)
|
||||||
|
})
|
||||||
|
|
||||||
|
event := FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "send-1", Size: 32},
|
||||||
|
}
|
||||||
|
client.publishSendFileEvent(event)
|
||||||
|
|
||||||
|
if got, want := len(observed), 1; got != want {
|
||||||
|
t.Fatalf("observed count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := len(handled), 0; got != want {
|
||||||
|
t.Fatalf("handled count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := observed[0].Packet.FileID, "send-1"; got != want {
|
||||||
|
t.Fatalf("observed fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientPublishReceivedFileEventObserverAndHandler(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
|
||||||
|
var observed []FileEvent
|
||||||
|
var handled []FileEvent
|
||||||
|
client.setFileEventObserver(func(event FileEvent) {
|
||||||
|
observed = append(observed, event)
|
||||||
|
})
|
||||||
|
client.SetFileHandler(func(event FileEvent) {
|
||||||
|
handled = append(handled, event)
|
||||||
|
})
|
||||||
|
|
||||||
|
event := FileEvent{
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "recv-1", Size: 64},
|
||||||
|
Received: 64,
|
||||||
|
Done: true,
|
||||||
|
}
|
||||||
|
client.publishReceivedFileEvent(event)
|
||||||
|
|
||||||
|
if got, want := len(observed), 1; got != want {
|
||||||
|
t.Fatalf("observed count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := len(handled), 1; got != want {
|
||||||
|
t.Fatalf("handled count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := observed[0].Packet.FileID, "recv-1"; got != want {
|
||||||
|
t.Fatalf("observed fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := handled[0].Packet.FileID, "recv-1"; got != want {
|
||||||
|
t.Fatalf("handled fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerPublishSendFileEventObserverOnly(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
|
||||||
|
var observed []FileEvent
|
||||||
|
var handled []FileEvent
|
||||||
|
server.setFileEventObserver(func(event FileEvent) {
|
||||||
|
observed = append(observed, event)
|
||||||
|
})
|
||||||
|
server.SetFileHandler(func(event FileEvent) {
|
||||||
|
handled = append(handled, event)
|
||||||
|
})
|
||||||
|
|
||||||
|
event := FileEvent{
|
||||||
|
Kind: EnvelopeFileMeta,
|
||||||
|
Packet: FilePacket{FileID: "server-send-1", Size: 128},
|
||||||
|
}
|
||||||
|
server.publishSendFileEvent(event)
|
||||||
|
|
||||||
|
if got, want := len(observed), 1; got != want {
|
||||||
|
t.Fatalf("observed count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := len(handled), 0; got != want {
|
||||||
|
t.Fatalf("handled count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := observed[0].Packet.FileID, "server-send-1"; got != want {
|
||||||
|
t.Fatalf("observed fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,173 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileReceiveCheckpoint struct {
|
||||||
|
FileID string `json:"file_id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
Mode uint32 `json:"mode"`
|
||||||
|
ModTime int64 `json:"mod_time"`
|
||||||
|
Checksum string `json:"checksum"`
|
||||||
|
Received int64 `json:"received"`
|
||||||
|
TmpPath string `json:"tmp_path"`
|
||||||
|
FinalPath string `json:"final_path"`
|
||||||
|
StartedAt int64 `json:"started_at"`
|
||||||
|
UpdatedAt int64 `json:"updated_at"`
|
||||||
|
PreviousUpdatedAt int64 `json:"previous_updated_at"`
|
||||||
|
PreviousReceived int64 `json:"previous_received"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) restoreCheckpointLocked(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, bool, error) {
|
||||||
|
checkpoint, ok, err := p.loadCheckpointLocked(scope, packet.FileID)
|
||||||
|
if err != nil || !ok {
|
||||||
|
return nil, ok, err
|
||||||
|
}
|
||||||
|
name := filepath.Base(packet.Name)
|
||||||
|
if name == "." || name == "/" || name == "" {
|
||||||
|
name = "unnamed.bin"
|
||||||
|
}
|
||||||
|
if checkpoint.FileID != packet.FileID || checkpoint.Name != name || checkpoint.Size != packet.Size || !strings.EqualFold(checkpoint.Checksum, packet.Checksum) {
|
||||||
|
p.removeCheckpointLocked(scope, packet.FileID)
|
||||||
|
if checkpoint.TmpPath != "" {
|
||||||
|
_ = os.Remove(checkpoint.TmpPath)
|
||||||
|
}
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
if checkpoint.TmpPath == "" {
|
||||||
|
p.removeCheckpointLocked(scope, packet.FileID)
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
info, statErr := os.Stat(checkpoint.TmpPath)
|
||||||
|
if statErr != nil {
|
||||||
|
if checkpoint.FinalPath != "" && pathExists(checkpoint.FinalPath) {
|
||||||
|
session := checkpoint.toSession(now)
|
||||||
|
session.tmpPath = checkpoint.FinalPath
|
||||||
|
session.finalPath = checkpoint.FinalPath
|
||||||
|
session.received = session.size
|
||||||
|
p.completed[fileReceiveKey(scope, packet.FileID)] = session.copy()
|
||||||
|
p.removeCheckpointLocked(scope, packet.FileID)
|
||||||
|
return session.copy(), true, nil
|
||||||
|
}
|
||||||
|
p.removeCheckpointLocked(scope, packet.FileID)
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
received := info.Size()
|
||||||
|
if received < 0 {
|
||||||
|
received = 0
|
||||||
|
}
|
||||||
|
if packet.Size > 0 && received > packet.Size {
|
||||||
|
received = packet.Size
|
||||||
|
}
|
||||||
|
session := checkpoint.toSession(now)
|
||||||
|
session.name = name
|
||||||
|
session.mode = os.FileMode(packet.Mode)
|
||||||
|
session.modTime = filePacketModTime(packet)
|
||||||
|
session.checksum = packet.Checksum
|
||||||
|
session.received = received
|
||||||
|
if session.finalPath == "" || (session.finalPath != session.tmpPath && pathExists(session.finalPath)) {
|
||||||
|
session.finalPath = p.uniqueFinalPathLocked(p.receiveDirLocked(), name, packet.FileID)
|
||||||
|
}
|
||||||
|
p.sessions[fileReceiveKey(scope, packet.FileID)] = session
|
||||||
|
if session.received != checkpoint.Received || session.finalPath != checkpoint.FinalPath {
|
||||||
|
if err := p.saveCheckpointLocked(scope, session); err != nil {
|
||||||
|
return nil, true, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return session.copy(), true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) loadCheckpointLocked(scope string, fileID string) (fileReceiveCheckpoint, bool, error) {
|
||||||
|
path := p.checkpointPathLocked(scope, fileID)
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return fileReceiveCheckpoint{}, false, nil
|
||||||
|
}
|
||||||
|
return fileReceiveCheckpoint{}, false, err
|
||||||
|
}
|
||||||
|
var checkpoint fileReceiveCheckpoint
|
||||||
|
if err := json.Unmarshal(data, &checkpoint); err != nil {
|
||||||
|
_ = os.Remove(path)
|
||||||
|
return fileReceiveCheckpoint{}, false, nil
|
||||||
|
}
|
||||||
|
return checkpoint, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) saveCheckpointLocked(scope string, session *fileReceiveSession) error {
|
||||||
|
if p == nil || session == nil || session.fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
path := p.checkpointPathLocked(scope, session.fileID)
|
||||||
|
checkpoint := fileReceiveCheckpoint{
|
||||||
|
FileID: session.fileID,
|
||||||
|
Name: session.name,
|
||||||
|
Size: session.size,
|
||||||
|
Mode: uint32(session.mode.Perm()),
|
||||||
|
ModTime: session.modTime.UnixNano(),
|
||||||
|
Checksum: session.checksum,
|
||||||
|
Received: session.received,
|
||||||
|
TmpPath: session.tmpPath,
|
||||||
|
FinalPath: session.finalPath,
|
||||||
|
StartedAt: session.startedAt.UnixNano(),
|
||||||
|
UpdatedAt: session.updatedAt.UnixNano(),
|
||||||
|
PreviousUpdatedAt: session.previousUpdatedAt.UnixNano(),
|
||||||
|
PreviousReceived: session.previousReceived,
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(checkpoint)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
tmpPath := path + ".tmp"
|
||||||
|
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.Rename(tmpPath, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) removeCheckpointLocked(scope string, fileID string) {
|
||||||
|
if p == nil || fileID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_ = os.Remove(p.checkpointPathLocked(scope, fileID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) checkpointPathLocked(scope string, fileID string) string {
|
||||||
|
baseDir := p.receiveDirLocked()
|
||||||
|
sum := sha256.Sum256([]byte(fileReceiveKey(scope, fileID)))
|
||||||
|
return filepath.Join(baseDir, ".notify_recv_"+hex.EncodeToString(sum[:8])+".json")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (checkpoint fileReceiveCheckpoint) toSession(now time.Time) *fileReceiveSession {
|
||||||
|
now = normalizeFileEventTime(now)
|
||||||
|
session := &fileReceiveSession{
|
||||||
|
fileID: checkpoint.FileID,
|
||||||
|
name: checkpoint.Name,
|
||||||
|
size: checkpoint.Size,
|
||||||
|
mode: os.FileMode(checkpoint.Mode),
|
||||||
|
modTime: time.Unix(0, checkpoint.ModTime),
|
||||||
|
checksum: checkpoint.Checksum,
|
||||||
|
received: checkpoint.Received,
|
||||||
|
tmpPath: checkpoint.TmpPath,
|
||||||
|
finalPath: checkpoint.FinalPath,
|
||||||
|
previousReceived: checkpoint.PreviousReceived,
|
||||||
|
}
|
||||||
|
session.startedAt = unixNanoTime(checkpoint.StartedAt)
|
||||||
|
session.updatedAt = unixNanoTime(checkpoint.UpdatedAt)
|
||||||
|
session.previousUpdatedAt = unixNanoTime(checkpoint.PreviousUpdatedAt)
|
||||||
|
if session.startedAt.IsZero() {
|
||||||
|
session.startedAt = now
|
||||||
|
}
|
||||||
|
if session.updatedAt.IsZero() {
|
||||||
|
session.updatedAt = now
|
||||||
|
}
|
||||||
|
return session
|
||||||
|
}
|
||||||
@@ -0,0 +1,147 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func computeFileChecksum(path string) (string, error) {
|
||||||
|
fd, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer fd.Close()
|
||||||
|
h := sha256.New()
|
||||||
|
if _, err := io.Copy(h, fd); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(h.Sum(nil)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func filePacketModTime(packet FilePacket) time.Time {
|
||||||
|
if packet.ModTime <= 0 {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
return time.Unix(0, packet.ModTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func applyReceivedFileMeta(path string, mode os.FileMode, modTime time.Time) {
|
||||||
|
if mode != 0 {
|
||||||
|
_ = os.Chmod(path, mode.Perm())
|
||||||
|
}
|
||||||
|
if !modTime.IsZero() {
|
||||||
|
_ = os.Chtimes(path, modTime, modTime)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func sanitizeFileName(name string) string {
|
||||||
|
trimmed := strings.TrimSpace(name)
|
||||||
|
if trimmed == "" {
|
||||||
|
return "unnamed"
|
||||||
|
}
|
||||||
|
trimmed = strings.ReplaceAll(trimmed, "/", "_")
|
||||||
|
trimmed = strings.ReplaceAll(trimmed, "\\", "_")
|
||||||
|
trimmed = strings.ReplaceAll(trimmed, ":", "_")
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
|
||||||
|
func shortFileIDSuffix(fileID string) string {
|
||||||
|
cleaned := sanitizeFileName(fileID)
|
||||||
|
if len(cleaned) > 12 {
|
||||||
|
return cleaned[:12]
|
||||||
|
}
|
||||||
|
if cleaned == "" {
|
||||||
|
return "copy"
|
||||||
|
}
|
||||||
|
return cleaned
|
||||||
|
}
|
||||||
|
|
||||||
|
func pathExists(path string) bool {
|
||||||
|
_, err := os.Stat(path)
|
||||||
|
return err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) receiveDirLocked() string {
|
||||||
|
if p.dir != "" {
|
||||||
|
return p.dir
|
||||||
|
}
|
||||||
|
return os.TempDir()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) uniqueFinalPathLocked(baseDir string, name string, fileID string) string {
|
||||||
|
cleanName := sanitizeFileName(filepath.Base(name))
|
||||||
|
if cleanName == "" {
|
||||||
|
cleanName = "unnamed.bin"
|
||||||
|
}
|
||||||
|
ext := filepath.Ext(cleanName)
|
||||||
|
base := strings.TrimSuffix(cleanName, ext)
|
||||||
|
candidate := filepath.Join(baseDir, cleanName)
|
||||||
|
if !p.pathReservedLocked(candidate) && !pathExists(candidate) {
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
suffix := shortFileIDSuffix(fileID)
|
||||||
|
candidate = filepath.Join(baseDir, fmt.Sprintf("%s.%s%s", base, suffix, ext))
|
||||||
|
if !p.pathReservedLocked(candidate) && !pathExists(candidate) {
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
for i := 1; ; i++ {
|
||||||
|
candidate = filepath.Join(baseDir, fmt.Sprintf("%s.%s.%d%s", base, suffix, i, ext))
|
||||||
|
if !p.pathReservedLocked(candidate) && !pathExists(candidate) {
|
||||||
|
return candidate
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) pathReservedLocked(path string) bool {
|
||||||
|
for _, session := range p.sessions {
|
||||||
|
if session.finalPath == path || session.tmpPath == path {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) trimCompletedLocked() {
|
||||||
|
if p.completedLimit <= 0 || len(p.completed) <= p.completedLimit {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for len(p.completed) > p.completedLimit {
|
||||||
|
oldestKey := ""
|
||||||
|
oldestTime := time.Time{}
|
||||||
|
for key, session := range p.completed {
|
||||||
|
candidateTime := completedFileReceiveTime(session)
|
||||||
|
if oldestKey == "" || candidateTime.Before(oldestTime) || (candidateTime.Equal(oldestTime) && key < oldestKey) {
|
||||||
|
oldestKey = key
|
||||||
|
oldestTime = candidateTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if oldestKey == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(p.completed, oldestKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func completedFileReceiveTime(session *fileReceiveSession) time.Time {
|
||||||
|
if session == nil {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
if !session.updatedAt.IsZero() {
|
||||||
|
return session.updatedAt
|
||||||
|
}
|
||||||
|
return session.startedAt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileReceiveSession) copy() *fileReceiveSession {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dup := *s
|
||||||
|
return &dup
|
||||||
|
}
|
||||||
@@ -0,0 +1,278 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileReceiveSession struct {
|
||||||
|
fileID string
|
||||||
|
name string
|
||||||
|
size int64
|
||||||
|
mode os.FileMode
|
||||||
|
modTime time.Time
|
||||||
|
checksum string
|
||||||
|
received int64
|
||||||
|
tmpPath string
|
||||||
|
finalPath string
|
||||||
|
startedAt time.Time
|
||||||
|
updatedAt time.Time
|
||||||
|
previousUpdatedAt time.Time
|
||||||
|
previousReceived int64
|
||||||
|
}
|
||||||
|
|
||||||
|
const defaultFileReceiveCompletedLimit = 128
|
||||||
|
|
||||||
|
type fileReceivePool struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
dir string
|
||||||
|
sessions map[string]*fileReceiveSession
|
||||||
|
completed map[string]*fileReceiveSession
|
||||||
|
completedLimit int
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileReceiveKey(scope string, fileID string) string {
|
||||||
|
return normalizeFileScope(scope) + "|" + fileID
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileReceivePool() *fileReceivePool {
|
||||||
|
return newFileReceivePoolWithConfig(defaultFileTransferConfig())
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileReceivePoolWithConfig(cfg fileTransferConfig) *fileReceivePool {
|
||||||
|
cfg = normalizeFileTransferConfig(cfg)
|
||||||
|
return newFileReceivePoolWithCompletedLimit(cfg.ReceiveCompletedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileReceivePoolWithCompletedLimit(limit int) *fileReceivePool {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = defaultFileReceiveCompletedLimit
|
||||||
|
}
|
||||||
|
return &fileReceivePool{
|
||||||
|
sessions: make(map[string]*fileReceiveSession),
|
||||||
|
completed: make(map[string]*fileReceiveSession),
|
||||||
|
completedLimit: limit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) applyConfig(cfg fileTransferConfig) {
|
||||||
|
if p == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg = normalizeFileTransferConfig(cfg)
|
||||||
|
p.mu.Lock()
|
||||||
|
p.completedLimit = cfg.ReceiveCompletedLimit
|
||||||
|
p.trimCompletedLocked()
|
||||||
|
p.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) setDir(dir string) error {
|
||||||
|
cleaned := strings.TrimSpace(dir)
|
||||||
|
if cleaned == "" {
|
||||||
|
p.mu.Lock()
|
||||||
|
p.dir = ""
|
||||||
|
p.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
cleaned = filepath.Clean(cleaned)
|
||||||
|
if err := os.MkdirAll(cleaned, 0o755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
info, err := os.Stat(cleaned)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !info.IsDir() {
|
||||||
|
return errors.New("file receive path is not a directory")
|
||||||
|
}
|
||||||
|
p.mu.Lock()
|
||||||
|
p.dir = cleaned
|
||||||
|
p.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) onMeta(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
|
||||||
|
if packet.FileID == "" {
|
||||||
|
return nil, errors.New("empty file id")
|
||||||
|
}
|
||||||
|
now = normalizeFileEventTime(now)
|
||||||
|
sessionKey := fileReceiveKey(scope, packet.FileID)
|
||||||
|
name := filepath.Base(packet.Name)
|
||||||
|
if name == "." || name == "/" || name == "" {
|
||||||
|
name = "unnamed.bin"
|
||||||
|
}
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
if old, ok := p.completed[sessionKey]; ok {
|
||||||
|
if old.name == name && old.size == packet.Size && old.checksum == packet.Checksum {
|
||||||
|
return old.copy(), nil
|
||||||
|
}
|
||||||
|
delete(p.completed, sessionKey)
|
||||||
|
}
|
||||||
|
if old, ok := p.sessions[sessionKey]; ok {
|
||||||
|
if old.name == name && old.size == packet.Size && old.checksum == packet.Checksum {
|
||||||
|
return old.copy(), nil
|
||||||
|
}
|
||||||
|
_ = os.Remove(old.tmpPath)
|
||||||
|
p.removeCheckpointLocked(scope, packet.FileID)
|
||||||
|
delete(p.sessions, sessionKey)
|
||||||
|
}
|
||||||
|
if restored, ok, err := p.restoreCheckpointLocked(scope, packet, now); ok || err != nil {
|
||||||
|
return restored, err
|
||||||
|
}
|
||||||
|
baseDir := p.receiveDirLocked()
|
||||||
|
finalPath := p.uniqueFinalPathLocked(baseDir, name, packet.FileID)
|
||||||
|
prefix := "notify_recv_" + sanitizeFileName(name) + "_"
|
||||||
|
tmp, err := os.CreateTemp(baseDir, prefix+"*.part")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_ = tmp.Close()
|
||||||
|
session := &fileReceiveSession{
|
||||||
|
fileID: packet.FileID,
|
||||||
|
name: name,
|
||||||
|
size: packet.Size,
|
||||||
|
mode: os.FileMode(packet.Mode),
|
||||||
|
modTime: filePacketModTime(packet),
|
||||||
|
checksum: packet.Checksum,
|
||||||
|
received: 0,
|
||||||
|
tmpPath: tmp.Name(),
|
||||||
|
finalPath: finalPath,
|
||||||
|
startedAt: now,
|
||||||
|
updatedAt: now,
|
||||||
|
}
|
||||||
|
p.sessions[sessionKey] = session
|
||||||
|
if err := p.saveCheckpointLocked(scope, session); err != nil {
|
||||||
|
_ = os.Remove(session.tmpPath)
|
||||||
|
delete(p.sessions, sessionKey)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return session.copy(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) onChunk(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
|
||||||
|
now = normalizeFileEventTime(now)
|
||||||
|
sessionKey := fileReceiveKey(scope, packet.FileID)
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
session, ok := p.sessions[sessionKey]
|
||||||
|
if !ok {
|
||||||
|
if completed, ok := p.completed[sessionKey]; ok {
|
||||||
|
return completed.copy(), nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("unknown file id")
|
||||||
|
}
|
||||||
|
if packet.Offset < session.received {
|
||||||
|
return session.copy(), nil
|
||||||
|
}
|
||||||
|
if packet.Offset > session.received {
|
||||||
|
return nil, errors.New("chunk offset mismatch")
|
||||||
|
}
|
||||||
|
if len(packet.Chunk) == 0 {
|
||||||
|
return session.copy(), nil
|
||||||
|
}
|
||||||
|
prevUpdatedAt := session.updatedAt
|
||||||
|
prevReceived := session.received
|
||||||
|
fd, err := os.OpenFile(session.tmpPath, os.O_WRONLY|os.O_APPEND, 0o600)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer fd.Close()
|
||||||
|
n, err := fd.Write(packet.Chunk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
session.received += int64(n)
|
||||||
|
session.previousUpdatedAt = prevUpdatedAt
|
||||||
|
session.previousReceived = prevReceived
|
||||||
|
session.updatedAt = now
|
||||||
|
if err := p.saveCheckpointLocked(scope, session); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return session.copy(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) onEnd(scope string, packet FilePacket, now time.Time) (string, *fileReceiveSession, error) {
|
||||||
|
now = normalizeFileEventTime(now)
|
||||||
|
sessionKey := fileReceiveKey(scope, packet.FileID)
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
session, ok := p.sessions[sessionKey]
|
||||||
|
if !ok {
|
||||||
|
if completed, ok := p.completed[sessionKey]; ok {
|
||||||
|
return completed.finalPath, completed.copy(), nil
|
||||||
|
}
|
||||||
|
return "", nil, errors.New("unknown file id")
|
||||||
|
}
|
||||||
|
if session.size > 0 && session.received != session.size {
|
||||||
|
return "", session.copy(), errors.New("file size not match")
|
||||||
|
}
|
||||||
|
if session.checksum != "" {
|
||||||
|
sum, err := computeFileChecksum(session.tmpPath)
|
||||||
|
if err != nil {
|
||||||
|
return "", session.copy(), err
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(sum, session.checksum) {
|
||||||
|
_ = os.Remove(session.tmpPath)
|
||||||
|
delete(p.sessions, sessionKey)
|
||||||
|
return "", session.copy(), errors.New("file checksum not match")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
finalPath := session.finalPath
|
||||||
|
baseDir := filepath.Dir(session.tmpPath)
|
||||||
|
if baseDir == "" || baseDir == "." {
|
||||||
|
baseDir = p.receiveDirLocked()
|
||||||
|
}
|
||||||
|
if finalPath == "" || pathExists(finalPath) {
|
||||||
|
finalPath = p.uniqueFinalPathLocked(baseDir, session.name, packet.FileID)
|
||||||
|
}
|
||||||
|
if err := os.Rename(session.tmpPath, finalPath); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
session.previousUpdatedAt = session.updatedAt
|
||||||
|
session.previousReceived = session.received
|
||||||
|
session.updatedAt = now
|
||||||
|
applyReceivedFileMeta(finalPath, session.mode, session.modTime)
|
||||||
|
delete(p.sessions, sessionKey)
|
||||||
|
session.tmpPath = finalPath
|
||||||
|
session.finalPath = finalPath
|
||||||
|
p.removeCheckpointLocked(scope, packet.FileID)
|
||||||
|
p.completed[sessionKey] = session.copy()
|
||||||
|
p.trimCompletedLocked()
|
||||||
|
return finalPath, session.copy(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *fileReceivePool) onAbort(scope string, packet FilePacket, now time.Time) (*fileReceiveSession, error) {
|
||||||
|
now = normalizeFileEventTime(now)
|
||||||
|
sessionKey := fileReceiveKey(scope, packet.FileID)
|
||||||
|
p.mu.Lock()
|
||||||
|
defer p.mu.Unlock()
|
||||||
|
session, ok := p.sessions[sessionKey]
|
||||||
|
if !ok {
|
||||||
|
if completed, ok := p.completed[sessionKey]; ok {
|
||||||
|
return completed.copy(), nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
session.previousUpdatedAt = session.updatedAt
|
||||||
|
session.previousReceived = session.received
|
||||||
|
session.updatedAt = now
|
||||||
|
dup := session.copy()
|
||||||
|
_ = os.Remove(session.tmpPath)
|
||||||
|
p.removeCheckpointLocked(scope, packet.FileID)
|
||||||
|
delete(p.sessions, sessionKey)
|
||||||
|
delete(p.completed, sessionKey)
|
||||||
|
return dup, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) getFileReceivePool() *fileReceivePool {
|
||||||
|
return c.getLogicalSessionState().fileReceives
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) getFileReceivePool() *fileReceivePool {
|
||||||
|
return s.getLogicalSessionState().fileReceives
|
||||||
|
}
|
||||||
@@ -0,0 +1,520 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFileReceivePoolUsesConfiguredDirAndStableName(t *testing.T) {
|
||||||
|
pool := newFileReceivePool()
|
||||||
|
scope := "client:test"
|
||||||
|
dir := t.TempDir()
|
||||||
|
now := time.Now()
|
||||||
|
if err := pool.setDir(dir); err != nil {
|
||||||
|
t.Fatalf("setDir failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte("hello notify")
|
||||||
|
meta := FilePacket{
|
||||||
|
FileID: "file-1",
|
||||||
|
Name: "greeting.txt",
|
||||||
|
Size: int64(len(payload)),
|
||||||
|
Checksum: testFileChecksum(payload),
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := pool.onMeta(scope, meta, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("onMeta failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := filepath.Dir(session.tmpPath), dir; got != want {
|
||||||
|
t.Fatalf("tmp dir mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := session.finalPath, filepath.Join(dir, "greeting.txt"); got != want {
|
||||||
|
t.Fatalf("final path mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err = pool.onChunk(scope, FilePacket{
|
||||||
|
FileID: meta.FileID,
|
||||||
|
Offset: 0,
|
||||||
|
Chunk: payload,
|
||||||
|
}, now.Add(time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("onChunk failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := session.received, int64(len(payload)); got != want {
|
||||||
|
t.Fatalf("received mismatch after chunk: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
finalPath, session, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("onEnd failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := finalPath, filepath.Join(dir, "greeting.txt"); got != want {
|
||||||
|
t.Fatalf("completed path mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := session.finalPath, finalPath; got != want {
|
||||||
|
t.Fatalf("session final path mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
gotData, err := os.ReadFile(finalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFile failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(gotData, payload) {
|
||||||
|
t.Fatalf("completed file content mismatch: got %q want %q", gotData, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
dupMeta, err := pool.onMeta(scope, meta, now.Add(3*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("duplicate onMeta failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := dupMeta.finalPath, finalPath; got != want {
|
||||||
|
t.Fatalf("duplicate meta final path mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
dupChunk, err := pool.onChunk(scope, FilePacket{
|
||||||
|
FileID: meta.FileID,
|
||||||
|
Offset: 0,
|
||||||
|
Chunk: payload,
|
||||||
|
}, now.Add(4*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("duplicate onChunk failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := dupChunk.received, int64(len(payload)); got != want {
|
||||||
|
t.Fatalf("duplicate chunk received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
dupPath, dupEnd, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(5*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("duplicate onEnd failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := dupPath, finalPath; got != want {
|
||||||
|
t.Fatalf("duplicate end path mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := dupEnd.finalPath, finalPath; got != want {
|
||||||
|
t.Fatalf("duplicate end session final path mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileReceivePoolAvoidsOverwriteWhenFinalPathBecomesBusy(t *testing.T) {
|
||||||
|
pool := newFileReceivePool()
|
||||||
|
scope := "client:test"
|
||||||
|
dir := t.TempDir()
|
||||||
|
now := time.Now()
|
||||||
|
if err := pool.setDir(dir); err != nil {
|
||||||
|
t.Fatalf("setDir failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte("new report payload")
|
||||||
|
meta := FilePacket{
|
||||||
|
FileID: "file-2",
|
||||||
|
Name: "report.txt",
|
||||||
|
Size: int64(len(payload)),
|
||||||
|
Checksum: testFileChecksum(payload),
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := pool.onMeta(scope, meta, now)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("onMeta failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
occupiedPath := session.finalPath
|
||||||
|
occupiedContent := []byte("existing report")
|
||||||
|
if err := os.WriteFile(occupiedPath, occupiedContent, 0o644); err != nil {
|
||||||
|
t.Fatalf("WriteFile occupied path failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := pool.onChunk(scope, FilePacket{
|
||||||
|
FileID: meta.FileID,
|
||||||
|
Offset: 0,
|
||||||
|
Chunk: payload,
|
||||||
|
}, now.Add(time.Second)); err != nil {
|
||||||
|
t.Fatalf("onChunk failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("onEnd failed: %v", err)
|
||||||
|
}
|
||||||
|
if finalPath == occupiedPath {
|
||||||
|
t.Fatalf("expected final path to avoid occupied path %q", occupiedPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotOccupied, err := os.ReadFile(occupiedPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFile occupied path failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(gotOccupied, occupiedContent) {
|
||||||
|
t.Fatalf("occupied file content changed: got %q want %q", gotOccupied, occupiedContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotFinal, err := os.ReadFile(finalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFile final path failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(gotFinal, payload) {
|
||||||
|
t.Fatalf("final file content mismatch: got %q want %q", gotFinal, payload)
|
||||||
|
}
|
||||||
|
if got, want := filepath.Dir(finalPath), dir; got != want {
|
||||||
|
t.Fatalf("final dir mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileReceivePoolAbortAfterCompletionKeepsDeliveredFile(t *testing.T) {
|
||||||
|
pool := newFileReceivePool()
|
||||||
|
scope := "client:test"
|
||||||
|
dir := t.TempDir()
|
||||||
|
now := time.Now()
|
||||||
|
if err := pool.setDir(dir); err != nil {
|
||||||
|
t.Fatalf("setDir failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte("keep me")
|
||||||
|
meta := FilePacket{
|
||||||
|
FileID: "file-3",
|
||||||
|
Name: "keep.txt",
|
||||||
|
Size: int64(len(payload)),
|
||||||
|
Checksum: testFileChecksum(payload),
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := pool.onMeta(scope, meta, now); err != nil {
|
||||||
|
t.Fatalf("onMeta failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := pool.onChunk(scope, FilePacket{
|
||||||
|
FileID: meta.FileID,
|
||||||
|
Offset: 0,
|
||||||
|
Chunk: payload,
|
||||||
|
}, now.Add(time.Second)); err != nil {
|
||||||
|
t.Fatalf("onChunk failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("onEnd failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := pool.onAbort(scope, FilePacket{FileID: meta.FileID}, now.Add(3*time.Second)); err != nil {
|
||||||
|
t.Fatalf("onAbort failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotData, err := os.ReadFile(finalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFile final path after abort failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(gotData, payload) {
|
||||||
|
t.Fatalf("final file content mismatch after abort: got %q want %q", gotData, payload)
|
||||||
|
}
|
||||||
|
|
||||||
|
dupPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(4*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("duplicate onEnd after abort failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := dupPath, finalPath; got != want {
|
||||||
|
t.Fatalf("duplicate end path mismatch after abort: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileReceivePoolAppliesMetaModeAndModTime(t *testing.T) {
|
||||||
|
pool := newFileReceivePool()
|
||||||
|
scope := "client:test"
|
||||||
|
dir := t.TempDir()
|
||||||
|
now := time.Now()
|
||||||
|
if err := pool.setDir(dir); err != nil {
|
||||||
|
t.Fatalf("setDir failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
payload := []byte("meta test")
|
||||||
|
wantMode := os.FileMode(0o640)
|
||||||
|
wantTime := time.Now().Add(-2 * time.Hour).Truncate(time.Second)
|
||||||
|
meta := FilePacket{
|
||||||
|
FileID: "file-meta",
|
||||||
|
Name: "meta.txt",
|
||||||
|
Size: int64(len(payload)),
|
||||||
|
Checksum: testFileChecksum(payload),
|
||||||
|
Mode: uint32(wantMode),
|
||||||
|
ModTime: wantTime.UnixNano(),
|
||||||
|
}
|
||||||
|
if _, err := pool.onMeta(scope, meta, now); err != nil {
|
||||||
|
t.Fatalf("onMeta failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := pool.onChunk(scope, FilePacket{
|
||||||
|
FileID: meta.FileID,
|
||||||
|
Offset: 0,
|
||||||
|
Chunk: payload,
|
||||||
|
}, now.Add(time.Second)); err != nil {
|
||||||
|
t.Fatalf("onChunk failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
finalPath, _, err := pool.onEnd(scope, FilePacket{FileID: meta.FileID}, now.Add(2*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("onEnd failed: %v", err)
|
||||||
|
}
|
||||||
|
info, err := os.Stat(finalPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Stat failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := info.Mode().Perm(), wantMode; got != want {
|
||||||
|
t.Fatalf("mode mismatch: got %o want %o", got, want)
|
||||||
|
}
|
||||||
|
gotMTime := info.ModTime().Truncate(time.Second)
|
||||||
|
if got, want := gotMTime, wantTime; !got.Equal(want) {
|
||||||
|
t.Fatalf("mtime mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileReceivePoolScopeIsolation(t *testing.T) {
|
||||||
|
pool := newFileReceivePool()
|
||||||
|
dir := t.TempDir()
|
||||||
|
now := time.Now()
|
||||||
|
if err := pool.setDir(dir); err != nil {
|
||||||
|
t.Fatalf("setDir failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
const sharedFileID = "shared-file-id"
|
||||||
|
payloadA := []byte("from client A")
|
||||||
|
payloadB := []byte("from client B")
|
||||||
|
metaA := FilePacket{
|
||||||
|
FileID: sharedFileID,
|
||||||
|
Name: "shared.txt",
|
||||||
|
Size: int64(len(payloadA)),
|
||||||
|
Checksum: testFileChecksum(payloadA),
|
||||||
|
}
|
||||||
|
metaB := FilePacket{
|
||||||
|
FileID: sharedFileID,
|
||||||
|
Name: "shared.txt",
|
||||||
|
Size: int64(len(payloadB)),
|
||||||
|
Checksum: testFileChecksum(payloadB),
|
||||||
|
}
|
||||||
|
|
||||||
|
scopeA := "server:client-a"
|
||||||
|
scopeB := "server:client-b"
|
||||||
|
if _, err := pool.onMeta(scopeA, metaA, now); err != nil {
|
||||||
|
t.Fatalf("onMeta scopeA failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := pool.onMeta(scopeB, metaB, now); err != nil {
|
||||||
|
t.Fatalf("onMeta scopeB failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := pool.onChunk(scopeA, FilePacket{
|
||||||
|
FileID: sharedFileID,
|
||||||
|
Offset: 0,
|
||||||
|
Chunk: payloadA,
|
||||||
|
}, now.Add(time.Second)); err != nil {
|
||||||
|
t.Fatalf("onChunk scopeA failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := pool.onChunk(scopeB, FilePacket{
|
||||||
|
FileID: sharedFileID,
|
||||||
|
Offset: 0,
|
||||||
|
Chunk: payloadB,
|
||||||
|
}, now.Add(time.Second)); err != nil {
|
||||||
|
t.Fatalf("onChunk scopeB failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
finalPathA, _, err := pool.onEnd(scopeA, FilePacket{FileID: sharedFileID}, now.Add(2*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("onEnd scopeA failed: %v", err)
|
||||||
|
}
|
||||||
|
finalPathB, _, err := pool.onEnd(scopeB, FilePacket{FileID: sharedFileID}, now.Add(2*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("onEnd scopeB failed: %v", err)
|
||||||
|
}
|
||||||
|
if finalPathA == finalPathB {
|
||||||
|
t.Fatalf("scope-isolated files should not share path: %q", finalPathA)
|
||||||
|
}
|
||||||
|
|
||||||
|
gotA, err := os.ReadFile(finalPathA)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFile scopeA failed: %v", err)
|
||||||
|
}
|
||||||
|
gotB, err := os.ReadFile(finalPathB)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFile scopeB failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(gotA, payloadA) {
|
||||||
|
t.Fatalf("scopeA content mismatch: got %q want %q", gotA, payloadA)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(gotB, payloadB) {
|
||||||
|
t.Fatalf("scopeB content mismatch: got %q want %q", gotB, payloadB)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileReceivePoolCompletedRetentionEvictsOldest(t *testing.T) {
|
||||||
|
pool := newFileReceivePoolWithCompletedLimit(2)
|
||||||
|
dir := t.TempDir()
|
||||||
|
now := time.Now()
|
||||||
|
scope := "client:test"
|
||||||
|
if err := pool.setDir(dir); err != nil {
|
||||||
|
t.Fatalf("setDir failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
complete := func(fileID string, offset time.Duration) {
|
||||||
|
payload := []byte("payload-" + fileID)
|
||||||
|
meta := FilePacket{
|
||||||
|
FileID: fileID,
|
||||||
|
Name: fileID + ".txt",
|
||||||
|
Size: int64(len(payload)),
|
||||||
|
Checksum: testFileChecksum(payload),
|
||||||
|
}
|
||||||
|
eventTime := now.Add(offset)
|
||||||
|
if _, err := pool.onMeta(scope, meta, eventTime); err != nil {
|
||||||
|
t.Fatalf("onMeta %s failed: %v", fileID, err)
|
||||||
|
}
|
||||||
|
if _, err := pool.onChunk(scope, FilePacket{
|
||||||
|
FileID: fileID,
|
||||||
|
Offset: 0,
|
||||||
|
Chunk: payload,
|
||||||
|
}, eventTime.Add(time.Second)); err != nil {
|
||||||
|
t.Fatalf("onChunk %s failed: %v", fileID, err)
|
||||||
|
}
|
||||||
|
if _, _, err := pool.onEnd(scope, FilePacket{FileID: fileID}, eventTime.Add(2*time.Second)); err != nil {
|
||||||
|
t.Fatalf("onEnd %s failed: %v", fileID, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
complete("done-1", 0)
|
||||||
|
complete("done-2", 10*time.Second)
|
||||||
|
|
||||||
|
activePayload := []byte("still-active")
|
||||||
|
if _, err := pool.onMeta(scope, FilePacket{
|
||||||
|
FileID: "active-1",
|
||||||
|
Name: "active-1.txt",
|
||||||
|
Size: int64(len(activePayload)),
|
||||||
|
Checksum: testFileChecksum(activePayload),
|
||||||
|
}, now.Add(20*time.Second)); err != nil {
|
||||||
|
t.Fatalf("onMeta active-1 failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
complete("done-3", 30*time.Second)
|
||||||
|
|
||||||
|
if got, want := len(pool.completed), 2; got != want {
|
||||||
|
t.Fatalf("completed size mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := len(pool.sessions), 1; got != want {
|
||||||
|
t.Fatalf("active session size mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if _, ok := pool.sessions[fileReceiveKey(scope, "active-1")]; !ok {
|
||||||
|
t.Fatal("active session should be retained")
|
||||||
|
}
|
||||||
|
if _, ok := pool.completed[fileReceiveKey(scope, "done-1")]; ok {
|
||||||
|
t.Fatal("oldest completed session should be evicted")
|
||||||
|
}
|
||||||
|
if _, ok := pool.completed[fileReceiveKey(scope, "done-2")]; !ok {
|
||||||
|
t.Fatal("newer completed session should be retained")
|
||||||
|
}
|
||||||
|
if _, ok := pool.completed[fileReceiveKey(scope, "done-3")]; !ok {
|
||||||
|
t.Fatal("latest completed session should be retained")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, _, err := pool.onEnd(scope, FilePacket{FileID: "done-1"}, now.Add(40*time.Second)); err == nil {
|
||||||
|
t.Fatal("evicted completed session should no longer resolve duplicate end")
|
||||||
|
}
|
||||||
|
if _, _, err := pool.onEnd(scope, FilePacket{FileID: "done-3"}, now.Add(41*time.Second)); err != nil {
|
||||||
|
t.Fatalf("latest completed session should still resolve duplicate end: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFillFileEventProgress(t *testing.T) {
|
||||||
|
event := FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{Size: 200},
|
||||||
|
Received: 50,
|
||||||
|
StartedAt: time.Unix(100, 0),
|
||||||
|
UpdatedAt: time.Unix(102, 0),
|
||||||
|
}
|
||||||
|
fillFileEventProgress(&event)
|
||||||
|
if got, want := event.Total, int64(200); got != want {
|
||||||
|
t.Fatalf("total mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := event.Percent, 25.0; got != want {
|
||||||
|
t.Fatalf("percent mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if event.Done {
|
||||||
|
t.Fatal("chunk event should not be done")
|
||||||
|
}
|
||||||
|
if got, want := event.Duration, 2*time.Second; got != want {
|
||||||
|
t.Fatalf("duration mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := event.RateBPS, 25.0; got != want {
|
||||||
|
t.Fatalf("rate mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
endEvent := FileEvent{
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{Size: 200},
|
||||||
|
Received: 180,
|
||||||
|
StartedAt: time.Unix(200, 0),
|
||||||
|
UpdatedAt: time.Unix(204, 0),
|
||||||
|
}
|
||||||
|
fillFileEventProgress(&endEvent)
|
||||||
|
if !endEvent.Done {
|
||||||
|
t.Fatal("end event should be done")
|
||||||
|
}
|
||||||
|
if got, want := endEvent.Received, int64(200); got != want {
|
||||||
|
t.Fatalf("end received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := endEvent.Percent, 100.0; got != want {
|
||||||
|
t.Fatalf("end percent mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := endEvent.Duration, 4*time.Second; got != want {
|
||||||
|
t.Fatalf("end duration mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := endEvent.RateBPS, 50.0; got != want {
|
||||||
|
t.Fatalf("end rate mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
abortEvent := FileEvent{
|
||||||
|
Kind: EnvelopeFileAbort,
|
||||||
|
Packet: FilePacket{Size: 200},
|
||||||
|
Received: 60,
|
||||||
|
StartedAt: time.Unix(300, 0),
|
||||||
|
UpdatedAt: time.Unix(303, 0),
|
||||||
|
}
|
||||||
|
fillFileEventProgress(&abortEvent)
|
||||||
|
if abortEvent.Done {
|
||||||
|
t.Fatal("abort event should not be done")
|
||||||
|
}
|
||||||
|
if got, want := abortEvent.Percent, 30.0; got != want {
|
||||||
|
t.Fatalf("abort percent mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := abortEvent.Duration, 3*time.Second; got != want {
|
||||||
|
t.Fatalf("abort duration mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := abortEvent.RateBPS, 20.0; got != want {
|
||||||
|
t.Fatalf("abort rate mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFillFileEventTiming(t *testing.T) {
|
||||||
|
event := FileEvent{
|
||||||
|
Received: 120,
|
||||||
|
}
|
||||||
|
session := &fileReceiveSession{
|
||||||
|
startedAt: time.Unix(100, 0),
|
||||||
|
updatedAt: time.Unix(110, 0),
|
||||||
|
previousUpdatedAt: time.Unix(108, 0),
|
||||||
|
previousReceived: 80,
|
||||||
|
}
|
||||||
|
fillFileEventTiming(&event, session)
|
||||||
|
|
||||||
|
if got, want := event.StartedAt, session.startedAt; !got.Equal(want) {
|
||||||
|
t.Fatalf("startedAt mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := event.UpdatedAt, session.updatedAt; !got.Equal(want) {
|
||||||
|
t.Fatalf("updatedAt mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := event.StepDuration, 2*time.Second; got != want {
|
||||||
|
t.Fatalf("step duration mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := event.InstantRateBPS, 20.0; got != want {
|
||||||
|
t.Fatalf("instant rate mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func testFileChecksum(data []byte) string {
|
||||||
|
sum := sha256.Sum256(data)
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultFileScope = "default"
|
||||||
|
clientFileDomain = "client"
|
||||||
|
serverFileDomain = "server"
|
||||||
|
serverTransportScopeSuffix = "#tg:"
|
||||||
|
)
|
||||||
|
|
||||||
|
func normalizeFileScope(scope string) string {
|
||||||
|
cleaned := strings.TrimSpace(scope)
|
||||||
|
if cleaned == "" {
|
||||||
|
return defaultFileScope
|
||||||
|
}
|
||||||
|
return cleaned
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientFileScope() string {
|
||||||
|
return clientFileDomain
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverFileScope(peer any) string {
|
||||||
|
logical := logicalConnFromPeer(peer)
|
||||||
|
if logical == nil {
|
||||||
|
return serverFileDomain + ":unknown"
|
||||||
|
}
|
||||||
|
id := strings.TrimSpace(logical.ID())
|
||||||
|
if id == "" {
|
||||||
|
return serverFileDomain + ":unknown"
|
||||||
|
}
|
||||||
|
return serverFileDomain + ":" + id
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverTransportScope(peer any) string {
|
||||||
|
logical := logicalConnFromPeer(peer)
|
||||||
|
if logical == nil {
|
||||||
|
return serverFileDomain + ":unknown"
|
||||||
|
}
|
||||||
|
return serverTransportScopeByGeneration(logical, logical.transportGenerationSnapshot())
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverTransportScopeForTransport(transport *TransportConn) string {
|
||||||
|
if transport == nil {
|
||||||
|
return serverFileDomain + ":unknown"
|
||||||
|
}
|
||||||
|
return transport.transportScope()
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverTransportScopeByGeneration(peer any, generation uint64) string {
|
||||||
|
base := serverFileScope(peer)
|
||||||
|
if generation == 0 {
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
return base + serverTransportScopeSuffix + strconv.FormatUint(generation, 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverTransportDeliveryScopes(peer any) []string {
|
||||||
|
logical := logicalConnFromPeer(peer)
|
||||||
|
if logical == nil {
|
||||||
|
return []string{serverFileDomain + ":unknown"}
|
||||||
|
}
|
||||||
|
base := serverFileScope(logical)
|
||||||
|
transport := serverTransportScope(logical)
|
||||||
|
if transport == base {
|
||||||
|
return []string{base}
|
||||||
|
}
|
||||||
|
return []string{transport, base}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverTransportDeliveryScopesForTransport(transport *TransportConn) []string {
|
||||||
|
if transport == nil {
|
||||||
|
return []string{serverFileDomain + ":unknown"}
|
||||||
|
}
|
||||||
|
return transport.deliveryScopes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func scopeBelongsToServerFileScope(scope string, base string) bool {
|
||||||
|
scope = normalizeFileScope(scope)
|
||||||
|
base = normalizeFileScope(base)
|
||||||
|
return scope == base || strings.HasPrefix(scope, base+serverTransportScopeSuffix)
|
||||||
|
}
|
||||||
+451
@@ -0,0 +1,451 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultFileChunkSize = 64 * 1024
|
||||||
|
|
||||||
|
type fileSendHooks struct {
|
||||||
|
config fileTransferConfig
|
||||||
|
startSession func(*fileSendSession)
|
||||||
|
sendReliable func(context.Context, Envelope) error
|
||||||
|
sendAbort func(fileID string, stage string, offset int64, cause error) error
|
||||||
|
publishEvent func(FileEvent)
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileSendError struct {
|
||||||
|
stage string
|
||||||
|
offset int64
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *fileSendError) Error() string {
|
||||||
|
if e == nil || e.err == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return e.err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *fileSendError) Unwrap() error {
|
||||||
|
if e == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return e.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) SendFile(ctx context.Context, filePath string) error {
|
||||||
|
target := transferSendTarget{
|
||||||
|
runtime: c.getTransferRuntime(),
|
||||||
|
runtimeScope: clientFileScope(),
|
||||||
|
publicScope: clientFileScope(),
|
||||||
|
transportGeneration: 0,
|
||||||
|
sequenceEn: c.sequenceEn,
|
||||||
|
sequenceDe: c.sequenceDe,
|
||||||
|
openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) {
|
||||||
|
return c.OpenStream(ctx, opt)
|
||||||
|
},
|
||||||
|
sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) {
|
||||||
|
return SendTransferBeginClient(ctx, c, req)
|
||||||
|
},
|
||||||
|
sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) {
|
||||||
|
return SendTransferResumeClient(ctx, c, req)
|
||||||
|
},
|
||||||
|
sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) {
|
||||||
|
return SendTransferCommitClient(ctx, c, req)
|
||||||
|
},
|
||||||
|
sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) {
|
||||||
|
return SendTransferAbortClient(ctx, c, req)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return c.sendFileViaTransfer(ctx, filePath, target, func(event FileEvent) {
|
||||||
|
event.NetType = NET_CLIENT
|
||||||
|
event.ServerConn = c
|
||||||
|
c.publishSendFileEventMonitorOnly(event)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) SendFile(ctx context.Context, client *ClientConn, filePath string) error {
|
||||||
|
return s.SendFileLogical(ctx, logicalConnFromClient(client), filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) SendFileLogical(ctx context.Context, client *LogicalConn, filePath string) error {
|
||||||
|
if client == nil {
|
||||||
|
return s.SendFileTransport(ctx, nil, filePath)
|
||||||
|
}
|
||||||
|
return s.SendFileTransport(ctx, s.resolveOutboundTransport(client), filePath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) SendFileTransport(ctx context.Context, transport *TransportConn, filePath string) error {
|
||||||
|
if transport == nil {
|
||||||
|
return transportDetachedErrorForTransport(transport)
|
||||||
|
}
|
||||||
|
logical := transport.logicalConnSnapshot()
|
||||||
|
if logical == nil || !transport.Attached() || !transport.IsCurrent() {
|
||||||
|
return transportDetachedErrorForTransport(transport)
|
||||||
|
}
|
||||||
|
target := transferSendTarget{
|
||||||
|
runtime: s.getTransferRuntime(),
|
||||||
|
runtimeScope: serverTransportScopeForTransport(transport),
|
||||||
|
publicScope: serverFileScope(logical),
|
||||||
|
transportGeneration: transport.TransportGeneration(),
|
||||||
|
logical: logical,
|
||||||
|
transport: transport,
|
||||||
|
sequenceEn: s.sequenceEn,
|
||||||
|
sequenceDe: s.sequenceDe,
|
||||||
|
openStream: func(ctx context.Context, opt StreamOpenOptions) (Stream, error) {
|
||||||
|
return s.OpenStreamTransport(ctx, transport, opt)
|
||||||
|
},
|
||||||
|
sendBegin: func(ctx context.Context, req TransferBeginRequest) (TransferBeginResponse, error) {
|
||||||
|
return SendTransferBeginTransport(ctx, s, transport, req)
|
||||||
|
},
|
||||||
|
sendResume: func(ctx context.Context, req TransferResumeRequest) (TransferResumeResponse, error) {
|
||||||
|
return SendTransferResumeTransport(ctx, s, transport, req)
|
||||||
|
},
|
||||||
|
sendCommit: func(ctx context.Context, req TransferCommitRequest) (TransferCommitResponse, error) {
|
||||||
|
return SendTransferCommitTransport(ctx, s, transport, req)
|
||||||
|
},
|
||||||
|
sendAbort: func(ctx context.Context, req TransferAbortRequest) (TransferAbortResponse, error) {
|
||||||
|
return SendTransferAbortTransport(ctx, s, transport, req)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return s.sendFileViaTransfer(ctx, filePath, target, func(event FileEvent) {
|
||||||
|
event.NetType = NET_SERVER
|
||||||
|
event.LogicalConn = logical
|
||||||
|
event.TransportConn = transport
|
||||||
|
s.publishSendFileEventMonitorOnly(event)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error {
|
||||||
|
return sendFileViaTransfer(ctx, filePath, target, publishEvent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error {
|
||||||
|
return sendFileViaTransfer(ctx, filePath, target, publishEvent)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendFileViaTransfer(ctx context.Context, filePath string, target transferSendTarget, publishEvent func(FileEvent)) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
session, err := newFileSendSession(filePath, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
session.fileID = buildStableFileTransferID(session)
|
||||||
|
source, err := newTransferFileSource(filePath, session.size)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer source.Close()
|
||||||
|
if publishEvent != nil {
|
||||||
|
hooks := transferSendHooks{
|
||||||
|
onNegotiated: func(nextOffset int64, _ bool) {
|
||||||
|
session.syncProgress(nextOffset, time.Now())
|
||||||
|
publishEvent(session.onMetaSent(time.Now()))
|
||||||
|
},
|
||||||
|
onSegmentSent: func(offset int64, sentBytes int64) {
|
||||||
|
event, chunkErr := session.onChunkSent(offset, sentBytes, time.Now())
|
||||||
|
if chunkErr == nil {
|
||||||
|
publishEvent(event)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
onCommitted: func() {
|
||||||
|
publishEvent(session.onEndSent(time.Now()))
|
||||||
|
},
|
||||||
|
onAbort: func(stage string, offset int64, cause error) {
|
||||||
|
publishEvent(session.onAbort(stage, offset, cause, time.Now()))
|
||||||
|
},
|
||||||
|
}
|
||||||
|
handle, err := startTransferSendWithHooks(ctx, TransferSendOptions{
|
||||||
|
Descriptor: buildFileTransferDescriptor(session),
|
||||||
|
Source: source,
|
||||||
|
ChunkSize: defaultFileChunkSize,
|
||||||
|
VerifyChecksum: false,
|
||||||
|
}, target, hooks)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return handle.Wait(ctx)
|
||||||
|
}
|
||||||
|
handle, err := startTransferSend(ctx, TransferSendOptions{
|
||||||
|
Descriptor: buildFileTransferDescriptor(session),
|
||||||
|
Source: source,
|
||||||
|
ChunkSize: defaultFileChunkSize,
|
||||||
|
VerifyChecksum: false,
|
||||||
|
}, target)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return handle.Wait(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendFileWithHooks(ctx context.Context, filePath string, hooks fileSendHooks) error {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
hooks.config = normalizeFileTransferConfig(hooks.config)
|
||||||
|
session, err := newFileSendSession(filePath, time.Now())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if hooks.startSession != nil {
|
||||||
|
hooks.startSession(session)
|
||||||
|
}
|
||||||
|
if err := sendFileMetaWithHooks(ctx, session, hooks); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := sendFileChunksWithHooks(ctx, session, hooks); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := sendFileEndWithHooks(ctx, session, hooks); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileSendSession(filePath string, now time.Time) (*fileSendSession, error) {
|
||||||
|
fi, err := os.Stat(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if fi.IsDir() {
|
||||||
|
return nil, fmt.Errorf("file path is a directory: %s", filePath)
|
||||||
|
}
|
||||||
|
checksum, err := computeFileChecksum(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
now = normalizeFileEventTime(now)
|
||||||
|
name := filepath.Base(filePath)
|
||||||
|
if name == "" || name == "." || name == string(filepath.Separator) {
|
||||||
|
name = "unnamed.bin"
|
||||||
|
}
|
||||||
|
return &fileSendSession{
|
||||||
|
fileID: buildFileID(filePath),
|
||||||
|
path: filePath,
|
||||||
|
name: name,
|
||||||
|
size: fi.Size(),
|
||||||
|
mode: fi.Mode().Perm(),
|
||||||
|
modTime: fi.ModTime(),
|
||||||
|
checksum: checksum,
|
||||||
|
startedAt: now,
|
||||||
|
updatedAt: now,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileSendSession struct {
|
||||||
|
fileID string
|
||||||
|
path string
|
||||||
|
name string
|
||||||
|
size int64
|
||||||
|
mode os.FileMode
|
||||||
|
modTime time.Time
|
||||||
|
checksum string
|
||||||
|
sent int64
|
||||||
|
startedAt time.Time
|
||||||
|
updatedAt time.Time
|
||||||
|
previousUpdatedAt time.Time
|
||||||
|
previousSent int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) metaEnvelope() Envelope {
|
||||||
|
return newFileMetaEnvelope(s.fileID, s.name, s.size, s.checksum, uint32(s.mode.Perm()), s.modTime.UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) chunkEnvelope(offset int64, chunk []byte) Envelope {
|
||||||
|
return newFileChunkEnvelope(s.fileID, offset, chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) endEnvelope() Envelope {
|
||||||
|
return newFileEndEnvelope(s.fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) filePacket() FilePacket {
|
||||||
|
return FilePacket{
|
||||||
|
FileID: s.fileID,
|
||||||
|
Name: s.name,
|
||||||
|
Size: s.size,
|
||||||
|
Mode: uint32(s.mode.Perm()),
|
||||||
|
ModTime: s.modTime.UnixNano(),
|
||||||
|
Checksum: s.checksum,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) advance(delta int64, now time.Time) {
|
||||||
|
now = normalizeFileEventTime(now)
|
||||||
|
if s.startedAt.IsZero() {
|
||||||
|
s.startedAt = now
|
||||||
|
}
|
||||||
|
s.previousUpdatedAt = s.updatedAt
|
||||||
|
s.previousSent = s.sent
|
||||||
|
s.updatedAt = now
|
||||||
|
s.sent += delta
|
||||||
|
if s.sent < 0 {
|
||||||
|
s.sent = 0
|
||||||
|
}
|
||||||
|
if s.size > 0 && s.sent > s.size {
|
||||||
|
s.sent = s.size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) syncProgress(progress int64, now time.Time) {
|
||||||
|
now = normalizeFileEventTime(now)
|
||||||
|
if progress < 0 {
|
||||||
|
progress = 0
|
||||||
|
}
|
||||||
|
if s.size > 0 && progress > s.size {
|
||||||
|
progress = s.size
|
||||||
|
}
|
||||||
|
if s.startedAt.IsZero() {
|
||||||
|
s.startedAt = now
|
||||||
|
}
|
||||||
|
s.previousUpdatedAt = s.updatedAt
|
||||||
|
s.previousSent = s.sent
|
||||||
|
s.updatedAt = now
|
||||||
|
s.sent = progress
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) buildEvent(kind EnvelopeKind, packet FilePacket, err error, now time.Time) FileEvent {
|
||||||
|
now = normalizeFileEventTime(now)
|
||||||
|
if err != nil && packet.Error == "" {
|
||||||
|
packet.Error = err.Error()
|
||||||
|
}
|
||||||
|
event := FileEvent{
|
||||||
|
Kind: kind,
|
||||||
|
Packet: packet,
|
||||||
|
Path: s.path,
|
||||||
|
Received: s.sent,
|
||||||
|
Err: err,
|
||||||
|
Time: now,
|
||||||
|
}
|
||||||
|
fillFileSendEventTiming(&event, s)
|
||||||
|
fillFileEventProgress(&event)
|
||||||
|
return event
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) onMetaSent(now time.Time) FileEvent {
|
||||||
|
s.advance(0, now)
|
||||||
|
return s.buildEvent(EnvelopeFileMeta, s.filePacket(), nil, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) onChunkSent(offset int64, chunkSize int64, now time.Time) (FileEvent, error) {
|
||||||
|
if offset != s.sent {
|
||||||
|
return FileEvent{}, fmt.Errorf("file chunk offset mismatch: got %d want %d", offset, s.sent)
|
||||||
|
}
|
||||||
|
packet := s.filePacket()
|
||||||
|
packet.Offset = offset
|
||||||
|
s.advance(chunkSize, now)
|
||||||
|
return s.buildEvent(EnvelopeFileChunk, packet, nil, now), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) onEndSent(now time.Time) FileEvent {
|
||||||
|
s.advance(0, now)
|
||||||
|
return s.buildEvent(EnvelopeFileEnd, s.filePacket(), nil, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileSendSession) onAbort(stage string, offset int64, cause error, now time.Time) FileEvent {
|
||||||
|
packet := s.filePacket()
|
||||||
|
packet.Stage = stage
|
||||||
|
packet.Offset = offset
|
||||||
|
s.advance(0, now)
|
||||||
|
return s.buildEvent(EnvelopeFileAbort, packet, cause, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendFileMetaWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error {
|
||||||
|
if err := hooks.sendReliable(ctx, session.metaEnvelope()); err != nil {
|
||||||
|
return handleFileSendFailure(session, hooks, "meta", 0, err)
|
||||||
|
}
|
||||||
|
publishFileSendEvent(hooks, session.onMetaSent(time.Now()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendFileChunksWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error {
|
||||||
|
fd, err := os.Open(session.path)
|
||||||
|
if err != nil {
|
||||||
|
return handleFileSendFailure(session, hooks, "chunk", session.sent, err)
|
||||||
|
}
|
||||||
|
defer fd.Close()
|
||||||
|
streamErr := streamFileChunks(ctx, fd, hooks.config.ChunkSize, func(offset int64, chunk []byte) error {
|
||||||
|
err := hooks.sendReliable(ctx, session.chunkEnvelope(offset, chunk))
|
||||||
|
if err != nil {
|
||||||
|
return &fileSendError{stage: "chunk", offset: offset, err: err}
|
||||||
|
}
|
||||||
|
event, stateErr := session.onChunkSent(offset, int64(len(chunk)), time.Now())
|
||||||
|
if stateErr != nil {
|
||||||
|
return &fileSendError{stage: "chunk", offset: offset, err: stateErr}
|
||||||
|
}
|
||||||
|
publishFileSendEvent(hooks, event)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
if streamErr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var sendErr *fileSendError
|
||||||
|
if errors.As(streamErr, &sendErr) {
|
||||||
|
return handleFileSendFailure(session, hooks, sendErr.stage, sendErr.offset, sendErr.err)
|
||||||
|
}
|
||||||
|
return handleFileSendFailure(session, hooks, "chunk", session.sent, streamErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sendFileEndWithHooks(ctx context.Context, session *fileSendSession, hooks fileSendHooks) error {
|
||||||
|
if err := hooks.sendReliable(ctx, session.endEnvelope()); err != nil {
|
||||||
|
return handleFileSendFailure(session, hooks, "end", session.sent, err)
|
||||||
|
}
|
||||||
|
publishFileSendEvent(hooks, session.onEndSent(time.Now()))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleFileSendFailure(session *fileSendSession, hooks fileSendHooks, stage string, offset int64, cause error) error {
|
||||||
|
if session != nil && hooks.sendAbort != nil && session.fileID != "" {
|
||||||
|
_ = hooks.sendAbort(session.fileID, stage, offset, cause)
|
||||||
|
}
|
||||||
|
if session != nil {
|
||||||
|
publishFileSendEvent(hooks, session.onAbort(stage, offset, cause, time.Now()))
|
||||||
|
}
|
||||||
|
return cause
|
||||||
|
}
|
||||||
|
|
||||||
|
func publishFileSendEvent(hooks fileSendHooks, event FileEvent) {
|
||||||
|
if hooks.publishEvent != nil {
|
||||||
|
hooks.publishEvent(event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func streamFileChunks(ctx context.Context, reader io.Reader, chunkSize int, sendChunk func(offset int64, chunk []byte) error) error {
|
||||||
|
if chunkSize <= 0 {
|
||||||
|
chunkSize = defaultFileChunkSize
|
||||||
|
}
|
||||||
|
buf := make([]byte, chunkSize)
|
||||||
|
var offset int64
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return fmt.Errorf("file stream canceled: %w", ctx.Err())
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
n, readErr := reader.Read(buf)
|
||||||
|
if n > 0 {
|
||||||
|
chunk := make([]byte, n)
|
||||||
|
copy(chunk, buf[:n])
|
||||||
|
if err := sendChunk(offset, chunk); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
offset += int64(n)
|
||||||
|
}
|
||||||
|
if readErr == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errors.Is(readErr, io.EOF) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return readErr
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,224 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFileSendSessionProgress(t *testing.T) {
|
||||||
|
session := &fileSendSession{
|
||||||
|
fileID: "file-1",
|
||||||
|
path: "/tmp/demo.bin",
|
||||||
|
name: "demo.bin",
|
||||||
|
size: 200,
|
||||||
|
checksum: "sum",
|
||||||
|
startedAt: time.Unix(100, 0),
|
||||||
|
updatedAt: time.Unix(100, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
metaEvent := session.onMetaSent(time.Unix(100, 0))
|
||||||
|
if got, want := metaEvent.Kind, EnvelopeFileMeta; got != want {
|
||||||
|
t.Fatalf("meta kind mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := metaEvent.Total, int64(200); got != want {
|
||||||
|
t.Fatalf("meta total mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := metaEvent.Received, int64(0); got != want {
|
||||||
|
t.Fatalf("meta received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
chunkEvent, err := session.onChunkSent(0, 80, time.Unix(104, 0))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("onChunkSent failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := chunkEvent.Received, int64(80); got != want {
|
||||||
|
t.Fatalf("chunk received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := chunkEvent.Percent, 40.0; got != want {
|
||||||
|
t.Fatalf("chunk percent mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := chunkEvent.Duration, 4*time.Second; got != want {
|
||||||
|
t.Fatalf("chunk duration mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := chunkEvent.StepDuration, 4*time.Second; got != want {
|
||||||
|
t.Fatalf("chunk step duration mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := chunkEvent.InstantRateBPS, 20.0; got != want {
|
||||||
|
t.Fatalf("chunk instant rate mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
secondChunkEvent, err := session.onChunkSent(80, 120, time.Unix(108, 0))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("second onChunkSent failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := secondChunkEvent.Received, int64(200); got != want {
|
||||||
|
t.Fatalf("second chunk received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := secondChunkEvent.Percent, 100.0; got != want {
|
||||||
|
t.Fatalf("second chunk percent mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := secondChunkEvent.RateBPS, 25.0; got != want {
|
||||||
|
t.Fatalf("second chunk rate mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := secondChunkEvent.StepDuration, 4*time.Second; got != want {
|
||||||
|
t.Fatalf("second chunk step duration mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := secondChunkEvent.InstantRateBPS, 30.0; got != want {
|
||||||
|
t.Fatalf("second chunk instant rate mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
endEvent := session.onEndSent(time.Unix(110, 0))
|
||||||
|
if !endEvent.Done {
|
||||||
|
t.Fatal("end event should be done")
|
||||||
|
}
|
||||||
|
if got, want := endEvent.Received, int64(200); got != want {
|
||||||
|
t.Fatalf("end received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := endEvent.Percent, 100.0; got != want {
|
||||||
|
t.Fatalf("end percent mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := endEvent.Duration, 10*time.Second; got != want {
|
||||||
|
t.Fatalf("end duration mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := endEvent.StepDuration, 2*time.Second; got != want {
|
||||||
|
t.Fatalf("end step duration mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := endEvent.RateBPS, 20.0; got != want {
|
||||||
|
t.Fatalf("end rate mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := endEvent.InstantRateBPS, 0.0; got != want {
|
||||||
|
t.Fatalf("end instant rate mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendFileWithHooksLogsLocalProgress(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
filePath := filepath.Join(dir, "demo.txt")
|
||||||
|
data := []byte("hello notify send progress")
|
||||||
|
if err := os.WriteFile(filePath, data, 0o644); err != nil {
|
||||||
|
t.Fatalf("WriteFile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var sentKinds []EnvelopeKind
|
||||||
|
var events []FileEvent
|
||||||
|
err := sendFileWithHooks(context.Background(), filePath, fileSendHooks{
|
||||||
|
sendReliable: func(ctx context.Context, env Envelope) error {
|
||||||
|
sentKinds = append(sentKinds, env.Kind)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
sendAbort: func(fileID string, stage string, offset int64, cause error) error {
|
||||||
|
t.Fatalf("unexpected abort: fileID=%s stage=%s offset=%d err=%v", fileID, stage, offset, cause)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
publishEvent: func(event FileEvent) {
|
||||||
|
events = append(events, event)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sendFileWithHooks failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := len(sentKinds), 3; got != want {
|
||||||
|
t.Fatalf("sent kinds count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if sentKinds[0] != EnvelopeFileMeta || sentKinds[1] != EnvelopeFileChunk || sentKinds[2] != EnvelopeFileEnd {
|
||||||
|
t.Fatalf("unexpected sent kinds: %v", sentKinds)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := len(events), 3; got != want {
|
||||||
|
t.Fatalf("event count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if events[0].Kind != EnvelopeFileMeta || events[1].Kind != EnvelopeFileChunk || events[2].Kind != EnvelopeFileEnd {
|
||||||
|
t.Fatalf("unexpected event kinds: %+v", []EnvelopeKind{events[0].Kind, events[1].Kind, events[2].Kind})
|
||||||
|
}
|
||||||
|
if got, want := events[1].Received, int64(len(data)); got != want {
|
||||||
|
t.Fatalf("chunk received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if !events[2].Done {
|
||||||
|
t.Fatal("end event should be done")
|
||||||
|
}
|
||||||
|
if got, want := events[2].Received, int64(len(data)); got != want {
|
||||||
|
t.Fatalf("end received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := events[2].Path, filePath; got != want {
|
||||||
|
t.Fatalf("end path mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if events[0].Packet.FileID == "" {
|
||||||
|
t.Fatal("fileID should not be empty")
|
||||||
|
}
|
||||||
|
if events[0].Packet.FileID != events[1].Packet.FileID || events[1].Packet.FileID != events[2].Packet.FileID {
|
||||||
|
t.Fatalf("fileID should stay stable across events: %+v", events)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendFileWithHooksAbortOnChunkFailure(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
filePath := filepath.Join(dir, "demo.txt")
|
||||||
|
data := []byte("hello notify send failure")
|
||||||
|
if err := os.WriteFile(filePath, data, 0o644); err != nil {
|
||||||
|
t.Fatalf("WriteFile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wantErr := errors.New("chunk ack timeout")
|
||||||
|
var abortFileID string
|
||||||
|
var abortStage string
|
||||||
|
var abortOffset int64
|
||||||
|
var abortCause error
|
||||||
|
var events []FileEvent
|
||||||
|
|
||||||
|
err := sendFileWithHooks(context.Background(), filePath, fileSendHooks{
|
||||||
|
sendReliable: func(ctx context.Context, env Envelope) error {
|
||||||
|
if env.Kind == EnvelopeFileChunk {
|
||||||
|
return wantErr
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
sendAbort: func(fileID string, stage string, offset int64, cause error) error {
|
||||||
|
abortFileID = fileID
|
||||||
|
abortStage = stage
|
||||||
|
abortOffset = offset
|
||||||
|
abortCause = cause
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
publishEvent: func(event FileEvent) {
|
||||||
|
events = append(events, event)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if !errors.Is(err, wantErr) {
|
||||||
|
t.Fatalf("sendFileWithHooks error mismatch: got %v want %v", err, wantErr)
|
||||||
|
}
|
||||||
|
if abortFileID == "" {
|
||||||
|
t.Fatal("abort should capture fileID")
|
||||||
|
}
|
||||||
|
if got, want := abortStage, "chunk"; got != want {
|
||||||
|
t.Fatalf("abort stage mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := abortOffset, int64(0); got != want {
|
||||||
|
t.Fatalf("abort offset mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if !errors.Is(abortCause, wantErr) {
|
||||||
|
t.Fatalf("abort cause mismatch: got %v want %v", abortCause, wantErr)
|
||||||
|
}
|
||||||
|
if got, want := len(events), 2; got != want {
|
||||||
|
t.Fatalf("event count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := events[0].Kind, EnvelopeFileMeta; got != want {
|
||||||
|
t.Fatalf("first event kind mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := events[1].Kind, EnvelopeFileAbort; got != want {
|
||||||
|
t.Fatalf("abort event kind mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := events[1].Packet.Stage, "chunk"; got != want {
|
||||||
|
t.Fatalf("abort packet stage mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := events[1].Received, int64(0); got != want {
|
||||||
|
t.Fatalf("abort received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if !errors.Is(events[1].Err, wantErr) {
|
||||||
|
t.Fatalf("abort event error mismatch: got %v want %v", events[1].Err, wantErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,328 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
fileTransferMetadataKindKey = "_notify.file_adapter_kind"
|
||||||
|
fileTransferMetadataKindValue = "file"
|
||||||
|
fileTransferMetadataNameKey = "_notify.file_name"
|
||||||
|
fileTransferMetadataModeKey = "_notify.file_mode"
|
||||||
|
fileTransferMetadataModTimeKey = "_notify.file_mod_time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type transferFileSource struct {
|
||||||
|
file *os.File
|
||||||
|
size int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTransferFileSource(path string, size int64) (*transferFileSource, error) {
|
||||||
|
file, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &transferFileSource{
|
||||||
|
file: file,
|
||||||
|
size: size,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *transferFileSource) ReadAt(p []byte, off int64) (int, error) {
|
||||||
|
if s == nil || s.file == nil {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
return s.file.ReadAt(p, off)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *transferFileSource) Size() int64 {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return s.size
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *transferFileSource) Close() error {
|
||||||
|
if s == nil || s.file == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.file.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
type transferCloseWithError interface {
|
||||||
|
CloseWithError(error) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type transferReceiveOffsetProvider interface {
|
||||||
|
NextOffset() int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileTransferReceiveSink struct {
|
||||||
|
pool *fileReceivePool
|
||||||
|
scope string
|
||||||
|
packet FilePacket
|
||||||
|
publishEvent func(FileEvent)
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
offset int64
|
||||||
|
committed bool
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileTransferReceiveSink(pool *fileReceivePool, scope string, packet FilePacket, publishEvent func(FileEvent)) (*fileTransferReceiveSink, error) {
|
||||||
|
if pool == nil {
|
||||||
|
return nil, errTransferSinkNil
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
session, err := pool.onMeta(scope, packet, now)
|
||||||
|
if publishEvent != nil {
|
||||||
|
publishEvent(fileReceiveEventFromSession(EnvelopeFileMeta, packet, session, "", err, now))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &fileTransferReceiveSink{
|
||||||
|
pool: pool,
|
||||||
|
scope: normalizeFileScope(scope),
|
||||||
|
packet: packet,
|
||||||
|
publishEvent: publishEvent,
|
||||||
|
offset: session.received,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferReceiveSink) NextOffset() int64 {
|
||||||
|
if s == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.offset
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferReceiveSink) WriteAt(p []byte, off int64) (int, error) {
|
||||||
|
if len(p) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
closed := s.closed
|
||||||
|
s.mu.Unlock()
|
||||||
|
if closed {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
packet := s.packet
|
||||||
|
packet.Offset = off
|
||||||
|
packet.Chunk = append([]byte(nil), p...)
|
||||||
|
session, err := s.pool.onChunk(s.scope, packet, now)
|
||||||
|
if s.publishEvent != nil {
|
||||||
|
s.publishEvent(fileReceiveEventFromSession(EnvelopeFileChunk, packet, session, "", err, now))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
if end := off + int64(len(p)); end > s.offset {
|
||||||
|
s.offset = end
|
||||||
|
}
|
||||||
|
s.mu.Unlock()
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferReceiveSink) Sync(context.Context) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferReceiveSink) Commit(context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
closed := s.closed
|
||||||
|
s.mu.Unlock()
|
||||||
|
if closed {
|
||||||
|
return os.ErrClosed
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
finalPath, session, err := s.pool.onEnd(s.scope, FilePacket{FileID: s.packet.FileID}, now)
|
||||||
|
if s.publishEvent != nil {
|
||||||
|
s.publishEvent(fileReceiveEventFromSession(EnvelopeFileEnd, s.packet, session, finalPath, err, now))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.committed = true
|
||||||
|
s.offset = s.packet.Size
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferReceiveSink) Close() error {
|
||||||
|
return s.closeWithError(nil, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferReceiveSink) CloseWithError(err error) error {
|
||||||
|
return s.closeWithError(err, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferReceiveSink) closeWithError(err error, publish bool) error {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
if s.closed {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s.closed = true
|
||||||
|
committed := s.committed
|
||||||
|
offset := s.offset
|
||||||
|
s.mu.Unlock()
|
||||||
|
if committed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
packet := FilePacket{
|
||||||
|
FileID: s.packet.FileID,
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
packet.Stage = "abort"
|
||||||
|
packet.Error = err.Error()
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
session, abortErr := s.pool.onAbort(s.scope, packet, now)
|
||||||
|
if publish && err != nil && s.publishEvent != nil {
|
||||||
|
s.publishEvent(fileReceiveEventFromSession(EnvelopeFileAbort, packet, session, "", firstErr(abortErr, err), now))
|
||||||
|
}
|
||||||
|
return abortErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstErr(primary error, fallback error) error {
|
||||||
|
if primary != nil {
|
||||||
|
return primary
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileReceiveEventFromSession(kind EnvelopeKind, packet FilePacket, session *fileReceiveSession, path string, err error, now time.Time) FileEvent {
|
||||||
|
event := FileEvent{
|
||||||
|
Kind: kind,
|
||||||
|
Packet: packet,
|
||||||
|
Time: now,
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
switch kind {
|
||||||
|
case EnvelopeFileAbort:
|
||||||
|
event.Received = packet.Offset
|
||||||
|
case EnvelopeFileEnd:
|
||||||
|
event.Path = path
|
||||||
|
}
|
||||||
|
if session != nil {
|
||||||
|
if event.Path == "" {
|
||||||
|
if kind == EnvelopeFileEnd && session.finalPath != "" {
|
||||||
|
event.Path = session.finalPath
|
||||||
|
} else {
|
||||||
|
event.Path = session.tmpPath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if kind != EnvelopeFileAbort {
|
||||||
|
event.Received = session.received
|
||||||
|
}
|
||||||
|
fillFileEventTiming(&event, session)
|
||||||
|
}
|
||||||
|
fillFileEventProgress(&event)
|
||||||
|
return event
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildFileTransferDescriptor(session *fileSendSession) TransferDescriptor {
|
||||||
|
return TransferDescriptor{
|
||||||
|
ID: session.fileID,
|
||||||
|
Channel: TransferChannelData,
|
||||||
|
Size: session.size,
|
||||||
|
Checksum: session.checksum,
|
||||||
|
Metadata: map[string]string{
|
||||||
|
fileTransferMetadataKindKey: fileTransferMetadataKindValue,
|
||||||
|
fileTransferMetadataNameKey: session.name,
|
||||||
|
fileTransferMetadataModeKey: strconv.FormatUint(uint64(session.mode.Perm()), 10),
|
||||||
|
fileTransferMetadataModTimeKey: strconv.FormatInt(session.modTime.UnixNano(), 10),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildStableFileTransferID(session *fileSendSession) string {
|
||||||
|
if session == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
sum := sha256.Sum256([]byte(session.name + "|" + strconv.FormatInt(session.size, 10) + "|" + normalizeChecksum(session.checksum)))
|
||||||
|
return fmt.Sprintf("%s-%s", fileIDBaseName(session.name), hex.EncodeToString(sum[:8]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseFileTransferPacket(desc TransferDescriptor) (FilePacket, bool) {
|
||||||
|
if desc.Metadata[fileTransferMetadataKindKey] != fileTransferMetadataKindValue {
|
||||||
|
return FilePacket{}, false
|
||||||
|
}
|
||||||
|
packet := FilePacket{
|
||||||
|
FileID: desc.ID,
|
||||||
|
Name: desc.Metadata[fileTransferMetadataNameKey],
|
||||||
|
Size: desc.Size,
|
||||||
|
Checksum: desc.Checksum,
|
||||||
|
}
|
||||||
|
if modeValue := desc.Metadata[fileTransferMetadataModeKey]; modeValue != "" {
|
||||||
|
if mode, err := strconv.ParseUint(modeValue, 10, 32); err == nil {
|
||||||
|
packet.Mode = uint32(mode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if modTimeValue := desc.Metadata[fileTransferMetadataModTimeKey]; modTimeValue != "" {
|
||||||
|
if modTime, err := strconv.ParseInt(modTimeValue, 10, 64); err == nil {
|
||||||
|
packet.ModTime = modTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return packet, packet.FileID != "" && packet.Name != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) {
|
||||||
|
packet, ok := parseFileTransferPacket(info.Descriptor)
|
||||||
|
if !ok {
|
||||||
|
return TransferReceiveOptions{}, false, nil
|
||||||
|
}
|
||||||
|
sink, err := newFileTransferReceiveSink(c.getFileReceivePool(), clientFileScope(), packet, func(event FileEvent) {
|
||||||
|
event.NetType = NET_CLIENT
|
||||||
|
event.ServerConn = c
|
||||||
|
c.publishReceivedFileEventMonitorOnly(event)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return TransferReceiveOptions{}, true, err
|
||||||
|
}
|
||||||
|
return TransferReceiveOptions{
|
||||||
|
Descriptor: cloneTransferDescriptor(info.Descriptor),
|
||||||
|
Sink: sink,
|
||||||
|
VerifyChecksum: false,
|
||||||
|
SyncOnCheckpoint: false,
|
||||||
|
}, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) builtinFileTransferHandler(info TransferAcceptInfo) (TransferReceiveOptions, bool, error) {
|
||||||
|
packet, ok := parseFileTransferPacket(info.Descriptor)
|
||||||
|
if !ok {
|
||||||
|
return TransferReceiveOptions{}, false, nil
|
||||||
|
}
|
||||||
|
sink, err := newFileTransferReceiveSink(s.getFileReceivePool(), transferPublicScopeForPeer(info.LogicalConn), packet, func(event FileEvent) {
|
||||||
|
event.NetType = NET_SERVER
|
||||||
|
event.LogicalConn = info.LogicalConn
|
||||||
|
event.TransportConn = info.TransportConn
|
||||||
|
s.publishReceivedFileEventMonitorOnly(event)
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return TransferReceiveOptions{}, true, err
|
||||||
|
}
|
||||||
|
return TransferReceiveOptions{
|
||||||
|
Descriptor: cloneTransferDescriptor(info.Descriptor),
|
||||||
|
Sink: sink,
|
||||||
|
VerifyChecksum: false,
|
||||||
|
SyncOnCheckpoint: false,
|
||||||
|
}, true, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,131 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSendFileUsesTransferKernelAndBuiltinFileReceiver(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
if err := UseModernPSKServer(server, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKServer failed: %v", err)
|
||||||
|
}
|
||||||
|
receiveDir := t.TempDir()
|
||||||
|
if err := server.SetFileReceiveDir(receiveDir); err != nil {
|
||||||
|
t.Fatalf("SetFileReceiveDir failed: %v", err)
|
||||||
|
}
|
||||||
|
var serverMu sync.Mutex
|
||||||
|
var serverEvents []FileEvent
|
||||||
|
server.SetFileHandler(func(event FileEvent) {
|
||||||
|
serverMu.Lock()
|
||||||
|
serverEvents = append(serverEvents, event)
|
||||||
|
serverMu.Unlock()
|
||||||
|
})
|
||||||
|
if err := server.Listen("tcp", "127.0.0.1:0"); err != nil {
|
||||||
|
t.Fatalf("server Listen failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = server.Stop() }()
|
||||||
|
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
if err := UseModernPSKClient(client, integrationSharedSecret, integrationModernPSKOptions()); err != nil {
|
||||||
|
t.Fatalf("UseModernPSKClient failed: %v", err)
|
||||||
|
}
|
||||||
|
var clientMu sync.Mutex
|
||||||
|
var clientEvents []FileEvent
|
||||||
|
client.setFileEventObserver(func(event FileEvent) {
|
||||||
|
clientMu.Lock()
|
||||||
|
clientEvents = append(clientEvents, event)
|
||||||
|
clientMu.Unlock()
|
||||||
|
})
|
||||||
|
if err := client.Connect("tcp", server.listener.Addr().String()); err != nil {
|
||||||
|
t.Fatalf("client Connect failed: %v", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = client.Stop() }()
|
||||||
|
|
||||||
|
payload := bytes.Repeat([]byte("send-file-transfer-kernel-"), 1024)
|
||||||
|
sendPath := filepath.Join(t.TempDir(), "payload.bin")
|
||||||
|
if err := os.WriteFile(sendPath, payload, 0o600); err != nil {
|
||||||
|
t.Fatalf("WriteFile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := client.SendFile(context.Background(), sendPath); err != nil {
|
||||||
|
t.Fatalf("SendFile failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
receivedPath := waitForSingleFileInDir(t, receiveDir, 2*time.Second)
|
||||||
|
received, err := os.ReadFile(receivedPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadFile failed: %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(received, payload) {
|
||||||
|
t.Fatalf("received payload mismatch: got %d want %d", len(received), len(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
clientSnapshots, err := GetClientTransferSnapshots(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientTransferSnapshots failed: %v", err)
|
||||||
|
}
|
||||||
|
serverSnapshots, err := GetServerTransferSnapshots(server)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetServerTransferSnapshots failed: %v", err)
|
||||||
|
}
|
||||||
|
if !containsFileTransferSnapshot(clientSnapshots) {
|
||||||
|
t.Fatalf("client snapshots do not contain file transfer metadata: %+v", clientSnapshots)
|
||||||
|
}
|
||||||
|
if !containsFileTransferSnapshot(serverSnapshots) {
|
||||||
|
t.Fatalf("server snapshots do not contain file transfer metadata: %+v", serverSnapshots)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientMu.Lock()
|
||||||
|
serverMu.Lock()
|
||||||
|
defer clientMu.Unlock()
|
||||||
|
defer serverMu.Unlock()
|
||||||
|
if !containsFileEventKind(clientEvents, EnvelopeFileMeta) || !containsFileEventKind(clientEvents, EnvelopeFileEnd) {
|
||||||
|
t.Fatalf("client file events missing meta/end: %+v", clientEvents)
|
||||||
|
}
|
||||||
|
if !containsFileEventKind(serverEvents, EnvelopeFileMeta) || !containsFileEventKind(serverEvents, EnvelopeFileEnd) {
|
||||||
|
t.Fatalf("server file events missing meta/end: %+v", serverEvents)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForSingleFileInDir(t *testing.T, dir string, timeout time.Duration) string {
|
||||||
|
t.Helper()
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
entries, err := os.ReadDir(dir)
|
||||||
|
if err == nil {
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return filepath.Join(dir, entry.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
}
|
||||||
|
t.Fatalf("timed out waiting for received file in %s", dir)
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsFileTransferSnapshot(list []TransferSnapshot) bool {
|
||||||
|
for _, snapshot := range list {
|
||||||
|
if snapshot.Metadata[fileTransferMetadataKindKey] == fileTransferMetadataKindValue && snapshot.State == TransferStateDone {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsFileEventKind(list []FileEvent, kind EnvelopeKind) bool {
|
||||||
|
for _, event := range list {
|
||||||
|
if event.Kind == kind {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -0,0 +1,81 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
const defaultFileSendRetry = 3
|
||||||
|
|
||||||
|
const defaultFileAckTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
type fileTransferConfig struct {
|
||||||
|
ChunkSize int
|
||||||
|
AckTimeout time.Duration
|
||||||
|
SendRetry int
|
||||||
|
ReceiveCompletedLimit int
|
||||||
|
MonitorCompletedLimit int
|
||||||
|
}
|
||||||
|
|
||||||
|
func defaultFileTransferConfig() fileTransferConfig {
|
||||||
|
return fileTransferConfig{
|
||||||
|
ChunkSize: defaultFileChunkSize,
|
||||||
|
AckTimeout: defaultFileAckTimeout,
|
||||||
|
SendRetry: defaultFileSendRetry,
|
||||||
|
ReceiveCompletedLimit: defaultFileReceiveCompletedLimit,
|
||||||
|
MonitorCompletedLimit: defaultFileTransferCompletedLimit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeFileTransferConfig(cfg fileTransferConfig) fileTransferConfig {
|
||||||
|
defaults := defaultFileTransferConfig()
|
||||||
|
if cfg.ChunkSize <= 0 {
|
||||||
|
cfg.ChunkSize = defaults.ChunkSize
|
||||||
|
}
|
||||||
|
if cfg.AckTimeout <= 0 {
|
||||||
|
cfg.AckTimeout = defaults.AckTimeout
|
||||||
|
}
|
||||||
|
if cfg.SendRetry <= 0 {
|
||||||
|
cfg.SendRetry = defaults.SendRetry
|
||||||
|
}
|
||||||
|
if cfg.ReceiveCompletedLimit <= 0 {
|
||||||
|
cfg.ReceiveCompletedLimit = defaults.ReceiveCompletedLimit
|
||||||
|
}
|
||||||
|
if cfg.MonitorCompletedLimit <= 0 {
|
||||||
|
cfg.MonitorCompletedLimit = defaults.MonitorCompletedLimit
|
||||||
|
}
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) getFileTransferConfig() fileTransferConfig {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.fileTransferCfg = normalizeFileTransferConfig(c.fileTransferCfg)
|
||||||
|
return c.fileTransferCfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) getFileTransferConfig() fileTransferConfig {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
s.fileTransferCfg = normalizeFileTransferConfig(s.fileTransferCfg)
|
||||||
|
return s.fileTransferCfg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) setFileTransferConfig(cfg fileTransferConfig) {
|
||||||
|
cfg = normalizeFileTransferConfig(cfg)
|
||||||
|
c.mu.Lock()
|
||||||
|
c.fileTransferCfg = cfg
|
||||||
|
state := c.logicalSession
|
||||||
|
c.mu.Unlock()
|
||||||
|
if state != nil {
|
||||||
|
state.applyFileTransferConfig(cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) setFileTransferConfig(cfg fileTransferConfig) {
|
||||||
|
cfg = normalizeFileTransferConfig(cfg)
|
||||||
|
s.mu.Lock()
|
||||||
|
s.fileTransferCfg = cfg
|
||||||
|
state := s.logicalSession
|
||||||
|
s.mu.Unlock()
|
||||||
|
if state != nil {
|
||||||
|
state.applyFileTransferConfig(cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,104 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientFileTransferConfigDefaults(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
|
||||||
|
cfg := client.getFileTransferConfig()
|
||||||
|
|
||||||
|
if got, want := cfg.ChunkSize, defaultFileChunkSize; got != want {
|
||||||
|
t.Fatalf("chunk size mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := cfg.AckTimeout, defaultFileAckTimeout; got != want {
|
||||||
|
t.Fatalf("ack timeout mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := cfg.SendRetry, defaultFileSendRetry; got != want {
|
||||||
|
t.Fatalf("send retry mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := cfg.ReceiveCompletedLimit, defaultFileReceiveCompletedLimit; got != want {
|
||||||
|
t.Fatalf("receive completed limit mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := cfg.MonitorCompletedLimit, defaultFileTransferCompletedLimit; got != want {
|
||||||
|
t.Fatalf("monitor completed limit mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerFileTransferConfigNormalization(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
|
||||||
|
server.setFileTransferConfig(fileTransferConfig{})
|
||||||
|
cfg := server.getFileTransferConfig()
|
||||||
|
|
||||||
|
if got, want := cfg.ChunkSize, defaultFileChunkSize; got != want {
|
||||||
|
t.Fatalf("normalized chunk size mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := cfg.AckTimeout, defaultFileAckTimeout; got != want {
|
||||||
|
t.Fatalf("normalized ack timeout mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := cfg.SendRetry, defaultFileSendRetry; got != want {
|
||||||
|
t.Fatalf("normalized retry mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := cfg.ReceiveCompletedLimit, defaultFileReceiveCompletedLimit; got != want {
|
||||||
|
t.Fatalf("normalized receive completed limit mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := cfg.MonitorCompletedLimit, defaultFileTransferCompletedLimit; got != want {
|
||||||
|
t.Fatalf("normalized monitor completed limit mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientFileTransferConfigPropagatesRetentionLimits(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
|
||||||
|
client.setFileTransferConfig(fileTransferConfig{
|
||||||
|
ChunkSize: 64,
|
||||||
|
AckTimeout: time.Second,
|
||||||
|
SendRetry: 2,
|
||||||
|
ReceiveCompletedLimit: 7,
|
||||||
|
MonitorCompletedLimit: 9,
|
||||||
|
})
|
||||||
|
|
||||||
|
if got, want := client.getFileReceivePool().completedLimit, 7; got != want {
|
||||||
|
t.Fatalf("client receive pool completed limit mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := client.getFileTransferState().monitorView().completedLimit, 9; got != want {
|
||||||
|
t.Fatalf("client transfer monitor completed limit mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSendFileWithHooksHonorsConfiguredChunkSize(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "payload.bin")
|
||||||
|
if err := os.WriteFile(path, []byte("abcdefg"), 0o600); err != nil {
|
||||||
|
t.Fatalf("write temp file failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunks []int
|
||||||
|
err := sendFileWithHooks(context.Background(), path, fileSendHooks{
|
||||||
|
config: fileTransferConfig{
|
||||||
|
ChunkSize: 3,
|
||||||
|
AckTimeout: time.Millisecond,
|
||||||
|
SendRetry: 1,
|
||||||
|
},
|
||||||
|
sendReliable: func(ctx context.Context, env Envelope) error {
|
||||||
|
if env.Kind == EnvelopeFileChunk {
|
||||||
|
chunks = append(chunks, len(env.File.Chunk))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("sendFileWithHooks failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := chunks, []int{3, 3, 1}; !reflect.DeepEqual(got, want) {
|
||||||
|
t.Fatalf("chunk sizes mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,174 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "sync"
|
||||||
|
|
||||||
|
const defaultFileTransferCompletedLimit = 128
|
||||||
|
|
||||||
|
type fileTransferMonitor struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
active map[string]fileTransferSnapshot
|
||||||
|
completed map[string]fileTransferSnapshot
|
||||||
|
runtimeActive map[string]fileTransferSnapshot
|
||||||
|
runtimeCompleted map[string]fileTransferSnapshot
|
||||||
|
completedLimit int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileTransferMonitor() *fileTransferMonitor {
|
||||||
|
return newFileTransferMonitorWithConfig(defaultFileTransferConfig())
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileTransferMonitorWithConfig(cfg fileTransferConfig) *fileTransferMonitor {
|
||||||
|
cfg = normalizeFileTransferConfig(cfg)
|
||||||
|
return newFileTransferMonitorWithCompletedLimit(cfg.MonitorCompletedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileTransferMonitorWithCompletedLimit(limit int) *fileTransferMonitor {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = defaultFileTransferCompletedLimit
|
||||||
|
}
|
||||||
|
return &fileTransferMonitor{
|
||||||
|
active: make(map[string]fileTransferSnapshot),
|
||||||
|
completed: make(map[string]fileTransferSnapshot),
|
||||||
|
runtimeActive: make(map[string]fileTransferSnapshot),
|
||||||
|
runtimeCompleted: make(map[string]fileTransferSnapshot),
|
||||||
|
completedLimit: limit,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) applyConfig(cfg fileTransferConfig) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg = normalizeFileTransferConfig(cfg)
|
||||||
|
m.mu.Lock()
|
||||||
|
m.completedLimit = cfg.MonitorCompletedLimit
|
||||||
|
m.trimCompletedLocked()
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) observe(direction fileTransferDirection, event FileEvent) {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !isFileTransferObservable(event.Kind) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
snapshot := fileTransferSnapshotFromEvent(direction, event)
|
||||||
|
key := fileTransferMonitorKey(direction, snapshot.Scope, snapshot.FileID)
|
||||||
|
runtimeKey := fileTransferRuntimeMonitorKey(direction, snapshot.RuntimeScope, snapshot.FileID)
|
||||||
|
if key == "" || runtimeKey == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if isFileTransferTerminal(snapshot.Kind) {
|
||||||
|
delete(m.active, key)
|
||||||
|
m.completed[key] = snapshot
|
||||||
|
delete(m.runtimeActive, runtimeKey)
|
||||||
|
m.runtimeCompleted[runtimeKey] = snapshot
|
||||||
|
m.trimCompletedLocked()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(m.completed, key)
|
||||||
|
m.active[key] = snapshot
|
||||||
|
delete(m.runtimeCompleted, runtimeKey)
|
||||||
|
m.runtimeActive[runtimeKey] = snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) activeSnapshots() []fileTransferSnapshot {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return sortedFileTransferSnapshots(m.active)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) activeSnapshotsByDirection(direction fileTransferDirection) []fileTransferSnapshot {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return filteredFileTransferSnapshots(m.active, direction)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) completedSnapshots() []fileTransferSnapshot {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return sortedFileTransferSnapshots(m.completed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) completedSnapshotsByDirection(direction fileTransferDirection) []fileTransferSnapshot {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return filteredFileTransferSnapshots(m.completed, direction)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) latestSnapshot(direction fileTransferDirection, scope string, fileID string) (fileTransferSnapshot, bool) {
|
||||||
|
if m == nil {
|
||||||
|
return fileTransferSnapshot{}, false
|
||||||
|
}
|
||||||
|
key := fileTransferMonitorKey(direction, scope, fileID)
|
||||||
|
if key == "" {
|
||||||
|
return fileTransferSnapshot{}, false
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if snapshot, ok := m.active[key]; ok {
|
||||||
|
return snapshot, true
|
||||||
|
}
|
||||||
|
snapshot, ok := m.completed[key]
|
||||||
|
return snapshot, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) snapshotsByFileID(fileID string) []fileTransferSnapshot {
|
||||||
|
if m == nil || fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
latest := latestFileTransferSnapshotsLocked(m.active, m.completed)
|
||||||
|
return filterFileTransferSnapshotsByFileID(latest, fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) snapshotsByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSnapshot {
|
||||||
|
if m == nil || fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
latest := latestFileTransferSnapshotsLocked(m.active, m.completed)
|
||||||
|
return filterFileTransferSnapshotsByDirectionAndFileID(latest, direction, fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) trimCompletedLocked() {
|
||||||
|
trimFileTransferSnapshotsLocked(m.completed, m.completedLimit)
|
||||||
|
trimFileTransferSnapshotsLocked(m.runtimeCompleted, m.completedLimit)
|
||||||
|
}
|
||||||
|
|
||||||
|
func trimFileTransferSnapshotsLocked(snapshots map[string]fileTransferSnapshot, limit int) {
|
||||||
|
if limit <= 0 || len(snapshots) <= limit {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for len(snapshots) > limit {
|
||||||
|
oldestKey := ""
|
||||||
|
oldestSnapshot := fileTransferSnapshot{}
|
||||||
|
for key, snapshot := range snapshots {
|
||||||
|
if oldestKey == "" || fileTransferSnapshotOlder(snapshot, oldestSnapshot, key, oldestKey) {
|
||||||
|
oldestKey = key
|
||||||
|
oldestSnapshot = snapshot
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if oldestKey == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(snapshots, oldestKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,329 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientTransferMonitorTracksSendLifecycle(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
monitor := client.getFileTransferState().monitorView()
|
||||||
|
now := time.Unix(100, 0)
|
||||||
|
|
||||||
|
client.publishSendFileEvent(FileEvent{
|
||||||
|
NetType: NET_CLIENT,
|
||||||
|
Kind: EnvelopeFileMeta,
|
||||||
|
Packet: FilePacket{FileID: "send-1", Size: 100},
|
||||||
|
Path: "/tmp/send-1.bin",
|
||||||
|
Total: 100,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
client.publishSendFileEvent(FileEvent{
|
||||||
|
NetType: NET_CLIENT,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "send-1", Size: 100},
|
||||||
|
Path: "/tmp/send-1.bin",
|
||||||
|
Received: 40,
|
||||||
|
Total: 100,
|
||||||
|
Percent: 40,
|
||||||
|
StartedAt: now,
|
||||||
|
UpdatedAt: now.Add(2 * time.Second),
|
||||||
|
Duration: 2 * time.Second,
|
||||||
|
RateBPS: 20,
|
||||||
|
Time: now.Add(2 * time.Second),
|
||||||
|
StepDuration: 2 * time.Second,
|
||||||
|
})
|
||||||
|
|
||||||
|
active := monitor.activeSnapshots()
|
||||||
|
if got, want := len(active), 1; got != want {
|
||||||
|
t.Fatalf("active count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active[0].Direction, fileTransferDirectionSend; got != want {
|
||||||
|
t.Fatalf("direction mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active[0].Scope, clientFileScope(); got != want {
|
||||||
|
t.Fatalf("scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active[0].Received, int64(40); got != want {
|
||||||
|
t.Fatalf("received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
snapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "send-1")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("latest snapshot should exist while active")
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Kind, EnvelopeFileChunk; got != want {
|
||||||
|
t.Fatalf("latest active kind mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Received, int64(40); got != want {
|
||||||
|
t.Fatalf("latest active received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.publishSendFileEvent(FileEvent{
|
||||||
|
NetType: NET_CLIENT,
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "send-1", Size: 100},
|
||||||
|
Path: "/tmp/send-1.bin",
|
||||||
|
Received: 100,
|
||||||
|
Total: 100,
|
||||||
|
Percent: 100,
|
||||||
|
Done: true,
|
||||||
|
StartedAt: now,
|
||||||
|
UpdatedAt: now.Add(4 * time.Second),
|
||||||
|
Duration: 4 * time.Second,
|
||||||
|
RateBPS: 25,
|
||||||
|
Time: now.Add(4 * time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
active = monitor.activeSnapshots()
|
||||||
|
if got, want := len(active), 0; got != want {
|
||||||
|
t.Fatalf("active count after end mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
completed := monitor.completedSnapshots()
|
||||||
|
if got, want := len(completed), 1; got != want {
|
||||||
|
t.Fatalf("completed count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := completed[0].Done, true; got != want {
|
||||||
|
t.Fatalf("done mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := completed[0].Received, int64(100); got != want {
|
||||||
|
t.Fatalf("completed received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
snapshot, ok = monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "send-1")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("latest snapshot should exist after completion")
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Kind, EnvelopeFileEnd; got != want {
|
||||||
|
t.Fatalf("latest completed kind mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Done, true; got != want {
|
||||||
|
t.Fatalf("latest completed done mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestServerTransferMonitorUsesClientScope(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
monitor := server.getFileTransferState().monitorView()
|
||||||
|
client := &ClientConn{ClientID: "client-1"}
|
||||||
|
now := time.Unix(200, 0)
|
||||||
|
|
||||||
|
server.publishReceivedFileEvent(FileEvent{
|
||||||
|
NetType: NET_SERVER,
|
||||||
|
ClientConn: client,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "recv-1", Size: 50},
|
||||||
|
Path: "/tmp/recv-1.part",
|
||||||
|
Received: 20,
|
||||||
|
Total: 50,
|
||||||
|
Percent: 40,
|
||||||
|
StartedAt: now,
|
||||||
|
UpdatedAt: now.Add(time.Second),
|
||||||
|
Duration: time.Second,
|
||||||
|
RateBPS: 20,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
active := monitor.activeSnapshots()
|
||||||
|
if got, want := len(active), 1; got != want {
|
||||||
|
t.Fatalf("active count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active[0].Direction, fileTransferDirectionReceive; got != want {
|
||||||
|
t.Fatalf("direction mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active[0].Scope, serverFileScope(client); got != want {
|
||||||
|
t.Fatalf("scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active[0].FileID, "recv-1"; got != want {
|
||||||
|
t.Fatalf("fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransferMonitorDirectionQueries(t *testing.T) {
|
||||||
|
monitor := newFileTransferMonitor()
|
||||||
|
now := time.Unix(300, 0)
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 10},
|
||||||
|
Received: 4,
|
||||||
|
Total: 10,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 10},
|
||||||
|
Received: 7,
|
||||||
|
Total: 10,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
sendSnapshots := monitor.activeSnapshotsByDirection(fileTransferDirectionSend)
|
||||||
|
if got, want := len(sendSnapshots), 1; got != want {
|
||||||
|
t.Fatalf("send snapshots count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := sendSnapshots[0].Received, int64(4); got != want {
|
||||||
|
t.Fatalf("send snapshot received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
recvSnapshots := monitor.activeSnapshotsByDirection(fileTransferDirectionReceive)
|
||||||
|
if got, want := len(recvSnapshots), 1; got != want {
|
||||||
|
t.Fatalf("recv snapshots count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := recvSnapshots[0].Received, int64(7); got != want {
|
||||||
|
t.Fatalf("recv snapshot received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
sendSnapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "shared")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("send latest snapshot should exist")
|
||||||
|
}
|
||||||
|
if got, want := sendSnapshot.Received, int64(4); got != want {
|
||||||
|
t.Fatalf("send latest received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
recvSnapshot, ok := monitor.latestSnapshot(fileTransferDirectionReceive, clientFileScope(), "shared")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("recv latest snapshot should exist")
|
||||||
|
}
|
||||||
|
if got, want := recvSnapshot.Received, int64(7); got != want {
|
||||||
|
t.Fatalf("recv latest received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransferMonitorSnapshotsByFileID(t *testing.T) {
|
||||||
|
monitor := newFileTransferMonitor()
|
||||||
|
now := time.Unix(400, 0)
|
||||||
|
serverClientA := &ClientConn{ClientID: "client-a"}
|
||||||
|
serverClientB := &ClientConn{ClientID: "client-b"}
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||||
|
Received: 8,
|
||||||
|
Total: 20,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClientA,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||||
|
Received: 12,
|
||||||
|
Total: 20,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClientB,
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||||
|
Received: 20,
|
||||||
|
Total: 20,
|
||||||
|
Done: true,
|
||||||
|
Time: now.Add(2 * time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
allSnapshots := monitor.snapshotsByFileID("shared")
|
||||||
|
if got, want := len(allSnapshots), 3; got != want {
|
||||||
|
t.Fatalf("all snapshots count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := allSnapshots[0].Direction, fileTransferDirectionReceive; got != want {
|
||||||
|
t.Fatalf("first snapshot direction mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := allSnapshots[0].Scope, serverFileScope(serverClientA); got != want {
|
||||||
|
t.Fatalf("first snapshot scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := allSnapshots[1].Scope, serverFileScope(serverClientB); got != want {
|
||||||
|
t.Fatalf("second snapshot scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := allSnapshots[2].Direction, fileTransferDirectionSend; got != want {
|
||||||
|
t.Fatalf("third snapshot direction mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
recvSnapshots := monitor.snapshotsByDirectionAndFileID(fileTransferDirectionReceive, "shared")
|
||||||
|
if got, want := len(recvSnapshots), 2; got != want {
|
||||||
|
t.Fatalf("recv snapshots count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := recvSnapshots[0].Scope, serverFileScope(serverClientA); got != want {
|
||||||
|
t.Fatalf("recv first scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := recvSnapshots[1].Scope, serverFileScope(serverClientB); got != want {
|
||||||
|
t.Fatalf("recv second scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := recvSnapshots[1].Done, true; got != want {
|
||||||
|
t.Fatalf("recv completed snapshot mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
sendSnapshots := monitor.snapshotsByDirectionAndFileID(fileTransferDirectionSend, "shared")
|
||||||
|
if got, want := len(sendSnapshots), 1; got != want {
|
||||||
|
t.Fatalf("send snapshots count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := sendSnapshots[0].Received, int64(8); got != want {
|
||||||
|
t.Fatalf("send snapshot received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransferMonitorCompletedRetentionEvictsOldest(t *testing.T) {
|
||||||
|
monitor := newFileTransferMonitorWithCompletedLimit(2)
|
||||||
|
now := time.Unix(500, 0)
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "active-1", Size: 10},
|
||||||
|
Received: 3,
|
||||||
|
Total: 10,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "done-1", Size: 10},
|
||||||
|
Received: 10,
|
||||||
|
Total: 10,
|
||||||
|
Done: true,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "done-2", Size: 10},
|
||||||
|
Received: 10,
|
||||||
|
Total: 10,
|
||||||
|
Done: true,
|
||||||
|
Time: now.Add(2 * time.Second),
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "done-3", Size: 10},
|
||||||
|
Received: 10,
|
||||||
|
Total: 10,
|
||||||
|
Done: true,
|
||||||
|
Time: now.Add(3 * time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
active := monitor.activeSnapshots()
|
||||||
|
if got, want := len(active), 1; got != want {
|
||||||
|
t.Fatalf("active count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active[0].FileID, "active-1"; got != want {
|
||||||
|
t.Fatalf("active fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
completed := monitor.completedSnapshots()
|
||||||
|
if got, want := len(completed), 2; got != want {
|
||||||
|
t.Fatalf("completed count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := completed[0].FileID, "done-2"; got != want {
|
||||||
|
t.Fatalf("first completed fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := completed[1].FileID, "done-3"; got != want {
|
||||||
|
t.Fatalf("second completed fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "done-1"); ok {
|
||||||
|
t.Fatal("oldest completed snapshot should be evicted")
|
||||||
|
}
|
||||||
|
if _, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "done-3"); !ok {
|
||||||
|
t.Fatal("latest completed snapshot should be retained")
|
||||||
|
}
|
||||||
|
if snapshot, ok := monitor.latestSnapshot(fileTransferDirectionSend, clientFileScope(), "active-1"); !ok {
|
||||||
|
t.Fatal("active snapshot should remain available")
|
||||||
|
} else if got, want := snapshot.Kind, EnvelopeFileChunk; got != want {
|
||||||
|
t.Fatalf("active latest kind mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,283 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type FileTransferSummary struct {
|
||||||
|
Direction TransferDirection
|
||||||
|
Scope string
|
||||||
|
RuntimeScope string
|
||||||
|
TransportGeneration uint64
|
||||||
|
NetType NetType
|
||||||
|
Kind EnvelopeKind
|
||||||
|
FileID string
|
||||||
|
Path string
|
||||||
|
Received int64
|
||||||
|
Total int64
|
||||||
|
Percent float64
|
||||||
|
Active bool
|
||||||
|
Terminal bool
|
||||||
|
Done bool
|
||||||
|
Failed bool
|
||||||
|
Err error
|
||||||
|
StartedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
Duration time.Duration
|
||||||
|
RateBPS float64
|
||||||
|
StepDuration time.Duration
|
||||||
|
InstantRateBPS float64
|
||||||
|
Time time.Time
|
||||||
|
Stage string
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileTransferSummaryGroup struct {
|
||||||
|
Send []FileTransferSummary
|
||||||
|
Receive []FileTransferSummary
|
||||||
|
}
|
||||||
|
|
||||||
|
type FileTransferSummaryQuery struct {
|
||||||
|
Scope string
|
||||||
|
RuntimeScope string
|
||||||
|
TransportGeneration uint64
|
||||||
|
MatchTransportGeneration bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type clientFileTransferSummaryReader interface {
|
||||||
|
clientFileTransferActiveSummaries() FileTransferSummaryGroup
|
||||||
|
clientFileTransferCompletedSummaries() FileTransferSummaryGroup
|
||||||
|
clientFileTransferFailedSummaries() FileTransferSummaryGroup
|
||||||
|
clientFileTransferLatestByFileID(string) FileTransferSummaryGroup
|
||||||
|
clientFileTransferLatestByFileIDQuery(string, FileTransferSummaryQuery) FileTransferSummaryGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
type serverFileTransferSummaryReader interface {
|
||||||
|
serverFileTransferActiveSummaries() FileTransferSummaryGroup
|
||||||
|
serverFileTransferCompletedSummaries() FileTransferSummaryGroup
|
||||||
|
serverFileTransferFailedSummaries() FileTransferSummaryGroup
|
||||||
|
serverFileTransferLatestByFileID(string) FileTransferSummaryGroup
|
||||||
|
serverFileTransferLatestByFileIDQuery(string, FileTransferSummaryQuery) FileTransferSummaryGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
errClientFileTransferSummaryNil = errors.New("client file transfer summary target is nil")
|
||||||
|
errServerFileTransferSummaryNil = errors.New("server file transfer summary target is nil")
|
||||||
|
errClientFileTransferSummaryUnsupported = errors.New("client file transfer summary target type is unsupported")
|
||||||
|
errServerFileTransferSummaryUnsupported = errors.New("server file transfer summary target type is unsupported")
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetClientFileTransferActiveSummaries(c Client) (FileTransferSummaryGroup, error) {
|
||||||
|
if c == nil {
|
||||||
|
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
|
||||||
|
}
|
||||||
|
reader, ok := any(c).(clientFileTransferSummaryReader)
|
||||||
|
if !ok {
|
||||||
|
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
|
||||||
|
}
|
||||||
|
return reader.clientFileTransferActiveSummaries(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetServerFileTransferActiveSummaries(s Server) (FileTransferSummaryGroup, error) {
|
||||||
|
if s == nil {
|
||||||
|
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
|
||||||
|
}
|
||||||
|
reader, ok := any(s).(serverFileTransferSummaryReader)
|
||||||
|
if !ok {
|
||||||
|
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
|
||||||
|
}
|
||||||
|
return reader.serverFileTransferActiveSummaries(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetClientFileTransferCompletedSummaries(c Client) (FileTransferSummaryGroup, error) {
|
||||||
|
if c == nil {
|
||||||
|
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
|
||||||
|
}
|
||||||
|
reader, ok := any(c).(clientFileTransferSummaryReader)
|
||||||
|
if !ok {
|
||||||
|
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
|
||||||
|
}
|
||||||
|
return reader.clientFileTransferCompletedSummaries(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetServerFileTransferCompletedSummaries(s Server) (FileTransferSummaryGroup, error) {
|
||||||
|
if s == nil {
|
||||||
|
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
|
||||||
|
}
|
||||||
|
reader, ok := any(s).(serverFileTransferSummaryReader)
|
||||||
|
if !ok {
|
||||||
|
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
|
||||||
|
}
|
||||||
|
return reader.serverFileTransferCompletedSummaries(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetClientFileTransferFailedSummaries(c Client) (FileTransferSummaryGroup, error) {
|
||||||
|
if c == nil {
|
||||||
|
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
|
||||||
|
}
|
||||||
|
reader, ok := any(c).(clientFileTransferSummaryReader)
|
||||||
|
if !ok {
|
||||||
|
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
|
||||||
|
}
|
||||||
|
return reader.clientFileTransferFailedSummaries(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetServerFileTransferFailedSummaries(s Server) (FileTransferSummaryGroup, error) {
|
||||||
|
if s == nil {
|
||||||
|
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
|
||||||
|
}
|
||||||
|
reader, ok := any(s).(serverFileTransferSummaryReader)
|
||||||
|
if !ok {
|
||||||
|
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
|
||||||
|
}
|
||||||
|
return reader.serverFileTransferFailedSummaries(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetClientFileTransferLatestByFileID(c Client, fileID string) (FileTransferSummaryGroup, error) {
|
||||||
|
if c == nil {
|
||||||
|
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
|
||||||
|
}
|
||||||
|
reader, ok := any(c).(clientFileTransferSummaryReader)
|
||||||
|
if !ok {
|
||||||
|
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
|
||||||
|
}
|
||||||
|
return reader.clientFileTransferLatestByFileID(fileID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetServerFileTransferLatestByFileID(s Server, fileID string) (FileTransferSummaryGroup, error) {
|
||||||
|
if s == nil {
|
||||||
|
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
|
||||||
|
}
|
||||||
|
reader, ok := any(s).(serverFileTransferSummaryReader)
|
||||||
|
if !ok {
|
||||||
|
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
|
||||||
|
}
|
||||||
|
return reader.serverFileTransferLatestByFileID(fileID), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetClientFileTransferLatestByFileIDQuery(c Client, fileID string, query FileTransferSummaryQuery) (FileTransferSummaryGroup, error) {
|
||||||
|
if c == nil {
|
||||||
|
return FileTransferSummaryGroup{}, errClientFileTransferSummaryNil
|
||||||
|
}
|
||||||
|
reader, ok := any(c).(clientFileTransferSummaryReader)
|
||||||
|
if !ok {
|
||||||
|
return FileTransferSummaryGroup{}, errClientFileTransferSummaryUnsupported
|
||||||
|
}
|
||||||
|
return reader.clientFileTransferLatestByFileIDQuery(fileID, query), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetServerFileTransferLatestByFileIDQuery(s Server, fileID string, query FileTransferSummaryQuery) (FileTransferSummaryGroup, error) {
|
||||||
|
if s == nil {
|
||||||
|
return FileTransferSummaryGroup{}, errServerFileTransferSummaryNil
|
||||||
|
}
|
||||||
|
reader, ok := any(s).(serverFileTransferSummaryReader)
|
||||||
|
if !ok {
|
||||||
|
return FileTransferSummaryGroup{}, errServerFileTransferSummaryUnsupported
|
||||||
|
}
|
||||||
|
return reader.serverFileTransferLatestByFileIDQuery(fileID, query), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientFileTransferActiveSummaries() FileTransferSummaryGroup {
|
||||||
|
return publicFileTransferSummaryGroup(c.getFileTransferState().active())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientFileTransferCompletedSummaries() FileTransferSummaryGroup {
|
||||||
|
return publicFileTransferSummaryGroup(c.getFileTransferState().completed())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientFileTransferFailedSummaries() FileTransferSummaryGroup {
|
||||||
|
return publicFileTransferSummaryGroup(c.getFileTransferState().failed())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientFileTransferLatestByFileID(fileID string) FileTransferSummaryGroup {
|
||||||
|
return publicFileTransferSummaryGroup(c.getFileTransferState().latestByFileID(fileID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) clientFileTransferLatestByFileIDQuery(fileID string, query FileTransferSummaryQuery) FileTransferSummaryGroup {
|
||||||
|
return publicFileTransferSummaryGroup(c.getFileTransferState().latestByFileIDQuery(fileID, internalFileTransferSummaryQuery(query)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) serverFileTransferActiveSummaries() FileTransferSummaryGroup {
|
||||||
|
return publicFileTransferSummaryGroup(s.getFileTransferState().active())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) serverFileTransferCompletedSummaries() FileTransferSummaryGroup {
|
||||||
|
return publicFileTransferSummaryGroup(s.getFileTransferState().completed())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) serverFileTransferFailedSummaries() FileTransferSummaryGroup {
|
||||||
|
return publicFileTransferSummaryGroup(s.getFileTransferState().failed())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) serverFileTransferLatestByFileID(fileID string) FileTransferSummaryGroup {
|
||||||
|
return publicFileTransferSummaryGroup(s.getFileTransferState().latestByFileID(fileID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) serverFileTransferLatestByFileIDQuery(fileID string, query FileTransferSummaryQuery) FileTransferSummaryGroup {
|
||||||
|
return publicFileTransferSummaryGroup(s.getFileTransferState().latestByFileIDQuery(fileID, internalFileTransferSummaryQuery(query)))
|
||||||
|
}
|
||||||
|
|
||||||
|
func publicFileTransferSummaryGroup(src fileTransferSummaryGroup) FileTransferSummaryGroup {
|
||||||
|
return FileTransferSummaryGroup{
|
||||||
|
Send: publicFileTransferSummaries(src.Send),
|
||||||
|
Receive: publicFileTransferSummaries(src.Receive),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func publicFileTransferSummaries(src []fileTransferSummary) []FileTransferSummary {
|
||||||
|
if len(src) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := make([]FileTransferSummary, 0, len(src))
|
||||||
|
for _, summary := range src {
|
||||||
|
out = append(out, publicFileTransferSummary(summary))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func publicFileTransferSummary(summary fileTransferSummary) FileTransferSummary {
|
||||||
|
return FileTransferSummary{
|
||||||
|
Direction: publicFileTransferDirection(summary.Direction),
|
||||||
|
Scope: summary.Scope,
|
||||||
|
RuntimeScope: summary.RuntimeScope,
|
||||||
|
TransportGeneration: summary.TransportGeneration,
|
||||||
|
NetType: summary.NetType,
|
||||||
|
Kind: summary.Kind,
|
||||||
|
FileID: summary.FileID,
|
||||||
|
Path: summary.Path,
|
||||||
|
Received: summary.Received,
|
||||||
|
Total: summary.Total,
|
||||||
|
Percent: summary.Percent,
|
||||||
|
Active: summary.Active,
|
||||||
|
Terminal: summary.Terminal,
|
||||||
|
Done: summary.Done,
|
||||||
|
Failed: summary.Failed,
|
||||||
|
Err: summary.Err,
|
||||||
|
StartedAt: summary.StartedAt,
|
||||||
|
UpdatedAt: summary.UpdatedAt,
|
||||||
|
Duration: summary.Duration,
|
||||||
|
RateBPS: summary.RateBPS,
|
||||||
|
StepDuration: summary.StepDuration,
|
||||||
|
InstantRateBPS: summary.InstantRateBPS,
|
||||||
|
Time: summary.Time,
|
||||||
|
Stage: summary.Stage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func publicFileTransferDirection(direction fileTransferDirection) TransferDirection {
|
||||||
|
switch direction {
|
||||||
|
case fileTransferDirectionReceive:
|
||||||
|
return TransferDirectionReceive
|
||||||
|
default:
|
||||||
|
return TransferDirectionSend
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func internalFileTransferSummaryQuery(query FileTransferSummaryQuery) fileTransferSummaryQuery {
|
||||||
|
return fileTransferSummaryQuery{
|
||||||
|
Scope: query.Scope,
|
||||||
|
RuntimeScope: query.RuntimeScope,
|
||||||
|
TransportGeneration: query.TransportGeneration,
|
||||||
|
MatchTransportGeneration: query.MatchTransportGeneration,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,202 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetClientFileTransferSummariesRejectNil(t *testing.T) {
|
||||||
|
if _, err := GetClientFileTransferActiveSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) {
|
||||||
|
t.Fatalf("GetClientFileTransferActiveSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil)
|
||||||
|
}
|
||||||
|
if _, err := GetClientFileTransferCompletedSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) {
|
||||||
|
t.Fatalf("GetClientFileTransferCompletedSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil)
|
||||||
|
}
|
||||||
|
if _, err := GetClientFileTransferFailedSummaries(nil); !errors.Is(err, errClientFileTransferSummaryNil) {
|
||||||
|
t.Fatalf("GetClientFileTransferFailedSummaries nil error = %v, want %v", err, errClientFileTransferSummaryNil)
|
||||||
|
}
|
||||||
|
if _, err := GetClientFileTransferLatestByFileID(nil, "x"); !errors.Is(err, errClientFileTransferSummaryNil) {
|
||||||
|
t.Fatalf("GetClientFileTransferLatestByFileID nil error = %v, want %v", err, errClientFileTransferSummaryNil)
|
||||||
|
}
|
||||||
|
if _, err := GetClientFileTransferLatestByFileIDQuery(nil, "x", FileTransferSummaryQuery{}); !errors.Is(err, errClientFileTransferSummaryNil) {
|
||||||
|
t.Fatalf("GetClientFileTransferLatestByFileIDQuery nil error = %v, want %v", err, errClientFileTransferSummaryNil)
|
||||||
|
}
|
||||||
|
if _, err := GetServerFileTransferActiveSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) {
|
||||||
|
t.Fatalf("GetServerFileTransferActiveSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil)
|
||||||
|
}
|
||||||
|
if _, err := GetServerFileTransferCompletedSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) {
|
||||||
|
t.Fatalf("GetServerFileTransferCompletedSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil)
|
||||||
|
}
|
||||||
|
if _, err := GetServerFileTransferFailedSummaries(nil); !errors.Is(err, errServerFileTransferSummaryNil) {
|
||||||
|
t.Fatalf("GetServerFileTransferFailedSummaries nil error = %v, want %v", err, errServerFileTransferSummaryNil)
|
||||||
|
}
|
||||||
|
if _, err := GetServerFileTransferLatestByFileID(nil, "x"); !errors.Is(err, errServerFileTransferSummaryNil) {
|
||||||
|
t.Fatalf("GetServerFileTransferLatestByFileID nil error = %v, want %v", err, errServerFileTransferSummaryNil)
|
||||||
|
}
|
||||||
|
if _, err := GetServerFileTransferLatestByFileIDQuery(nil, "x", FileTransferSummaryQuery{}); !errors.Is(err, errServerFileTransferSummaryNil) {
|
||||||
|
t.Fatalf("GetServerFileTransferLatestByFileIDQuery nil error = %v, want %v", err, errServerFileTransferSummaryNil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetClientFileTransferSummariesPublicAPI(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
now := time.Unix(2000, 0)
|
||||||
|
|
||||||
|
client.publishSendFileEvent(FileEvent{
|
||||||
|
NetType: NET_CLIENT,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "client-public", Size: 16},
|
||||||
|
Received: 6,
|
||||||
|
Total: 16,
|
||||||
|
Percent: 37.5,
|
||||||
|
StartedAt: now,
|
||||||
|
UpdatedAt: now.Add(time.Second),
|
||||||
|
Duration: time.Second,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
active, err := GetClientFileTransferActiveSummaries(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientFileTransferActiveSummaries failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := len(active.Send), 1; got != want {
|
||||||
|
t.Fatalf("active send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active.Send[0].RuntimeScope, clientFileScope(); got != want {
|
||||||
|
t.Fatalf("active runtime scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got := active.Send[0].TransportGeneration; got != 0 {
|
||||||
|
t.Fatalf("active transport generation mismatch: got %d want 0", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
latest, err := GetClientFileTransferLatestByFileID(client, "client-public")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientFileTransferLatestByFileID failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := len(latest.Send), 1; got != want {
|
||||||
|
t.Fatalf("latest send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := latest.Send[0].Direction, TransferDirectionSend; got != want {
|
||||||
|
t.Fatalf("latest direction mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
query, err := GetClientFileTransferLatestByFileIDQuery(client, "client-public", FileTransferSummaryQuery{
|
||||||
|
RuntimeScope: clientFileScope(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientFileTransferLatestByFileIDQuery failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := len(query.Send), 1; got != want {
|
||||||
|
t.Fatalf("query send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.publishSendFileEvent(FileEvent{
|
||||||
|
NetType: NET_CLIENT,
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "client-public", Size: 16},
|
||||||
|
Received: 16,
|
||||||
|
Total: 16,
|
||||||
|
Percent: 100,
|
||||||
|
Done: true,
|
||||||
|
StartedAt: now,
|
||||||
|
UpdatedAt: now.Add(2 * time.Second),
|
||||||
|
Duration: 2 * time.Second,
|
||||||
|
Time: now.Add(2 * time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
completed, err := GetClientFileTransferCompletedSummaries(client)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetClientFileTransferCompletedSummaries failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := len(completed.Send), 1; got != want {
|
||||||
|
t.Fatalf("completed send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := completed.Send[0].Done, true; got != want {
|
||||||
|
t.Fatalf("completed done mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetServerFileTransferLatestByFileIDQueryResolvesTransportGenerationPublicAPI(t *testing.T) {
|
||||||
|
server := NewServer().(*ServerCommon)
|
||||||
|
now := time.Unix(2100, 0)
|
||||||
|
serverClient := &ClientConn{ClientID: "public-gen"}
|
||||||
|
serverClient.markClientConnIdentityBound()
|
||||||
|
serverClient.markClientConnStreamTransport()
|
||||||
|
serverClient.markClientConnTransportAttached()
|
||||||
|
|
||||||
|
server.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared-public", Size: 20},
|
||||||
|
Received: 5,
|
||||||
|
Total: 20,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
firstRuntimeScope := serverTransportScope(serverClient)
|
||||||
|
logicalScope := serverFileScope(serverClient)
|
||||||
|
|
||||||
|
serverClient.markClientConnTransportDetached("read error", nil)
|
||||||
|
serverClient.markClientConnTransportAttached()
|
||||||
|
|
||||||
|
server.getFileTransferState().observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared-public", Size: 20},
|
||||||
|
Received: 9,
|
||||||
|
Total: 20,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
secondRuntimeScope := serverTransportScope(serverClient)
|
||||||
|
|
||||||
|
legacy, err := GetServerFileTransferLatestByFileID(server, "shared-public")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetServerFileTransferLatestByFileID failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := len(legacy.Receive), 1; got != want {
|
||||||
|
t.Fatalf("legacy receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := legacy.Receive[0].TransportGeneration, uint64(2); got != want {
|
||||||
|
t.Fatalf("legacy receive generation mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
allRuntime, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{
|
||||||
|
Scope: logicalScope,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetServerFileTransferLatestByFileIDQuery scope failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := len(allRuntime.Receive), 2; got != want {
|
||||||
|
t.Fatalf("runtime receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
gen1, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{
|
||||||
|
Scope: logicalScope,
|
||||||
|
TransportGeneration: 1,
|
||||||
|
MatchTransportGeneration: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetServerFileTransferLatestByFileIDQuery generation-1 failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := len(gen1.Receive), 1; got != want {
|
||||||
|
t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := gen1.Receive[0].RuntimeScope, firstRuntimeScope; got != want {
|
||||||
|
t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
gen2, err := GetServerFileTransferLatestByFileIDQuery(server, "shared-public", FileTransferSummaryQuery{
|
||||||
|
Scope: logicalScope,
|
||||||
|
TransportGeneration: 2,
|
||||||
|
MatchTransportGeneration: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetServerFileTransferLatestByFileIDQuery generation-2 failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := len(gen2.Receive), 1; got != want {
|
||||||
|
t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := gen2.Receive[0].RuntimeScope, secondRuntimeScope; got != want {
|
||||||
|
t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,146 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "sort"
|
||||||
|
|
||||||
|
type fileTransferSummaryGroup struct {
|
||||||
|
Send []fileTransferSummary
|
||||||
|
Receive []fileTransferSummary
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileTransferSummaryQuery struct {
|
||||||
|
Scope string
|
||||||
|
RuntimeScope string
|
||||||
|
TransportGeneration uint64
|
||||||
|
MatchTransportGeneration bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileTransferQuery struct {
|
||||||
|
monitor *fileTransferMonitor
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileTransferQuery(m *fileTransferMonitor) fileTransferQuery {
|
||||||
|
return fileTransferQuery{monitor: m}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q fileTransferQuery) active() fileTransferSummaryGroup {
|
||||||
|
if q.monitor == nil {
|
||||||
|
return fileTransferSummaryGroup{}
|
||||||
|
}
|
||||||
|
return groupFileTransferSummaries(q.monitor.activeSummaries())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q fileTransferQuery) completed() fileTransferSummaryGroup {
|
||||||
|
if q.monitor == nil {
|
||||||
|
return fileTransferSummaryGroup{}
|
||||||
|
}
|
||||||
|
return groupFileTransferSummaries(filterFileTransferSummaries(q.monitor.completedSummaries(), func(summary fileTransferSummary) bool {
|
||||||
|
return summary.Done && !summary.Failed
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q fileTransferQuery) failed() fileTransferSummaryGroup {
|
||||||
|
return groupFileTransferSummaries(filterFileTransferSummaries(latestFileTransferSummaries(q.monitor), func(summary fileTransferSummary) bool {
|
||||||
|
return summary.Failed
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q fileTransferQuery) latestByFileID(fileID string) fileTransferSummaryGroup {
|
||||||
|
if q.monitor == nil || fileID == "" {
|
||||||
|
return fileTransferSummaryGroup{}
|
||||||
|
}
|
||||||
|
return groupFileTransferSummaries(q.monitor.summariesByFileID(fileID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q fileTransferQuery) latestSendByFileID(fileID string) []fileTransferSummary {
|
||||||
|
if q.monitor == nil || fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return q.monitor.summariesByDirectionAndFileID(fileTransferDirectionSend, fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q fileTransferQuery) latestReceiveByFileID(fileID string) []fileTransferSummary {
|
||||||
|
if q.monitor == nil || fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return q.monitor.summariesByDirectionAndFileID(fileTransferDirectionReceive, fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q fileTransferQuery) latestByFileIDQuery(fileID string, query fileTransferSummaryQuery) fileTransferSummaryGroup {
|
||||||
|
if q.monitor == nil || fileID == "" {
|
||||||
|
return fileTransferSummaryGroup{}
|
||||||
|
}
|
||||||
|
return groupFileTransferSummaries(filterFileTransferSummaries(q.monitor.runtimeSummariesByFileID(fileID), func(summary fileTransferSummary) bool {
|
||||||
|
return fileTransferSummaryQueryMatch(summary, query)
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q fileTransferQuery) latestSendByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
|
||||||
|
if q.monitor == nil || fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return filterFileTransferSummaries(q.monitor.runtimeSummariesByDirectionAndFileID(fileTransferDirectionSend, fileID), func(summary fileTransferSummary) bool {
|
||||||
|
return fileTransferSummaryQueryMatch(summary, query)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q fileTransferQuery) latestReceiveByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
|
||||||
|
if q.monitor == nil || fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return filterFileTransferSummaries(q.monitor.runtimeSummariesByDirectionAndFileID(fileTransferDirectionReceive, fileID), func(summary fileTransferSummary) bool {
|
||||||
|
return fileTransferSummaryQueryMatch(summary, query)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func latestFileTransferSummaries(m *fileTransferMonitor) []fileTransferSummary {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
summaries := append([]fileTransferSummary{}, m.activeSummaries()...)
|
||||||
|
summaries = append(summaries, m.completedSummaries()...)
|
||||||
|
sort.Slice(summaries, func(i int, j int) bool {
|
||||||
|
return fileTransferSummarySortKey(summaries[i]) < fileTransferSummarySortKey(summaries[j])
|
||||||
|
})
|
||||||
|
return summaries
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferSummarySortKey(summary fileTransferSummary) string {
|
||||||
|
return fileTransferMonitorKey(summary.Direction, summary.Scope, summary.FileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func groupFileTransferSummaries(src []fileTransferSummary) fileTransferSummaryGroup {
|
||||||
|
var group fileTransferSummaryGroup
|
||||||
|
for _, summary := range src {
|
||||||
|
switch summary.Direction {
|
||||||
|
case fileTransferDirectionReceive:
|
||||||
|
group.Receive = append(group.Receive, summary)
|
||||||
|
case fileTransferDirectionSend:
|
||||||
|
group.Send = append(group.Send, summary)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return group
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterFileTransferSummaries(src []fileTransferSummary, keep func(fileTransferSummary) bool) []fileTransferSummary {
|
||||||
|
out := make([]fileTransferSummary, 0, len(src))
|
||||||
|
for _, summary := range src {
|
||||||
|
if !keep(summary) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, summary)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferSummaryQueryMatch(summary fileTransferSummary, query fileTransferSummaryQuery) bool {
|
||||||
|
if query.Scope != "" && normalizeFileScope(summary.Scope) != normalizeFileScope(query.Scope) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if query.RuntimeScope != "" && normalizeFileScope(summary.RuntimeScope) != normalizeFileScope(query.RuntimeScope) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if query.MatchTransportGeneration && summary.TransportGeneration != query.TransportGeneration {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -0,0 +1,248 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFileTransferQueryActiveCompletedAndFailed(t *testing.T) {
|
||||||
|
monitor := newFileTransferMonitor()
|
||||||
|
query := newFileTransferQuery(monitor)
|
||||||
|
now := time.Unix(800, 0)
|
||||||
|
serverClient := &ClientConn{ClientID: "client-a"}
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "active-send", Size: 10},
|
||||||
|
Received: 4,
|
||||||
|
Total: 10,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "done-recv", Size: 12},
|
||||||
|
Received: 12,
|
||||||
|
Total: 12,
|
||||||
|
Done: true,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileAbort,
|
||||||
|
Packet: FilePacket{FileID: "failed-send", Size: 8, Stage: "chunk"},
|
||||||
|
Received: 3,
|
||||||
|
Total: 8,
|
||||||
|
Time: now.Add(2 * time.Second),
|
||||||
|
Err: errString("send failed"),
|
||||||
|
})
|
||||||
|
|
||||||
|
active := query.active()
|
||||||
|
if got, want := len(active.Send), 1; got != want {
|
||||||
|
t.Fatalf("active send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active.Send[0].FileID, "active-send"; got != want {
|
||||||
|
t.Fatalf("active send fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := len(active.Receive), 0; got != want {
|
||||||
|
t.Fatalf("active receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
completed := query.completed()
|
||||||
|
if got, want := len(completed.Send), 0; got != want {
|
||||||
|
t.Fatalf("completed send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := len(completed.Receive), 1; got != want {
|
||||||
|
t.Fatalf("completed receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := completed.Receive[0].FileID, "done-recv"; got != want {
|
||||||
|
t.Fatalf("completed receive fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := completed.Receive[0].Done, true; got != want {
|
||||||
|
t.Fatalf("completed receive done mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
failed := query.failed()
|
||||||
|
if got, want := len(failed.Send), 1; got != want {
|
||||||
|
t.Fatalf("failed send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := failed.Send[0].FileID, "failed-send"; got != want {
|
||||||
|
t.Fatalf("failed send fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := failed.Send[0].Failed, true; got != want {
|
||||||
|
t.Fatalf("failed send flag mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := len(failed.Receive), 0; got != want {
|
||||||
|
t.Fatalf("failed receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileTransferQueryLatestByFileID(t *testing.T) {
|
||||||
|
monitor := newFileTransferMonitor()
|
||||||
|
query := newFileTransferQuery(monitor)
|
||||||
|
now := time.Unix(900, 0)
|
||||||
|
serverClientA := &ClientConn{ClientID: "client-a"}
|
||||||
|
serverClientB := &ClientConn{ClientID: "client-b"}
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||||
|
Received: 6,
|
||||||
|
Total: 20,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClientA,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||||
|
Received: 9,
|
||||||
|
Total: 20,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClientB,
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||||
|
Received: 20,
|
||||||
|
Total: 20,
|
||||||
|
Done: true,
|
||||||
|
Time: now.Add(2 * time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
group := query.latestByFileID("shared")
|
||||||
|
if got, want := len(group.Send), 1; got != want {
|
||||||
|
t.Fatalf("group send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := group.Send[0].FileID, "shared"; got != want {
|
||||||
|
t.Fatalf("group send fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := len(group.Receive), 2; got != want {
|
||||||
|
t.Fatalf("group receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := group.Receive[0].Scope, serverFileScope(serverClientA); got != want {
|
||||||
|
t.Fatalf("first receive scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := group.Receive[1].Scope, serverFileScope(serverClientB); got != want {
|
||||||
|
t.Fatalf("second receive scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
send := query.latestSendByFileID("shared")
|
||||||
|
if got, want := len(send), 1; got != want {
|
||||||
|
t.Fatalf("send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := send[0].Received, int64(6); got != want {
|
||||||
|
t.Fatalf("send received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
receive := query.latestReceiveByFileID("shared")
|
||||||
|
if got, want := len(receive), 2; got != want {
|
||||||
|
t.Fatalf("receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := receive[0].Scope, serverFileScope(serverClientA); got != want {
|
||||||
|
t.Fatalf("receive first scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := receive[1].Done, true; got != want {
|
||||||
|
t.Fatalf("receive second done mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestClientTransferQueryFollowsPublishedEvents(t *testing.T) {
|
||||||
|
client := NewClient().(*ClientCommon)
|
||||||
|
now := time.Unix(1000, 0)
|
||||||
|
|
||||||
|
client.publishSendFileEvent(FileEvent{
|
||||||
|
NetType: NET_CLIENT,
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "client-done", Size: 16},
|
||||||
|
Received: 16,
|
||||||
|
Total: 16,
|
||||||
|
Done: true,
|
||||||
|
StartedAt: now,
|
||||||
|
UpdatedAt: now.Add(time.Second),
|
||||||
|
Duration: time.Second,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
completed := client.getFileTransferState().completed()
|
||||||
|
if got, want := len(completed.Send), 1; got != want {
|
||||||
|
t.Fatalf("client completed send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := completed.Send[0].FileID, "client-done"; got != want {
|
||||||
|
t.Fatalf("client completed send fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileTransferQueryLatestByFileIDQueryResolvesTransportGeneration(t *testing.T) {
|
||||||
|
monitor := newFileTransferMonitor()
|
||||||
|
query := newFileTransferQuery(monitor)
|
||||||
|
now := time.Unix(960, 0)
|
||||||
|
serverClient := &ClientConn{ClientID: "client-gen"}
|
||||||
|
serverClient.markClientConnIdentityBound()
|
||||||
|
serverClient.markClientConnStreamTransport()
|
||||||
|
serverClient.markClientConnTransportAttached()
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||||
|
Received: 5,
|
||||||
|
Total: 20,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
firstRuntimeScope := serverTransportScope(serverClient)
|
||||||
|
logicalScope := serverFileScope(serverClient)
|
||||||
|
|
||||||
|
serverClient.markClientConnTransportDetached("read error", nil)
|
||||||
|
serverClient.markClientConnTransportAttached()
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 20},
|
||||||
|
Received: 9,
|
||||||
|
Total: 20,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
secondRuntimeScope := serverTransportScope(serverClient)
|
||||||
|
if secondRuntimeScope == firstRuntimeScope {
|
||||||
|
t.Fatalf("runtime scope should change across transport generations: got %q", secondRuntimeScope)
|
||||||
|
}
|
||||||
|
|
||||||
|
legacy := query.latestReceiveByFileID("shared")
|
||||||
|
if got, want := len(legacy), 1; got != want {
|
||||||
|
t.Fatalf("legacy receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := legacy[0].TransportGeneration, uint64(2); got != want {
|
||||||
|
t.Fatalf("legacy receive generation mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
runtimeAll := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
|
||||||
|
Scope: logicalScope,
|
||||||
|
})
|
||||||
|
if got, want := len(runtimeAll), 2; got != want {
|
||||||
|
t.Fatalf("runtime receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
gen1 := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
|
||||||
|
Scope: logicalScope,
|
||||||
|
TransportGeneration: 1,
|
||||||
|
MatchTransportGeneration: true,
|
||||||
|
})
|
||||||
|
if got, want := len(gen1), 1; got != want {
|
||||||
|
t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := gen1[0].RuntimeScope, firstRuntimeScope; got != want {
|
||||||
|
t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
gen2 := query.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
|
||||||
|
Scope: logicalScope,
|
||||||
|
TransportGeneration: 2,
|
||||||
|
MatchTransportGeneration: true,
|
||||||
|
})
|
||||||
|
if got, want := len(gen2), 1; got != want {
|
||||||
|
t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := gen2[0].RuntimeScope, secondRuntimeScope; got != want {
|
||||||
|
t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,190 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileTransferDirection uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
fileTransferDirectionReceive fileTransferDirection = iota
|
||||||
|
fileTransferDirectionSend
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileTransferSnapshot struct {
|
||||||
|
Direction fileTransferDirection
|
||||||
|
Scope string
|
||||||
|
RuntimeScope string
|
||||||
|
TransportGeneration uint64
|
||||||
|
NetType NetType
|
||||||
|
Kind EnvelopeKind
|
||||||
|
FileID string
|
||||||
|
Path string
|
||||||
|
Received int64
|
||||||
|
Total int64
|
||||||
|
Percent float64
|
||||||
|
Done bool
|
||||||
|
Err error
|
||||||
|
StartedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
Duration time.Duration
|
||||||
|
RateBPS float64
|
||||||
|
StepDuration time.Duration
|
||||||
|
InstantRateBPS float64
|
||||||
|
Time time.Time
|
||||||
|
Stage string
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferMonitorScope(event FileEvent) string {
|
||||||
|
if logical := fileEventLogicalConnSnapshot(event); logical != nil {
|
||||||
|
return serverFileScope(logical)
|
||||||
|
}
|
||||||
|
return clientFileScope()
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferRuntimeScope(event FileEvent) string {
|
||||||
|
if event.TransportConn != nil {
|
||||||
|
return serverTransportScopeForTransport(event.TransportConn)
|
||||||
|
}
|
||||||
|
if logical := fileEventLogicalConnSnapshot(event); logical != nil {
|
||||||
|
return serverTransportScope(logical)
|
||||||
|
}
|
||||||
|
return clientFileScope()
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferTransportGeneration(event FileEvent) uint64 {
|
||||||
|
if event.TransportConn != nil {
|
||||||
|
return event.TransportConn.TransportGeneration()
|
||||||
|
}
|
||||||
|
logical := fileEventLogicalConnSnapshot(event)
|
||||||
|
if logical == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return logical.transportGenerationSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferMonitorKey(direction fileTransferDirection, scope string, fileID string) string {
|
||||||
|
if fileID == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return strconv.Itoa(int(direction)) + "|" + scope + "|" + fileID
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferRuntimeMonitorKey(direction fileTransferDirection, runtimeScope string, fileID string) string {
|
||||||
|
return fileTransferMonitorKey(direction, normalizeFileScope(runtimeScope), fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferSnapshotFromEvent(direction fileTransferDirection, event FileEvent) fileTransferSnapshot {
|
||||||
|
return fileTransferSnapshot{
|
||||||
|
Direction: direction,
|
||||||
|
Scope: fileTransferMonitorScope(event),
|
||||||
|
RuntimeScope: fileTransferRuntimeScope(event),
|
||||||
|
TransportGeneration: fileTransferTransportGeneration(event),
|
||||||
|
NetType: event.NetType,
|
||||||
|
Kind: event.Kind,
|
||||||
|
FileID: event.Packet.FileID,
|
||||||
|
Path: event.Path,
|
||||||
|
Received: event.Received,
|
||||||
|
Total: event.Total,
|
||||||
|
Percent: event.Percent,
|
||||||
|
Done: event.Done,
|
||||||
|
Err: event.Err,
|
||||||
|
StartedAt: event.StartedAt,
|
||||||
|
UpdatedAt: event.UpdatedAt,
|
||||||
|
Duration: event.Duration,
|
||||||
|
RateBPS: event.RateBPS,
|
||||||
|
StepDuration: event.StepDuration,
|
||||||
|
InstantRateBPS: event.InstantRateBPS,
|
||||||
|
Time: event.Time,
|
||||||
|
Stage: event.Packet.Stage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isFileTransferTerminal(kind EnvelopeKind) bool {
|
||||||
|
return kind == EnvelopeFileEnd || kind == EnvelopeFileAbort
|
||||||
|
}
|
||||||
|
|
||||||
|
func isFileTransferObservable(kind EnvelopeKind) bool {
|
||||||
|
return kind == EnvelopeFileMeta || kind == EnvelopeFileChunk || kind == EnvelopeFileEnd || kind == EnvelopeFileAbort
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortedFileTransferSnapshots(src map[string]fileTransferSnapshot) []fileTransferSnapshot {
|
||||||
|
keys := make([]string, 0, len(src))
|
||||||
|
for key := range src {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
out := make([]fileTransferSnapshot, 0, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
out = append(out, src[key])
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func latestFileTransferSnapshotsLocked(active map[string]fileTransferSnapshot, completed map[string]fileTransferSnapshot) []fileTransferSnapshot {
|
||||||
|
merged := make(map[string]fileTransferSnapshot, len(active)+len(completed))
|
||||||
|
for key, snapshot := range completed {
|
||||||
|
merged[key] = snapshot
|
||||||
|
}
|
||||||
|
for key, snapshot := range active {
|
||||||
|
merged[key] = snapshot
|
||||||
|
}
|
||||||
|
return sortedFileTransferSnapshots(merged)
|
||||||
|
}
|
||||||
|
|
||||||
|
func filteredFileTransferSnapshots(src map[string]fileTransferSnapshot, direction fileTransferDirection) []fileTransferSnapshot {
|
||||||
|
out := make([]fileTransferSnapshot, 0, len(src))
|
||||||
|
for _, snapshot := range sortedFileTransferSnapshots(src) {
|
||||||
|
if snapshot.Direction != direction {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, snapshot)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterFileTransferSnapshotsByFileID(src []fileTransferSnapshot, fileID string) []fileTransferSnapshot {
|
||||||
|
out := make([]fileTransferSnapshot, 0, len(src))
|
||||||
|
for _, snapshot := range src {
|
||||||
|
if snapshot.FileID != fileID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, snapshot)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterFileTransferSnapshotsByDirectionAndFileID(src []fileTransferSnapshot, direction fileTransferDirection, fileID string) []fileTransferSnapshot {
|
||||||
|
out := make([]fileTransferSnapshot, 0, len(src))
|
||||||
|
for _, snapshot := range src {
|
||||||
|
if snapshot.Direction != direction || snapshot.FileID != fileID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, snapshot)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferSnapshotOlder(candidate fileTransferSnapshot, current fileTransferSnapshot, candidateKey string, currentKey string) bool {
|
||||||
|
candidateTime := fileTransferSnapshotCompletedTime(candidate)
|
||||||
|
currentTime := fileTransferSnapshotCompletedTime(current)
|
||||||
|
if candidateTime.Before(currentTime) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if currentTime.Before(candidateTime) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return candidateKey < currentKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferSnapshotCompletedTime(snapshot fileTransferSnapshot) time.Time {
|
||||||
|
if !snapshot.Time.IsZero() {
|
||||||
|
return snapshot.Time
|
||||||
|
}
|
||||||
|
if !snapshot.UpdatedAt.IsZero() {
|
||||||
|
return snapshot.UpdatedAt
|
||||||
|
}
|
||||||
|
return snapshot.StartedAt
|
||||||
|
}
|
||||||
@@ -0,0 +1,302 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import itransfer "b612.me/notify/internal/transfer"
|
||||||
|
|
||||||
|
type fileTransferState struct {
|
||||||
|
monitor *fileTransferMonitor
|
||||||
|
query fileTransferQuery
|
||||||
|
runtime *transferRuntime
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileTransferState() *fileTransferState {
|
||||||
|
return newFileTransferStateWithConfig(defaultFileTransferConfig())
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFileTransferStateWithConfig(cfg fileTransferConfig) *fileTransferState {
|
||||||
|
monitor := newFileTransferMonitorWithConfig(cfg)
|
||||||
|
return &fileTransferState{
|
||||||
|
monitor: monitor,
|
||||||
|
query: newFileTransferQuery(monitor),
|
||||||
|
runtime: newTransferRuntime(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) observe(direction fileTransferDirection, event FileEvent) {
|
||||||
|
if s == nil || s.monitor == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.monitor.observe(direction, event)
|
||||||
|
s.observeRuntime(direction, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) observeMonitorOnly(direction fileTransferDirection, event FileEvent) {
|
||||||
|
if s == nil || s.monitor == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.monitor.observe(direction, event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) applyConfig(cfg fileTransferConfig) {
|
||||||
|
if s == nil || s.monitor == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.monitor.applyConfig(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) monitorView() *fileTransferMonitor {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.monitor
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) active() fileTransferSummaryGroup {
|
||||||
|
if s == nil {
|
||||||
|
return fileTransferSummaryGroup{}
|
||||||
|
}
|
||||||
|
return s.query.active()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) completed() fileTransferSummaryGroup {
|
||||||
|
if s == nil {
|
||||||
|
return fileTransferSummaryGroup{}
|
||||||
|
}
|
||||||
|
return s.query.completed()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) failed() fileTransferSummaryGroup {
|
||||||
|
if s == nil {
|
||||||
|
return fileTransferSummaryGroup{}
|
||||||
|
}
|
||||||
|
return s.query.failed()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) latest(direction fileTransferDirection, scope string, fileID string) (fileTransferSummary, bool) {
|
||||||
|
if s == nil || s.monitor == nil {
|
||||||
|
return fileTransferSummary{}, false
|
||||||
|
}
|
||||||
|
return s.monitor.latestSummary(direction, scope, fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) latestByFileID(fileID string) fileTransferSummaryGroup {
|
||||||
|
if s == nil {
|
||||||
|
return fileTransferSummaryGroup{}
|
||||||
|
}
|
||||||
|
return s.query.latestByFileID(fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) latestSendByFileID(fileID string) []fileTransferSummary {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.query.latestSendByFileID(fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) latestReceiveByFileID(fileID string) []fileTransferSummary {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.query.latestReceiveByFileID(fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) latestByFileIDQuery(fileID string, query fileTransferSummaryQuery) fileTransferSummaryGroup {
|
||||||
|
if s == nil {
|
||||||
|
return fileTransferSummaryGroup{}
|
||||||
|
}
|
||||||
|
return s.query.latestByFileIDQuery(fileID, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) latestSendByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.query.latestSendByFileIDQuery(fileID, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) latestReceiveByFileIDQuery(fileID string, query fileTransferSummaryQuery) []fileTransferSummary {
|
||||||
|
if s == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return s.query.latestReceiveByFileIDQuery(fileID, query)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) observeRuntime(direction fileTransferDirection, event FileEvent) {
|
||||||
|
if s == nil || s.runtime == nil || event.Packet.FileID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtimeScope := transferRuntimeScopeForEvent(event)
|
||||||
|
publicScope := transferRuntimePublicScopeForEvent(event)
|
||||||
|
transportGeneration := transferRuntimeTransportGenerationForEvent(event)
|
||||||
|
s.ensureRuntimeTransfer(direction, runtimeScope, publicScope, transportGeneration, event)
|
||||||
|
s.recordRuntimeStage(direction, runtimeScope, event.Packet.FileID, runtimeTransferStageForEvent(event))
|
||||||
|
switch event.Kind {
|
||||||
|
case EnvelopeFileChunk:
|
||||||
|
s.runtime.activate(direction, runtimeScope, event.Packet.FileID)
|
||||||
|
s.syncRuntimeProgress(direction, runtimeScope, event)
|
||||||
|
case EnvelopeFileEnd:
|
||||||
|
s.runtime.activate(direction, runtimeScope, event.Packet.FileID)
|
||||||
|
s.syncRuntimeProgress(direction, runtimeScope, event)
|
||||||
|
switch direction {
|
||||||
|
case fileTransferDirectionSend:
|
||||||
|
s.runtime.beginCommit(direction, runtimeScope, event.Packet.FileID)
|
||||||
|
case fileTransferDirectionReceive:
|
||||||
|
s.runtime.beginVerify(direction, runtimeScope, event.Packet.FileID)
|
||||||
|
}
|
||||||
|
s.runtime.complete(direction, runtimeScope, event.Packet.FileID)
|
||||||
|
case EnvelopeFileAbort:
|
||||||
|
s.syncRuntimeProgress(direction, runtimeScope, event)
|
||||||
|
s.recordRuntimeFailureStage(direction, runtimeScope, event.Packet.FileID, event.Packet.Stage)
|
||||||
|
s.runtime.abort(direction, runtimeScope, event.Packet.FileID, event.Err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) ensureRuntimeTransfer(direction fileTransferDirection, runtimeScope string, publicScope string, transportGeneration uint64, event FileEvent) {
|
||||||
|
if s == nil || s.runtime == nil || event.Packet.FileID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.ensureTransferDescriptor(direction, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{
|
||||||
|
ID: event.Packet.FileID,
|
||||||
|
Channel: itransfer.DataChannel,
|
||||||
|
Size: event.Packet.Size,
|
||||||
|
Checksum: event.Packet.Checksum,
|
||||||
|
Metadata: buildKernelTransferMetadata(event),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) startRuntimeSendSession(runtimeScope string, publicScope string, transportGeneration uint64, session *fileSendSession) {
|
||||||
|
if s == nil || s.runtime == nil || session == nil || session.fileID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.ensureTransferDescriptor(fileTransferDirectionSend, runtimeScope, publicScope, transportGeneration, itransfer.Descriptor{
|
||||||
|
ID: session.fileID,
|
||||||
|
Channel: itransfer.DataChannel,
|
||||||
|
Size: session.size,
|
||||||
|
Checksum: session.checksum,
|
||||||
|
Metadata: itransfer.Metadata{
|
||||||
|
"name": session.name,
|
||||||
|
"path": session.path,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKernelTransferMetadata(event FileEvent) itransfer.Metadata {
|
||||||
|
metadata := make(itransfer.Metadata)
|
||||||
|
if event.Packet.Name != "" {
|
||||||
|
metadata["name"] = event.Packet.Name
|
||||||
|
}
|
||||||
|
if event.Path != "" {
|
||||||
|
metadata["path"] = event.Path
|
||||||
|
}
|
||||||
|
if len(metadata) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) syncRuntimeProgress(direction fileTransferDirection, scope string, event FileEvent) {
|
||||||
|
if s == nil || s.runtime == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
snapshot, ok := s.runtimeSnapshot(direction, scope, event.Packet.FileID)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
progress := event.Received
|
||||||
|
if progress < 0 {
|
||||||
|
progress = 0
|
||||||
|
}
|
||||||
|
switch direction {
|
||||||
|
case fileTransferDirectionReceive:
|
||||||
|
if delta := progress - snapshot.ReceivedBytes; delta > 0 {
|
||||||
|
s.runtime.recordReceive(direction, scope, event.Packet.FileID, delta)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if delta := progress - snapshot.SentBytes; delta > 0 {
|
||||||
|
s.runtime.recordSend(direction, scope, event.Packet.FileID, delta)
|
||||||
|
}
|
||||||
|
s.runtime.setAckedBytes(direction, scope, event.Packet.FileID, progress)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) recordRuntimeRetry(direction fileTransferDirection, scope string, fileID string) {
|
||||||
|
if s == nil || s.runtime == nil || fileID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.recordRetry(direction, scope, fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) recordRuntimeTimeout(direction fileTransferDirection, scope string, fileID string) {
|
||||||
|
if s == nil || s.runtime == nil || fileID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.recordTimeout(direction, scope, fileID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) recordRuntimeStage(direction fileTransferDirection, scope string, fileID string, stage string) {
|
||||||
|
if s == nil || s.runtime == nil || fileID == "" || stage == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.recordStage(direction, scope, fileID, stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) recordRuntimeFailureStage(direction fileTransferDirection, scope string, fileID string, stage string) {
|
||||||
|
if s == nil || s.runtime == nil || fileID == "" || stage == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.runtime.recordFailureStage(direction, scope, fileID, stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *fileTransferState) runtimeSnapshot(direction fileTransferDirection, scope string, transferID string) (itransfer.Snapshot, bool) {
|
||||||
|
if s == nil || s.runtime == nil || transferID == "" {
|
||||||
|
return itransfer.Snapshot{}, false
|
||||||
|
}
|
||||||
|
return s.runtime.snapshot(direction, scope, transferID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func transferRuntimeScopeForEvent(event FileEvent) string {
|
||||||
|
if event.TransportConn != nil {
|
||||||
|
return serverTransportScopeForTransport(event.TransportConn)
|
||||||
|
}
|
||||||
|
if logical := fileEventLogicalConnSnapshot(event); logical != nil {
|
||||||
|
return serverTransportScope(logical)
|
||||||
|
}
|
||||||
|
return clientFileScope()
|
||||||
|
}
|
||||||
|
|
||||||
|
func transferRuntimePublicScopeForEvent(event FileEvent) string {
|
||||||
|
return fileTransferMonitorScope(event)
|
||||||
|
}
|
||||||
|
|
||||||
|
func transferRuntimeTransportGenerationForEvent(event FileEvent) uint64 {
|
||||||
|
if event.TransportConn != nil {
|
||||||
|
return event.TransportConn.TransportGeneration()
|
||||||
|
}
|
||||||
|
logical := fileEventLogicalConnSnapshot(event)
|
||||||
|
if logical == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return logical.transportGenerationSnapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func runtimeTransferStageForEvent(event FileEvent) string {
|
||||||
|
if event.Packet.Stage != "" {
|
||||||
|
return event.Packet.Stage
|
||||||
|
}
|
||||||
|
return fileStageByKind(event.Kind)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) getTransferRuntime() *transferRuntime {
|
||||||
|
return c.getFileTransferState().runtime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) getTransferRuntime() *transferRuntime {
|
||||||
|
return s.getFileTransferState().runtime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ClientCommon) getFileTransferState() *fileTransferState {
|
||||||
|
return c.getLogicalSessionState().fileTransfers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *ServerCommon) getFileTransferState() *fileTransferState {
|
||||||
|
return s.getLogicalSessionState().fileTransfers
|
||||||
|
}
|
||||||
@@ -0,0 +1,371 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
itransfer "b612.me/notify/internal/transfer"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFileTransferStateObserveFeedsQuery(t *testing.T) {
|
||||||
|
state := newFileTransferState()
|
||||||
|
now := time.Unix(1100, 0)
|
||||||
|
|
||||||
|
state.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "state-active", Size: 32},
|
||||||
|
Received: 10,
|
||||||
|
Total: 32,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
Kind: EnvelopeFileAbort,
|
||||||
|
Packet: FilePacket{FileID: "state-failed", Size: 16, Stage: "chunk"},
|
||||||
|
Received: 6,
|
||||||
|
Total: 16,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
Err: errString("receive failed"),
|
||||||
|
})
|
||||||
|
|
||||||
|
active := state.active()
|
||||||
|
if got, want := len(active.Send), 1; got != want {
|
||||||
|
t.Fatalf("active send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active.Send[0].FileID, "state-active"; got != want {
|
||||||
|
t.Fatalf("active send fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
failed := state.failed()
|
||||||
|
if got, want := len(failed.Receive), 1; got != want {
|
||||||
|
t.Fatalf("failed receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := failed.Receive[0].FileID, "state-failed"; got != want {
|
||||||
|
t.Fatalf("failed receive fileID mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := failed.Receive[0].Failed, true; got != want {
|
||||||
|
t.Fatalf("failed receive flag mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileTransferStateLatestHelpers(t *testing.T) {
|
||||||
|
state := newFileTransferState()
|
||||||
|
now := time.Unix(1200, 0)
|
||||||
|
serverClient := &ClientConn{ClientID: "client-a"}
|
||||||
|
|
||||||
|
state.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "state-shared", Size: 40},
|
||||||
|
Received: 15,
|
||||||
|
Total: 40,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "state-shared", Size: 40},
|
||||||
|
Received: 40,
|
||||||
|
Total: 40,
|
||||||
|
Done: true,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
summary, ok := state.latest(fileTransferDirectionSend, clientFileScope(), "state-shared")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("latest send summary should exist")
|
||||||
|
}
|
||||||
|
if got, want := summary.Received, int64(15); got != want {
|
||||||
|
t.Fatalf("latest send received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
group := state.latestByFileID("state-shared")
|
||||||
|
if got, want := len(group.Send), 1; got != want {
|
||||||
|
t.Fatalf("latest group send count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := len(group.Receive), 1; got != want {
|
||||||
|
t.Fatalf("latest group receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := group.Receive[0].Scope, serverFileScope(serverClient); got != want {
|
||||||
|
t.Fatalf("latest group receive scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
send := state.latestSendByFileID("state-shared")
|
||||||
|
if got, want := len(send), 1; got != want {
|
||||||
|
t.Fatalf("latest send list count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
receive := state.latestReceiveByFileID("state-shared")
|
||||||
|
if got, want := len(receive), 1; got != want {
|
||||||
|
t.Fatalf("latest receive list count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := receive[0].Done, true; got != want {
|
||||||
|
t.Fatalf("latest receive done mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileTransferStateObserveFeedsTransferRuntime(t *testing.T) {
|
||||||
|
state := newFileTransferState()
|
||||||
|
now := time.Unix(1300, 0)
|
||||||
|
|
||||||
|
state.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileMeta,
|
||||||
|
Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"},
|
||||||
|
Path: "/tmp/demo.bin",
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
state.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"},
|
||||||
|
Received: 8,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
state.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "kernel-send", Name: "demo.bin", Size: 8, Checksum: "sum-send"},
|
||||||
|
Received: 8,
|
||||||
|
Done: true,
|
||||||
|
Time: now.Add(2 * time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
sendSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), "kernel-send")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("send snapshot should exist")
|
||||||
|
}
|
||||||
|
if got, want := sendSnapshot.State, itransfer.StateDone; got != want {
|
||||||
|
t.Fatalf("send state = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := sendSnapshot.Direction, itransfer.DirectionSend; got != want {
|
||||||
|
t.Fatalf("send direction = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := sendSnapshot.SentBytes, int64(8); got != want {
|
||||||
|
t.Fatalf("send bytes = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := sendSnapshot.AckedBytes, int64(8); got != want {
|
||||||
|
t.Fatalf("send acked bytes = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got := sendSnapshot.Metadata["name"]; got != "demo.bin" {
|
||||||
|
t.Fatalf("send metadata name = %q, want demo.bin", got)
|
||||||
|
}
|
||||||
|
|
||||||
|
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
Kind: EnvelopeFileMeta,
|
||||||
|
Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
|
||||||
|
Received: 6,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "kernel-recv", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
|
||||||
|
Received: 6,
|
||||||
|
Done: true,
|
||||||
|
Time: now.Add(2 * time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
recvSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, clientFileScope(), "kernel-recv")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("receive snapshot should exist")
|
||||||
|
}
|
||||||
|
if got, want := recvSnapshot.State, itransfer.StateDone; got != want {
|
||||||
|
t.Fatalf("receive state = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := recvSnapshot.Direction, itransfer.DirectionReceive; got != want {
|
||||||
|
t.Fatalf("receive direction = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := recvSnapshot.ReceivedBytes, int64(6); got != want {
|
||||||
|
t.Fatalf("receive bytes = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileTransferStateRuntimeResilienceStats(t *testing.T) {
|
||||||
|
state := newFileTransferState()
|
||||||
|
session := &fileSendSession{
|
||||||
|
fileID: "kernel-retry",
|
||||||
|
path: "/tmp/retry.bin",
|
||||||
|
name: "retry.bin",
|
||||||
|
size: 5,
|
||||||
|
checksum: "sum-retry",
|
||||||
|
}
|
||||||
|
|
||||||
|
state.startRuntimeSendSession(clientFileScope(), clientFileScope(), 0, session)
|
||||||
|
state.recordRuntimeTimeout(fileTransferDirectionSend, clientFileScope(), session.fileID)
|
||||||
|
state.recordRuntimeRetry(fileTransferDirectionSend, clientFileScope(), session.fileID)
|
||||||
|
state.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileAbort,
|
||||||
|
Packet: FilePacket{FileID: session.fileID, Name: session.name, Size: session.size, Checksum: session.checksum, Stage: "meta"},
|
||||||
|
Received: 0,
|
||||||
|
Err: errString("ack timeout"),
|
||||||
|
Time: time.Unix(1400, 0),
|
||||||
|
})
|
||||||
|
|
||||||
|
snapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), session.fileID)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("runtime snapshot should exist")
|
||||||
|
}
|
||||||
|
if got, want := snapshot.TimeoutCount, 1; got != want {
|
||||||
|
t.Fatalf("timeout count = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.RetryCount, 1; got != want {
|
||||||
|
t.Fatalf("retry count = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.State, itransfer.StateAborted; got != want {
|
||||||
|
t.Fatalf("state = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.LastError, "ack timeout"; got != want {
|
||||||
|
t.Fatalf("last error = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Stage, "meta"; got != want {
|
||||||
|
t.Fatalf("stage = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.LastFailureStage, "meta"; got != want {
|
||||||
|
t.Fatalf("last failure stage = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileTransferStateRuntimeSeparatesScopeAndDirection(t *testing.T) {
|
||||||
|
state := newFileTransferState()
|
||||||
|
now := time.Unix(1450, 0)
|
||||||
|
serverClient := &ClientConn{ClientID: "client-b"}
|
||||||
|
|
||||||
|
state.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileMeta,
|
||||||
|
Packet: FilePacket{FileID: "shared-id", Name: "send.bin", Size: 4, Checksum: "sum-send"},
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileMeta,
|
||||||
|
Packet: FilePacket{FileID: "shared-id", Name: "recv.bin", Size: 6, Checksum: "sum-recv"},
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
sendSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionSend, clientFileScope(), "shared-id")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("send snapshot should exist")
|
||||||
|
}
|
||||||
|
if got, want := sendSnapshot.Direction, itransfer.DirectionSend; got != want {
|
||||||
|
t.Fatalf("send direction = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
recvSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, serverTransportScope(serverClient), "shared-id")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("receive snapshot should exist")
|
||||||
|
}
|
||||||
|
if got, want := recvSnapshot.Direction, itransfer.DirectionReceive; got != want {
|
||||||
|
t.Fatalf("receive direction = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileTransferStateRuntimeSeparatesServerTransportGenerations(t *testing.T) {
|
||||||
|
state := newFileTransferState()
|
||||||
|
now := time.Unix(1500, 0)
|
||||||
|
serverClient := &ClientConn{ClientID: "client-gen"}
|
||||||
|
serverClient.markClientConnIdentityBound()
|
||||||
|
serverClient.markClientConnStreamTransport()
|
||||||
|
serverClient.markClientConnTransportAttached()
|
||||||
|
|
||||||
|
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileMeta,
|
||||||
|
Packet: FilePacket{FileID: "shared-transfer", Name: "recv-a.bin", Size: 4, Checksum: "sum-a"},
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
|
||||||
|
firstScope := serverTransportScope(serverClient)
|
||||||
|
firstSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, firstScope, "shared-transfer")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("first generation snapshot should exist")
|
||||||
|
}
|
||||||
|
if got, want := firstSnapshot.Metadata[transferMetadataScopeKey], serverFileScope(serverClient); got != want {
|
||||||
|
t.Fatalf("first generation public scope metadata = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverClient.markClientConnTransportDetached("read error", nil)
|
||||||
|
serverClient.markClientConnTransportAttached()
|
||||||
|
|
||||||
|
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileMeta,
|
||||||
|
Packet: FilePacket{FileID: "shared-transfer", Name: "recv-b.bin", Size: 6, Checksum: "sum-b"},
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
secondScope := serverTransportScope(serverClient)
|
||||||
|
if secondScope == firstScope {
|
||||||
|
t.Fatalf("runtime scope should change across transport generations: got %q", secondScope)
|
||||||
|
}
|
||||||
|
secondSnapshot, ok := state.runtimeSnapshot(fileTransferDirectionReceive, secondScope, "shared-transfer")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("second generation snapshot should exist")
|
||||||
|
}
|
||||||
|
if got, want := transferSnapshotRuntimeScope(secondSnapshot.Metadata), secondScope; got != want {
|
||||||
|
t.Fatalf("second generation runtime scope metadata = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := transferSnapshotTransportGeneration(secondSnapshot.Metadata), uint64(2); got != want {
|
||||||
|
t.Fatalf("second generation transport generation metadata = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFileTransferStateLatestByFileIDQueryResolvesTransportGeneration(t *testing.T) {
|
||||||
|
state := newFileTransferState()
|
||||||
|
now := time.Unix(1510, 0)
|
||||||
|
serverClient := &ClientConn{ClientID: "client-query-gen"}
|
||||||
|
serverClient.markClientConnIdentityBound()
|
||||||
|
serverClient.markClientConnStreamTransport()
|
||||||
|
serverClient.markClientConnTransportAttached()
|
||||||
|
|
||||||
|
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 30},
|
||||||
|
Received: 6,
|
||||||
|
Total: 30,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
firstRuntimeScope := serverTransportScope(serverClient)
|
||||||
|
logicalScope := serverFileScope(serverClient)
|
||||||
|
|
||||||
|
serverClient.markClientConnTransportDetached("read error", nil)
|
||||||
|
serverClient.markClientConnTransportAttached()
|
||||||
|
|
||||||
|
state.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClient,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "shared", Size: 30},
|
||||||
|
Received: 10,
|
||||||
|
Total: 30,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
secondRuntimeScope := serverTransportScope(serverClient)
|
||||||
|
|
||||||
|
legacy := state.latestReceiveByFileID("shared")
|
||||||
|
if got, want := len(legacy), 1; got != want {
|
||||||
|
t.Fatalf("legacy receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
gen1 := state.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
|
||||||
|
Scope: logicalScope,
|
||||||
|
TransportGeneration: 1,
|
||||||
|
MatchTransportGeneration: true,
|
||||||
|
})
|
||||||
|
if got, want := len(gen1), 1; got != want {
|
||||||
|
t.Fatalf("generation-1 receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := gen1[0].RuntimeScope, firstRuntimeScope; got != want {
|
||||||
|
t.Fatalf("generation-1 runtime scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
gen2 := state.latestReceiveByFileIDQuery("shared", fileTransferSummaryQuery{
|
||||||
|
Scope: logicalScope,
|
||||||
|
TransportGeneration: 2,
|
||||||
|
MatchTransportGeneration: true,
|
||||||
|
})
|
||||||
|
if got, want := len(gen2), 1; got != want {
|
||||||
|
t.Fatalf("generation-2 receive count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := gen2[0].RuntimeScope, secondRuntimeScope; got != want {
|
||||||
|
t.Fatalf("generation-2 runtime scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,210 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fileTransferSummary struct {
|
||||||
|
Direction fileTransferDirection
|
||||||
|
Scope string
|
||||||
|
RuntimeScope string
|
||||||
|
TransportGeneration uint64
|
||||||
|
NetType NetType
|
||||||
|
Kind EnvelopeKind
|
||||||
|
FileID string
|
||||||
|
Path string
|
||||||
|
Received int64
|
||||||
|
Total int64
|
||||||
|
Percent float64
|
||||||
|
Active bool
|
||||||
|
Terminal bool
|
||||||
|
Done bool
|
||||||
|
Failed bool
|
||||||
|
Err error
|
||||||
|
StartedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
Duration time.Duration
|
||||||
|
RateBPS float64
|
||||||
|
StepDuration time.Duration
|
||||||
|
InstantRateBPS float64
|
||||||
|
Time time.Time
|
||||||
|
Stage string
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileTransferSummaryRecord struct {
|
||||||
|
snapshot fileTransferSnapshot
|
||||||
|
active bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func fileTransferSummaryFromSnapshot(snapshot fileTransferSnapshot, active bool) fileTransferSummary {
|
||||||
|
return fileTransferSummary{
|
||||||
|
Direction: snapshot.Direction,
|
||||||
|
Scope: snapshot.Scope,
|
||||||
|
RuntimeScope: snapshot.RuntimeScope,
|
||||||
|
TransportGeneration: snapshot.TransportGeneration,
|
||||||
|
NetType: snapshot.NetType,
|
||||||
|
Kind: snapshot.Kind,
|
||||||
|
FileID: snapshot.FileID,
|
||||||
|
Path: snapshot.Path,
|
||||||
|
Received: snapshot.Received,
|
||||||
|
Total: snapshot.Total,
|
||||||
|
Percent: snapshot.Percent,
|
||||||
|
Active: active,
|
||||||
|
Terminal: !active && isFileTransferTerminal(snapshot.Kind),
|
||||||
|
Done: snapshot.Done,
|
||||||
|
Failed: snapshot.Kind == EnvelopeFileAbort || snapshot.Err != nil,
|
||||||
|
Err: snapshot.Err,
|
||||||
|
StartedAt: snapshot.StartedAt,
|
||||||
|
UpdatedAt: snapshot.UpdatedAt,
|
||||||
|
Duration: snapshot.Duration,
|
||||||
|
RateBPS: snapshot.RateBPS,
|
||||||
|
StepDuration: snapshot.StepDuration,
|
||||||
|
InstantRateBPS: snapshot.InstantRateBPS,
|
||||||
|
Time: snapshot.Time,
|
||||||
|
Stage: snapshot.Stage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) activeSummaries() []fileTransferSummary {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return summariesFromSnapshots(sortedFileTransferSnapshots(m.active), true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) completedSummaries() []fileTransferSummary {
|
||||||
|
if m == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return summariesFromSnapshots(sortedFileTransferSnapshots(m.completed), false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) latestSummary(direction fileTransferDirection, scope string, fileID string) (fileTransferSummary, bool) {
|
||||||
|
if m == nil {
|
||||||
|
return fileTransferSummary{}, false
|
||||||
|
}
|
||||||
|
key := fileTransferMonitorKey(direction, scope, fileID)
|
||||||
|
if key == "" {
|
||||||
|
return fileTransferSummary{}, false
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if snapshot, ok := m.active[key]; ok {
|
||||||
|
return fileTransferSummaryFromSnapshot(snapshot, true), true
|
||||||
|
}
|
||||||
|
snapshot, ok := m.completed[key]
|
||||||
|
if !ok {
|
||||||
|
return fileTransferSummary{}, false
|
||||||
|
}
|
||||||
|
return fileTransferSummaryFromSnapshot(snapshot, false), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) summariesByFileID(fileID string) []fileTransferSummary {
|
||||||
|
if m == nil || fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return summariesFromRecords(filterFileTransferSummaryRecordsByFileID(latestFileTransferSummaryRecordsLocked(m.active, m.completed), fileID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) summariesByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSummary {
|
||||||
|
if m == nil || fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return summariesFromRecords(filterFileTransferSummaryRecordsByDirectionAndFileID(latestFileTransferSummaryRecordsLocked(m.active, m.completed), direction, fileID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) runtimeSummariesByFileID(fileID string) []fileTransferSummary {
|
||||||
|
if m == nil || fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return summariesFromRecords(filterFileTransferSummaryRecordsByFileID(latestFileTransferSummaryRecordsLocked(m.runtimeActive, m.runtimeCompleted), fileID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *fileTransferMonitor) runtimeSummariesByDirectionAndFileID(direction fileTransferDirection, fileID string) []fileTransferSummary {
|
||||||
|
if m == nil || fileID == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
return summariesFromRecords(filterFileTransferSummaryRecordsByDirectionAndFileID(latestFileTransferSummaryRecordsLocked(m.runtimeActive, m.runtimeCompleted), direction, fileID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func latestFileTransferSummaryRecordsLocked(active map[string]fileTransferSnapshot, completed map[string]fileTransferSnapshot) []fileTransferSummaryRecord {
|
||||||
|
keys := make([]string, 0, len(active)+len(completed))
|
||||||
|
seen := make(map[string]struct{}, len(active)+len(completed))
|
||||||
|
for key := range completed {
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
for key := range active {
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
out := make([]fileTransferSummaryRecord, 0, len(keys))
|
||||||
|
for _, key := range keys {
|
||||||
|
if snapshot, ok := active[key]; ok {
|
||||||
|
out = append(out, fileTransferSummaryRecord{snapshot: snapshot, active: true})
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if snapshot, ok := completed[key]; ok {
|
||||||
|
out = append(out, fileTransferSummaryRecord{snapshot: snapshot, active: false})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func summariesFromSnapshots(src []fileTransferSnapshot, active bool) []fileTransferSummary {
|
||||||
|
out := make([]fileTransferSummary, 0, len(src))
|
||||||
|
for _, snapshot := range src {
|
||||||
|
out = append(out, fileTransferSummaryFromSnapshot(snapshot, active))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func summariesFromRecords(src []fileTransferSummaryRecord) []fileTransferSummary {
|
||||||
|
out := make([]fileTransferSummary, 0, len(src))
|
||||||
|
for _, record := range src {
|
||||||
|
out = append(out, fileTransferSummaryFromSnapshot(record.snapshot, record.active))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterFileTransferSummaryRecordsByFileID(src []fileTransferSummaryRecord, fileID string) []fileTransferSummaryRecord {
|
||||||
|
out := make([]fileTransferSummaryRecord, 0, len(src))
|
||||||
|
for _, record := range src {
|
||||||
|
if record.snapshot.FileID != fileID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, record)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterFileTransferSummaryRecordsByDirectionAndFileID(src []fileTransferSummaryRecord, direction fileTransferDirection, fileID string) []fileTransferSummaryRecord {
|
||||||
|
out := make([]fileTransferSummaryRecord, 0, len(src))
|
||||||
|
for _, record := range src {
|
||||||
|
if record.snapshot.Direction != direction || record.snapshot.FileID != fileID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = append(out, record)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTransferMonitorLatestSummaryPrefersActive(t *testing.T) {
|
||||||
|
monitor := newFileTransferMonitor()
|
||||||
|
now := time.Unix(500, 0)
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "summary-1", Size: 30},
|
||||||
|
Received: 12,
|
||||||
|
Total: 30,
|
||||||
|
Percent: 40,
|
||||||
|
StartedAt: now,
|
||||||
|
UpdatedAt: now.Add(time.Second),
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
summary, ok := monitor.latestSummary(fileTransferDirectionSend, clientFileScope(), "summary-1")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("latest summary should exist while active")
|
||||||
|
}
|
||||||
|
if got, want := summary.Active, true; got != want {
|
||||||
|
t.Fatalf("active summary mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.Terminal, false; got != want {
|
||||||
|
t.Fatalf("terminal summary mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.Received, int64(12); got != want {
|
||||||
|
t.Fatalf("active summary received mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "summary-1", Size: 30},
|
||||||
|
Received: 30,
|
||||||
|
Total: 30,
|
||||||
|
Percent: 100,
|
||||||
|
Done: true,
|
||||||
|
StartedAt: now,
|
||||||
|
UpdatedAt: now.Add(2 * time.Second),
|
||||||
|
Time: now.Add(2 * time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
summary, ok = monitor.latestSummary(fileTransferDirectionSend, clientFileScope(), "summary-1")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("latest summary should exist after completion")
|
||||||
|
}
|
||||||
|
if got, want := summary.Active, false; got != want {
|
||||||
|
t.Fatalf("completed summary active mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.Terminal, true; got != want {
|
||||||
|
t.Fatalf("completed summary terminal mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summary.Done, true; got != want {
|
||||||
|
t.Fatalf("completed summary done mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransferMonitorSummariesByFileID(t *testing.T) {
|
||||||
|
monitor := newFileTransferMonitor()
|
||||||
|
now := time.Unix(600, 0)
|
||||||
|
serverClientA := &ClientConn{ClientID: "client-a"}
|
||||||
|
serverClientB := &ClientConn{ClientID: "client-b"}
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "summary-shared", Size: 20},
|
||||||
|
Received: 8,
|
||||||
|
Total: 20,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClientA,
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "summary-shared", Size: 20},
|
||||||
|
Received: 12,
|
||||||
|
Total: 20,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
ClientConn: serverClientB,
|
||||||
|
Kind: EnvelopeFileAbort,
|
||||||
|
Packet: FilePacket{FileID: "summary-shared", Size: 20, Stage: "chunk"},
|
||||||
|
Received: 14,
|
||||||
|
Total: 20,
|
||||||
|
Time: now.Add(2 * time.Second),
|
||||||
|
Err: errString("recv failed"),
|
||||||
|
})
|
||||||
|
|
||||||
|
summaries := monitor.summariesByFileID("summary-shared")
|
||||||
|
if got, want := len(summaries), 3; got != want {
|
||||||
|
t.Fatalf("summaries count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summaries[0].Scope, serverFileScope(serverClientA); got != want {
|
||||||
|
t.Fatalf("first summary scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summaries[0].Active, true; got != want {
|
||||||
|
t.Fatalf("first summary active mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summaries[1].Scope, serverFileScope(serverClientB); got != want {
|
||||||
|
t.Fatalf("second summary scope mismatch: got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summaries[1].Failed, true; got != want {
|
||||||
|
t.Fatalf("second summary failed mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summaries[1].Terminal, true; got != want {
|
||||||
|
t.Fatalf("second summary terminal mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := summaries[2].Direction, fileTransferDirectionSend; got != want {
|
||||||
|
t.Fatalf("third summary direction mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTransferMonitorActiveAndCompletedSummaries(t *testing.T) {
|
||||||
|
monitor := newFileTransferMonitor()
|
||||||
|
now := time.Unix(700, 0)
|
||||||
|
|
||||||
|
monitor.observe(fileTransferDirectionSend, FileEvent{
|
||||||
|
Kind: EnvelopeFileChunk,
|
||||||
|
Packet: FilePacket{FileID: "active-1", Size: 10},
|
||||||
|
Received: 3,
|
||||||
|
Total: 10,
|
||||||
|
Time: now,
|
||||||
|
})
|
||||||
|
monitor.observe(fileTransferDirectionReceive, FileEvent{
|
||||||
|
Kind: EnvelopeFileEnd,
|
||||||
|
Packet: FilePacket{FileID: "done-1", Size: 10},
|
||||||
|
Received: 10,
|
||||||
|
Total: 10,
|
||||||
|
Done: true,
|
||||||
|
Time: now.Add(time.Second),
|
||||||
|
})
|
||||||
|
|
||||||
|
active := monitor.activeSummaries()
|
||||||
|
if got, want := len(active), 1; got != want {
|
||||||
|
t.Fatalf("active summaries count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := active[0].Active, true; got != want {
|
||||||
|
t.Fatalf("active summary state mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
completed := monitor.completedSummaries()
|
||||||
|
if got, want := len(completed), 1; got != want {
|
||||||
|
t.Fatalf("completed summaries count mismatch: got %d want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := completed[0].Active, false; got != want {
|
||||||
|
t.Fatalf("completed summary state mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := completed[0].Done, true; got != want {
|
||||||
|
t.Fatalf("completed summary done mismatch: got %v want %v", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type errString string
|
||||||
|
|
||||||
|
func (e errString) Error() string {
|
||||||
|
return string(e)
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
module b612.me/notify
|
||||||
|
|
||||||
|
go 1.24.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
b612.me/starcrypto v1.0.2
|
||||||
|
b612.me/stario v0.1.0
|
||||||
|
github.com/Microsoft/go-winio v0.6.2
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/emmansun/gmsm v0.41.1 // indirect
|
||||||
|
golang.org/x/crypto v0.48.0 // indirect
|
||||||
|
golang.org/x/sys v0.41.0 // indirect
|
||||||
|
golang.org/x/term v0.40.0 // indirect
|
||||||
|
)
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
b612.me/starcrypto v1.0.2 h1:6f8YHNMHZPwxDSRxY2OJeMP4ExKa/cakLIO04f0gLhE=
|
||||||
|
b612.me/starcrypto v1.0.2/go.mod h1:I7oYTmQgnVPj5S5yKwoTyqkItq1HgF9XdJT/v3qs5QE=
|
||||||
|
b612.me/stario v0.1.0 h1:V1uA7fLYzgTadOXpnyPaFC3z0MAKFIM/RKXzZUDXvL4=
|
||||||
|
b612.me/stario v0.1.0/go.mod h1:7kjE69oFqNrca0P72L5+ZbTV09QGJ2N3bBY3qeFXOGc=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||||
|
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||||
|
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||||
|
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/emmansun/gmsm v0.41.1 h1:mD1MqmaXTEqt+9UVmDpRYvcEMIa5vuslFEnw7IWp6/w=
|
||||||
|
github.com/emmansun/gmsm v0.41.1/go.mod h1:FD1EQk4XcSMkahZFzNwFoI/uXzAlODB9JVsJ9G5N7Do=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||||
|
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||||
|
golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts=
|
||||||
|
golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos=
|
||||||
|
golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k=
|
||||||
|
golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||||
|
golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg=
|
||||||
|
golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||||
|
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||||
@@ -0,0 +1,127 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultInboundDispatchSource = "_notify.default_inbound_source"
|
||||||
|
|
||||||
|
type inboundDispatcher struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
closed bool
|
||||||
|
workers map[string]*inboundDispatchWorker
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
type inboundDispatchWorker struct {
|
||||||
|
queue []func()
|
||||||
|
running bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func newInboundDispatcher() *inboundDispatcher {
|
||||||
|
return &inboundDispatcher{
|
||||||
|
workers: make(map[string]*inboundDispatchWorker),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *inboundDispatcher) Dispatch(source string, fn func()) bool {
|
||||||
|
if d == nil || fn == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if source == "" {
|
||||||
|
source = defaultInboundDispatchSource
|
||||||
|
}
|
||||||
|
d.mu.Lock()
|
||||||
|
if d.closed {
|
||||||
|
d.mu.Unlock()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
worker := d.workers[source]
|
||||||
|
if worker == nil {
|
||||||
|
worker = &inboundDispatchWorker{}
|
||||||
|
d.workers[source] = worker
|
||||||
|
}
|
||||||
|
worker.queue = append(worker.queue, fn)
|
||||||
|
if worker.running {
|
||||||
|
d.mu.Unlock()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
worker.running = true
|
||||||
|
d.wg.Add(1)
|
||||||
|
d.mu.Unlock()
|
||||||
|
go d.run(source, worker)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *inboundDispatcher) run(source string, worker *inboundDispatchWorker) {
|
||||||
|
defer d.wg.Done()
|
||||||
|
for {
|
||||||
|
d.mu.Lock()
|
||||||
|
if len(worker.queue) == 0 {
|
||||||
|
worker.running = false
|
||||||
|
if current := d.workers[source]; current == worker {
|
||||||
|
delete(d.workers, source)
|
||||||
|
}
|
||||||
|
d.mu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
fn := worker.queue[0]
|
||||||
|
worker.queue[0] = nil
|
||||||
|
worker.queue = worker.queue[1:]
|
||||||
|
d.mu.Unlock()
|
||||||
|
fn()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *inboundDispatcher) CloseAndWait() {
|
||||||
|
if d == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.mu.Lock()
|
||||||
|
d.closed = true
|
||||||
|
d.mu.Unlock()
|
||||||
|
d.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientInboundDispatchSource() string {
|
||||||
|
return "client"
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverInboundDispatchSource(source interface{}) string {
|
||||||
|
switch data := source.(type) {
|
||||||
|
case serverInboundSource:
|
||||||
|
return serverInboundDispatchSourceKey(data)
|
||||||
|
case *serverInboundSource:
|
||||||
|
if data == nil {
|
||||||
|
return defaultInboundDispatchSource
|
||||||
|
}
|
||||||
|
return serverInboundDispatchSourceKey(*data)
|
||||||
|
case net.Conn:
|
||||||
|
return fmt.Sprintf("conn:%p", data)
|
||||||
|
case string:
|
||||||
|
if data == "" {
|
||||||
|
return defaultInboundDispatchSource
|
||||||
|
}
|
||||||
|
return "peer:" + data
|
||||||
|
default:
|
||||||
|
return defaultInboundDispatchSource
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func serverInboundDispatchSourceKey(source serverInboundSource) string {
|
||||||
|
if source.Conn != nil {
|
||||||
|
return fmt.Sprintf("conn:%p:%d", source.Conn, source.TransportGeneration)
|
||||||
|
}
|
||||||
|
if source.Logical != nil {
|
||||||
|
return fmt.Sprintf("logical:%s:%d", source.Logical.ID(), source.TransportGeneration)
|
||||||
|
}
|
||||||
|
if source.Source != "" {
|
||||||
|
return fmt.Sprintf("peer:%s:%d", source.Source, source.TransportGeneration)
|
||||||
|
}
|
||||||
|
if source.RemoteAddr != nil {
|
||||||
|
return fmt.Sprintf("addr:%s:%d", source.RemoteAddr.String(), source.TransportGeneration)
|
||||||
|
}
|
||||||
|
return defaultInboundDispatchSource
|
||||||
|
}
|
||||||
@@ -0,0 +1,103 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInboundDispatcherSerializesPerSource(t *testing.T) {
|
||||||
|
dispatcher := newInboundDispatcher()
|
||||||
|
defer dispatcher.CloseAndWait()
|
||||||
|
|
||||||
|
firstStarted := make(chan struct{}, 1)
|
||||||
|
secondStarted := make(chan struct{}, 1)
|
||||||
|
otherStarted := make(chan struct{}, 1)
|
||||||
|
releaseFirst := make(chan struct{})
|
||||||
|
|
||||||
|
var mu sync.Mutex
|
||||||
|
var order []string
|
||||||
|
|
||||||
|
record := func(step string) {
|
||||||
|
mu.Lock()
|
||||||
|
order = append(order, step)
|
||||||
|
mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
if !dispatcher.Dispatch("alpha", func() {
|
||||||
|
record("alpha-1-start")
|
||||||
|
firstStarted <- struct{}{}
|
||||||
|
<-releaseFirst
|
||||||
|
record("alpha-1-end")
|
||||||
|
}) {
|
||||||
|
t.Fatal("dispatch alpha-1 failed")
|
||||||
|
}
|
||||||
|
if !dispatcher.Dispatch("alpha", func() {
|
||||||
|
record("alpha-2-start")
|
||||||
|
secondStarted <- struct{}{}
|
||||||
|
record("alpha-2-end")
|
||||||
|
}) {
|
||||||
|
t.Fatal("dispatch alpha-2 failed")
|
||||||
|
}
|
||||||
|
if !dispatcher.Dispatch("beta", func() {
|
||||||
|
record("beta-1-start")
|
||||||
|
otherStarted <- struct{}{}
|
||||||
|
record("beta-1-end")
|
||||||
|
}) {
|
||||||
|
t.Fatal("dispatch beta-1 failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-firstStarted:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting for alpha-1")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-otherStarted:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting for beta-1")
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-secondStarted:
|
||||||
|
t.Fatal("alpha-2 started before alpha-1 finished")
|
||||||
|
case <-time.After(100 * time.Millisecond):
|
||||||
|
}
|
||||||
|
|
||||||
|
close(releaseFirst)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-secondStarted:
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("timed out waiting for alpha-2")
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatcher.CloseAndWait()
|
||||||
|
|
||||||
|
mu.Lock()
|
||||||
|
defer mu.Unlock()
|
||||||
|
if len(order) == 0 {
|
||||||
|
t.Fatal("dispatch order is empty")
|
||||||
|
}
|
||||||
|
alpha1Start := indexOfString(order, "alpha-1-start")
|
||||||
|
alpha1End := indexOfString(order, "alpha-1-end")
|
||||||
|
alpha2Start := indexOfString(order, "alpha-2-start")
|
||||||
|
beta1Start := indexOfString(order, "beta-1-start")
|
||||||
|
if alpha1Start < 0 || alpha1End < 0 || alpha2Start < 0 || beta1Start < 0 {
|
||||||
|
t.Fatalf("unexpected order trace: %v", order)
|
||||||
|
}
|
||||||
|
if alpha2Start < alpha1End {
|
||||||
|
t.Fatalf("alpha source was not serialized: %v", order)
|
||||||
|
}
|
||||||
|
if beta1Start > alpha1End {
|
||||||
|
t.Fatalf("beta source did not run in parallel window: %v", order)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func indexOfString(list []string, target string) int {
|
||||||
|
for idx, item := range list {
|
||||||
|
if item == target {
|
||||||
|
return idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return -1
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package notify
|
||||||
|
|
||||||
|
import "b612.me/starcrypto"
|
||||||
|
|
||||||
|
var integrationSharedSecret = []byte("notify-integration-modern-psk")
|
||||||
|
|
||||||
|
func integrationModernPSKOptions() *ModernPSKOptions {
|
||||||
|
return &ModernPSKOptions{
|
||||||
|
Salt: []byte("notify-integration-modern-psk-salt"),
|
||||||
|
AAD: []byte("notify-integration-modern-psk-aad"),
|
||||||
|
Argon2Params: starcrypto.Argon2Params{
|
||||||
|
Time: 1,
|
||||||
|
Memory: 8,
|
||||||
|
Threads: 1,
|
||||||
|
KeyLen: 32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,40 @@
|
|||||||
|
package codec
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/gob"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Register(data interface{}) {
|
||||||
|
gob.Register(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterName(name string, data interface{}) {
|
||||||
|
gob.RegisterName(name, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterAll(data []interface{}) {
|
||||||
|
for _, v := range data {
|
||||||
|
gob.Register(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func RegisterNames(data map[string]interface{}) {
|
||||||
|
for k, v := range data {
|
||||||
|
gob.RegisterName(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Encode(src interface{}) ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
enc := gob.NewEncoder(&buf)
|
||||||
|
err := enc.Encode(&src)
|
||||||
|
return buf.Bytes(), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func Decode(src []byte) (interface{}, error) {
|
||||||
|
dec := gob.NewDecoder(bytes.NewReader(src))
|
||||||
|
var dst interface{}
|
||||||
|
err := dec.Decode(&dst)
|
||||||
|
return dst, err
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
package timeutil
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
func NowUnixNano() int64 {
|
||||||
|
return time.Now().UnixNano()
|
||||||
|
}
|
||||||
@@ -0,0 +1,366 @@
|
|||||||
|
package transfer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrTransferIDEmpty = errors.New("transfer id is empty")
|
||||||
|
ErrTransferExists = errors.New("transfer already exists")
|
||||||
|
ErrTransferNotFound = errors.New("transfer not found")
|
||||||
|
ErrTransferBytesInvalid = errors.New("transfer bytes must be non-negative")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
now func() time.Time
|
||||||
|
transfers map[string]*managedTransfer
|
||||||
|
}
|
||||||
|
|
||||||
|
type managedTransfer struct {
|
||||||
|
snapshot Snapshot
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager() *Manager {
|
||||||
|
return NewManagerWithClock(time.Now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManagerWithClock(now func() time.Time) *Manager {
|
||||||
|
if now == nil {
|
||||||
|
now = time.Now
|
||||||
|
}
|
||||||
|
return &Manager{
|
||||||
|
now: now,
|
||||||
|
transfers: make(map[string]*managedTransfer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) StartOutgoing(desc Descriptor) (Snapshot, error) {
|
||||||
|
return m.start(desc, DirectionSend, StateNegotiating)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) StartIncoming(desc Descriptor) (Snapshot, error) {
|
||||||
|
return m.start(desc, DirectionReceive, StatePrepared)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Activate(id string) (Snapshot, error) {
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
snapshot.State = StateActive
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Pause(id string) (Snapshot, error) {
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
if snapshot.State.Terminal() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
snapshot.State = StatePaused
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Resume(id string, confirmedBytes int64) (Snapshot, error) {
|
||||||
|
if confirmedBytes < 0 {
|
||||||
|
return Snapshot{}, ErrTransferBytesInvalid
|
||||||
|
}
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
switch snapshot.Direction {
|
||||||
|
case DirectionSend:
|
||||||
|
if confirmedBytes > snapshot.SentBytes {
|
||||||
|
snapshot.SentBytes = confirmedBytes
|
||||||
|
}
|
||||||
|
snapshot.AckedBytes = confirmedBytes
|
||||||
|
if snapshot.Size > 0 && snapshot.AckedBytes > snapshot.Size {
|
||||||
|
snapshot.AckedBytes = snapshot.Size
|
||||||
|
}
|
||||||
|
case DirectionReceive:
|
||||||
|
if confirmedBytes > snapshot.ReceivedBytes {
|
||||||
|
snapshot.ReceivedBytes = confirmedBytes
|
||||||
|
}
|
||||||
|
if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size {
|
||||||
|
snapshot.ReceivedBytes = snapshot.Size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
snapshot.State = StateActive
|
||||||
|
snapshot.InflightBytes = inflightBytes(*snapshot)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) RecordSend(id string, sentBytes int64) (Snapshot, error) {
|
||||||
|
if sentBytes < 0 {
|
||||||
|
return Snapshot{}, ErrTransferBytesInvalid
|
||||||
|
}
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
snapshot.SentBytes += sentBytes
|
||||||
|
if snapshot.Size > 0 && snapshot.SentBytes > snapshot.Size {
|
||||||
|
snapshot.SentBytes = snapshot.Size
|
||||||
|
}
|
||||||
|
snapshot.InflightBytes = inflightBytes(*snapshot)
|
||||||
|
if !snapshot.State.Terminal() {
|
||||||
|
snapshot.State = StateActive
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) RecordReceive(id string, recvBytes int64) (Snapshot, error) {
|
||||||
|
if recvBytes < 0 {
|
||||||
|
return Snapshot{}, ErrTransferBytesInvalid
|
||||||
|
}
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
snapshot.ReceivedBytes += recvBytes
|
||||||
|
if snapshot.Size > 0 && snapshot.ReceivedBytes > snapshot.Size {
|
||||||
|
snapshot.ReceivedBytes = snapshot.Size
|
||||||
|
}
|
||||||
|
if !snapshot.State.Terminal() {
|
||||||
|
snapshot.State = StateActive
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) SetAckedBytes(id string, ackedBytes int64) (Snapshot, error) {
|
||||||
|
if ackedBytes < 0 {
|
||||||
|
return Snapshot{}, ErrTransferBytesInvalid
|
||||||
|
}
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
snapshot.AckedBytes = ackedBytes
|
||||||
|
if snapshot.Size > 0 && snapshot.AckedBytes > snapshot.Size {
|
||||||
|
snapshot.AckedBytes = snapshot.Size
|
||||||
|
}
|
||||||
|
if snapshot.AckedBytes > snapshot.SentBytes {
|
||||||
|
snapshot.SentBytes = snapshot.AckedBytes
|
||||||
|
}
|
||||||
|
snapshot.InflightBytes = inflightBytes(*snapshot)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) BeginCommit(id string) (Snapshot, error) {
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
snapshot.State = StateCommitting
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) BeginVerify(id string) (Snapshot, error) {
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
snapshot.State = StateVerifying
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Complete(id string) (Snapshot, error) {
|
||||||
|
now := m.currentTime()
|
||||||
|
return m.updateWithTime(id, now, func(snapshot *Snapshot, now time.Time) error {
|
||||||
|
snapshot.State = StateDone
|
||||||
|
snapshot.CompletedAt = now.UnixNano()
|
||||||
|
snapshot.InflightBytes = inflightBytes(*snapshot)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Abort(id string, err error) (Snapshot, error) {
|
||||||
|
return m.finishWithError(id, StateAborted, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Fail(id string, err error) (Snapshot, error) {
|
||||||
|
return m.finishWithError(id, StateFailed, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) RecordRetry(id string) (Snapshot, error) {
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
snapshot.RetryCount++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) RecordTimeout(id string) (Snapshot, error) {
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
snapshot.TimeoutCount++
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) SetStage(id string, stage string) (Snapshot, error) {
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
snapshot.Stage = stage
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) SetFailureStage(id string, stage string) (Snapshot, error) {
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
snapshot.LastFailureStage = stage
|
||||||
|
if stage != "" {
|
||||||
|
snapshot.Stage = stage
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) MergeMetadata(id string, metadata Metadata) (Snapshot, error) {
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
if len(metadata) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if snapshot.Metadata == nil {
|
||||||
|
snapshot.Metadata = make(Metadata, len(metadata))
|
||||||
|
}
|
||||||
|
for key, value := range metadata {
|
||||||
|
if value == "" {
|
||||||
|
delete(snapshot.Metadata, key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
snapshot.Metadata[key] = value
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) RecordTelemetry(id string, delta TelemetryDelta) (Snapshot, error) {
|
||||||
|
return m.update(id, func(snapshot *Snapshot) error {
|
||||||
|
if delta.SourceReadDuration > 0 {
|
||||||
|
snapshot.SourceReadDuration += delta.SourceReadDuration
|
||||||
|
}
|
||||||
|
if delta.StreamWriteDuration > 0 {
|
||||||
|
snapshot.StreamWriteDuration += delta.StreamWriteDuration
|
||||||
|
}
|
||||||
|
if delta.SinkWriteDuration > 0 {
|
||||||
|
snapshot.SinkWriteDuration += delta.SinkWriteDuration
|
||||||
|
}
|
||||||
|
if delta.SyncDuration > 0 {
|
||||||
|
snapshot.SyncDuration += delta.SyncDuration
|
||||||
|
}
|
||||||
|
if delta.VerifyDuration > 0 {
|
||||||
|
snapshot.VerifyDuration += delta.VerifyDuration
|
||||||
|
}
|
||||||
|
if delta.CommitDuration > 0 {
|
||||||
|
snapshot.CommitDuration += delta.CommitDuration
|
||||||
|
}
|
||||||
|
if delta.CommitWaitDuration > 0 {
|
||||||
|
snapshot.CommitWaitDuration += delta.CommitWaitDuration
|
||||||
|
}
|
||||||
|
if delta.SourceReadCount > 0 {
|
||||||
|
snapshot.SourceReadCount += delta.SourceReadCount
|
||||||
|
}
|
||||||
|
if delta.StreamWriteCount > 0 {
|
||||||
|
snapshot.StreamWriteCount += delta.StreamWriteCount
|
||||||
|
}
|
||||||
|
if delta.SinkWriteCount > 0 {
|
||||||
|
snapshot.SinkWriteCount += delta.SinkWriteCount
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Snapshot(id string) (Snapshot, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
transfer, ok := m.transfers[id]
|
||||||
|
if !ok {
|
||||||
|
return Snapshot{}, false
|
||||||
|
}
|
||||||
|
return cloneSnapshot(transfer.snapshot), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Snapshots() []Snapshot {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
out := make([]Snapshot, 0, len(m.transfers))
|
||||||
|
for _, transfer := range m.transfers {
|
||||||
|
out = append(out, cloneSnapshot(transfer.snapshot))
|
||||||
|
}
|
||||||
|
sort.Slice(out, func(i int, j int) bool {
|
||||||
|
return out[i].ID < out[j].ID
|
||||||
|
})
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Restore(snapshot Snapshot) (Snapshot, error) {
|
||||||
|
if snapshot.ID == "" {
|
||||||
|
return Snapshot{}, ErrTransferIDEmpty
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.transfers[snapshot.ID] = &managedTransfer{snapshot: cloneSnapshot(snapshot)}
|
||||||
|
return cloneSnapshot(snapshot), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) start(desc Descriptor, direction Direction, state State) (Snapshot, error) {
|
||||||
|
if desc.ID == "" {
|
||||||
|
return Snapshot{}, ErrTransferIDEmpty
|
||||||
|
}
|
||||||
|
now := m.currentTime()
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
if _, exists := m.transfers[desc.ID]; exists {
|
||||||
|
return Snapshot{}, ErrTransferExists
|
||||||
|
}
|
||||||
|
snapshot := Snapshot{
|
||||||
|
ID: desc.ID,
|
||||||
|
Direction: direction,
|
||||||
|
Channel: normalizeChannel(desc.Channel),
|
||||||
|
State: state,
|
||||||
|
Size: desc.Size,
|
||||||
|
Checksum: desc.Checksum,
|
||||||
|
Metadata: cloneMetadata(desc.Metadata),
|
||||||
|
StartedAt: now.UnixNano(),
|
||||||
|
UpdatedAt: now.UnixNano(),
|
||||||
|
}
|
||||||
|
m.transfers[desc.ID] = &managedTransfer{snapshot: snapshot}
|
||||||
|
return cloneSnapshot(snapshot), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) finishWithError(id string, state State, err error) (Snapshot, error) {
|
||||||
|
now := m.currentTime()
|
||||||
|
return m.updateWithTime(id, now, func(snapshot *Snapshot, now time.Time) error {
|
||||||
|
snapshot.State = state
|
||||||
|
snapshot.CompletedAt = now.UnixNano()
|
||||||
|
if err != nil {
|
||||||
|
snapshot.LastError = err.Error()
|
||||||
|
}
|
||||||
|
snapshot.InflightBytes = inflightBytes(*snapshot)
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) update(id string, fn func(*Snapshot) error) (Snapshot, error) {
|
||||||
|
return m.updateWithTime(id, m.currentTime(), func(snapshot *Snapshot, _ time.Time) error {
|
||||||
|
return fn(snapshot)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateWithTime(id string, now time.Time, fn func(*Snapshot, time.Time) error) (Snapshot, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
transfer, ok := m.transfers[id]
|
||||||
|
if !ok {
|
||||||
|
return Snapshot{}, ErrTransferNotFound
|
||||||
|
}
|
||||||
|
snapshot := &transfer.snapshot
|
||||||
|
if err := fn(snapshot, now); err != nil {
|
||||||
|
return Snapshot{}, err
|
||||||
|
}
|
||||||
|
snapshot.UpdatedAt = now.UnixNano()
|
||||||
|
return cloneSnapshot(*snapshot), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) currentTime() time.Time {
|
||||||
|
return m.now()
|
||||||
|
}
|
||||||
|
|
||||||
|
func inflightBytes(snapshot Snapshot) int64 {
|
||||||
|
if snapshot.Direction != DirectionSend {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if snapshot.SentBytes <= snapshot.AckedBytes {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return snapshot.SentBytes - snapshot.AckedBytes
|
||||||
|
}
|
||||||
@@ -0,0 +1,193 @@
|
|||||||
|
package transfer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type fakeClock struct {
|
||||||
|
now time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeClock) Now() time.Time {
|
||||||
|
return f.now
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeClock) Advance(d time.Duration) {
|
||||||
|
f.now = f.now.Add(d)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerOutgoingLifecycle(t *testing.T) {
|
||||||
|
clock := &fakeClock{now: time.Unix(100, 0)}
|
||||||
|
manager := NewManagerWithClock(clock.Now)
|
||||||
|
|
||||||
|
snapshot, err := manager.StartOutgoing(Descriptor{
|
||||||
|
ID: "tx-1",
|
||||||
|
Size: 100,
|
||||||
|
Checksum: "sum-1",
|
||||||
|
Metadata: Metadata{"kind": "file"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartOutgoing failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.State, StateNegotiating; got != want {
|
||||||
|
t.Fatalf("start state = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Channel, DataChannel; got != want {
|
||||||
|
t.Fatalf("channel = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
clock.Advance(time.Second)
|
||||||
|
if _, err := manager.Activate("tx-1"); err != nil {
|
||||||
|
t.Fatalf("Activate failed: %v", err)
|
||||||
|
}
|
||||||
|
clock.Advance(time.Second)
|
||||||
|
if _, err := manager.RecordSend("tx-1", 60); err != nil {
|
||||||
|
t.Fatalf("RecordSend failed: %v", err)
|
||||||
|
}
|
||||||
|
clock.Advance(time.Second)
|
||||||
|
if _, err := manager.SetAckedBytes("tx-1", 40); err != nil {
|
||||||
|
t.Fatalf("SetAckedBytes failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.RecordRetry("tx-1"); err != nil {
|
||||||
|
t.Fatalf("RecordRetry failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.RecordTimeout("tx-1"); err != nil {
|
||||||
|
t.Fatalf("RecordTimeout failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.Pause("tx-1"); err != nil {
|
||||||
|
t.Fatalf("Pause failed: %v", err)
|
||||||
|
}
|
||||||
|
clock.Advance(time.Second)
|
||||||
|
if _, err := manager.Resume("tx-1", 40); err != nil {
|
||||||
|
t.Fatalf("Resume failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.BeginCommit("tx-1"); err != nil {
|
||||||
|
t.Fatalf("BeginCommit failed: %v", err)
|
||||||
|
}
|
||||||
|
clock.Advance(time.Second)
|
||||||
|
snapshot, err = manager.Complete("tx-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Complete failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got, want := snapshot.State, StateDone; got != want {
|
||||||
|
t.Fatalf("complete state = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.SentBytes, int64(60); got != want {
|
||||||
|
t.Fatalf("sent bytes = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.AckedBytes, int64(40); got != want {
|
||||||
|
t.Fatalf("acked bytes = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.InflightBytes, int64(20); got != want {
|
||||||
|
t.Fatalf("inflight bytes = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.RetryCount, 1; got != want {
|
||||||
|
t.Fatalf("retry count = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.TimeoutCount, 1; got != want {
|
||||||
|
t.Fatalf("timeout count = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if _, err := manager.SetStage("tx-1", "chunk"); err != nil {
|
||||||
|
t.Fatalf("SetStage failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.SetFailureStage("tx-1", "chunk"); err != nil {
|
||||||
|
t.Fatalf("SetFailureStage failed: %v", err)
|
||||||
|
}
|
||||||
|
if snapshot.CompletedAt == 0 {
|
||||||
|
t.Fatal("completed timestamp should be set")
|
||||||
|
}
|
||||||
|
if got := snapshot.Metadata["kind"]; got != "file" {
|
||||||
|
t.Fatalf("metadata kind = %q, want file", got)
|
||||||
|
}
|
||||||
|
snapshot, ok := manager.Snapshot("tx-1")
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("snapshot should still exist")
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Stage, "chunk"; got != want {
|
||||||
|
t.Fatalf("stage = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.LastFailureStage, "chunk"; got != want {
|
||||||
|
t.Fatalf("last failure stage = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerIncomingResumeAndVerify(t *testing.T) {
|
||||||
|
clock := &fakeClock{now: time.Unix(200, 0)}
|
||||||
|
manager := NewManagerWithClock(clock.Now)
|
||||||
|
|
||||||
|
snapshot, err := manager.StartIncoming(Descriptor{
|
||||||
|
ID: "rx-1",
|
||||||
|
Channel: ControlChannel,
|
||||||
|
Size: 64,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("StartIncoming failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.State, StatePrepared; got != want {
|
||||||
|
t.Fatalf("prepared state = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
clock.Advance(time.Second)
|
||||||
|
snapshot, err = manager.Resume("rx-1", 16)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Resume failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.ReceivedBytes, int64(16); got != want {
|
||||||
|
t.Fatalf("received bytes after resume = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := manager.RecordReceive("rx-1", 20); err != nil {
|
||||||
|
t.Fatalf("RecordReceive failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.BeginVerify("rx-1"); err != nil {
|
||||||
|
t.Fatalf("BeginVerify failed: %v", err)
|
||||||
|
}
|
||||||
|
clock.Advance(time.Second)
|
||||||
|
snapshot, err = manager.Complete("rx-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Complete failed: %v", err)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.State, StateDone; got != want {
|
||||||
|
t.Fatalf("complete state = %v, want %v", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.ReceivedBytes, int64(36); got != want {
|
||||||
|
t.Fatalf("received bytes = %d, want %d", got, want)
|
||||||
|
}
|
||||||
|
if got, want := snapshot.Channel, ControlChannel; got != want {
|
||||||
|
t.Fatalf("channel = %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerValidatesIDsAndSortedSnapshots(t *testing.T) {
|
||||||
|
manager := NewManager()
|
||||||
|
|
||||||
|
if _, err := manager.StartOutgoing(Descriptor{}); !errors.Is(err, ErrTransferIDEmpty) {
|
||||||
|
t.Fatalf("empty id error = %v, want %v", err, ErrTransferIDEmpty)
|
||||||
|
}
|
||||||
|
if _, err := manager.StartOutgoing(Descriptor{ID: "b"}); err != nil {
|
||||||
|
t.Fatalf("StartOutgoing b failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.StartIncoming(Descriptor{ID: "a"}); err != nil {
|
||||||
|
t.Fatalf("StartIncoming a failed: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := manager.StartOutgoing(Descriptor{ID: "b"}); !errors.Is(err, ErrTransferExists) {
|
||||||
|
t.Fatalf("duplicate id error = %v, want %v", err, ErrTransferExists)
|
||||||
|
}
|
||||||
|
if _, err := manager.RecordSend("missing", 1); !errors.Is(err, ErrTransferNotFound) {
|
||||||
|
t.Fatalf("missing transfer error = %v, want %v", err, ErrTransferNotFound)
|
||||||
|
}
|
||||||
|
if _, err := manager.RecordReceive("a", -1); !errors.Is(err, ErrTransferBytesInvalid) {
|
||||||
|
t.Fatalf("negative bytes error = %v, want %v", err, ErrTransferBytesInvalid)
|
||||||
|
}
|
||||||
|
|
||||||
|
snapshots := manager.Snapshots()
|
||||||
|
if len(snapshots) != 2 {
|
||||||
|
t.Fatalf("snapshot count = %d, want 2", len(snapshots))
|
||||||
|
}
|
||||||
|
if snapshots[0].ID != "a" || snapshots[1].ID != "b" {
|
||||||
|
t.Fatalf("snapshot order = [%s %s], want [a b]", snapshots[0].ID, snapshots[1].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,188 @@
|
|||||||
|
package transfer
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
type Channel string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ControlChannel Channel = "control"
|
||||||
|
DataChannel Channel = "data"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Direction uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
DirectionSend Direction = iota
|
||||||
|
DirectionReceive
|
||||||
|
)
|
||||||
|
|
||||||
|
type State uint8
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateInit State = iota
|
||||||
|
StateNegotiating
|
||||||
|
StatePrepared
|
||||||
|
StateActive
|
||||||
|
StatePaused
|
||||||
|
StateCommitting
|
||||||
|
StateVerifying
|
||||||
|
StateDone
|
||||||
|
StateAborted
|
||||||
|
StateFailed
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s State) Terminal() bool {
|
||||||
|
switch s {
|
||||||
|
case StateDone, StateAborted, StateFailed:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Range struct {
|
||||||
|
Offset int64
|
||||||
|
Length int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Metadata map[string]string
|
||||||
|
|
||||||
|
type TelemetryDelta struct {
|
||||||
|
SourceReadDuration time.Duration
|
||||||
|
StreamWriteDuration time.Duration
|
||||||
|
SinkWriteDuration time.Duration
|
||||||
|
SyncDuration time.Duration
|
||||||
|
VerifyDuration time.Duration
|
||||||
|
CommitDuration time.Duration
|
||||||
|
CommitWaitDuration time.Duration
|
||||||
|
SourceReadCount int
|
||||||
|
StreamWriteCount int
|
||||||
|
SinkWriteCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
type Descriptor struct {
|
||||||
|
ID string
|
||||||
|
Direction Direction
|
||||||
|
Channel Channel
|
||||||
|
Size int64
|
||||||
|
Checksum string
|
||||||
|
Metadata Metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
type Snapshot struct {
|
||||||
|
ID string
|
||||||
|
Direction Direction
|
||||||
|
Channel Channel
|
||||||
|
State State
|
||||||
|
Stage string
|
||||||
|
LastFailureStage string
|
||||||
|
Size int64
|
||||||
|
Checksum string
|
||||||
|
Metadata Metadata
|
||||||
|
SentBytes int64
|
||||||
|
AckedBytes int64
|
||||||
|
ReceivedBytes int64
|
||||||
|
InflightBytes int64
|
||||||
|
RetryCount int
|
||||||
|
TimeoutCount int
|
||||||
|
LastError string
|
||||||
|
SourceReadDuration time.Duration
|
||||||
|
StreamWriteDuration time.Duration
|
||||||
|
SinkWriteDuration time.Duration
|
||||||
|
SyncDuration time.Duration
|
||||||
|
VerifyDuration time.Duration
|
||||||
|
CommitDuration time.Duration
|
||||||
|
CommitWaitDuration time.Duration
|
||||||
|
SourceReadCount int
|
||||||
|
StreamWriteCount int
|
||||||
|
SinkWriteCount int
|
||||||
|
StartedAt int64
|
||||||
|
UpdatedAt int64
|
||||||
|
CompletedAt int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Begin struct {
|
||||||
|
TransferID string
|
||||||
|
Channel Channel
|
||||||
|
Size int64
|
||||||
|
Checksum string
|
||||||
|
Metadata Metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
type BeginAck struct {
|
||||||
|
TransferID string
|
||||||
|
Accepted bool
|
||||||
|
NextOffset int64
|
||||||
|
Missing []Range
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Resume struct {
|
||||||
|
TransferID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type ResumeAck struct {
|
||||||
|
TransferID string
|
||||||
|
Accepted bool
|
||||||
|
NextOffset int64
|
||||||
|
Missing []Range
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Commit struct {
|
||||||
|
TransferID string
|
||||||
|
Size int64
|
||||||
|
Checksum string
|
||||||
|
}
|
||||||
|
|
||||||
|
type CommitAck struct {
|
||||||
|
TransferID string
|
||||||
|
Accepted bool
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Abort struct {
|
||||||
|
TransferID string
|
||||||
|
Stage string
|
||||||
|
Offset int64
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
type Segment struct {
|
||||||
|
TransferID string
|
||||||
|
Channel Channel
|
||||||
|
Offset int64
|
||||||
|
Payload []byte
|
||||||
|
Flags uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type Ack struct {
|
||||||
|
TransferID string
|
||||||
|
NextOffset int64
|
||||||
|
Missing []Range
|
||||||
|
Final bool
|
||||||
|
Error string
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeChannel(channel Channel) Channel {
|
||||||
|
if channel == "" {
|
||||||
|
return DataChannel
|
||||||
|
}
|
||||||
|
return channel
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneMetadata(src Metadata) Metadata {
|
||||||
|
if len(src) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
dst := make(Metadata, len(src))
|
||||||
|
for key, value := range src {
|
||||||
|
dst[key] = value
|
||||||
|
}
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneSnapshot(src Snapshot) Snapshot {
|
||||||
|
src.Metadata = cloneMetadata(src.Metadata)
|
||||||
|
return src
|
||||||
|
}
|
||||||
@@ -0,0 +1,16 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func dialNamedPipe(_ string, _ *time.Duration) (net.Conn, error) {
|
||||||
|
return nil, ErrNamedPipeUnsupported
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenNamedPipe(_ string) (net.Listener, error) {
|
||||||
|
return nil, ErrNamedPipeUnsupported
|
||||||
|
}
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Microsoft/go-winio"
|
||||||
|
)
|
||||||
|
|
||||||
|
func dialNamedPipe(addr string, timeout *time.Duration) (net.Conn, error) {
|
||||||
|
return winio.DialPipe(NormalizeNamedPipeAddr(addr), timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenNamedPipe(addr string) (net.Listener, error) {
|
||||||
|
return winio.ListenPipe(NormalizeNamedPipeAddr(addr), &winio.PipeConfig{
|
||||||
|
MessageMode: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,79 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrNamedPipeUnsupported = errors.New("named pipe transport is only supported on windows")
|
||||||
|
|
||||||
|
func IsUDPNetwork(network string) bool {
|
||||||
|
return strings.Contains(strings.ToLower(strings.TrimSpace(network)), "udp")
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsNamedPipeNetwork(network string) bool {
|
||||||
|
switch strings.ToLower(strings.TrimSpace(network)) {
|
||||||
|
case "npipe", "pipe", "namedpipe", "named-pipe":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dial(network string, addr string) (net.Conn, error) {
|
||||||
|
if IsNamedPipeNetwork(network) {
|
||||||
|
return dialNamedPipe(addr, nil)
|
||||||
|
}
|
||||||
|
return net.Dial(network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialTimeout(network string, addr string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
if IsNamedPipeNetwork(network) {
|
||||||
|
return dialNamedPipe(addr, &timeout)
|
||||||
|
}
|
||||||
|
return net.DialTimeout(network, addr, timeout)
|
||||||
|
}
|
||||||
|
|
||||||
|
func Listen(network string, addr string) (net.Listener, error) {
|
||||||
|
if IsNamedPipeNetwork(network) {
|
||||||
|
return listenNamedPipe(addr)
|
||||||
|
}
|
||||||
|
return net.Listen(network, addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NormalizeNamedPipeAddr(addr string) string {
|
||||||
|
trimmed := strings.TrimSpace(addr)
|
||||||
|
if trimmed == "" {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(trimmed, `\\.\pipe\`) {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(trimmed, `//./pipe/`) {
|
||||||
|
return `\\.\pipe\` + strings.TrimPrefix(trimmed, `//./pipe/`)
|
||||||
|
}
|
||||||
|
trimmed = strings.TrimPrefix(trimmed, `\\`)
|
||||||
|
trimmed = strings.TrimPrefix(trimmed, `//`)
|
||||||
|
trimmed = strings.TrimPrefix(trimmed, `.\pipe\`)
|
||||||
|
trimmed = strings.TrimPrefix(trimmed, `./pipe/`)
|
||||||
|
trimmed = strings.TrimPrefix(trimmed, `pipe\`)
|
||||||
|
trimmed = strings.TrimPrefix(trimmed, `pipe/`)
|
||||||
|
trimmed = strings.TrimLeft(strings.ReplaceAll(trimmed, "/", `\`), `\`)
|
||||||
|
return `\\.\pipe\` + trimmed
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnRemoteAddrString(conn net.Conn) string {
|
||||||
|
if conn == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
addr := conn.RemoteAddr()
|
||||||
|
if addr == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
if value := addr.String(); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDialNamedPipeUnsupportedOnNonWindows(t *testing.T) {
|
||||||
|
_, err := DialTimeout("npipe", "notify-demo", time.Millisecond)
|
||||||
|
if !errors.Is(err, ErrNamedPipeUnsupported) {
|
||||||
|
t.Fatalf("DialTimeout error = %v, want %v", err, ErrNamedPipeUnsupported)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListenNamedPipeUnsupportedOnNonWindows(t *testing.T) {
|
||||||
|
_, err := Listen("npipe", "notify-demo")
|
||||||
|
if !errors.Is(err, ErrNamedPipeUnsupported) {
|
||||||
|
t.Fatalf("Listen error = %v, want %v", err, ErrNamedPipeUnsupported)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestNamedPipeNetworkAliases(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
network string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{network: "npipe", want: true},
|
||||||
|
{network: "pipe", want: true},
|
||||||
|
{network: "namedpipe", want: true},
|
||||||
|
{network: "named-pipe", want: true},
|
||||||
|
{network: "tcp", want: false},
|
||||||
|
{network: "unix", want: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := IsNamedPipeNetwork(tt.network); got != tt.want {
|
||||||
|
t.Fatalf("IsNamedPipeNetwork(%q) = %v, want %v", tt.network, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeNamedPipeAddr(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addr string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{name: "short-name", addr: "notify-demo", want: `\\.\pipe\notify-demo`},
|
||||||
|
{name: "pipe-prefix", addr: `pipe\notify-demo`, want: `\\.\pipe\notify-demo`},
|
||||||
|
{name: "slash-prefix", addr: "//./pipe/notify-demo", want: `\\.\pipe\notify-demo`},
|
||||||
|
{name: "normalized", addr: `\\.\pipe\notify-demo`, want: `\\.\pipe\notify-demo`},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := NormalizeNamedPipeAddr(tt.addr); got != tt.want {
|
||||||
|
t.Fatalf("NormalizeNamedPipeAddr(%q) = %q, want %q", tt.addr, got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user